Commit a7d85cb0 authored by mparker2's avatar mparker2
Browse files

refactor

parent 648cb985
Loading
Loading
Loading
Loading

d3pendr/bam.py

0 → 100755
+121 −0
Original line number Diff line number Diff line
import numpy as np

def bam_cigar_to_invs(aln):
    """Convert an aligned read's CIGAR into reference-space intervals.

    Parameters
    ----------
    aln : pysam.AlignedSegment
        Aligned read exposing reference_start, reference_end,
        is_reverse and cigar.

    Returns
    -------
    tuple
        (start, end, strand, invs) where invs is an (n, 2) int array of
        half-open [start, end) reference intervals, split at N (intron)
        operations.
    """
    invs = []
    start = aln.reference_start
    end = aln.reference_end
    strand = '-' if aln.is_reverse else '+'
    left = start
    right = left
    # NOTE: removed unused `has_ins` flag that was set but never read
    for op, ln in aln.cigar:
        if op in (1, 4, 5):
            # insertion, soft-clip, hard-clip: do not consume reference
            continue
        elif op in (0, 2, 7, 8):
            # match, deletion, seq-match, seq-mismatch: consume reference
            # but do not close the current interval
            right += ln
        elif op == 3:
            # N (intron): close the current interval and skip the gap
            invs.append([left, right])
            left = right + ln
            right = left
    if right > left:
        invs.append([left, right])
    # sanity check: intervals must span the full alignment extent
    assert invs[0][0] == start
    assert invs[-1][1] == end
    return start, end, strand, np.array(invs)


def pair_filter(filt_type):
    """Build a predicate that selects reads by mate.

    Parameters
    ----------
    filt_type : str
        'both' or 'single' keeps every read, '1' keeps only read1,
        '2' keeps only read2.

    Returns
    -------
    callable
        A one-argument predicate taking an aligned read.

    Raises
    ------
    ValueError
        For any other filt_type (previously an unknown value caused an
        UnboundLocalError on return).
    """
    if filt_type == 'both' or filt_type == 'single':
        def _pair_filter(aln):
            return True
    elif filt_type == '1':
        def _pair_filter(aln):
            return aln.is_read1
    elif filt_type == '2':
        def _pair_filter(aln):
            return aln.is_read2
    else:
        raise ValueError(f'unrecognised pair filter type: {filt_type}')
    return _pair_filter


def strand_filter(filt_type, pair_type):
    """Build a predicate testing a read's orientation against a feature strand.

    Parameters
    ----------
    filt_type : str
        'same' or 'opposite' — required orientation relative to the feature.
    pair_type : str
        Pairing mode; 'both' switches on paired-end semantics (see below).

    Returns
    -------
    callable
        A predicate taking (aln, strand) where strand is '+' or '-'.

    Raises
    ------
    ValueError
        If filt_type is not 'same' or 'opposite' (previously an unknown
        value caused an UnboundLocalError on return).
    """
    if pair_type != "both":
        # if we are just using either read1 or 2,
        # assume that "same" or "opposite" refer
        # to the orientation of that read to the
        # strand
        if filt_type == 'same':
            def _strand_filter(aln, strand):
                is_reverse = strand == '-'
                return aln.is_reverse == is_reverse
        elif filt_type == 'opposite':
            def _strand_filter(aln, strand):
                is_reverse = strand == '-'
                return aln.is_reverse != is_reverse
        else:
            raise ValueError(f'unrecognised strand filter type: {filt_type}')
    else:
        # in paired ended data, assume that reads1&2 are
        # on the opposite strand from each other and
        # "same" or "opposite" is referring to read1
        # i.e. in paired mode, "same" returns read1 from the same
        # strand and read2 from the opposite strand
        if filt_type == 'same':
            def _strand_filter(aln, strand):
                is_reverse = strand == '-'
                strand_same = aln.is_reverse == is_reverse
                return aln.is_read1 == strand_same
        elif filt_type == 'opposite':
            def _strand_filter(aln, strand):
                is_reverse = strand == '-'
                strand_same = aln.is_reverse == is_reverse
                return aln.is_read1 != strand_same
        else:
            raise ValueError(f'unrecognised strand filter type: {filt_type}')
    return _strand_filter



