Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
ssHMM/bin/batch_seqstructhmm
Go to fileThis commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
executable file
151 lines (126 sloc)
8.32 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import datetime | |
import argparse | |
import multiprocessing | |
import itertools | |
import os | |
import random | |
import sys | |
from sshmm.sequence_container import readSequencesAndShapes | |
from sshmm.seqstructhmm import SeqStructHMM | |
from sshmm import seq_hmm | |
from sshmm.log import prepareLogger | |
def parseArguments(args):
    """Sets up the command-line parser and calls it on the command-line arguments to the program.

    arguments:
    args -- full argument vector in sys.argv layout (program name at index 0); everything after
            the program name is parsed. If None, argparse falls back to sys.argv[1:].

    returns:
    the argparse namespace holding the parsed positional arguments and options"""
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description='Trains multiple Hidden Markov Models for the sequence-structure binding '
                                                 'preferences of a given set of RNA-binding proteins. The models are trained on '
                                                 'sequences and structures in FASTA format located in a given data directory.\n'
                                                 'During the training process, statistics about the models are printed '
                                                 'on stdout. In every iteration, the current model and a visualization '
                                                 'of the model are stored in the batch directory.\n'
                                                 'The training processes terminate when no significant progress has been '
                                                 'made for three iterations.')

    # Positional arguments
    parser.add_argument('data_directory', type=str,
                        help='data directory; must contain the sequence files under fasta/<protein>/positive.fasta and structure files under <structure_type>/<protein>/positive.txt')
    parser.add_argument('proteins', type=str,
                        help='list of RNA-binding proteins to analyze (surrounded by quotation marks, separated by whitespace)')
    parser.add_argument('batch_directory', type=str,
                        help='directory for batch output')

    # Optional arguments
    parser.add_argument('--cores', '-c', type=int, help='number of cores to use (if not given, all cores are used)')
    parser.add_argument('--structure_type', '-s', type=str, default='shapes',
                        help='structure type to use; must match location of structure files (see data_directory argument above) (default: shapes)')
    parser.add_argument('--motif_length', '-n', type=int, default=6,
                        help='length of the motifs that shall be found (default: 6)')
    # NOTE(review): store_true combined with default=True means this option is True
    # whether or not the flag is given; it currently cannot be switched off.
    parser.add_argument('--baum_welch', '-b', action='store_true', default=True,
                        help='should the models be initialized with a Baum-Welch optimized sequence motif (default: yes)')
    parser.add_argument('--flexibility', '-f', type=int, default=10,
                        help='greedyness of Gibbs sampler: model parameters are sampled from among the top f configurations (default: f=10), set f to 0 in order to include all possible configurations')
    parser.add_argument('--block_size', type=int, default=1,
                        help='number of sequences to be held-out in each iteration (default: 1)')
    parser.add_argument('--threshold', '-t', type=float, default=10.0,
                        help='the iterative algorithm is terminated if this reduction in sequence structure '
                             'loglikelihood is not reached for any of the 3 last measurements (default: 10)')
    parser.add_argument('--termination_interval', '-i', type=int, default=100,
                        help='produce output every <i> iterations (default: i=100)')

    # Parse the arguments that were actually handed in instead of silently re-reading
    # sys.argv. The caller passes a sys.argv-style list, so the program name is skipped.
    if args is None:
        argument_list = None
    else:
        argument_list = list(args)[1:]
    return parser.parse_args(argument_list)
