# -*- coding: utf-8 -*-
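#
# Train an NFGEC-style entity-typing model (choice of averaging / lstm /
# attentive context encoder) on one dataset, evaluate on the dev split after
# each epoch, report final test scores, save per-example predictions, and
# persist the trained model.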
import argparse
import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23+; use the standalone package
from src.model.nn_model import Model
from src.batcher import Batcher
from src.hook import acc_hook, save_predictions, evaluate_perclass
parser = argparse.ArgumentParser()
parser.add_argument("dataset", help="dataset to train model", choices=["figer", "gillick", "other"])
parser.add_argument("encoder", help="context encoder to use in model", choices=["averaging", "lstm", "attentive"])
parser.add_argument("--feature", dest="feature", action="store_true")
parser.add_argument("--no-feature", dest="feature", action="store_false")
parser.add_argument("--hier", dest="hier", action="store_true")
parser.add_argument("--no-hier", dest="hier", action="store_false")
parser.set_defaults(feature=False, hier=False)
args = parser.parse_args()
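# Example invocation (the script name is assumed):
#   python train.py figer attentive --feature --hier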
print ("Creating the model")
model = Model(type=args.dataset,encoder=args.encoder,hier=args.hier,feature=args.feature)
# model = Model("figer", "averaging", True, True)
print ("Loading the dictionaries")
universe = "got"
src_data = "/var/tmp/wikia/entity-typing/deep-learning/attentionNER/" + universe + "/data"
dicts = joblib.load(src_data + "/dicts_" + args.dataset + ".pkl")
print ("Loading the datasets")
train_dataset = joblib.load(src_data+d+"/train_"+args.dataset+".pkl")
dev_dataset = joblib.load(src_data+d+"/dev_"+args.dataset+".pkl")
test_dataset = joblib.load(src_data+d+"/test_"+args.dataset+".pkl")
print()
print("train_size:", train_dataset["data"].shape[0])
print("dev_size:  ", dev_dataset["data"].shape[0])
print("test_size: ", test_dataset["data"].shape[0])

print("Creating batchers")
# Batcher(storage, data, batch_size, context_length, id2vec); training uses
# batch_size=1000 and context_length=10.
train_batcher = Batcher(train_dataset["storage"], train_dataset["data"], 1000, 10, dicts["id2vec"])
dev_batcher = Batcher(dev_dataset["storage"], dev_dataset["data"], dev_dataset["data"].shape[0], 10, dicts["id2vec"])
test_batcher = Batcher(test_dataset["storage"], test_dataset["data"], test_dataset["data"].shape[0], 10, dicts["id2vec"])
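# The dev and test batchers deliver each split as one full-size batch
# (batch size = number of examples), so a single next() call covers the split.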
steps_per_epoch = 2000 if args.dataset == "figer" else 150
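# More steps per epoch for figer, presumably because its training split is far larger.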
print ("start trainning")
for epoch in range(5):
train_batcher.shuffle()
print ("epoch",epoch)
for i in range(step_par_epoch):
# with warnings.catch_warnings():
# warnings.simplefilter("ignore", category=RuntimeWarning)
# try:
context_data, mention_representation_data, target_data, feature_data = train_batcher.next()
# print("check")
model.train(context_data, mention_representation_data, target_data, feature_data)
# print("train")
# except Exception:
# continue
print ("------dev--------")
# with warnings.catch_warnings():
# warnings.simplefilter("ignore", category=RuntimeWarning)
context_data, mention_representation_data, target_data, feature_data = dev_batcher.next()
scores = model.predict(context_data, mention_representation_data,feature_data)
acc_hook(scores, target_data)
print ("Training completed. Below are the final test scores: ")
print ("-----test--------")
context_data, mention_representation_data, target_data, feature_data = test_batcher.next()
scores = model.predict(context_data, mention_representation_data, feature_data)
acc_hook(scores, target_data)
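# Persist per-example predictions; the filename encodes the run configuration
# (dataset, encoder, feature flag, hier flag).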
fname = args.dataset + "_" + args.encoder + "_" + str(args.feature) + "_" + str(args.hier) + ".txt"
save_predictions(scores, target_data, dicts["id2label"], fname)
'''
print("Saving the model ..............")
save_dir = '/var/tmp/wikia/entity-typing/deep-learning/general-types/' + 'general-model-onto'
model_name = 'model'
model.save(save_dir, model_name)
'''
print("Saving the model ..............")
save_dir = "/var/tmp/wikia/entity-typing/deep-learning/attentionNER/" + universe
model_name = "model"
model.save(save_dir, model_name)
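# The disabled blocks below are kept for reference: per-class evaluation with
# a type-list dump, and testing the saved model on a different universe.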
'''
# Per-class evaluation; write well-predicted types (F1 >= 0.80) to file.
precision, recall, f1 = evaluate_perclass(scores, target_data, dicts["id2label"])
fout = open("/var/tmp/wikia/entity-typing/deep-learning/" + universe + "/coarse-classess", "w")
for key, value in sorted(f1.items(), key=lambda kv: kv[1], reverse=True):
    if value >= 0.80:
        print(key + ":\t" + str(precision[key]) + "\t" + str(recall[key]) + "\t" + str(value))
        fout.write(key + "\n")
fout.close()
'''
print ("Cheers!")
'''
print("Testing on different universe---------------------------------")
u = "got"
print("Testing on: " + u)
dicts = joblib.load("/var/tmp/wikia/entity-typing/" + u + "/data/dicts_" + args.dataset + ".pkl")
test_dataset = joblib.load("/var/tmp/wikia/entity-typing/" + u + "/data/test_" + args.dataset + ".pkl")
print("test_size: ", test_dataset["data"].shape[0])
test_batcher = Batcher(test_dataset["storage"], test_dataset["data"], test_dataset["data"].shape[0], 10, dicts["id2vec"])
context_data, mention_representation_data, target_data, feature_data = test_batcher.next()
scores = model.top_class_predict(context_data, mention_representation_data, feature_data)
acc_hook(scores, target_data)
fname = args.dataset + "_" + args.encoder + "_" + str(args.feature) + "_" + str(args.hier) + "_train_lotr_test_got.txt"
save_predictions(scores, target_data, dicts["id2label"], fname)
'''