def bam_query_iterator(bam, *args, **kwargs):
    """Yield (start, end, strand, invs) tuples for reads overlapping a query.

    Keyword arguments recognised here (popped before bam.fetch):
        strand      '+', '-', '.' or None (None/'.' means unstranded)
        read_strand 'same' or 'opposite'
        pairs       'both', 'single', '1' or '2'
    Remaining args/kwargs are forwarded to bam.fetch.
    """
    strand = kwargs.pop('strand', None)  # + or -
    orient = kwargs.pop('read_strand', 'same')  # same or opposite
    pairs = kwargs.pop('pairs', 'single')  # both, single, "1" or "2"
    keep_strand = strand_filter(orient, pairs)
    keep_pair = pair_filter(pairs)
    if strand is None or strand == '.':
        # unstranded query: only the mate filter applies
        def _keep(aln):
            return keep_pair(aln)
    elif strand in '+-':
        def _keep(aln):
            return keep_strand(aln, strand) and keep_pair(aln)
    else:
        raise ValueError('strand is not one of +-.')
    for aln in bam.fetch(*args, **kwargs):
        if _keep(aln):
            yield bam_cigar_to_invs(aln)


def add_read_to_cov(cov, aln, offset):
    """Increment per-base coverage counts for each aligned interval of a read.

    cov is indexed relative to `offset`; interval starts extending left of
    the window are clipped to zero. Returns the (mutated) cov array.
    """
    *_, read_invs = bam_cigar_to_invs(aln)
    for inv_start, inv_end in read_invs:
        lo = max(inv_start - offset, 0)
        hi = inv_end - offset
        cov[lo: hi] += 1
    return cov


def per_base_bam_coverage(bam, chrom, start, end, **kwargs):
    """Compute per-base read coverage over a genomic window.

    Keyword arguments recognised here (popped before bam.fetch):
        strand      '+', '-', '.' or None (None/'.' means unstranded)
        read_strand 'same' or 'opposite'
        pairs       'both', 'single', '1' or '2'
    Remaining kwargs are forwarded to bam.fetch.

    Returns
    -------
    numpy.ndarray
        int64 array of length (end - start) of read counts per base.

    Raises
    ------
    ValueError
        If strand is not one of '+', '-', '.', None.
    """
    strand = kwargs.pop('strand', None)
    orient = kwargs.pop('read_strand', 'same')
    pairs = kwargs.pop('pairs', 'both')
    strand_filt = strand_filter(orient, pairs)
    pair_filt = pair_filter(pairs)
    cov = np.zeros(end - start, dtype='int64')
    if strand is None or strand == '.':
        for aln in bam.fetch(chrom, start, end, **kwargs):
            if pair_filt(aln):
                cov = add_read_to_cov(cov, aln, start)
    elif strand in '+-':
        # NOTE: an unused local `is_reverse = strand == '-'` was removed
        # here; the strand test is handled entirely by strand_filt
        for aln in bam.fetch(chrom, start, end, **kwargs):
            if strand_filt(aln, strand) and pair_filt(aln):
                cov = add_read_to_cov(cov, aln, start)
    else:
        raise ValueError('strand is not one of +-.')
    return cov
 No newline at end of file

d3pendr/d3pendr.py

0 → 100755
+147 −0
Original line number Diff line number Diff line
import numpy as np
import pandas as pd
from statsmodels.stats.multitest import multipletests
from joblib import Parallel, delayed

from .multibam import MultiBamParser, MultiBigWigParser, bam_or_bw
from .gtf import gtf_iterator, chunk_gtf_records
from .tpe import get_tpe_distribution, get_apa_tpes
from .stats import d3pendr_stats


# columns in the output bed format
# (first six are standard bed6; score is the rounded wasserstein distance,
# wass_fdr is filled in after multiple-testing correction)
RESULTS_COLUMNS = [
    'chrom', 'start', 'end', 'gene_id', 'score', 'strand',
    'wass_dist', 'wass_dir', 'wass_pval', 'wass_fdr',
    'nreads_cntrl', 'nreads_treat'
]

# columns for the per-TPE alternative polyadenylation site output
TPE_APA_RESULTS_COLUMNS = [
    'chrom', 'start', 'end', 'gene_id', 'strand',
    'nreads_cntrl', 'nreads_treat',
    'frac_cntrl', 'frac_treat', 'relative_change'
]
        

