Skip to content
This repository has been archived by the owner. It is now read-only.

Commit

Permalink
Changed estimation of log2fc in BINDetect from density to GMM; included better finding of best number of components; bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
msbentsen committed Jan 9, 2019
1 parent 80bf5e4 commit 0efd686
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 2,618 deletions.
2,510 changes: 0 additions & 2,510 deletions test_data/genes_chr4.gtf.sorted

This file was deleted.

171 changes: 120 additions & 51 deletions tobias/footprinting/BINDetect.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@

import sklearn
from sklearn import mixture
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import GridSearchCV

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
from matplotlib.ticker import NullFormatter
import pandas as pd
import scipy
from scipy.optimize import curve_fit
from sklearn import preprocessing

#Bio-specific packages
import pyBigWig
Expand All @@ -49,6 +52,7 @@
from scipy.optimize import OptimizeWarning
warnings.simplefilter("ignore", OptimizeWarning)


#--------------------------------------------------------------------------------------------------------------#

def add_bindetect_arguments(parser):
Expand All @@ -67,9 +71,9 @@ def add_bindetect_arguments(parser):

required = parser.add_argument_group('Required arguments')
required.add_argument('--signals', metavar="<bigwig>", help="Signal per condition (.bigwig format)", nargs="*")
required.add_argument('--peaks', metavar="<bed>", help="Peaks.bed containing open chromatin regions across all conditions")
required.add_argument('--motifs', metavar="<motifs>", help="Motifs in pfm/jaspar format")
required.add_argument('--genome', metavar="<fasta>", help="Genome .fasta file")
required.add_argument('--peaks', metavar="<bed>", help="Peaks.bed containing open chromatin regions")

optargs = parser.add_argument_group('Optional arguments')
optargs.add_argument('--cond_names', metavar="<name>", nargs="*", help="Names of conditions fitting to --signals (default: prefix of --signals)")
Expand Down Expand Up @@ -97,9 +101,7 @@ def find_nearest_idx(array, value):

#----------------------------------------------------------------------------------------------------------------#
def run_bindetect(args):

#import matplotlib.pyplot as plt
#from matplotlib.backends.backend_pdf import PdfPages
""" Main function to run bindetect algorithm with input files and parameters given in args """

begin_time = datetime.now()

Expand All @@ -108,6 +110,7 @@ def run_bindetect(args):
args.cond_names = [os.path.basename(os.path.splitext(bw)[0]) for bw in args.signals] if args.cond_names is None else args.cond_names
args.outdir = os.path.abspath(args.outdir)

#Set output files
states = ["bound", "unbound"]
outfiles = [os.path.abspath(os.path.join(args.outdir, "*", "beds", "*_{0}_{1}.bed".format(condition, state))) for (condition, state) in itertools.product(args.cond_names, states)]
outfiles.append(os.path.abspath(os.path.join(args.outdir, "*", "beds", "*_all.bed")))
Expand All @@ -121,9 +124,10 @@ def run_bindetect(args):
outfiles.append(os.path.abspath(os.path.join(args.outdir, "bindetect_results.xlsx")))
outfiles.append(os.path.abspath(os.path.join(args.outdir, "bindetect_figures.pdf")))

#----------------------------------------------------------------------------------------------------#
#------------------------------------------- Setup logger -------------------------------------------#
#----------------------------------------------------------------------------------------------------#

#-------------------------------------------------------------------------------------------------------------#
#------------------------------------------------- Setup logger ----------------------------------------------#
#-------------------------------------------------------------------------------------------------------------#

logger = create_logger(args.verbosity, args.log)

Expand All @@ -140,8 +144,8 @@ def run_bindetect(args):
logger.comment("\n")

# Setup pool
worker_cores = max(1, args.cores - 1) #max(1, int(args.cores * 0.9))
writer_cores = 1 #int(args.cores * 0.1))
worker_cores = max(1, args.cores - 1) #max(1, int(args.cores * 0.9))
writer_cores = 1 #int(args.cores * 0.1))
logger.debug("Worker cores: {0}".format(worker_cores))
logger.debug("Writer cores: {0}".format(writer_cores))

Expand All @@ -150,6 +154,7 @@ def run_bindetect(args):

args.no_overwrite = False #leftover from earlier; remove


#-------------------------------------------------------------------------------------------------------------#
#-------------------------- Pre-processing data: Reading motifs, sequences, peaks ----------------------------#
#-------------------------------------------------------------------------------------------------------------#
Expand Down Expand Up @@ -191,6 +196,7 @@ def run_bindetect(args):


################# Peaks / GC in peaks ################

#Read peak and peak_header
peaks = RegionList().from_bed(args.peaks)
logger.info("- Found {0} regions in input peaks".format(len(peaks)))
Expand Down Expand Up @@ -385,9 +391,9 @@ def run_bindetect(args):
writer_pool.join()


