Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
CoNekT/planet/models/coexpression_clusters.py
Go to file. This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
253 lines (196 sloc)
10.4 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
from collections import defaultdict
from itertools import combinations
from math import log2

from flask import url_for
from sqlalchemy import join
from sqlalchemy.orm import joinedload, load_only

from planet import db
from planet.models.expression_networks import ExpressionNetwork
from planet.models.gene_families import GeneFamily
from planet.models.relationships import sequence_coexpression_cluster, SequenceGOAssociation, ClusterGOEnrichment
from planet.models.relationships import CoexpressionClusterSimilarity, SequenceCoexpressionClusterAssociation, SequenceFamilyAssociation
from utils.enrichment import hypergeo_sf, fdr_correction
from utils.benchmark import benchmark
from utils.jaccard import jaccard
class CoexpressionClusteringMethod(db.Model):
    """
    Database model for a clustering method; rows live in the table
    'coexpression_clustering_methods' and each row owns a set of
    CoexpressionCluster records (reachable through ``clusters``).
    """
    __tablename__ = 'coexpression_clustering_methods'

    id = db.Column(db.Integer, primary_key=True)
    network_method_id = db.Column(db.Integer, db.ForeignKey('expression_network_methods.id'), index=True)
    method = db.Column(db.Text)
    # Cached number of clusters, filled in by update_counts()
    cluster_count = db.Column(db.Integer)

    # Dynamic relationship: ``clusters`` is a query, each cluster gets a ``method`` backref
    clusters = db.relationship('CoexpressionCluster', backref=db.backref('method', lazy='joined'), lazy='dynamic')

    def calculate_enrichment(self):
        """
        Placeholder: prints the species' sequence count, the species id and the id of every
        cluster of this method. Nothing is computed or stored.

        NOTE(review): ``network_method`` is not declared in this class; presumably it is a backref
        created by the expression network method model - confirm.
        """
        total_genes = self.network_method.species.sequence_count
        species = self.network_method.species_id

        for member in self.clusters.all():
            print(total_genes, species, member.id)

    @staticmethod
    def update_counts():
        """
        Precompute the number of clusters for every clustering method and store it in
        ``cluster_count``, so the (slow) count does not have to be repeated later.

        On a failed commit the session is rolled back and the error is printed.
        """
        for clustering_method in CoexpressionClusteringMethod.query.all():
            clustering_method.cluster_count = clustering_method.clusters.count()

        try:
            db.session.commit()
        except Exception as e:
            db.session.rollback()
            print(e)
