Skip to content
This repository has been archived by the owner. It is now read-only.
Permalink
dc0d9bd781
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
462 lines (343 sloc) 15.4 KB
#!/usr/bin/env python
"""
Classes and functions for performing bias estimation, correction and visualization in ATACorrect
@author: Mette Bentsen
@contact: mette.bentsen (at) mpi-bn.mpg.de
@license: MIT
"""
import os
import sys
import gc
import numpy as np
import multiprocessing as mp
import time
from datetime import datetime
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from scipy.optimize import curve_fit
#Bio-specific packages
import pysam
#Internal functions and classes
from tobias.utils.sequences import *
from tobias.utils.signals import *
from tobias.utils.ngs import *
from tobias.utils.utilities import *
from tobias.utils.logger import *
#Catch warnings from curve_fit
import warnings
from scipy.optimize import OptimizeWarning
warnings.simplefilter("error", OptimizeWarning)
#--------------------------------------------------------------------------------------------------#
class AtacBias:
""" Class for storing information about estimated bias """
def __init__(self, L, stype):
self.stype = stype
self.bias = {"forward": SequenceMatrix.create(L, self.stype),
"reverse": SequenceMatrix.create(L, self.stype),
"both": SequenceMatrix.create(L, self.stype)}
self.no_reads = 0
def join(self, obj):
""" Join counts from AtacBias obj with other AtacBias obj """
self.bias["forward"].add_counts(obj.bias["forward"])
self.bias["reverse"].add_counts(obj.bias["reverse"])
self.bias["both"].add_counts(obj.bias["both"])
self.no_reads += obj.no_reads
#--------------------------------------------------------------------------------------------------#
def count_reads(regions_list, params):
""" Count reads from bam within regions (counts position of cutsite to prevent double-counting) """
bam_f = params.bam
read_shift = params.read_shift
bam_obj = pysam.AlignmentFile(bam_f, "rb")
log_q = params.log_q
logger = TobiasLogger("", params.verbosity, log_q) #sending all logger calls to log_q
#Count per region
read_count = 0
logger.spam("Started counting region_chunk ({0} -> {1})".format("_".join([str(element) for element in regions_list[0]]), "_".join([str(element) for element in regions_list[-1]])))
for region in regions_list:
read_lst = ReadList().from_bam(bam_obj, region)
for read in read_lst:
read.get_cutsite(read_shift)
if read.cutsite > region.start and read.cutsite < region.end: #only reads within borders
read_count += 1
logger.spam("Finished counting region_chunk ({0} -> {1})".format("_".join([str(element) for element in regions_list[0]]), "_".join([str(element) for element in regions_list[-1]])))
bam_obj.close()
return(read_count)
#--------------------------------------------------------------------------------------------------#
def bias_estimation(regions_list, params):
""" Estimates bias of insertions within regions """
#Info on run
bam_f = params.bam
fasta_f = params.genome
k_flank = params.k_flank
bg_shift = params.bg_shift
read_shift = params.read_shift
L = 2 * k_flank + 1
logger = TobiasLogger("", params.verbosity, params.log_q) #sending all logger calls to log_q
#Open objects for reading
bam_obj = pysam.AlignmentFile(bam_f, "rb")
fasta_obj = pysam.FastaFile(fasta_f)
chrom_lengths = dict(zip(bam_obj.references, bam_obj.lengths)) #Chromosome boundaries from bam_obj
bias_obj = AtacBias(L, params.score_mat)
strands = ["forward", "reverse"]
#Estimate bias at each region
for region in regions_list:
read_lst = ReadList().from_bam(bam_obj, region)
for read in read_lst:
read.get_cutsite(read_shift)
## Kmer cutting bias ##
if len(read_lst) > 0:
#Extract sequence
extended_region = region.extend_reg(k_flank + bg_shift) #Extend to allow full kmers
extended_region.check_boundary(chrom_lengths, "cut")
sequence_obj = GenomicSequence(extended_region).from_fasta(fasta_obj)
#Split reads forward/reverse
for_lst, rev_lst = read_lst.split_strands()
read_lst_strand = {"forward": for_lst, "reverse": rev_lst}
for strand in strands:
#Map reads to positions
read_per_pos = {}
for read in read_lst_strand[strand]:
if read.cigartuples is not None:
first_tuple = read.cigartuples[-1] if read.is_reverse else read.cigartuples[0]
if first_tuple[0] == 0 and first_tuple[1] > params.k_flank + max(np.abs(params.read_shift)): #Only include non-clipped reads
read_per_pos[read.cutsite] = read_per_pos.get(read.cutsite, []) + [read]
for cutsite in read_per_pos:
if cutsite > region.start and cutsite < region.end: #only reads within borders
read = read_per_pos[cutsite][0] #use first read in list to establish kmer
no_cut = min(len(read_per_pos[cutsite]), 10) #put cap on number of cuts to limit influence of outliers
read.get_kmer(sequence_obj, k_flank)
bias_obj.bias[strand].add_sequence(read.kmer, no_cut)
read.shift_cutsite(-bg_shift) #upstream of read; ensures that bg is not within fragment
read.get_kmer(sequence_obj, k_flank)
bias_obj.bias[strand].add_background(read.kmer, no_cut)
bias_obj.no_reads += no_cut
bam_obj.close()
fasta_obj.close()
return(bias_obj) #object containing information collected on bias
#--------------------------------------------------------------------------------------------------#
def relu(x, a, b):
""" a and b are components of a linear curve (y=a*x+b) """
y = np.maximum(0.0, a*x + b)
return(y)
#--------------------------------------------------------------------------------------------------#
def bias_correction(regions_list, params, bias_obj):
""" Corrects bias in cutsites (from bamfile) using estimated bias """
logger = TobiasLogger("", params.verbosity, params.log_q)
bam_f = params.bam
fasta_f = params.genome
k_flank = params.k_flank
read_shift = params.read_shift
L = 2 * k_flank + 1
w = params.window
f = int(w/2.0)
qs = params.qs
f_extend = k_flank + f
strands = ["forward", "reverse"]
pre_bias = {strand: SequenceMatrix.create(L, "PWM") for strand in strands}
post_bias = {strand: SequenceMatrix.create(L, "PWM") for strand in strands}
#Open bamfile and fasta
bam_obj = pysam.AlignmentFile(bam_f, "rb")
fasta_obj = pysam.FastaFile(fasta_f)
chrom_lengths = dict(zip(bam_obj.references, bam_obj.lengths)) #Chromosome boundaries from bam_obj
out_signals = {}
#Go through each region
for region_obj in regions_list:
region_obj.extend_reg(f_extend)
region_obj.check_boundary(chrom_lengths, "cut")
reg_len = region_obj.get_length() #length including flanking
reg_key = (region_obj.chrom, region_obj.start+f_extend, region_obj.end-f_extend) #output region
out_signals[reg_key] = {"uncorrected":{}, "bias":{}, "expected":{}, "corrected":{}}
################################
####### Uncorrected reads ######
################################
#Get cutsite positions for each read
read_lst = ReadList().from_bam(bam_obj, region_obj)
for read in read_lst:
read.get_cutsite(read_shift)
#Exclude reads with cutsites outside region
read_lst = ReadList([read for read in read_lst if read.cutsite > region_obj.start and read.cutsite < region_obj.end])
for_lst, rev_lst = read_lst.split_strands()
read_lst_strand = {"forward": for_lst, "reverse": rev_lst}
for strand in strands:
out_signals[reg_key]["uncorrected"][strand] = read_lst_strand[strand].signal(region_obj)
out_signals[reg_key]["uncorrected"][strand] = np.round(out_signals[reg_key]["uncorrected"][strand], 5)
################################
###### Estimation of bias ######
################################
#Get sequence in this region
sequence_obj = GenomicSequence(region_obj).from_fasta(fasta_obj)
#Score sequence using forward/reverse motifs
for strand in strands:
if strand == "forward":
seq = sequence_obj.sequence
bias = bias_obj.bias[strand].score_sequence(seq)
elif strand == "reverse":
seq = sequence_obj.revcomp
bias = bias_obj.bias[strand].score_sequence(seq)[::-1] #3'-5'
out_signals[reg_key]["bias"][strand] = np.nan_to_num(bias) #convert any nans to 0
#################################
###### Correction of reads ######
#################################
reg_end = reg_len - k_flank
step = 10
overlaps = int(params.window / step)
window_starts = list(range(k_flank, reg_end-params.window, step))
window_ends = list(range(k_flank+params.window, reg_end, step))
window_ends[-1] = reg_len
windows = list(zip(window_starts, window_ends))
for strand in strands:
########### Estimate bias threshold ###########
bias_predictions = np.zeros((overlaps,reg_len))
row = 0
for window in windows:
signal_w = out_signals[reg_key]["uncorrected"][strand][window[0]:window[1]]
bias_w = out_signals[reg_key]["bias"][strand][window[0]:window[1]]
signalmax = np.max(signal_w)
biasmin = np.min(bias_w)
biasmax = np.max(bias_w)
if signalmax > 0:
try:
popt, pcov = curve_fit(relu, bias_w, signal_w)
bias_predict = relu(bias_w, *popt)
except (OptimizeWarning, RuntimeError):
cut_positions = np.logical_not(np.isclose(signal_w, 0))
bias_min = np.min(bias_w[cut_positions])
bias_predict = bias_w - bias_min
bias_predict[bias_predict < 0] = 0
if np.max(bias_predict) > 0:
bias_predict = bias_predict / np.max(bias_predict)
else:
bias_predict = np.zeros(window[1]-window[0])
bias_predictions[row, window[0]:window[1]] = bias_predict
row += 1 if row < overlaps - 1 else 0
bias_prediction = np.mean(bias_predictions, axis=0)
bias = bias_prediction
######## Calculate expected signal ######
signal_sum = fast_rolling_math(out_signals[reg_key]["uncorrected"][strand], w, "sum")
signal_sum[np.isnan(signal_sum)] = 0 #f-width ends of region
bias_sum = fast_rolling_math(bias, w, "sum") #ends of arr are nan
nulls = np.logical_or(np.isclose(bias_sum, 0), np.isnan(bias_sum))
bias_sum[nulls] = 1 # N-regions will give stretches of 0-bias
bias_probas = bias / bias_sum
bias_probas[nulls] = 0 #nan to 0
out_signals[reg_key]["expected"][strand] = signal_sum * bias_probas
######## Correct signal ########
out_signals[reg_key]["uncorrected"][strand] *= bias_obj.correction_factor
out_signals[reg_key]["expected"][strand] *= bias_obj.correction_factor
out_signals[reg_key]["corrected"][strand] = out_signals[reg_key]["uncorrected"][strand] - out_signals[reg_key]["expected"][strand]
#######################################
######## Verify correction ########
#######################################
#Verify correction across all reads
for strand in strands:
for idx in range(k_flank,reg_len - k_flank -1):
if idx > k_flank and idx < reg_len-k_flank:
orig = out_signals[reg_key]["uncorrected"][strand][idx]
correct = out_signals[reg_key]["corrected"][strand][idx]
if orig != 0 or correct != 0: #if both are 0, don't add to pre/post bias
if strand == "forward":
kmer = sequence_obj.sequence[idx-k_flank:idx+k_flank+1]
else:
kmer = sequence_obj.revcomp[reg_len-idx-k_flank-1:reg_len-idx+k_flank]
#Save kmer for bias correction verification
pre_bias[strand].add_sequence(kmer, orig)
post_bias[strand].add_sequence(kmer, correct)
#######################################
######## Write to queue #########
#######################################
#Set size back to original
for track in out_signals[reg_key]:
for strand in out_signals[reg_key][track]:
out_signals[reg_key][track][strand] = out_signals[reg_key][track][strand][f_extend:-f_extend]
#Calculate "both" if split_strands == False
if params.split_strands == False:
for track in out_signals[reg_key]:
out_signals[reg_key][track]["both"] = out_signals[reg_key][track]["forward"] + out_signals[reg_key][track]["reverse"]
#Send to queue
strands_to_write = ["forward", "reverse"] if params.split_strands == True else ["both"]
for track in out_signals[reg_key]:
#Send to writer per strand
for strand in strands_to_write:
key = "{0}:{1}".format(track, strand)
logger.spam("Sending {0} signal from region {1} to writer queue".format(key, reg_key))
qs[key].put((key, reg_key, out_signals[reg_key][track][strand]))
#Sent to qs - delete from this process
out_signals[reg_key] = None
bam_obj.close()
fasta_obj.close()
gc.collect()
return([pre_bias, post_bias])
####################################################################################################
######################################## Plot functions ############################################
####################################################################################################
colors = {0:"green", 1:"red", 2:"blue", 3:"darkkhaki"}
names = {0:"A", 1:"T", 2:"C", 3:"G"}
def plot_pssm(matrix, title):
""" Plot pssm in matrix """
#Make figure
fig, ax = plt.subplots()
fig.suptitle(title, fontsize=16, weight="bold")
#Formatting of x axis
length = matrix.shape[1]
flank = int(length/2.0)
xvals = np.arange(length) # each position corresponds to i in mat
#Customize minor tick labels
xtick_pos = xvals[:-1] + 0.5
xtick_labels = list(range(-flank, flank))
ax.xaxis.set_major_locator(ticker.FixedLocator(xvals))
ax.xaxis.set_major_formatter(ticker.FixedFormatter(xtick_labels))
ax.xaxis.set_minor_locator(ticker.FixedLocator(xtick_pos)) #locate minor ticks between major ones (cutsites)
ax.xaxis.set_minor_formatter(ticker.NullFormatter())
#Make background grid on major ticks
plt.grid(color='0.8', which="minor", ls="--", axis="x")
plt.xlim([0, length-1])
plt.xlabel('Position from cutsite')
plt.ylabel('PSSM score')
######## Plot data #######
#Plot PSSM / bias motif
for nuc in range(4):
plt.plot(xvals, matrix[nuc,:], color=colors[nuc], label=names[nuc])
#Cutsite-line
plt.axvline(flank-0.5, linewidth=2, color="black", zorder=100)
#Finish up
plt.legend(loc="lower right")
plt.tight_layout()
fig.subplots_adjust(top=0.88, hspace=0.5)
return(fig)
#----------------------------------------------------------------------------------------------------#
def plot_correction(pre_mat, post_mat, title):
""" Plot comparison of pre-correction and post-correction matrices """
#Make figure
fig, ax = plt.subplots()
fig.suptitle(title, fontsize=16, weight="bold")
L = pre_mat.shape[1]
flank = int(L/2.0)
xvals = np.arange(L) # each position corresponds to i in mat
#Customize minor tick labels
xtick_pos = xvals[:-1] + 0.5
xtick_labels = list(range(-flank, flank)) #-flank - flank without 0
ax.xaxis.set_major_locator(ticker.FixedLocator(xvals))
ax.xaxis.set_major_formatter(ticker.FixedFormatter(xtick_labels))
ax.xaxis.set_minor_locator(ticker.FixedLocator(xtick_pos)) #locate minor ticks between major ones (cutsites)
ax.xaxis.set_minor_formatter(ticker.NullFormatter())
#PWMs for all mats
pre_pwm = pre_mat
post_pwm = post_mat
#Pre correction
for nuc in range(4):
yvals = [pre_pwm[nuc, m] for m in range(L)]
plt.plot(xvals, yvals, linestyle="--", color=colors[nuc], linewidth=1, alpha=0.5)
#Post correction
for nuc in range(4):
yvals = [post_pwm[nuc, m] for m in range(L)]
plt.plot(xvals, yvals, color=colors[nuc], linewidth=2, label=names[nuc])
plt.xlim([0, L-1])
plt.xlabel('Position from cutsite')
plt.ylabel('Nucleotide frequency')
#Set legend
plt.plot([0],[0], linestyle="--", linewidth=1, color="black", label="pre-correction")
plt.plot([0],[0], color="black", label="post-correction")
plt.legend(loc="lower right", prop={'size':6})
plt.tight_layout()
fig.subplots_adjust(top=0.88, hspace=0.4)
return(fig)