#--------------------------------------------------------------------------------------#
#---------------- Process information on background scores and overlaps ---------------#
#--------------------------------------------------------------------------------------#
#-------------------------------------------------------------------------------------------------------------#
#---------------------------- Process information on background scores and overlaps --------------------------#
#-------------------------------------------------------------------------------------------------------------#

logger.info("Merging results from subsets")
background = {}
Expand All @@ -401,12 +407,13 @@ def run_bindetect(args):
background["signal"][bigwig] = np.array(background["signal"][bigwig])


## Estimate score distribution to define bound/unbound threshold per condition
###### Estimate score distribution to define bound/unbound threshold per condition ######
logger.comment("")
logger.info("Estimating score distributions per condition")
args.thresholds = {}
pseudos = []
figures = [] #save figures before saving to file to unify x-ranges
accessible = {}
for bigwig in args.cond_names:

#Prepare scores (remove 0's etc.)
Expand All @@ -417,8 +424,17 @@ def run_bindetect(args):
x_max = np.percentile(bg_values, [99])

#Fit mixture of normals
gmm = sklearn.mixture.GaussianMixture(n_components=2)
gmm.fit(np.log(bg_values).reshape(-1, 1))
lowest_bic = np.inf
for n_components in range(1,3): #1/2 components
gmm = sklearn.mixture.GaussianMixture(n_components=n_components, random_state=1)
gmm.fit(np.log(bg_values).reshape(-1, 1))

bic = gmm.bic(np.log(bg_values).reshape(-1,1))
logger.debug("n_compontents: {0} | bic: {1}".format(n_components, bic))
if bic < lowest_bic:
lowest_bic = bic
best_gmm = gmm
gmm = best_gmm

#Extract right-most gaussian
means = gmm.means_.flatten()
Expand All @@ -427,24 +443,27 @@ def run_bindetect(args):
log_params = scipy.stats.lognorm.fit(bg_values[bg_values < x_max], f0=sds[chosen_i], fscale=np.exp(means[chosen_i]))

#Plot mixture
#plt.hist(np.log(bg_values), bins='auto', density=True)
#xlim = plt.xlim()
#x = np.linspace(xlim[0], xlim[1], 1000)
#for i in range(2):
# pdf = scipy.stats.norm.pdf(x, means[i], sds[i])
# plt.plot(x, pdf)
if args.debug:
plt.hist(np.log(bg_values), bins='auto', density=True)
xlim = plt.xlim()
x = np.linspace(xlim[0], xlim[1], 1000)
for i in range(2):
pdf = scipy.stats.norm.pdf(x, means[i], sds[i])
plt.plot(x, pdf)

#logprob = gmm.score_samples(x.reshape(-1, 1))
#df = np.exp(logprob)
#plt.plot(x, df)
#plt.show()
logprob = gmm.score_samples(x.reshape(-1, 1))
df = np.exp(logprob)
plt.plot(x, df)
plt.show()

#Estimate threshold and pseudocount
threshold = round(scipy.stats.lognorm.ppf(1-args.bound_threshold, *log_params), 5)
args.thresholds[bigwig] = threshold
logger.info("- Threshold for condition {0} estimated at: {1}".format(bigwig, threshold))

pseudo = round(scipy.stats.lognorm.ppf(0.2, *log_params), 5)
#Mode of distribution
mode = scipy.optimize.fmin(lambda x: -scipy.stats.lognorm.pdf(x, *log_params), 0, disp=False)[0]
pseudo = mode / 2.0 #pseudo is half the mode
pseudos.append(pseudo)

#Plot fit
Expand All @@ -453,7 +472,7 @@ def run_bindetect(args):

xvals = np.linspace(0, x_max, 1000)
probas = scipy.stats.lognorm.pdf(xvals, *log_params)
ax.plot(xvals, probas, label="Log-normal fit")
ax.plot(xvals, probas, label="Log-normal fit", color="orange")

ax.axvline(threshold, color="black", label="Bound/unbound threshold")
ymax = plt.ylim()[1]
Expand All @@ -478,15 +497,18 @@ def run_bindetect(args):
#Estimate pseudocount
if args.pseudo == None:
args.pseudo = np.mean(pseudo)
logger.info("Pseudocunt estimated to: {0}".format(round(args.pseudo, 5)))
logger.info("Pseudocount estimated at: {0}".format(round(args.pseudo, 5)))

####### Foldchanges between conditions ########


############ Foldchanges between conditions ################
logger.comment("")
log2fc_params = {}
if len(args.signals) > 1:
logger.info("Calculating log2 fold changes between conditions")

for (bigwig1, bigwig2) in comparisons: #cond1, cond2
logger.debug("- {0} / {1}".format(bigwig1, bigwig2))
logger.info("- {0} / {1}".format(bigwig1, bigwig2))

scores1 = np.copy(background["signal"][bigwig1])
scores2 = np.copy(background["signal"][bigwig2])
Expand All @@ -497,26 +519,74 @@ def run_bindetect(args):
scores1, scores2, gcs = scores1[included], scores2[included], gcs[included]

