diff --git a/planet/models/coexpression_clusters.py b/planet/models/coexpression_clusters.py index 5f74aa9..0038737 100644 --- a/planet/models/coexpression_clusters.py +++ b/planet/models/coexpression_clusters.py @@ -8,6 +8,7 @@ from utils.enrichment import hypergeo_sf, fdr_correction from utils.benchmark import benchmark +from utils.jaccard import jaccard from sqlalchemy import join from sqlalchemy.orm import joinedload, load_only @@ -195,7 +196,7 @@ def calculate_enrichment(empty=True): @staticmethod @benchmark - def calculate_similarities(gene_family_method_id=1): + def calculate_similarities(gene_family_method_id=1, percentile_pass=0.95): # sqlalchemy to fetch cluster associations fields = [SequenceCoexpressionClusterAssociation.__table__.c.sequence_id, @@ -225,5 +226,28 @@ def calculate_similarities(gene_family_method_id=1): if len(families) > 0: cluster_to_families[cluster_id] = families - for c, f in cluster_to_families.items(): - print(c,f) \ No newline at end of file + keys = list(cluster_to_families.keys()) + + data = [] + + for i in range(len(keys) - 1): + for j in range(i+1, len(keys)): + current_keys = [keys[x] for x in [i, j]] + current_families = [cluster_to_families[k] for k in current_keys] + + if len(current_families[0]) > 4 and len(current_families[1]) > 4: + j = jaccard(current_families[0], current_families[1]) + data.append([current_keys[0], current_keys[1], j]) + + ordered_j = sorted([a[2] for a in data]) + percentile_cutoff = ordered_j[int(len(ordered_j)*percentile_pass)] + + database = [{'source_id': d[0], + 'target_id': d[1], + 'gene_family_method_id': gene_family_method_id, + 'jaccard_index': d[2], + 'p_value': 0, + 'corrected_p_value': 0} for d in data if d[2] >= percentile_cutoff] + + db.engine.execute(CoexpressionClusterSimilarity.__table__.insert(), database) + diff --git a/utils/jaccard.py b/utils/jaccard.py index ced1a85..d63b046 100644 --- a/utils/jaccard.py +++ b/utils/jaccard.py @@ -11,3 +11,4 @@ def jaccard(list_a, list_b): intersection_count = len(set(list_a).intersection(set(list_b))) return intersection_count/union_count +