ENTYFI/attentionNER/train.py
# -*- coding: utf-8 -*-
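"""Train an entity-typing model on the chosen dataset, evaluate on the dev
split after every epoch, report final test scores, save the predictions,
and checkpoint the trained model."""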
import argparse
import os
import warnings

# sklearn.externals.joblib was removed in scikit-learn 0.23; fall back to the
# standalone joblib package when the old import path is unavailable.
try:
    from sklearn.externals import joblib
except ImportError:
    import joblib

import tensorflow as tf
from tensorflow import keras

from src.model.nn_model import Model
from src.batcher import Batcher
from src.hook import acc_hook, save_predictions, evaluate_perclass
parser = argparse.ArgumentParser()
parser.add_argument("dataset", help="dataset to train model", choices=["figer", "gillick", "other"])
parser.add_argument("encoder", help="context encoder to use in model", choices=["averaging", "lstm", "attentive"])
parser.add_argument('--feature', dest='feature', action='store_true')
parser.add_argument('--no-feature', dest='feature', action='store_false')
parser.set_defaults(feature=False)
parser.add_argument('--hier', dest='hier', action='store_true')
parser.add_argument('--no-hier', dest='hier', action='store_false')
parser.set_defaults(hier=False)
args = parser.parse_args()
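# Example invocation (dataset and encoder are required positionals):
#   python train.py figer attentive --feature --hier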
print ("Creating the model") | |
model = Model(type=args.dataset,encoder=args.encoder,hier=args.hier,feature=args.feature) | |
# model = Model("figer", "averaging", True, True) | |
print ("Loading the dictionaries") | |
# src_data = "/home/cxchu/Downloads/NFGEC-master/data/" | |
universe = "got" | |
src_data = "/var/tmp/wikia/entity-typing/deep-learning/attentionNER/" + universe +"/data" | |
# d = "Wiki" if args.dataset == "figer" else "OntoNotes" | |
d = "" | |
dicts = joblib.load(src_data+d+"/dicts_"+args.dataset+".pkl") | |
print ("Loading the datasets") | |
train_dataset = joblib.load(src_data+d+"/train_"+args.dataset+".pkl") | |
dev_dataset = joblib.load(src_data+d+"/dev_"+args.dataset+".pkl") | |
test_dataset = joblib.load(src_data+d+"/test_"+args.dataset+".pkl") | |
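# Each dataset pickle is a dict with "storage" and "data" arrays;
# data.shape[0] is the number of examples in the split (used below as the
# dev/test batch size).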
print ("train_size:", train_dataset["data"].shape[0]) | |
print ("dev_size: ", dev_dataset["data"].shape[0]) | |
print ("test_size: ", test_dataset["data"].shape[0]) | |
print ("Creating batchers") | |
# batch_size : 1000, context_length : 10 | |
train_batcher = Batcher(train_dataset["storage"],train_dataset["data"],1000,10,dicts["id2vec"]) | |
dev_batcher = Batcher(dev_dataset["storage"],dev_dataset["data"],dev_dataset["data"].shape[0],10,dicts["id2vec"]) | |
test_batcher = Batcher(test_dataset["storage"],test_dataset["data"],test_dataset["data"].shape[0],10,dicts["id2vec"]) | |
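# The dev and test batchers use the whole split as a single batch
# (batch_size == data.shape[0]), so one next() call yields the entire split.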
# print(dev_dataset["data"].shape[0])
steps_per_epoch = 2000 if args.dataset == "figer" else 150
# steps_per_epoch = 2000 if args.dataset == "figer" else 1000
print("start training")
for epoch in range(5):
    train_batcher.shuffle()
    print("epoch", epoch)
    for i in range(steps_per_epoch):
        # with warnings.catch_warnings():
        #     warnings.simplefilter("ignore", category=RuntimeWarning)
        #     try:
        context_data, mention_representation_data, target_data, feature_data = train_batcher.next()
        model.train(context_data, mention_representation_data, target_data, feature_data)
        #     except Exception:
        #         continue
    print("------dev--------")
    # with warnings.catch_warnings():
    #     warnings.simplefilter("ignore", category=RuntimeWarning)
    context_data, mention_representation_data, target_data, feature_data = dev_batcher.next()
    scores = model.predict(context_data, mention_representation_data, feature_data)
    acc_hook(scores, target_data)
print ("Training completed. Below are the final test scores: ") | |
print ("-----test--------") | |
context_data, mention_representation_data, target_data, feature_data = test_batcher.next() | |
scores = model.predict(context_data, mention_representation_data, feature_data) | |
acc_hook(scores, target_data) | |
fname = args.dataset + "_" + args.encoder + "_" + str(args.feature) + "_" + str(args.hier) + ".txt" | |
save_predictions(scores, target_data, dicts["id2label"],fname) | |
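# The output file name encodes the run configuration,
# e.g. "figer_attentive_True_False.txt".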
'''
print('Saving the model..............')
save_dir = '/var/tmp/wikia/entity-typing/deep-learning/general-types/' + 'general-model-onto'
model_name = 'model'
model.save(save_dir, model_name)
'''

print('Saving the model..............')
save_dir = '/var/tmp/wikia/entity-typing/deep-learning/attentionNER/' + universe
model_name = 'model'
model.save(save_dir, model_name)
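# Model.save is assumed to write the trained weights under save_dir using
# model_name as the checkpoint name (see src/model/nn_model.py).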
'''
# evaluation per class
precision, recall, f1 = evaluate_perclass(scores, target_data, dicts["id2label"])
# for key, value in precision.items():
#     print(key + ":\t" + str(value) + "\t" + str(recall[key]) + "\t" + str(f1[key]))

# writing types with F1 >= 0.80 to file
fout = open("/var/tmp/wikia/entity-typing/deep-learning/" + universe + "/coarse-classess", "w")
for key, value in sorted(f1.items(), key=lambda kv: kv[1], reverse=True):
    if value >= 0.80:
        print(key + ":\t" + str(precision[key]) + "\t" + str(recall[key]) + "\t" + str(value))
        fout.write(key + "\n")
fout.close()
'''

print("Cheers!")
'''
print("Testing on different universe---------------------------------")
u = "got"
print("Testing on: " + u)
dicts = joblib.load("/var/tmp/wikia/entity-typing/" + u + "/data/dicts_" + args.dataset + ".pkl")
test_dataset = joblib.load("/var/tmp/wikia/entity-typing/" + u + "/data/test_" + args.dataset + ".pkl")
print("test_size: ", test_dataset["data"].shape[0])
test_batcher = Batcher(test_dataset["storage"], test_dataset["data"], test_dataset["data"].shape[0], 10, dicts["id2vec"])
context_data, mention_representation_data, target_data, feature_data = test_batcher.next()
scores = model.top_class_predict(context_data, mention_representation_data, feature_data)
acc_hook(scores, target_data)
fname = args.dataset + "_" + args.encoder + "_" + str(args.feature) + "_" + str(args.hier) + "_train_lotr_test_got.txt"
save_predictions(scores, target_data, dicts["id2label"], fname)
'''