Commit 5e5fd9c2 authored by mparker2

delete old files

parent a7d85cb0

d3pendr/invs.py

deleted 100755 → 0
+0 −56
def relative_tpe(gene_start, gene_end, aln_start, aln_end, strand, read_end):
    # make tpe position relative to gene start rather than 
    # genomic coordinates
    if read_end == '3':
        if strand == '+':
            return aln_end - gene_start
        else:
            return gene_end - aln_start
    elif read_end == '5':
        if strand == '+':
            return aln_start - gene_start
        else:
            return gene_end - aln_end
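
For instance (made-up coordinates), a forward-strand gene spanning 1,000-2,000 with a read aligned at 1,200-1,950:

# hypothetical call: 3' end of the alignment, relative to the gene start
relative_tpe(1000, 2000, 1200, 1950, '+', read_end='3')  # -> 950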


def intersect(inv_a, inv_b):
    a_start, a_end = inv_a
    b_start, b_end = inv_b
    if a_end < b_start or a_start > b_end:
        return 0
    else:
        s = max(a_start, b_start)
        e = min(a_end, b_end)
        return e - s


def intersect_spliced_invs(invs_a, invs_b):
    score = 0
    invs_a = iter(invs_a)
    invs_b = iter(invs_b)
    try:
        a_start, a_end = next(invs_a)
        b_start, b_end = next(invs_b)
    except StopIteration:
        # one of the interval lists is empty so there can be no overlap
        return 0
    while True:
        if a_end < b_start:
            try:
                a_start, a_end = next(invs_a)
            except StopIteration:
                break
        elif a_start > b_end:
            try:
                b_start, b_end = next(invs_b)
            except StopIteration:
                break
        else:
            score += intersect([a_start, a_end], [b_start, b_end])
            if a_end > b_end:
                try:
                    b_start, b_end = next(invs_b)
                except StopIteration:
                    break
            else:
                try:
                    a_start, a_end = next(invs_a)
                except StopIteration:
                    break
    return score
\ No newline at end of file
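
A minimal sketch (made-up half-open coordinates) of how the two interval helpers compose: each overlapping block pair contributes its intersection to the total.

read_a = [(100, 200), (300, 400)]       # exon blocks of one spliced alignment
read_b = [(150, 350), (380, 500)]       # exon blocks of another
intersect_spliced_invs(read_a, read_b)  # 50 + 50 + 20 -> 120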

d3pendr/io.py

deleted 100755 → 0
+0 −323
import re
import numpy as np
import pysam
import pyBigWig as pybw


def bam_cigar_to_invs(aln):
    invs = []
    start = aln.reference_start
    end = aln.reference_end
    strand = '-' if aln.is_reverse else '+'
    left = start
    right = left
    has_ins = False
    for op, ln in aln.cigar:
        if op in (1, 4, 5):
            # does not consume reference
            continue
        elif op in (0, 2, 7, 8):
            # consume reference but do not add to invs yet
            right += ln
        elif op == 3:
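            # N (intron): close the current block, restart after the gap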
            invs.append([left, right])
            left = right + ln
            right = left
    if right > left:
        invs.append([left, right])
    assert invs[0][0] == start
    assert invs[-1][1] == end
    return start, end, strand, np.array(invs)


def pair_filter(filt_type):
    if filt_type == 'both':
        def _pair_filter(aln):
            return True
    elif filt_type == '1':
        def _pair_filter(aln):
            return aln.is_read1
    elif filt_type == '2':
        def _pair_filter(aln):
            return aln.is_read2
    else:
        raise ValueError("pair filter type must be one of 'both', '1', '2'")
    return _pair_filter


def bam_query_iterator(bam, *args, **kwargs):
    strand = kwargs.pop('strand', None)
    pairs = kwargs.pop('pairs', 'both')
    pair_filt = pair_filter(pairs)
    if strand is None or strand == '.':
        for aln in bam.fetch(*args, **kwargs):
            if pair_filt(aln):
                yield bam_cigar_to_invs(aln)
    elif strand in '+-':
        is_reverse = strand == '-'
        for aln in bam.fetch(*args, **kwargs):
            if is_reverse == aln.is_reverse and pair_filt(aln):
                yield bam_cigar_to_invs(aln)
    else:
        raise ValueError("strand must be one of '+', '-' or '.'")
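
A minimal usage sketch, assuming a hypothetical indexed file sample.bam:

# iterate read1 alignments on the + strand in a 10 kb window
with pysam.AlignmentFile('sample.bam') as bam:
    for start, end, strand, invs in bam_query_iterator(
            bam, 'chr1', 10_000, 20_000, strand='+', pairs='1'):
        print(start, end, strand, len(invs))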


class MultiParser(object):

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()


class MultiBamParser(MultiParser):

    def __init__(self, bam_fns):
        self.handles = {
            bam_fn: pysam.AlignmentFile(bam_fn) for bam_fn in bam_fns
        }
        self.closed = False

    def fetch(self, *args, **kwargs):
        queries = [
            bam_query_iterator(bam, *args, **kwargs)
            for bam in self.handles.values()
        ]
        return queries

    def close(self):
        for bam in self.handles.values():
            bam.close()


class MultiBigWigParser(MultiParser):

    def __init__(self, bw_fns):
        # attempt to infer if stranded, each bw_fn should be comma separated list
        bw_fns = [tuple(fn.split(',')) for fn in bw_fns]
        if all([len(fn) == 2 for fn in bw_fns]):
            stranded = True
        elif all([len(fn) == 1 for fn in bw_fns]):
            stranded = False
        else:
            raise ValueError('Please provide either single bw files or comma separated pos,neg bw files')
        if stranded:
            self.handles = {
                bw_fn: (pybw.open(bw_fn[0]), pybw.open(bw_fn[1])) for bw_fn in bw_fns
            }
        else:
            self.handles = {
                bw_fn: (pybw.open(bw_fn[0]),) for bw_fn in bw_fns
            }
        self.closed = False
        self.stranded = stranded

    def fetch(self, chrom, start, end, strand=None):
        if strand is not None and not self.stranded:
            raise ValueError('cannot specify strand on unstranded bigwigs')
        if self.stranded:
            if strand == '+':
                queries = [
                    pos_bw.values(chrom, start, end, numpy=True)
                    for pos_bw, neg_bw in self.handles.values()
                ]
            elif strand == '-':
                queries = [
                    neg_bw.values(chrom, start, end, numpy=True)
                    for pos_bw, neg_bw in self.handles.values()
                ]
            elif strand is None or strand == '.':
                queries = [
                    np.nansum([
                        pos_bw.values(chrom, start, end, numpy=True),
                        neg_bw.values(chrom, start, end, numpy=True)
                    ], axis=0)
                    for pos_bw, neg_bw in self.handles.values()
                ]
        else:
            queries = [
                bw[0].values(chrom, start, end, numpy=True)
                for bw in self.handles.values()
            ]
        return queries

    def close(self):
        for bws in self.handles.values():
            for bw in bws:
                bw.close()


def bam_or_bw(fn):
    if fn.endswith('.bam') or fn.endswith('.sam'):
        return True
    elif fn.endswith('.bw') or fn.lower().endswith('.bigwig'):
        return False
    else:
        raise ValueError('files must be bam, sam, or bigwig format')


def flatten_intervals(invs):
    flattened = []
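    # sorting starts and ends independently (axis=0) re-pairs the coordinates,
    # but preserves the coverage function, so the merged union is unchanged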
    all_invs = iter(np.sort(invs, axis=0))
    inv_start, inv_end = next(all_invs)
    for start, end in all_invs:
        if start <= inv_end:
            inv_end = max(inv_end, end)
        else:
            flattened.append([inv_start, inv_end])
            inv_start, inv_end = start, end
    if not flattened or flattened[-1] != [inv_start, inv_end]:
        flattened.append([inv_start, inv_end])
    return np.array(flattened)
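
For example (made-up coordinates), overlapping blocks collapse to their union:

flatten_intervals(np.array([[10, 50], [40, 80], [100, 120]]))
# -> array([[ 10,  80],
#           [100, 120]])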


def filter_terminal_exons(invs, max_intron_size, min_exon_size):
    # drop terminal exons that are very short or separated from the gene
    # body by an overly long intron; check the two ends independently
    if len(invs) == 1:
        return invs
    l_ex = invs[0, 1] - invs[0, 0]
    l_in = invs[1, 0] - invs[0, 1]
    if (l_ex < min_exon_size) or (l_in >= max_intron_size):
        invs = invs[1:]
        if len(invs) == 1:
            return invs
    r_ex = invs[-1, 1] - invs[-1, 0]
    r_in = invs[-1, 0] - invs[-2, 1]
    if (r_ex < min_exon_size) or (r_in >= max_intron_size):
        invs = invs[:-1]
    return invs


def get_record_range(invs,
                     max_terminal_intron_size=None,
                     min_terminal_exon_size=None,
                     filter_=True):
    invs = flatten_intervals(invs)
    if filter_:
        invs = filter_terminal_exons(
            invs,
            max_terminal_intron_size,
            min_terminal_exon_size,
        )
    return invs[0, 0], invs[-1, 1]


def get_gtf_attribute(gtf_record, attribute):
    try:
        attr = re.search(f'{attribute} "(.+?)";', gtf_record[8]).group(1)
    except AttributeError:
        raise ValueError(
            f'Could not parse attribute {attribute} '
            f'from GTF record with feature type {gtf_record[2]}'
        )
    return attr
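
For example, on a tab-split record with a synthetic attribute string:

rec = ['1', 'test', 'exon', '3631', '5899', '.', '+', '.',
       'gene_id "AT1G01010"; transcript_id "AT1G01010.1";']
get_gtf_attribute(rec, 'gene_id')  # -> 'AT1G01010'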


def gtf_iterator(gtf_fn,
                 extend_gene_five_prime=0,
                 use_5utr=False,
                 extend_gene_three_prime=0,
                 by_locus=True,
                 max_terminal_intron_size=100_000,
                 min_terminal_exon_size=20):
    gtf_records = {}
    if by_locus:
        gene_to_locus_mapping = {}
    with open(gtf_fn) as gtf:
        for i, record in enumerate(gtf):
            record = record.split('\t')
            chrom, _, feat_type, start, end, _, strand = record[:7]
            start = int(start) - 1
            end = int(end)
            if feat_type == 'transcript' and by_locus:
                locus_id = get_gtf_attribute(record, 'locus')
                gene_id = get_gtf_attribute(record, 'gene_id')
                gene_to_locus_mapping[gene_id] = locus_id
            elif feat_type == 'CDS' or feat_type == 'exon':
                gene_id = get_gtf_attribute(record, 'gene_id')
                idx = (chrom, gene_id, strand)
                if idx not in gtf_records:
                    gtf_records[idx] = {}
                if feat_type not in gtf_records[idx]:
                    gtf_records[idx][feat_type] = []
                gtf_records[idx][feat_type].append((start, end))

    if by_locus:
        # regroup gene invs by locus id:
        gtf_records_by_locus = {}
        for (chrom, gene_id, strand), feat_invs in gtf_records.items():
            locus_id = gene_to_locus_mapping[gene_id]
            new_idx = (chrom, locus_id, strand)
            if new_idx not in gtf_records_by_locus:
                gtf_records_by_locus[new_idx] = {}
            for feat_type, invs in feat_invs.items():
                if feat_type not in gtf_records_by_locus[new_idx]:
                    gtf_records_by_locus[new_idx][feat_type] = []
                gtf_records_by_locus[new_idx][feat_type] += invs
        gtf_records = gtf_records_by_locus

    # once whole file is parsed yield the intervals
    for (chrom, gene_id, strand), feat_invs in gtf_records.items():
        exon_start, exon_end = get_record_range(
            feat_invs['exon'],
            max_terminal_intron_size,
            min_terminal_exon_size,
        )
        try:
            cds_start, cds_end = get_record_range(
                feat_invs['CDS'],
                filter_=False,
            )
        except KeyError:
            # non-coding RNA
            cds_start, cds_end = exon_start, exon_end

        # remove region corresponding to 5'UTR if necessary
        if use_5utr:
            gene_start = exon_start
            gene_end = exon_end
        else:
            gene_start = cds_start if strand == '+' else exon_start
            gene_end = exon_end if strand == '+' else cds_end
        

        # add extensions to 3' and 5' ends
        start_ext, end_ext = extend_gene_five_prime, extend_gene_three_prime
        if strand == '-':
            start_ext, end_ext = end_ext, start_ext
        gene_start = max(0, gene_start - start_ext)
        gene_end = gene_end + end_ext

        yield chrom, gene_start, gene_end, gene_id, strand


def write_output_bed(output_bed_fn, results):
    with open(output_bed_fn, 'w') as bed:
        for (
            chrom, start, end, gene_id, score, strand,
            wass_dist, wass_dir, wass_pval, wass_fdr,
            nreads_cntrl, nreads_treat
        ) in results.itertuples(index=False):
            record = (f'{chrom}\t{int(start):d}\t{int(end):d}\t'
                      f'{gene_id}\t{int(score):d}\t{strand}\t'
                      f'{wass_dist:.1f}\t{wass_dir:.1f}\t'
                      f'{wass_pval:.3g}\t{wass_fdr:.3g}\t'
                      f'{sum(nreads_cntrl):d}\t{sum(nreads_treat):d}\n')
            bed.write(record)


def write_apa_site_bed(output_bed_fn, results):
    with open(output_bed_fn, 'w') as bed:
        for (
            chrom, start, end, gene_id, strand,
            cntrl_count, treat_count,
            cntrl_frac, treat_frac,
            relative_change
        ) in results.itertuples(index=False):
            direction = int(relative_change > 0)
            record = (f'{chrom}\t{start:d}\t{end:d}\t'
                      f'{gene_id}\t{direction:d}\t{strand}\t'
                      f'{cntrl_count:d}\t{treat_count:d}\t'
                      f'{cntrl_frac:.2f}\t{treat_frac:.2f}\t'
                      f'{relative_change:.2f}\n')
            bed.write(record)
\ No newline at end of file

d3pendr/ref_guided.py

deleted 100755 → 0
+0 −295
import itertools as it
from bisect import bisect_right
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
from scipy.ndimage import gaussian_filter1d
from statsmodels.stats.multitest import multipletests
from joblib import Parallel, delayed

from .io import (
    MultiBamParser, MultiBigWigParser, bam_or_bw,
    gtf_iterator, write_output_bed
)
from .invs import relative_tpe, intersect_spliced_invs
from .stats import tpe_stats


# columns in the output bed format
RESULTS_COLUMNS = [
    'chrom', 'start', 'end', 'gene_id', 'score', 'strand',
    'wass_dist', 'wass_dir', 'wass_pval', 'wass_fdr',
    'nreads_cntrl', 'nreads_treat'
]

TPE_APA_RESULTS_COLUMNS = [
    'chrom', 'start', 'end', 'gene_id', 'strand',
    'nreads_cntrl', 'nreads_treat',
    'frac_cntrl', 'frac_treat', 'relative_change'
]


def get_tpe_distribution(mbam, chrom, gene_start, gene_end, strand,
                         min_read_overlap, read_strand, read_end,
                         paired_end_read, is_bam):
    tpe_distribs = []
    nreads = []

    if read_strand == 'opposite':
        fetch_strand = '+' if strand == '-' else '-'
    elif read_strand == 'unstranded':
        fetch_strand = None
    else:
        fetch_strand = strand

    # fetch parsed alignments from bam files, filtering by strand
    if is_bam:

        for sample in mbam.fetch(
                chrom, gene_start, gene_end,
                strand=fetch_strand, pairs=paired_end_read):
            # sample is an iterator of parsed bam file alignments
            
            sample_distrib = []
            n = 0
            for aln_start, aln_end, aln_strand, aln_invs in sample:
                # calculate the fraction of the read alignment overlapping the
                # annotated gene
                aln_len = sum([e - s for s, e in aln_invs])
                i = intersect_spliced_invs(aln_invs, [(gene_start, gene_end),])

                # only use the read if it overlaps with the reference annotation
                # by at least min_read_overlap fraction of its aligned length
                if i / aln_len >= min_read_overlap:
                    tpe = relative_tpe(
                        gene_start, gene_end,
                        aln_start, aln_end,
                        strand, read_end
                    )
                    sample_distrib.append(tpe)
                    n += 1
            tpe_distribs.append(np.array(sample_distrib))
            nreads.append(n)

    else:
        for sample in mbam.fetch(
                chrom, gene_start, gene_end, strand=fetch_strand):
            # sample must be count data
            sample[np.isnan(sample)] = 0
            if np.mod(sample, 1).any():
                raise ValueError('Bigwig values are not discrete')
            sample = sample.astype(int)
            # sample is a hist of coverage over the gene from the bigwig file
            # it needs inverting for negative strand genes
            if strand == '-':
                sample = sample[::-1]
            sample_distrib = np.repeat(np.arange(gene_end - gene_start), sample)
            tpe_distribs.append(sample_distrib)
            nreads.append(len(sample_distrib))

    nreads = np.array(nreads)
    return tpe_distribs, nreads


def argrelmin_left_on_flat(arr, order):
    # relative-minimum search that tolerates flat-bottomed minima:
    # a position counts if it is strictly smaller than the `order` values
    # before it and no larger than the `order` values after it, so the
    # leftmost point of a flat minimum is returned
    idx = []
    for i in range(order, len(arr) - order):
        if (np.all(arr[i] < arr[i - order: i])
                and np.all(arr[i] <= arr[i + 1: i + 1 + order])):
            idx.append(i)
    return np.array(idx)


def cluster_by_endpoint(endpoints, conds, sigma):
    offset = endpoints.min() - sigma * 3
    endpoints_scaled = endpoints - offset
    endpoints_max = endpoints_scaled.max()
    endpoints_dist = np.bincount(
        endpoints_scaled,
        minlength=endpoints_max + sigma * 3
    ).astype('f')
    endpoints_dist = gaussian_filter1d(
        endpoints_dist, sigma=sigma, mode='constant', cval=0
    )

    # find local minima in three prime positions
    cut_idx = argrelmin_left_on_flat(endpoints_dist, order=sigma)
    cut_idx = cut_idx + offset

    # classify alignments in relation to local minima
    cluster_idx = defaultdict(list)
    cluster_cond_count = defaultdict(list)
    for pos, c in zip(endpoints, conds):
        i = bisect_right(cut_idx, pos)
        cluster_idx[i].append(pos)
        cluster_cond_count[i].append(c)

    # get actual start/end points of cluster and cluster counts for each cond
    clusters = {}
    for i, pos in cluster_idx.items():
        inv = (min(pos), max(pos))
        clusters[inv] = Counter(cluster_cond_count[i])
    return clusters
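
A small synthetic example: two well-separated groups of end positions are cut at the smoothed local minimum between them, keeping per-condition counts for each cluster.

ends = np.array([100, 101, 102, 150, 151, 152])
conds = np.array([0, 0, 1, 0, 1, 1])
cluster_by_endpoint(ends, conds, sigma=2)
# -> {(100, 102): Counter({0: 2, 1: 1}), (150, 152): Counter({1: 2, 0: 1})}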


def get_apa_tpes(cntrl_distrib, nreads_cntrl,
                 treat_distrib, nreads_treat,
                 sigma, min_count, min_rel_change):
    endpoints = np.concatenate([
        *cntrl_distrib, *treat_distrib
    ])
    conds = np.repeat(
        [0, 1],
        [nreads_cntrl, nreads_treat]
    )
    assert len(endpoints) == len(conds)
    tpes = cluster_by_endpoint(endpoints, conds, sigma)

    for (start, end), counts in tpes.items():
        cntrl_count = counts[0]
        treat_count = counts[1]
        if (cntrl_count + treat_count) >= min_count:
            cntrl_frac = cntrl_count / nreads_cntrl
            treat_frac = treat_count / nreads_treat
            relative_change = treat_frac - cntrl_frac
            if np.abs(relative_change) >= min_rel_change:
                yield (start, end, cntrl_count, treat_count,
                       cntrl_frac, treat_frac, relative_change)
        

def process_gtf_records(gtf_records, treat_bam_fns, cntrl_bam_fns,
                        read_strand, read_end, paired_end_read,
                        min_read_overlap, min_reads,
                        bootstraps, threshold,
                        use_gamma_model, test_homogeneity,
                        is_bam, find_apa_tpe_sites, tpe_sigma,
                        tpe_min_reads, tpe_min_rel_change):
    results = []
    tpe_apa_res = []

    parser = MultiBamParser if is_bam else MultiBigWigParser

    with parser(cntrl_bam_fns) as cntrl_bam, parser(treat_bam_fns) as treat_bam:

        for chrom, start, end, gene_id, strand in gtf_records:
            cntrl_distribs, nreads_cntrl = get_tpe_distribution(
                cntrl_bam, chrom, start, end, strand,
                min_read_overlap, read_strand, read_end,
                paired_end_read, is_bam
            )
            treat_distribs, nreads_treat = get_tpe_distribution(
                treat_bam, chrom, start, end, strand,
                min_read_overlap, read_strand, read_end,
                paired_end_read, is_bam
            )

            if (nreads_cntrl >= min_reads).all() and (nreads_treat >= min_reads).all():
                wass_dist, wass_dir, wass_pval = tpe_stats(
                    cntrl_distribs, treat_distribs,
                    bootstraps=bootstraps, threshold=threshold,
                    use_gamma_model=use_gamma_model,
                    test_homogeneity=test_homogeneity,
                )
                results.append([
                    chrom, start, end, gene_id, round(wass_dist), strand,
                    wass_dist, wass_dir, wass_pval, 1, # placeholder for wasserstein test fdr
                    nreads_cntrl, nreads_treat
                ])

                if find_apa_tpe_sites and wass_pval <= threshold:
                    tpes = get_apa_tpes(
                        cntrl_distribs, sum(nreads_cntrl),
                        treat_distribs, sum(nreads_treat),
                        tpe_sigma, tpe_min_reads, tpe_min_rel_change
                    )
                    for t in tpes:
                        tpe_start, tpe_end, *res = t
                        # revert tpe coords to absolute
                        if strand == '+':
                            tpe_start = tpe_start + start
                            tpe_end = tpe_end + start
                        elif strand == '-':
                            tpe_start, tpe_end = tpe_end, tpe_start
                            tpe_start = end - tpe_start
                            tpe_end = end - tpe_end

                        tpe_apa_res.append([
                            chrom, tpe_start, tpe_end, gene_id, strand, *res
                        ])

    results = pd.DataFrame(results, columns=RESULTS_COLUMNS)

    if find_apa_tpe_sites:
        tpe_apa_res = pd.DataFrame(tpe_apa_res, columns=TPE_APA_RESULTS_COLUMNS)
    else:
        tpe_apa_res = None
    return results, tpe_apa_res


def chunk_gtf_records(gtf_records, processes):
    # read the whole gtf file
    gtf_records = list(gtf_records)
    nrecords = len(gtf_records)
    n, r = divmod(nrecords, processes)
    split_points = ([0] + r * [n + 1] + (processes - r) * [n])
    split_points = np.cumsum(split_points)
    for i in range(processes):
        start = split_points[i]
        end = split_points[i + 1]
        yield gtf_records[start: end]
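
For example, ten records over three processes split as evenly as possible, with the remainder spread over the first chunks:

chunks = list(chunk_gtf_records(range(10), processes=3))
[len(c) for c in chunks]  # -> [4, 3, 3]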


def ref_guided_diff_tpe(gtf_fn, treat_bam_fns, cntrl_bam_fns,
                        read_strand, read_end, paired_end_read,
                        min_read_overlap, min_reads_per_cond,
                        extend_gene_five_prime, use_5utr,
                        extend_gene_three_prime,
                        by_locus,
                        max_terminal_intron_size,
                        min_terminal_exon_size,
                        bootstraps, threshold,
                        use_gamma_model, test_homogeneity,
                        find_tpe_sites, tpe_sigma,
                        tpe_min_reads, tpe_min_rel_change,
                        processes):

    filetypes = [bam_or_bw(fn) for fn in it.chain(treat_bam_fns, cntrl_bam_fns)]
    assert all(t == filetypes[0] for t in filetypes), \
        'input files must all be the same format'
    filetype = filetypes[0]
    args = (
        treat_bam_fns, cntrl_bam_fns,
        read_strand, read_end, paired_end_read,
        min_read_overlap, min_reads_per_cond,
        bootstraps, threshold,
        use_gamma_model, test_homogeneity,
        filetype, find_tpe_sites, tpe_sigma,
        tpe_min_reads, tpe_min_rel_change,
    )
    gtf_it = gtf_iterator(
        gtf_fn, extend_gene_five_prime, use_5utr, extend_gene_three_prime,
        by_locus, max_terminal_intron_size, min_terminal_exon_size,
    )
    if processes == 1:
        # run on main process
        results, tpe_apa_results = process_gtf_records(
            gtf_it, *args
        )
    else:
        results = Parallel(n_jobs=processes)(
            delayed(process_gtf_records)(gtf_chunk, *args)
            for gtf_chunk in chunk_gtf_records(gtf_it, processes)
        )
        results, tpe_apa_results = zip(*results)
        results = pd.concat(results)
        if find_tpe_sites:
            tpe_apa_results = pd.concat(tpe_apa_results)
        else:
            tpe_apa_results = None

    _, results['wass_fdr'], *_ = multipletests(results.wass_pval, method='fdr_bh')

    if find_tpe_sites:
        tpe_apa_results = tpe_apa_results[
            tpe_apa_results.gene_id.isin(results.query(f'wass_fdr <= {threshold}').gene_id)
        ]
    return results, tpe_apa_results
\ No newline at end of file