Commit 648cb985 authored by mparker2's avatar mparker2
Browse files

add option to cluster by locus id

parent 08de6715
Loading
Loading
Loading
Loading
+115 −36
Original line number Diff line number Diff line
@@ -47,7 +47,7 @@ def bam_query_iterator(bam, *args, **kwargs):
    strand = kwargs.pop('strand', None)
    pairs = kwargs.pop('pairs', 'both')
    pair_filt = pair_filter(pairs)
    if strand is None:
    if strand is None or strand == '.':
        for aln in bam.fetch(*args, **kwargs):
            if pair_filt(aln):
                yield bam_cigar_to_invs(aln)
@@ -57,7 +57,7 @@ def bam_query_iterator(bam, *args, **kwargs):
            if is_reverse == aln.is_reverse and pair_filt(aln):
                yield bam_cigar_to_invs(aln)
    else:
        raise ValueError('strand is not one of +-')
        raise ValueError('strand is not one of +-.')


class MultiParser(object):
@@ -125,7 +125,7 @@ class MultiBigWigParser(MultiParser):
                    neg_bw.values(chrom, start, end, numpy=True)
                    for pos_bw, neg_bw in self.handles.values()
                ]
            elif strand is None:
            elif strand is None or strand == '.':
                queries = [
                    np.nansum([
                        pos_bw.values(chrom, start, end, numpy=True),
@@ -155,52 +155,131 @@ def bam_or_bw(fn):
        raise ValueError('files must be bam, sam, or bigwig format')


def gtf_iterator(gtf_fn, extend_gene_five_prime, ignore_5utr, extend_gene_three_prime):
def flatten_intervals(invs):
    flattened = []
    all_invs = iter(np.sort(invs, axis=0))
    inv_start, inv_end = next(all_invs)
    for start, end in all_invs:
        if start <= inv_end:
            inv_end = max(inv_end, end)
        else:
            flattened.append([inv_start, inv_end])
            inv_start, inv_end = start, end
    if not flattened or flattened[-1] != [inv_start, inv_end]:
        flattened.append([inv_start, inv_end])
    return np.array(flattened)


def filter_terminal_exons(invs, max_intron_size, min_exon_size):
    if len(invs) == 1:
        return invs
    else:
        l_ex = invs[0, 1] - invs[0, 0]
        l_in = invs[1, 0] - invs[0, 1]
        if (l_ex < min_exon_size) or (l_in >= max_intron_size):
            invs = invs[1:]
            if len(invs) == 1:
                return invs
        else:
            r_ex = invs[-1, 1] - invs[-1, 0]
            r_in = invs[-1, 0] - invs[-2, 1]
            if (r_ex < min_exon_size) or (r_in >= max_intron_size):
                invs = invs[:-1]
    return invs


def get_record_range(invs,
                     max_terminal_intron_size=None,
                     min_terminal_exon_size=None,
                     filter_=True):
    invs = flatten_intervals(invs)
    if filter_:
        invs = filter_terminal_exons(
            invs,
            max_terminal_intron_size,
            min_terminal_exon_size,
        )
    return invs[0, 0], invs[-1, 1]


def get_gtf_attribute(gtf_record, attribute):
    try:
        attr = re.search(f'{attribute} "(.+?)";', gtf_record[8]).group(1)
    except AttributeError:
        raise ValueError(
            f'Could not parse attribute {attribute} '
            f'from GTF with feature type {record[2]}'
        )
    return attr


def gtf_iterator(gtf_fn,
                 extend_gene_five_prime=0,
                 use_5utr=False,
                 extend_gene_three_prime=0,
                 by_locus=True,
                 max_terminal_intron_size=100_000,
                 min_terminal_exon_size=20):
    gtf_records = {}
    if by_locus:
        gene_to_locus_mapping = {}
    with open(gtf_fn) as gtf:
        for i, record in enumerate(gtf):
            record = record.split('\t')
            feat_type = record[2]
            if feat_type == 'CDS' or feat_type == 'exon':
                try:
                    gene_id = re.search('gene_id "(.+?)";', record[8]).group(1)
                except AttributeError:
                    raise ValueError(f'Could not parse gene_id from GTF line {i}')
                if gene_id not in gtf_records:
                    gtf_records[gene_id] = {
                        'chrom': record[0],
                        'strand': record[6]
                    }
                start = int(record[3]) - 1
                end = int(record[4])
                if feat_type not in gtf_records[gene_id]:
                    gtf_records[gene_id][feat_type] = (start, end)
                else:
                    curr_range = gtf_records[gene_id][feat_type]
                    new_range = (
                        min(curr_range[0], start),
                        max(curr_range[1], end),
                    )
                    gtf_records[gene_id][feat_type] = new_range
            chrom, _, feat_type, start, end, _, strand = record[:7]
            start = int(start) - 1
            end = int(end)
            if feat_type == 'transcript' and by_locus:
                locus_id = get_gtf_attribute(record, 'locus')
                gene_id = get_gtf_attribute(record, 'gene_id')
                gene_to_locus_mapping[gene_id] = locus_id
            elif feat_type == 'CDS' or feat_type == 'exon':
                gene_id = get_gtf_attribute(record, 'gene_id')
                idx = (chrom, gene_id, strand)
                if idx not in gtf_records:
                    gtf_records[idx] = {}
                if feat_type not in gtf_records[idx]:
                    gtf_records[idx][feat_type] = []
                gtf_records[idx][feat_type].append((start, end))

    if by_locus:
        # regroup gene invs by locus id:
        gtf_records_by_locus = {}
        for (chrom, gene_id, strand), feat_invs in gtf_records.items():
            locus_id = gene_to_locus_mapping[gene_id]
            new_idx = (chrom, locus_id, strand)
            if new_idx not in gtf_records_by_locus:
                gtf_records_by_locus[new_idx] = {}
            for feat_type, invs in feat_invs.items():
                if feat_type not in gtf_records_by_locus[new_idx]:
                    gtf_records_by_locus[new_idx][feat_type] = []
                gtf_records_by_locus[new_idx][feat_type] += invs
        gtf_records = gtf_records_by_locus

    # once whole file is parsed yield the intervals
    for gene_id, gene_info in gtf_records.items():
        chrom = gene_info['chrom']
        strand = gene_info['strand']
        exon_start, exon_end = gene_info['exon']
    for (chrom, gene_id, strand), feat_invs in gtf_records.items():
        exon_start, exon_end = get_record_range(
            feat_invs['exon'],
            max_terminal_intron_size,
            min_terminal_exon_size,
        )
        try:
            cds_start, cds_end = gene_info['CDS']
            cds_start, cds_end = get_record_range(
                feat_invs['CDS'],
                filter_=False,
            )
        except KeyError:
            # non-coding RNA
            cds_start, cds_end = gene_info['exon']
            cds_start, cds_end = exon_start, exon_end

        # remove region corresponding to 5'UTR if necessary
        if ignore_5utr:
            gene_start = cds_start if strand == '+' else exon_start
            gene_end = exon_end if strand == '+' else cds_end
        else:
        if use_5utr:
            gene_start = exon_start
            gene_end = exon_end
        else:
            gene_start = cds_start if strand == '+' else exon_start
            gene_end = exon_end if strand == '+' else cds_end
        

        # add extensions to 3' and 5' ends
        start_ext, end_ext = extend_gene_five_prime, extend_gene_three_prime
+7 −0
Original line number Diff line number Diff line
@@ -18,6 +18,9 @@ from .ref_guided import ref_guided_diff_tpe
@click.option('--extend-gene-five-prime', default=0)
@click.option('--use-5utr/--ignore-5utr', default=True)
@click.option('--extend-gene-three-prime', default=0)
@click.option('--use-locus-tag/--use-gene-id-tag', default=False)
@click.option('--max-terminal-intron-size', default=100_000)
@click.option('--min-terminal-exon-size', default=30)
@click.option('--bootstraps', default=999)
@click.option('--threshold', default=0.05)
@click.option('--use-gamma-model/--no-model', default=True)
@@ -33,6 +36,8 @@ def d3pendr(treatment_fns, control_fns,
            min_read_overlap, min_reads_per_rep,
            extend_gene_five_prime, use_5utr,
            extend_gene_three_prime,
            use_locus_tag,
            max_terminal_intron_size, min_terminal_exon_size,
            bootstraps, threshold, use_gamma_model, test_homogeneity,
            tpe_cluster_sigma, min_tpe_reads, min_tpe_fractional_change,
            processes):
@@ -58,6 +63,8 @@ def d3pendr(treatment_fns, control_fns,
        min_read_overlap, min_reads_per_rep,
        extend_gene_five_prime, use_5utr,
        extend_gene_three_prime,
        use_locus_tag,
        max_terminal_intron_size, min_terminal_exon_size,
        bootstraps, threshold,
        use_gamma_model, test_homogeneity,
        write_apa_sites,
+5 −1
Original line number Diff line number Diff line
@@ -244,6 +244,9 @@ def ref_guided_diff_tpe(gtf_fn, treat_bam_fns, cntrl_bam_fns,
                        min_read_overlap, min_reads_per_cond,
                        extend_gene_five_prime, use_5utr,
                        extend_gene_three_prime,
                        by_locus,
                        max_terminal_intron_size,
                        min_terminal_exon_size,
                        bootstraps, threshold,
                        use_gamma_model, test_homogeneity,
                        find_tpe_sites, tpe_sigma,
@@ -263,7 +266,8 @@ def ref_guided_diff_tpe(gtf_fn, treat_bam_fns, cntrl_bam_fns,
        tpe_min_reads, tpe_min_rel_change,
    )
    gtf_it = gtf_iterator(
        gtf_fn, extend_gene_five_prime, use_5utr, extend_gene_three_prime
        gtf_fn, extend_gene_five_prime, use_5utr, extend_gene_three_prime,
        by_locus, max_terminal_intron_size, min_terminal_exon_size,
    )
    if processes == 1:
        # run on main process