def _process_gtf_records(gtf_records, treat_bam_fns, cntrl_bam_fns,
                         read_strand, read_end, paired_end_read,
                         min_read_overlap, min_reads,
                         bootstraps, threshold,
                         use_gamma_model, test_homogeneity,
                         is_bam, find_apa_tpe_sites, tpe_sigma,
                         tpe_min_reads, tpe_min_rel_change):
    # Worker that runs the d3pendr wasserstein test over a chunk of gtf
    # records; dispatched (possibly in parallel) by run_d3pendr.
    # Returns (results_df, tpe_apa_df_or_None).
    results = []
    tpe_apa_res = []

    # bam and bigwig inputs are handled by different parser classes with
    # a shared interface
    parser = MultiBamParser if is_bam else MultiBigWigParser

    with parser(cntrl_bam_fns) as cntrl_bam, parser(treat_bam_fns) as treat_bam:

        for chrom, start, end, gene_id, strand in gtf_records:
            # 3' end distributions and read counts for each condition
            cntrl_distribs, nreads_cntrl = get_tpe_distribution(
                cntrl_bam, chrom, start, end, strand,
                min_read_overlap, read_strand, read_end,
                paired_end_read, is_bam
            )
            treat_distribs, nreads_treat = get_tpe_distribution(
                treat_bam, chrom, start, end, strand,
                min_read_overlap, read_strand, read_end,
                paired_end_read, is_bam
            )

            # only test genes with enough reads in every sample of both
            # conditions (nreads_* appears to be per-replicate — .all())
            if (nreads_cntrl >= min_reads).all() and (nreads_treat >= min_reads).all():
                wass_dist, wass_dir, wass_pval = d3pendr_stats(
                    cntrl_distribs, treat_distribs,
                    bootstraps=bootstraps, threshold=threshold,
                    use_gamma_model=use_gamma_model,
                    test_homogeneity=test_homogeneity,
                )
                # bed score column is the rounded wasserstein distance;
                # the fdr placeholder is overwritten by run_d3pendr after
                # multiple-testing correction
                results.append([
                    chrom, start, end, gene_id, round(wass_dist), strand,
                    wass_dist, wass_dir, wass_pval, 1, # placeholder for wasserstein test fdr
                    nreads_cntrl, nreads_treat
                ])

                # for significant genes, optionally locate the individual
                # alternative polyadenylation / 3' end sites
                if find_apa_tpe_sites and wass_pval <= threshold:
                    tpes = get_apa_tpes(
                        cntrl_distribs, sum(nreads_cntrl),
                        treat_distribs, sum(nreads_treat),
                        tpe_sigma, tpe_min_reads, tpe_min_rel_change
                    )
                    for t in tpes:
                        tpe_start, tpe_end, *res = t
                        # revert tpe coords to absolute
                        if strand == '+':
                            tpe_start = tpe_start + start
                            tpe_end = tpe_end + start
                        elif strand == '-':
                            # minus-strand coords are measured from the gene
                            # 3' end, so swap and mirror around `end`
                            tpe_start, tpe_end = tpe_end, tpe_start
                            tpe_start = end - tpe_start
                            tpe_end = end - tpe_end

                        tpe_apa_res.append([
                            chrom, tpe_start, tpe_end, gene_id, strand, *res
                        ])

    results = pd.DataFrame(results, columns=RESULTS_COLUMNS)

    if find_apa_tpe_sites:
        tpe_apa_res = pd.DataFrame(tpe_apa_res, columns=TPE_APA_RESULTS_COLUMNS)
    else:
        tpe_apa_res = None
    return results, tpe_apa_res


