Commit
cxchu committed Apr 22, 2020
1 parent 213e016, commit c01dc3c
Showing 257 changed files with 3,751,637 additions and 0 deletions.
@@ -0,0 +1,5 @@
eclipse.preferences.version=1
org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8
org.eclipse.jdt.core.compiler.compliance=1.8
org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning
org.eclipse.jdt.core.compiler.source=1.8
@@ -0,0 +1,3 @@
Reference Univs Ranking
- Keep 1000 top tf-idf tokens
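The note above only states the goal; below is a minimal sketch of one way to keep the 1000 highest-scoring tf-idf tokens, assuming a recent scikit-learn (an illustration, not necessarily how the ranking in this repository is computed):

# Hypothetical sketch: select the 1000 tokens with the highest tf-idf score.
from sklearn.feature_extraction.text import TfidfVectorizer

def top_tfidf_tokens(documents, k=1000):
    vectorizer = TfidfVectorizer()
    matrix = vectorizer.fit_transform(documents)        # (n_docs, n_tokens) sparse matrix
    scores = matrix.max(axis=0).toarray().ravel()       # best score each token reaches
    tokens = vectorizer.get_feature_names_out()         # older scikit-learn: get_feature_names()
    ranked = sorted(zip(tokens, scores), key=lambda t: -t[1])
    return [token for token, _ in ranked[:k]]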
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-

import sys


def main():
    # Count how often each word, feature, and label occurs in the
    # tab-separated corpus given as the first argument.
    word2freq = {}
    feature2freq = {}
    label2freq = {}
    with open(sys.argv[1]) as f:
        for line in f:
            temp = line.strip().split("\t")
            labels, features, words = temp[3], temp[4], temp[2]
            for label in labels.split():
                if label not in label2freq:
                    label2freq[label] = 1
                else:
                    label2freq[label] += 1
            for word in words.split():
                if word not in word2freq:
                    word2freq[word] = 1
                else:
                    word2freq[word] += 1
            for feature in features.split():
                if feature not in feature2freq:
                    feature2freq[feature] = 1
                else:
                    feature2freq[feature] += 1

    def _local(file_path, X2freq, start_idx=0):
        # Write one "<id>\t<item>\t<frequency>" line per item, most frequent first.
        with open(file_path, "w") as f:
            for i, (X, freq) in enumerate(sorted(X2freq.items(), key=lambda t: -t[1]), start_idx):
                f.write(str(i) + "\t" + X + "\t" + str(freq) + "\n")

    _local(sys.argv[2], word2freq)
    _local(sys.argv[3], feature2freq, start_idx=1)
    _local(sys.argv[4], label2freq)


if __name__ == '__main__':
    main()
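For orientation, a hypothetical input line in the five-column format this counter expects (the sentence, labels, and features are made up, not taken from the commit):

# Hypothetical corpus line: start, end, words, labels, features (tab-separated).
example = "0\t2\tBarack Obama visited Berlin\t/person /person/politician\tHEAD_Obama SHAPE_Xx"
start, end, words, labels, features = example.split("\t")
# The script counts "Barack", "Obama", ... into word2freq, the two labels into
# label2freq and the feature strings into feature2freq, then writes each map as
# "<id>\t<item>\t<frequency>" lines sorted by descending frequency.

It would be run as something like `python <this script> corpus.tsv word_freqs.tsv feature_freqs.tsv label_freqs.tsv` (the file names are assumptions); features are numbered from 1 rather than 0, a choice the commit does not explain.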
@@ -0,0 +1,113 @@
'''
Created on Jul 1, 2019
@author: cxchu
'''
# -*- coding: utf-8 -*-
from sklearn.externals import joblib
import pickle
import numpy as np
import sys


def create_dataset(corpus_path, label2id, word2id, feature2id):
    # Encode each "<start>\t<end>\t<words>\t<labels>\t<features>" line as one row:
    # [sentence start, sentence end, mention start, mention end, one-hot labels],
    # with the word ids of all sentences concatenated into `storage`.
    num_of_labels = len(label2id.values())
    num_of_samples = sum(1 for line in open(corpus_path))
    storage = []
    #data = np.zeros((num_of_samples,4+70+num_of_labels),"int32")
    data = np.zeros((num_of_samples, 4 + num_of_labels), "int32")
    s_start_pointer = 0
    num = 0

    with open(corpus_path) as f:
        for line in f:
            if len(line.split("\t")) != 5:
                continue
            (start, end, words, labels, features) = line.strip().split("\t")
            labels, words, features = labels.split(), words.split(), features.split()
            length = len(words)
            start, end = int(start), int(end)
            labels_code = [0 for i in range(num_of_labels)]
            for label in labels:
                if label in label2id:
                    labels_code[label2id[label]] = 1
            words_code = [word2id[word] if word in word2id else word2id["unk"] for word in words]
            features_code = [feature2id[feature] for feature in features]
            storage += words_code
            data[num, 0] = s_start_pointer           # s_start
            data[num, 1] = s_start_pointer + length  # s_end
            data[num, 2] = s_start_pointer + start   # e_start
            data[num, 3] = s_start_pointer + end     # e_end
            #data[num,4:4+len(features_code)] = np.array(features_code)
            data[num, 4:] = labels_code
            s_start_pointer += length
            num += 1
            if num % 100000 == 0:
                print(num)
    return np.array(storage, "int32"), data


def create_raw_dataset(label2id, word2id, feature2id):
    # Same encoding as create_dataset, but the mentions arrive on stdin as
    # "<start>\t<end>\t<words>" lines, terminated by a line containing "end".
    num_of_labels = len(label2id.values())
    # num_of_samples = sum(1 for line in open(corpus_path))
    storage = []
    # data = np.zeros((num_of_samples,4+70+num_of_labels),"int32")
    # data = np.zeros((num_of_samples,4+num_of_labels),"int32")
    s_start_pointer = 0
    num = 0
    sentences = []
    mentions = []

    lines = []
    print('input')
    sys.stdout.flush()
    line = sys.stdin.readline()
    while line != 'end':
        lines.append(line.strip())
        line = sys.stdin.readline()
        line = line.strip()
    print('get all input')
    sys.stdout.flush()
    data = np.zeros((len(lines), 4 + num_of_labels), "int32")

    for line in lines:
        if len(line.split("\t")) != 3:
            continue
        (start, end, words) = line.strip().split("\t")
        sentences.append(words)
        words = words.split()
        length = len(words)
        start, end = int(start), int(end)
        if start == end:
            mention = words[start]
        else:
            mention = " ".join([words[i + start] for i in range(end - start)])
        mentions.append(mention)
        labels_code = [0 for _ in range(num_of_labels)]

        # falls back to the "unknown" token id (create_dataset above uses "unk")
        words_code = [word2id[word] if word in word2id else word2id["unknown"] for word in words]
        storage += words_code
        data[num, 0] = s_start_pointer           # s_start
        data[num, 1] = s_start_pointer + length  # s_end
        data[num, 2] = s_start_pointer + start   # e_start
        data[num, 3] = s_start_pointer + end     # e_end
        # data[num,4:4+len(features_code)] = np.array(features_code)
        # data[num,74:] = labels_code
        data[num, 4:] = labels_code
        s_start_pointer += length
        num += 1
        if num % 100000 == 0:
            print(num)
    return np.array(storage, "int32"), data, sentences, mentions


def main():
    dicts = joblib.load(sys.argv[1])
    label2id = dicts["label2id"]
    word2id = dicts["word2id"]
    feature2id = dicts["feature2id"]
    storage, data = create_dataset(sys.argv[2], label2id, word2id, feature2id)
    dataset = {"storage": storage, "data": data}
    joblib.dump(dataset, sys.argv[3])


if __name__ == '__main__':
    main()
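As a reading aid, a hypothetical sketch of decoding one encoded mention back out of the (storage, data) pair returned by create_dataset, assuming the id2word/id2label mappings built elsewhere in this commit and end-exclusive mention offsets:

# Hypothetical decoding of the first encoded mention (storage, data from create_dataset).
row = data[0]
s_start, s_end, e_start, e_end = row[0], row[1], row[2], row[3]
sentence_tokens = [id2word[int(i)] for i in storage[s_start:s_end]]   # the whole sentence
mention_tokens = [id2word[int(i)] for i in storage[e_start:e_end]]    # mention span (offset convention assumed)
active_labels = [id2label[j] for j, flag in enumerate(row[4:]) if flag == 1]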
@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-

from sklearn.externals import joblib
import pickle
import numpy as np
import sys


def load_X2id(file_path):
    # Read "<id>\t<item>\t<frequency>" lines (as written by the counting script)
    # and build item->id and id->item mappings.
    X2id = {}
    id2X = {}
    with open(file_path) as f:
        for line in f:
            temp = line.strip().split()
            id, X = temp[0], temp[1]
            X2id[X] = int(id)
            id2X[int(id)] = X
    return X2id, id2X


def load_word2vec(file_path):
    # Parse a plain-text embedding file: "<word> <v1> <v2> ..." per line;
    # malformed lines are silently skipped.
    word2vec = {}
    with open(file_path) as lines:
        for line in lines:
            try:
                split = line.split(" ")
                word = split[0]
                vector_strings = split[1:]
                vector = [float(num) for num in vector_strings]
                word2vec[word] = np.array(vector)
            except Exception:
                pass
    return word2vec


def create_id2vec(word2id, word2vec):
    # Build the embedding matrix indexed by token id; words without a
    # pre-trained vector get the "unk" vector.
    unk_vec = word2vec["unk"]
    dim_of_vector = len(unk_vec)
    num_of_tokens = len(word2id)
    id2vec = np.zeros((num_of_tokens, dim_of_vector))
    for word, t_id in word2id.items():
        id2vec[t_id, :] = word2vec[word] if word in word2vec else unk_vec
    return id2vec


def main():
    print("word2id...")
    word2id, id2word = load_X2id(sys.argv[1])
    print("feature2id...")
    feature2id, id2feature = load_X2id(sys.argv[2])
    print("label2id...")
    label2id, id2label = load_X2id(sys.argv[3])
    print("word2vec...")
    word2vec = load_word2vec(sys.argv[4])
    print("id2vec...")
    id2vec = create_id2vec(word2id, word2vec)
    print("done!")
    dicts = {"id2vec": id2vec, "word2id": word2id, "id2word": id2word,
             "label2id": label2id, "id2label": id2label,
             "feature2id": feature2id, "id2feature": id2feature}
    print("dicts save...")
    joblib.dump(dicts, sys.argv[5])


if __name__ == '__main__':
    main()
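A hypothetical sketch of consuming the pickle written above, e.g. to fetch the embedding of a single token (the path and the lookup word are assumptions; "unk" is the fallback token used throughout this commit):

# Hypothetical use of the saved dictionaries.
from sklearn.externals import joblib   # matches the (old) scikit-learn API used in this commit

dicts = joblib.load("dicts_gillick.pkl")                           # path is an assumption
word_id = dicts["word2id"].get("berlin", dicts["word2id"]["unk"])  # fall back to "unk"
vector = dicts["id2vec"][word_id]                                  # embedding row for that token
print(vector.shape)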
@@ -0,0 +1,2 @@
model_checkpoint_path: "/local/home/cxchu/workspace/NNForFineGrainedEntityTyping/general-model/model"
all_model_checkpoint_paths: "/local/home/cxchu/workspace/NNForFineGrainedEntityTyping/general-model/model"
@@ -0,0 +1,114 @@
'''
Created on Jul 1, 2019
@author: cxchu
'''
'''
Created on Jan 15, 2019
@author: cxchu
'''
from sklearn.externals import joblib
import sys, os

from create_dataset import create_raw_dataset
from src.batcher import Batcher
from src.hook import acc_hook, save_predictions, evaluate_perclass
from src.model.nn_model import Model
import tensorflow as tf

import optparse

optparser = optparse.OptionParser()
optparser.add_option(
    "-b", "--basedir", default="/var/tmp/wikia/entity-typing/deep-learning/",
    help="directory to model of top class prediction"
)
opts = optparser.parse_args()[0]

basedir = opts.basedir

dict = basedir + "general-model/dicts_gillick.pkl"
# dict = basedir + "general-types/all/data/dicts_gillick.pkl"

# universe = "onion"
# raw_data = "/var/tmp/wikia/entity-typing/input-data/" + universe + "/" + universe + "-3-supervised"
# save_data = "/var/tmp/wikia/entity-typing/deep-learning/got/got_test.pkl"

# Load the vocabulary/label dictionaries and read the mentions to type from stdin.
dicts = joblib.load(dict)
label2id = dicts["label2id"]
id2label = dicts["id2label"]
word2id = dicts["word2id"]
feature2id = dicts["feature2id"]
storage, data, sentences, mentions = create_raw_dataset(label2id, word2id, feature2id)
test_dataset = {"storage": storage, "data": data}
# joblib.dump(dataset,save_data)

print("Loading the dataset")
# test_dataset = joblib.load(save_data)

print("test_size: ", test_dataset["data"].shape[0])

print("Creating batchers")
# batch_size : 1000, context_length : 10
test_batcher = Batcher(test_dataset["storage"], test_dataset["data"], test_dataset["data"].shape[0], 10, dicts["id2vec"])


print('Loading the model..............')
save_dir = './general-model'
model_name = 'model'

# Restore the trained graph from its checkpoint and look up the tensors needed for inference.
checkpoint_file = os.path.join(save_dir, model_name)
graph = tf.Graph()
with graph.as_default():
    sess = tf.Session()
    saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
    saver.restore(sess, checkpoint_file)

    keep_prob = graph.get_operation_by_name("keep_prob").outputs[0]
    mention_representation = graph.get_operation_by_name("mention_representation").outputs[0]

    context_length = 8
    context = [graph.get_operation_by_name("context" + str(i)).outputs[0] for i in range(context_length * 2 + 1)]

    distribution = graph.get_operation_by_name("distribution").outputs[0]

    context_data, mention_representation_data, target_data, feature_data = test_batcher.next()

    feed = {mention_representation: mention_representation_data,
            keep_prob: [1.0]}
    # if self.feature == True and feature_data is not None:
    #     feed[self.features] = feature_data
    for i in range(context_length * 2 + 1):
        feed[context[i]] = context_data[:, i, :]
    scores = sess.run(distribution, feed_dict=feed)

    # writing to file.....
    # fname = "/var/tmp/wikia/entity-typing/input-data/" + universe + "/" + universe + "-3-supervised-general-prediction"
    # with open(fname,"w") as f:
    # Emit one line per mention listing every label whose score passes 0.5,
    # then a final "end" line so the caller knows the output is complete.
    print('results')
    sys.stdout.flush()
    for sent, score in zip(mentions, scores):
        res = []
        # print(sent + "===" + str(score))
        for id, s in enumerate(list(score)):
            if s >= 0.5:
                res.append(id2label[id] + "\t" + str(s))
        if len(res) > 0:
            print(sent + "=====[" + ", ".join([t for t in res]) + "]")
        sys.stdout.flush()
    print('end')
    sys.stdout.flush()
    # f.close()
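The script above talks to its caller over stdin/stdout: it prints 'input', reads "start<TAB>end<TAB>sentence" lines until a line containing only "end", then prints status lines, one prediction line per mention, and a final "end". A hypothetical driver (the script file name is an assumption):

# Hypothetical driver for the stdin/stdout protocol of the prediction script above.
import subprocess

proc = subprocess.Popen(
    ["python", "run_general_prediction.py"],                   # file name is an assumption
    stdin=subprocess.PIPE, stdout=subprocess.PIPE, universal_newlines=True)

proc.stdin.write("0\t2\tBarack Obama visited Berlin\n")        # mention = tokens [0, 2)
proc.stdin.write("end\n")                                      # terminates the input loop
proc.stdin.flush()

while True:
    reply = proc.stdout.readline()
    if not reply or reply.strip() == "end":
        break
    print(reply.rstrip())   # status lines, then "<mention>=====[<label>\t<score>, ...]" lines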