class CoexpressionCluster(db.Model):
    """
    Database model for a single co-expression cluster (table 'coexpression_clusters').

    Besides the columns and relationships it offers helpers to export the cluster as a
    node/edge network, to calculate GO enrichment and to calculate the similarity between
    clusters based on shared gene families.
    """
    __tablename__ = 'coexpression_clusters'

    id = db.Column(db.Integer, primary_key=True)
    method_id = db.Column(db.Integer, db.ForeignKey('coexpression_clustering_methods.id'))
    name = db.Column(db.String(50), index=True)

    # All relationships below are dynamic: accessing them yields a query, not a list
    sequences = db.relationship('Sequence', secondary=sequence_coexpression_cluster, lazy='dynamic')
    sequence_associations = db.relationship('SequenceCoexpressionClusterAssociation',
                                            backref=db.backref('coexpression_cluster', lazy='joined'),
                                            lazy='dynamic')
    go_enrichment = db.relationship('ClusterGOEnrichment',
                                    backref=db.backref('cluster', lazy='joined'),
                                    lazy='dynamic')

    @staticmethod
    def get_cluster(cluster_id):
        """
        Returns the network for a whole cluster (reporting edges only between members of the cluster !)

        :param cluster_id: internal ID of the cluster
        :return: dict with a list of "nodes" and a list of "edges"
        """
        cluster = CoexpressionCluster.query.get(cluster_id)
        network_method = cluster.method.network_method

        probes = [member.probe for member in cluster.sequence_associations.all()]
        # Set for constant-time membership tests inside the nested loops below
        cluster_probes = set(probes)

        network = network_method.probes.\
            options(joinedload('sequence').load_only('name')).\
            filter(ExpressionNetwork.probe.in_(probes)).all()

        nodes = []
        edges = []
        # Edges already emitted, stored in both directions so A-B and B-A are reported once
        seen_edges = set()

        for node in network:
            has_sequence = node.sequence_id is not None
            nodes.append({"id": node.probe,
                          "name": node.probe,
                          "gene_id": int(node.sequence_id) if has_sequence else None,
                          "gene_name": node.sequence.name if has_sequence else node.probe,
                          "depth": 0})

            # node.network holds the node's links as a JSON encoded list
            for link in json.loads(node.network):
                target = link["probe_name"]

                # only add links that are in the cluster and were not reported before !
                if target not in cluster_probes or (node.probe, target) in seen_edges:
                    continue

                edges.append({"source": node.probe,
                              "target": target,
                              "profile_comparison":
                                  url_for('expression_profile.expression_profile_compare_probes',
                                          probe_a=node.probe,
                                          probe_b=target,
                                          species_id=node.method.species.id),
                              "depth": 0,
                              "link_score": link["link_score"],
                              "edge_type": network_method.edge_type})
                seen_edges.add((node.probe, target))
                seen_edges.add((target, node.probe))

        return {"nodes": nodes, "edges": edges}

    def __calculate_enrichment(self):
        """
        Initial implementation to calculate GO enrichment for a single cluster

        For every GO term associated with a sequence in the cluster a ClusterGOEnrichment row is
        stored with the counts, the log2 enrichment, the hypergeometric p-value and the FDR
        corrected p-value. On a failed commit the session is rolled back and the error printed.
        """
        gene_count = self.method.network_method.species.sequence_count
        species_id = self.method.network_method.species_id

        sequences = self.sequences.options(load_only("id")).all()
        cluster_size = len(sequences)

        # Nothing to test; also avoids sending an empty IN () clause to the database
        if cluster_size == 0:
            return

        associations = SequenceGOAssociation.query\
            .filter(SequenceGOAssociation.sequence_id.in_([s.id for s in sequences]))\
            .options(load_only("sequence_id", "go_id"))\
            .group_by(SequenceGOAssociation.sequence_id, SequenceGOAssociation.go_id)

        # go_id -> {"total_count": genes with the term in this species,
        #           "cluster_count": genes with the term in this cluster}
        go_data = {}

        for a in associations:
            if a.go_id in go_data:
                go_data[a.go_id]["cluster_count"] += 1
            else:
                go_data[a.go_id] = {
                    "total_count": json.loads(a.go.species_counts)[str(species_id)],
                    "cluster_count": 1,
                }

        # No GO annotation in this cluster: nothing to correct or store
        if not go_data:
            return

        p_values = [hypergeo_sf(counts['cluster_count'],
                                cluster_size,
                                counts['total_count'],
                                gene_count) for counts in go_data.values()]

        corrected_p_values = fdr_correction(p_values)

        # go_data, p_values and corrected_p_values share the same (insertion) order
        for (go_id, counts), p_value, corrected_p_value in zip(go_data.items(), p_values, corrected_p_values):
            enrichment = ClusterGOEnrichment()
            enrichment.cluster_id = self.id
            enrichment.go_id = go_id

            enrichment.cluster_count = counts['cluster_count']
            enrichment.cluster_size = cluster_size
            enrichment.go_count = counts['total_count']
            enrichment.go_size = gene_count

            # log2 of (fraction in cluster) / (fraction in the whole species)
            enrichment.enrichment = log2((counts['cluster_count']/cluster_size)/(counts['total_count']/gene_count))
            enrichment.p_value = p_value
            enrichment.corrected_p_value = corrected_p_value

            db.session.add(enrichment)

        try:
            db.session.commit()
        except Exception as e:
            db.session.rollback()
            print(e)

    @staticmethod
    def calculate_enrichment(empty=True):
        """
        Static method to calculate the enrichment for all cluster in the database

        The calculation runs for both values of ``empty``; it is only skipped when emptying the
        table was requested and failed, so no duplicates are added on top of stale rows.

        :param empty: empty table cluster_go_enrichment first
        :return:
        """
        # If required empty the table first
        if empty:
            try:
                db.session.query(ClusterGOEnrichment).delete()
                db.session.commit()
            except Exception as e:
                db.session.rollback()
                print(e)
                return

        for i, cluster in enumerate(CoexpressionCluster.query.all()):
            print(i, "\t cluster: ", cluster.method_id, cluster.name)
            # Inside the class body the private name is mangled correctly
            cluster.__calculate_enrichment()

    @staticmethod
    @benchmark
    def calculate_similarities(gene_family_method_id=1, percentile_pass=0.95):
        """
        Calculates the Jaccard index between all pairs of clusters, based on the gene families
        their sequences belong to, and stores the pairs at or above the requested percentile in
        the CoexpressionClusterSimilarity table.

        Only clusters with more than four distinct gene families are compared. When no pair
        can be compared nothing is written.

        :param gene_family_method_id: internal ID of the gene family method to use
        :param percentile_pass: fraction (0-1) of the ordered jaccard indices used as cutoff
        """
        # Clusters need strictly more families than this to be compared
        min_family_count = 4

        # sqlalchemy to fetch cluster associations
        cluster_table = SequenceCoexpressionClusterAssociation.__table__
        fields = [cluster_table.c.sequence_id, cluster_table.c.coexpression_cluster_id]
        # isnot(None) renders "IS NOT NULL"; a python "is not None" would just evaluate to True
        condition = cluster_table.c.sequence_id.isnot(None)
        cluster_associations = db.engine.execute(db.select(fields).where(condition)).fetchall()

        # sqlalchemy to fetch sequence family associations
        association_table = SequenceFamilyAssociation.__table__
        family_table = GeneFamily.__table__
        fields = [association_table.c.sequence_id, association_table.c.gene_family_id, family_table.c.method_id]
        condition = family_table.c.method_id == gene_family_method_id
        table = join(association_table, family_table, association_table.c.gene_family_id == family_table.c.id)
        sequence_families = db.engine.execute(db.select(fields).select_from(table).where(condition)).fetchall()

        # convert sqlachemy results into dictionary
        sequence_to_family = {seq_id: fam_id for seq_id, fam_id, _ in sequence_families}

        cluster_to_sequences = defaultdict(list)
        for seq_id, cluster_id in cluster_associations:
            cluster_to_sequences[cluster_id].append(seq_id)

        # Filter on size up front so the quadratic pair loop only sees usable clusters
        cluster_to_families = {}
        for cluster_id, sequences in cluster_to_sequences.items():
            families = list({sequence_to_family[s] for s in sequences if s in sequence_to_family})
            if len(families) > min_family_count:
                cluster_to_families[cluster_id] = families

        data = [[source, target, jaccard(cluster_to_families[source], cluster_to_families[target])]
                for source, target in combinations(cluster_to_families, 2)]

        # No comparable pair: there is no percentile to take and nothing to insert
        if not data:
            return

        ordered_j = sorted(score for _, _, score in data)
        # Clamp so that percentile_pass == 1.0 selects the highest score instead of overflowing
        cutoff_index = min(int(len(ordered_j)*percentile_pass), len(ordered_j) - 1)
        percentile_cutoff = ordered_j[cutoff_index]

        database = [{'source_id': source,
                     'target_id': target,
                     'gene_family_method_id': gene_family_method_id,
                     'jaccard_index': score,
                     'p_value': 0,
                     'corrected_p_value': 0} for source, target, score in data if score >= percentile_cutoff]

        db.engine.execute(CoexpressionClusterSimilarity.__table__.insert(), database)