def run_d3pendr(gtf_fn, treat_bam_fns, cntrl_bam_fns,
                read_strand, read_end, paired_end_read,
                min_read_overlap, min_reads_per_cond,
                extend_gene_five_prime, use_5utr,
                extend_gene_three_prime,
                by_locus,
                max_terminal_intron_size,
                min_terminal_exon_size,
                bootstraps, threshold,
                use_gamma_model, test_homogeneity,
                find_tpe_sites, tpe_sigma,
                tpe_min_reads, tpe_min_rel_change,
                processes):
    """Run the differential 3' end test over every record of a GTF file.

    Dispatches _process_gtf_records either on the main process or in
    parallel chunks, applies Benjamini-Hochberg FDR correction to the
    wasserstein p values, and filters TPE APA sites to significant genes.

    Returns
    -------
    tuple
        (results DataFrame, tpe_apa_results DataFrame or None).
    """

    filetype = bam_or_bw(*treat_bam_fns, *cntrl_bam_fns)

    args = (
        treat_bam_fns, cntrl_bam_fns,
        read_strand, read_end, paired_end_read,
        min_read_overlap, min_reads_per_cond,
        bootstraps, threshold,
        use_gamma_model, test_homogeneity,
        filetype, find_tpe_sites, tpe_sigma,
        tpe_min_reads, tpe_min_rel_change,
    )
    gtf_it = gtf_iterator(
        gtf_fn, extend_gene_five_prime, use_5utr, extend_gene_three_prime,
        by_locus, max_terminal_intron_size, min_terminal_exon_size,
    )
    if processes == 1:
        # run on main process
        # fixed: was assigned to misspelled `tpa_apa_results`, which left
        # tpe_apa_results undefined (NameError) in single-process mode
        results, tpe_apa_results = _process_gtf_records(
            gtf_it, *args
        )
    else:
        results = Parallel(n_jobs=processes)(
            delayed(_process_gtf_records)(gtf_chunk, *args)
            for gtf_chunk in chunk_gtf_records(gtf_it, processes)
        )
        results, tpe_apa_results = zip(*results)
        results = pd.concat(results)
        if find_tpe_sites:
            tpe_apa_results = pd.concat(tpe_apa_results)
        else:
            tpe_apa_results = None

    # Benjamini-Hochberg correction; multipletests returns
    # (reject, pvals_corrected, ...)
    _, results['wass_fdr'], *_ = multipletests(results.wass_pval, method='fdr_bh')

    if find_tpe_sites:
        # keep only APA sites belonging to genes significant after FDR
        tpe_apa_results = tpe_apa_results[
            tpe_apa_results.gene_id.isin(results.query(f'wass_fdr <= {threshold}').gene_id)
        ]
    return results, tpe_apa_results
 No newline at end of file

d3pendr/gtf.py

0 → 100755
+151 −0
Original line number Diff line number Diff line
import re
import numpy as np


def chunk_gtf_records(gtf_records, processes):
    """Split gtf records into `processes` near-equal chunks.

    The iterator is fully materialised first; the first (len % processes)
    chunks receive one extra record each.
    """
    records = list(gtf_records)
    base_size, remainder = divmod(len(records), processes)
    pos = 0
    for chunk_idx in range(processes):
        chunk_len = base_size + 1 if chunk_idx < remainder else base_size
        yield records[pos: pos + chunk_len]
        pos += chunk_len


def flatten_intervals(invs):
    """Merge overlapping/adjacent intervals into a flat sorted array.

    Starts and ends are sorted column-wise; this preserves the per-base
    coverage multiset, so the merged union is unchanged.
    """
    ordered = iter(np.sort(invs, axis=0))
    cur_start, cur_end = next(ordered)
    merged = []
    for nxt_start, nxt_end in ordered:
        if nxt_start > cur_end:
            # gap found: flush the pending interval
            merged.append([cur_start, cur_end])
            cur_start, cur_end = nxt_start, nxt_end
        else:
            # overlap or adjacency: extend the pending interval
            cur_end = max(cur_end, nxt_end)
    merged.append([cur_start, cur_end])
    return np.array(merged)


def filter_terminal_exons(invs, max_intron_size, min_exon_size):
    """Drop poorly supported terminal exons from a merged interval set.

    A terminal exon is dropped when it is shorter than min_exon_size or
    when the terminal intron separating it is max_intron_size or longer.
    Fixed: the right terminus was previously only checked when the left
    terminus was NOT trimmed (the right-end test sat in the else branch),
    so both termini are now examined independently.

    Parameters
    ----------
    invs : numpy.ndarray
        (n, 2) array of non-overlapping, sorted intervals.

    Returns
    -------
    numpy.ndarray
        The input with zero, one, or two terminal exons removed.
    """
    if len(invs) == 1:
        return invs
    # left terminus
    l_ex = invs[0, 1] - invs[0, 0]
    l_in = invs[1, 0] - invs[0, 1]
    if (l_ex < min_exon_size) or (l_in >= max_intron_size):
        invs = invs[1:]
        if len(invs) == 1:
            return invs
    # right terminus
    r_ex = invs[-1, 1] - invs[-1, 0]
    r_in = invs[-1, 0] - invs[-2, 1]
    if (r_ex < min_exon_size) or (r_in >= max_intron_size):
        invs = invs[:-1]
    return invs


