Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
BlueWhale/bluewhale.py
Go to fileThis commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
105 lines (92 sloc)
4.98 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import sys | |
from dnn.dhs_sum_model import DhsSumModel as DHSSUM | |
from dnn.dhs_sum_model_within import DhsSumModelWithin as DHSSUMWITHIN | |
from dnn.dhs_sum_model_openonlyprofile import DhsSumModel as DHSSUMOOP | |
from dnn.dna_convnet_flank import DnaNeuralNet as DNA | |
from dnn.dna_convnet_flank_within import DnaNeuralNet as DNAWITHIN | |
from dnn.full_model_loop import FullModel as FULL | |
from dnn.full_model_loop_tfrna import FullModel as FULLRNA | |
from dnn.full_model_finetune import FullModelFinetune as FULLFINE | |
from dnn.full_model_within import FullModelWithin as FULLWITHIN | |
from dnn.full_model_within_finetune import FullModelWithinFinetune as FULLWITHINFINE | |
from dnn.full_model_aggregate import FullModel as FULLAGG | |
def main(name, dataset, mode, batchid, threshold, reuse, num_epochs=500):
    """Train, evaluate or run predictions with one of the imported models.

    Args:
        name: Key looked up in this module's globals to obtain the model
            class (e.g. one of the aliases imported at the top of the file).
        dataset: One of 'train', 'test' or 'ladder'.
        mode: One of 'training', 'prediction' or 'evaluate'.
        batchid: Passed through unchanged to the model constructor.
        threshold: Passed through unchanged to the model constructor.
        reuse: If truthy, previously saved parameters are loaded before
            training starts. Only consulted when mode is 'training'.
        num_epochs: Passed to ``model.train``. Only used when mode is
            'training'.

    Raises:
        AssertionError: If ``mode`` or ``dataset`` is not an accepted value.
        KeyError: If ``name`` is not defined in this module's globals.
        ValueError: If mode is 'training' but dataset is not 'train'.
    """
    assert mode in ["training", "prediction", 'evaluate'], "mode must be 'training', 'evaluate' or 'prediction'"
    assert dataset in ["train", "test", "ladder"], "dataset must be 'train', 'test' or 'ladder'"

    model_cls = globals()[name]

    if mode == "evaluate":
        # Evaluation builds the model in "training" mode, then only loads
        # the stored parameters and runs the final evaluation.
        model = model_cls(batchid, threshold, dataset, "training")
        model.loadParams()
        model.finalEvaluation()
        return

    # Reject the invalid combination before the model is constructed.
    if mode == "training" and dataset != "train":
        raise ValueError("mode=training can only performed on the training dataset")

    model = model_cls(batchid, threshold, dataset, mode)

    if mode == "training":
        if reuse:
            print("reuse parameters")
            model.loadParams()
        try:
            model.train(num_epochs)
        except Exception as e:
            # Best effort: report the failure, then fall through to the
            # final evaluation and parameter saving below.
            print("error occurred:" + str(e))
            print("error occurred:" + str(sys.exc_info()[0]))
        finally:
            print("do final evaluation")
            model.finalEvaluation()
            model.saveParams()
    else:
        model.loadParams()
        model.predict()
def _print_usage(program):
    """Print the command-line help text, using ``program`` as the script name."""
    print("Trains various statistical models on DREAM dataset.")
    print("Usage: %s modelname dataset mode [epochs] [reuse]" % program)
    print("")
    print("modelnames:")
    print("\tDHSSUM:\tCell-type independent DhsModel that uses Dnase-fold enrichment across cell-types profile and 9 consecutive 200 bp bins")
    print("\tDHSSUMWITHIN:\tWithin-cell-type DhsModel that uses Dnase-fold enrichment across cell-types profile and 9 consecutive 200 bp bins")
    print("\tDHSSUMOOP:\tCell-type specific DhsModel that uses Dnase-fold enrichment across 9 consecutive 200 bp bins")
    print("\tDNA:\tCell-type independent DNA model that extracts features from the DNA sequence using a convolutional neural network")
    print("\tDNAWITHIN:\tWithin cell-type DNA model that extracts features from the DNA sequence using a convolutional neural network")
    print("\tFULL:\tCell-type specific Full model that uses the pretrained FULLAGG and DHSSUMOOP model and trains a cell-type specific model on top.")
    print("\tFULLAGG:\tCell-type independent Full model that uses the pretrained DNA and DHSSUM model and trains another cell-type independent model on top that combines those pieces of information.")
    print("\tFULLRNA:\tCell-type specific Full model that uses the pretrained FULLAGG and DHSSUMOOP model and trains a cell-type specific model on top. This model also leverages RNA-seq profiles of the TFs as a source of cell-type specificity.")
    print("\tFULLFINE:\tFinetuning of FULLRNA")
    print("\tFULLWITHIN:\tCell-type specific Full model that uses the pretrained DNAWITHIN and DHSSUMWITHIN model and trains a cell-type specific model on top.")
    print("\tFULLWITHINFINE:\tFinetuning of FULLFINE.")
    print("")
    print("dataset:")
    print("\ttrain:\tTo use the training dataset")
    print("\tladder:\tTo use the leaderboard dataset")
    print("\ttest:\tTo use the final-test dataset")
    print("")
    print("mode:")
    print("\ttraining:\tTrains the model parameters")
    print("\tprediction:\tPerforms predictions using already trained parameters")
    print("\tevaluate:\tEvaluates the model on the validation data.")
    print("")
    print("epochs: number of training epochs to perform (default: 500)")
    print("")
    print("reuse: if reuse is >0, the previously saved parameters are loaded and training is continued without random initial weights.")
    print("")
    print("Example usage:")
    print("")
    print("bluewhale.py DHSSUM train training 10")


def _parse_cli_args(argv):
    """Turn the raw argument list into the keyword arguments for ``main``.

    Args:
        argv: Argument list laid out like ``sys.argv``: program name,
            model name, dataset, mode, then optionally the number of
            epochs and the reuse flag.

    Returns:
        A dict with 'threshold' (always 'conservative'), 'name', 'dataset',
        'mode', 'batchid' (always None), 'reuse' (bool) and, if an epochs
        argument was given, 'num_epochs' (int).

    Raises:
        ValueError: If the epochs or reuse argument is not an integer.
    """
    kwargs = {
        'threshold': 'conservative',
        'name': argv[1],
        'dataset': argv[2],
        'mode': argv[3],
        'batchid': None,
    }
    if len(argv) > 4:
        kwargs['num_epochs'] = int(argv[4])
    # The flag arrives as a string; convert it before comparing so that
    # "0" really disables reuse.
    kwargs['reuse'] = len(argv) > 5 and int(argv[5]) > 0
    return kwargs


if __name__ == '__main__':
    # Three positional arguments are mandatory; anything less gets the
    # help text rather than an IndexError while reading sys.argv.
    if ('--help' in sys.argv) or ('-h' in sys.argv) or (len(sys.argv) < 4):
        _print_usage(sys.argv[0])
    else:
        main(**_parse_cli_args(sys.argv))