def do_training_for_configuration(data_directory, batch_job_directory, protein, structure_type, motif_length, baum_welch, flexibility, block_size, termination_interval, threshold):
    """Trains one model for a single protein and stores all results in a job-specific directory.

    arguments:
    data_directory -- directory holding fasta/<protein>/positive.fasta and <structure_type>/<protein>/positive.txt
    batch_job_directory -- directory of the batch run; the job directory is created inside it
    protein -- name of the protein to train on
    structure_type -- name of the structure sub-directory inside data_directory
    motif_length -- motif length handed to the reader, the model and the Baum-Welch initialization
    baum_welch -- if true, initialize the model from Viterbi paths of a Baum-Welch sequence model, else randomly
    flexibility -- passed through to SeqStructHMM
    block_size -- passed through to SeqStructHMM
    termination_interval -- passed through to the model's do_training
    threshold -- passed through to the model's do_training

    returns:
    list consisting of the protein name followed by the key statistics returned by the training"""
    # Create output directory for job (its name encodes the configuration)
    name_parts = (batch_job_directory, protein, structure_type, motif_length,
                  flexibility, block_size, termination_interval, threshold)
    job_directory = '{0}/{1}_{2}_ml{3}_fl{4}_bs{5}_ti{6}_tt{7}/'.format(*name_parts)
    os.mkdir(job_directory)

    # Reinitialize the random number generator because else random numbers would repeat
    random.seed()

    # Initialize logs
    main_logger = prepareLogger('main_logger', job_directory + '/verbose.log', verbose=True)
    numbers_logger = prepareLogger('numbers_logger', job_directory + '/numbers.log')

    # Read in sequences and shapes
    fasta_file = '{0}/fasta/{1}/positive.fasta'.format(data_directory, protein)
    shapes_file = '{0}/{1}/{2}/positive.txt'.format(data_directory, structure_type, protein)
    sequences = readSequencesAndShapes(fasta_file, shapes_file, motif_length, main_logger)
    main_logger.info('Read %s training sequences', sequences.get_length())

    # Initialize model
    model = SeqStructHMM(job_directory, main_logger, numbers_logger, sequences, motif_length,
                         flexibility, block_size)
    if not baum_welch:
        model.prepare_model_randomly()
    else:
        # Element 1 of the search result holds the Viterbi paths of the best sequence model
        viterbi_paths = seq_hmm.find_best_baumwelch_sequence_models(motif_length, sequences, main_logger)[1]
        model.prepare_model_with_viterbi(viterbi_paths)

    # Train model; only the first element of the training result is reported
    training_result = model.do_training(termination_interval, threshold, write_model_state=True)
    key_statistics = training_result[0]

    # Write results: each PWM is fetched, then stored as text file and as weblogo image
    pwm_outputs = (('get_pwm_global', 'logo_global.txt', 'logo_global.png'),
                   ('get_pwm_best_sequences', 'logo_best_sequences.txt', 'logo_best_sequences.png'),
                   ('get_pwm_hairpin', 'logo_hairpin.txt', 'logo_hairpin.png'))
    for getter_name, text_name, image_name in pwm_outputs:
        pwm = getattr(model, getter_name)()
        pwm.write_to_file(job_directory + text_name)
        pwm.write_weblogo(job_directory + image_name)

    # Store the final model as graph image and as XML
    final_hmm = model.model
    final_hmm.printAsGraph(job_directory + 'final_graph.png', model.sequence_container.get_length())
    model.model.write(job_directory + 'final_model.xml')

    return [protein] + key_statistics
def do_training_for_configuration_star(argsAsList):
    """Adapter that lets a single-argument mapper run do_training_for_configuration.

    arguments:
    argsAsList -- iterable with the positional arguments of do_training_for_configuration, in order

    returns:
    whatever do_training_for_configuration returns for these arguments"""
    job_arguments = tuple(argsAsList)
    return do_training_for_configuration(*job_arguments)
def start_process(batchLogger):
    """Logs the name of the current process; used as initializer of the worker pool.

    arguments:
    batchLogger -- logger whose info() method receives the start message"""
    process_name = multiprocessing.current_process().name
    batchLogger.info('Starting ' + process_name)
def main(args):
    """Trains one model per protein in a pool of worker processes and logs a result table.

    arguments:
    args -- command-line arguments to the program; handed to parseArguments unchanged"""
    options = parseArguments(args)

    # Create output directory for the batch run (named after the start time)
    batch_job_directory = "{0}/batch_{1}".format(options.batch_directory, datetime.datetime.now().strftime('%y%m%d_%H%M%S'))
    os.mkdir(batch_job_directory)

    # Initialize log
    batch_logger = prepareLogger('batch_logger', batch_job_directory + '/batch.log', verbose=True, stdout=True)

    # Generate jobs: one argument tuple per protein, all other settings are shared.
    # options.baum_welch replaces a hard-coded True here; with the parser in this file
    # the option always evaluates to True, so the jobs are unchanged.
    proteins = options.proteins.split()
    jobs = itertools.product([options.data_directory], [batch_job_directory], proteins, [options.structure_type],
                             [options.motif_length], [options.baum_welch], [options.flexibility],
                             [options.block_size], [options.termination_interval], [options.threshold])
    batch_logger.info('Started batch run on {0} proteins.'.format(len(proteins)))

    # Determine number of cores to use
    if options.cores:
        pool_size = options.cores
    else:
        # Leave one core free, but never ask for fewer than one worker: on a single-core
        # machine (or with an empty protein list) the plain minimum would be 0, which
        # multiprocessing.Pool rejects with a ValueError.
        pool_size = max(1, min(multiprocessing.cpu_count() - 1, len(proteins)))
    batch_logger.info('Using {0} cores.'.format(pool_size))

    pool = multiprocessing.Pool(processes=pool_size,
                                initializer=start_process,
                                initargs=[batch_logger],
                                maxtasksperchild=2,
                                )

    # Execute jobs
    try:
        # NOTE(review): the very large timeout is presumably there so that the wait stays
        # interruptible (e.g. by Ctrl-C) -- confirm before changing it.
        results = pool.map_async(do_training_for_configuration_star, jobs).get(9999999)
    except BaseException:
        # Do not leave worker processes behind when a job failed or the run was interrupted
        pool.terminate()
        pool.join()
        raise
    pool.close()  # no more tasks
    pool.join()  # wrap up current tasks

    batch_logger.info('Finished all proteins.')
    batch_logger.info('Result statistics:')
    batch_logger.info('\t'.join(['Protein', 'Number of iterations', 'Sequence loglikelihood', 'Sequence-structure loglikelihood']))
    for configuration in results:
        batch_logger.info('\t'.join(map(str, configuration)))
# Script entry point: run the batch training and use main's return value as exit status
if __name__ == "__main__":
    exit_status = main(sys.argv)
    sys.exit(exit_status)