def get_record_range(invs,
                     max_terminal_intron_size=None,
                     min_terminal_exon_size=None,
                     filter_=True):
    """Return the overall (start, end) span of a set of intervals.

    Intervals are merged first; when filter_ is True, weakly supported
    terminal exons are removed (see filter_terminal_exons) before the
    span is taken.
    """
    merged = flatten_intervals(invs)
    if not filter_:
        return merged[0, 0], merged[-1, 1]
    trimmed = filter_terminal_exons(
        merged,
        max_terminal_intron_size,
        min_terminal_exon_size,
    )
    return trimmed[0, 0], trimmed[-1, 1]


def get_gtf_attribute(gtf_record, attribute):
    """Extract a quoted attribute value from column 9 of a GTF record.

    Parameters
    ----------
    gtf_record : sequence of str
        Tab-split GTF fields; index 8 holds the attribute string,
        index 2 the feature type.
    attribute : str
        Attribute key to extract, e.g. 'gene_id'.

    Raises
    ------
    ValueError
        If the attribute is not present in the record. (Fixed: the error
        path referenced an undefined name `record`, so it raised
        NameError instead of the intended ValueError.)
    """
    match = re.search(f'{attribute} "(.+?)";', gtf_record[8])
    if match is None:
        raise ValueError(
            f'Could not parse attribute {attribute} '
            f'from GTF with feature type {gtf_record[2]}'
        )
    return match.group(1)


def gtf_iterator(gtf_fn,
                 extend_gene_five_prime=0,
                 use_5utr=False,
                 extend_gene_three_prime=0,
                 by_locus=True,
                 max_terminal_intron_size=100_000,
                 min_terminal_exon_size=20):
    # Parse a GTF file and yield one (chrom, start, end, gene_id, strand)
    # tuple per gene — or per locus when by_locus is True, in which case
    # transcript records must carry a 'locus' attribute used to regroup
    # genes. Coordinates are converted to 0-based half-open; the gene
    # range may have the 5'UTR removed and 5'/3' extensions applied.
    gtf_records = {}
    if by_locus:
        gene_to_locus_mapping = {}
    with open(gtf_fn) as gtf:
        # NOTE(review): `i` from enumerate is unused
        for i, record in enumerate(gtf):
            record = record.split('\t')
            chrom, _, feat_type, start, end, _, strand = record[:7]
            # GTF is 1-based inclusive; convert to 0-based half-open
            start = int(start) - 1
            end = int(end)
            if feat_type == 'transcript' and by_locus:
                locus_id = get_gtf_attribute(record, 'locus')
                gene_id = get_gtf_attribute(record, 'gene_id')
                gene_to_locus_mapping[gene_id] = locus_id
            elif feat_type == 'CDS' or feat_type == 'exon':
                # collect CDS and exon intervals per (chrom, gene, strand)
                gene_id = get_gtf_attribute(record, 'gene_id')
                idx = (chrom, gene_id, strand)
                if idx not in gtf_records:
                    gtf_records[idx] = {}
                if feat_type not in gtf_records[idx]:
                    gtf_records[idx][feat_type] = []
                gtf_records[idx][feat_type].append((start, end))

    if by_locus:
        # regroup gene invs by locus id:
        gtf_records_by_locus = {}
        for (chrom, gene_id, strand), feat_invs in gtf_records.items():
            locus_id = gene_to_locus_mapping[gene_id]
            new_idx = (chrom, locus_id, strand)
            if new_idx not in gtf_records_by_locus:
                gtf_records_by_locus[new_idx] = {}
            for feat_type, invs in feat_invs.items():
                if feat_type not in gtf_records_by_locus[new_idx]:
                    gtf_records_by_locus[new_idx][feat_type] = []
                gtf_records_by_locus[new_idx][feat_type] += invs
        gtf_records = gtf_records_by_locus

    # once whole file is parsed yield the intervals
    for (chrom, gene_id, strand), feat_invs in gtf_records.items():
        # exon span with weakly supported terminal exons trimmed
        exon_start, exon_end = get_record_range(
            feat_invs['exon'],
            max_terminal_intron_size,
            min_terminal_exon_size,
        )
        try:
            cds_start, cds_end = get_record_range(
                feat_invs['CDS'],
                filter_=False,
            )
        except KeyError:
            # non-coding RNA
            cds_start, cds_end = exon_start, exon_end

        # remove region corresponding to 5'UTR if necessary
        if use_5utr:
            gene_start = exon_start
            gene_end = exon_end
        else:
            # 5' boundary comes from the CDS, 3' boundary from the exons;
            # which end is which depends on the strand
            gene_start = cds_start if strand == '+' else exon_start
            gene_end = exon_end if strand == '+' else cds_end
        

        # add extensions to 3' and 5' ends
        start_ext, end_ext = extend_gene_five_prime, extend_gene_three_prime
        if strand == '-':
            # on the minus strand, 5' is the right/high-coordinate end
            start_ext, end_ext = end_ext, start_ext
        gene_start = max(0, gene_start - start_ext)
        gene_end = gene_end + end_ext

        yield chrom, gene_start, gene_end, gene_id, strand
 No newline at end of file