log2fcs = np.log2(np.true_divide(scores1 + args.pseudo, scores2 + args.pseudo))
included = np.logical_and(np.logical_not(np.isclose(log2fcs, 0)), np.logical_not(np.isnan((gcs))))

log2fcs, gcs = log2fcs[included], gcs[included]
values = np.vstack([log2fcs, gcs])

if values.shape[1] == 0:
sys.exit("ERROR: Bigwig values of conditions {0} and {1} are equal or contain only zeroes - please check your input data.".format(bigwig1, bigwig2))

kernel = scipy.stats.gaussian_kde(values)
log2fc_params[(bigwig1, bigwig2)] = kernel
#Fit mixture to describe relationship between GC/log2fc
lowest_bic = np.inf
n_components_range = range(1,10)
for n_components in n_components_range:
gmm = sklearn.mixture.GaussianMixture(n_components=n_components, covariance_type="full", random_state=1)
gmm.fit(values.T)
bic = gmm.bic(values.T)

#Create plot
fig, ax = plt.subplots()
g = sns.jointplot(x=log2fcs, y=gcs, kind="kde") #, ax=ax)
g.set_axis_labels(xlabel="Log2 fold change", ylabel="GC content") #yaxis and xaxis label
g.fig.suptitle("Background log2FCs ({0} / {1})".format(bigwig1, bigwig2))
plt.tight_layout()
figure_pdf.savefig(g.fig)
plt.close()
logger.debug("n_compontents: {0} | bic: {1}".format(n_components, bic))
if bic < lowest_bic:
lowest_bic = bic
best_gmm = gmm

log2fc_params[(bigwig1, bigwig2)] = best_gmm

#Create plot
nullfmt = NullFormatter()
xmin, xmax = np.percentile(log2fcs, [1,99])
ymin, ymax = np.percentile(gcs, [1,99])

# definitions for the axes
left, width = 0.1, 0.65
bottom, height = 0.1, 0.65
bottom_h = left_h = left + width + 0.02

rect_scatter = [left, bottom, width, height]
rect_histx = [left, bottom_h, width, 0.2]
rect_histy = [left_h, bottom, 0.2, height]

#Initialize figure
plt.figure(1, figsize=(8, 8))
axScatter = plt.axes(rect_scatter)
axHistx = plt.axes(rect_histx)
axHisty = plt.axes(rect_histy)

#Gaussian mixture
X, Y = np.meshgrid(np.linspace(xmin,xmax,100), np.linspace(ymin,ymax,100))
positions = np.array([X.ravel(), Y.ravel()]).T
Z = np.exp(best_gmm.score_samples(positions))
Z = np.reshape(Z, X.shape)

zmax = np.percentile(Z, [99])
axScatter.contourf(X, Y, Z, 20, cmap=plt.cm.viridis)
#axScatter.scatter(log2fcs, gcs, s=0.5, alpha=0.5)

#Histograms
axHistx.hist(log2fcs[np.logical_and(xmin < log2fcs, log2fcs < xmax)], bins='auto', density=True, color="darkslateblue")
axHisty.hist(gcs[np.logical_and(ymin < gcs, gcs < ymax)], bins='auto', orientation='horizontal', density=True, color="darkslateblue")
axHistx.set_xlim(axScatter.get_xlim())
axHisty.set_ylim(axScatter.get_ylim())
axHisty.yaxis.set_major_formatter(nullfmt)
axHistx.xaxis.set_major_formatter(nullfmt)
axHistx.yaxis.set_major_formatter(nullfmt)
axHistx.xaxis.set_major_formatter(nullfmt)

#Decorate
axScatter.set_xlabel("Log2 fold change")
axScatter.set_ylabel("GC content")
axHistx.set_title("Background log2FCs ({0} / {1})".format(bigwig1, bigwig2))

figure_pdf.savefig(bbox_inches='tight')
plt.close()

background = None #free up space


Expand Down Expand Up @@ -549,7 +619,7 @@ def run_bindetect(args):

logger.info("Concatenating results from subsets")
info_table = pd.concat(results) #pandas tables
index_names = info_table.index
#index_names = info_table.index

pool.terminate()
pool.join()
Expand Down Expand Up @@ -604,12 +674,11 @@ def run_bindetect(args):

#Cluster distance matrix


#Test index names against names
index_names = info_table.index
for name in names:
if name not in index_names:
logger.info("{0} not in index".format(name))
#index_names = info_table.index
#for name in names:
# if name not in index_names:
# logger.info("{0} not in index".format(name))

#Plotting bindetect per comparison
for (cond1, cond2) in comparisons:
Expand All @@ -633,7 +702,7 @@ def run_bindetect(args):

end_time = datetime.now()
logger.comment("")
logger.info("Finished BINDetect run (time elapsed: {0})".format(end_time - begin_time))
logger.info("Finished BINDetect run (time elapsed: {0}). Results are found in: {1}".format(end_time - begin_time, args.outdir))



Expand Down
Loading

0 comments on commit 0efd686

Please sign in to comment.