+7 −8
Original line number Diff line number Diff line
import click

from .io import write_output_bed, write_apa_site_bed
from .ref_guided import ref_guided_diff_tpe
from .output import (
    write_wass_test_output_bed,
    write_apa_site_bed,
)
from .d3pendr import run_d3pendr


@click.command()
@@ -52,11 +55,7 @@ def d3pendr(treatment_fns, control_fns,
    Outputs bed6 format with extra columns.
    '''

    if paired_end_read == 'single':
        # for options single or both, all reads are used
        paired_end_read = 'both'

    results, apa_sites = ref_guided_diff_tpe(
    results, apa_sites = run_d3pendr(
        annotation_gtf_fn,
        treatment_fns, control_fns,
        read_strand, read_end, paired_end_read,
@@ -72,7 +71,7 @@ def d3pendr(treatment_fns, control_fns,
        processes
    )
    output_bed = f'{output_prefix}.apa_results.bed'
    write_output_bed(output_bed, results)
    write_wass_test_output_bed(output_bed, results)
    if write_apa_sites:
        apa_site_bed = f'{output_prefix}.apa_sites.bed'
        write_apa_site_bed(apa_site_bed, apa_sites)

d3pendr/multibam.py

0 → 100755
+147 −0
Original line number Diff line number Diff line
import numpy as np

import pysam
import pyBigWig as pybw

from .bam import bam_query_iterator, per_base_bam_coverage

def _bam_or_bw(fn):
    if fn.endswith('.bam') or fn.endswith('.sam'):
        return True
    elif fn.endswith('.bw') or fn.lower().endswith('.bigwig'):
        return False
    else:
        raise ValueError('files must be bam, sam, or bigwig format')


def bam_or_bw(*fns):
    """Return True if all filenames are bam/sam, False if all are bigwig.

    Raises
    ------
    ValueError
        If the filetypes are mixed, or if any file has an unrecognised
        extension (via _bam_or_bw). Previously mixed types tripped a bare
        `assert`, which is stripped under `python -O`.
    """
    filetypes = [_bam_or_bw(fn) for fn in fns]
    if any(t != filetypes[0] for t in filetypes):
        raise ValueError('input files must all be the same format '
                         '(all bam/sam or all bigwig)')
    return filetypes[0]


class MultiParser(object):
    # Minimal context-manager base class for the multi-file parsers below;
    # subclasses are expected to define close() releasing their handles.

    def __enter__(self):
        # the parser object itself is the managed resource
        return self

    def __exit__(self, *args):
        # delegate cleanup to the subclass-defined close()
        self.close()


class MultiBamParser(MultiParser):
    """Context manager over a set of bam files opened with pysam.

    Per-file normalisation factors are derived from mapped read counts so
    coverage rows can be scaled to the mean library size.
    """

    def __init__(self, bam_fns):
        self.handles = [
            pysam.AlignmentFile(bam_fn) for bam_fn in bam_fns
        ]
        self._calc_norm_factors(self.handles)
        self.closed = False

    def fetch(self, *args, **kwargs):
        """Return one bam_query_iterator per bam file (args as for pysam fetch)."""
        queries = [
            bam_query_iterator(bam, *args, **kwargs)
            for bam in self.handles
        ]
        return queries

    def coverage(self, chrom, start, end, **kwargs):
        """Return per-base coverage, one row per bam file.

        Pass normalise=True to scale each row by its library-size factor
        (result is then float rather than int).
        """
        normalise = kwargs.pop('normalise', False)
        coverage = np.array([
            per_base_bam_coverage(bam, chrom, start, end, **kwargs)
            for bam in self.handles
        ])

        if normalise:
            # fixed: `coverage *= self.norm_factors` failed because the
            # int64 coverage array cannot be multiplied in place by float
            # factors (numpy same-kind casting); build a new float array
            coverage = coverage * self.norm_factors

        return coverage

    def _calc_norm_factors(self, bams):
        # scale each file to the mean number of mapped reads; column
        # vector shape so factors broadcast over coverage rows
        p = np.reshape([b.mapped for b in bams], newshape=(-1, 1))
        n = p.mean()
        self.norm_factors = n / p

    def close(self):
        for bam in self.handles:
            bam.close()
        # fixed: the closed flag was never updated on close
        self.closed = True


class MultiBigWigParser(MultiParser):
    """Context manager over a set of bigwig files opened with pyBigWig.

    Each input may be a single bigwig filename or a comma separated
    'pos,neg' pair of stranded bigwigs; strandedness is inferred from
    the inputs and must be consistent across files.
    """

    def __init__(self, bw_fns):
        # attempt to infer if stranded, each bw_fn should be
        # comma separated list
        bw_fns = [tuple(fn.split(',')) for fn in bw_fns]
        if all([len(fn) == 2 for fn in bw_fns]):
            stranded = True
        elif all([len(fn) == 1 for fn in bw_fns]):
            stranded = False
        else:
            raise ValueError('Please provide either single bw files '
                             'or comma separated pos,neg bw files')
        if stranded:
            self.handles = [
                (pybw.open(bw_fn[0]), pybw.open(bw_fn[1]))
                for bw_fn in bw_fns
            ]
        else:
            self.handles = [
                (pybw.open(bw_fn[0]),) for bw_fn in bw_fns
            ]
        self._calc_norm_factors(self.handles)
        self.closed = False
        self.stranded = stranded

    def fetch(self, *args, **kwargs):
        # fixed: was `raise NotImplemented`, which raises TypeError
        # because NotImplemented is not an exception class
        raise NotImplementedError('fetch not possible for bigwig')

    def coverage(self, chrom, start, end, strand=None, normalise=False, **kwargs):
        """Return per-base signal for the window, one row per sample.

        For stranded inputs, strand selects the pos/neg bigwig; None or
        '.' sums both. NaNs are replaced with zero; normalise=True scales
        each row by its sumData-based factor.
        """
        if strand is not None and not self.stranded:
            raise ValueError('cannot specify strand on unstranded bigwigs')
        if self.stranded:
            if strand == '+':
                coverage = [
                    pos_bw.values(chrom, start, end, numpy=True)
                    for pos_bw, neg_bw in self.handles
                ]
            elif strand == '-':
                coverage = [
                    neg_bw.values(chrom, start, end, numpy=True)
                    for pos_bw, neg_bw in self.handles
                ]
            elif strand is None or strand == '.':
                coverage = [
                    np.nansum([
                        pos_bw.values(chrom, start, end, numpy=True),
                        neg_bw.values(chrom, start, end, numpy=True)
                    ], axis=0)
                    for pos_bw, neg_bw in self.handles
                ]
            else:
                # fixed: an unrecognised strand previously fell through,
                # leaving `coverage` unbound (NameError)
                raise ValueError('strand is not one of +-.')
        else:
            coverage = [
                bw[0].values(chrom, start, end, numpy=True)
                for bw in self.handles
            ]
        coverage = np.array(coverage)
        coverage[np.isnan(coverage)] = 0

        if normalise:
            coverage *= self.norm_factors

        return coverage

    def _calc_norm_factors(self, bws):
        # scale each sample to the mean total signal (bigwig sumData);
        # column vector shape so factors broadcast over coverage rows
        p = np.reshape(
            [sum([s.header()['sumData'] for s in b]) for b in bws],
            newshape=(-1, 1)
        )
        n = p.mean()
        self.norm_factors = n / p

    def close(self):
        for bws in self.handles:
            for bw in bws:
                bw.close()
        # fixed: the closed flag was never updated on close
        self.closed = True
 No newline at end of file
Loading