Commit 4d57c0fc authored by mparker2's avatar mparker2
Browse files

change annotation file to gtf

parent 53357041
Loading
Loading
Loading
Loading
+53 −37
Original line number Diff line number Diff line
import re
import numpy as np
import pysam
import pyBigWig as pybw
@@ -154,46 +155,61 @@ def bam_or_bw(fn):
        raise ValueError('files must be bam, sam, or bigwig format')


def parse_inv(record, use_5utr):
    '''
    Parse a BED12 record into (chrom, start, end, strand, gene_id).

    Coordinates come from columns 1-2 (0-based half-open). When
    use_5utr is False, the 5'UTR is removed by moving the gene's
    strand-aware 5' end to the CDS boundary (thickStart/thickEnd,
    columns 6-7). Genes where thickStart == thickEnd (the BED
    convention for non-coding records) are left untrimmed.
    '''
    chrom = record[0]
    start = int(record[1])
    end = int(record[2])
    gene_id = record[3]
    strand = record[5]
    if not use_5utr:
        cds_start = int(record[6])
        cds_end = int(record[7])
        if cds_start != cds_end:
            # protein coding gene: trim the 5'UTR by moving the
            # strand-aware 5' end of the interval to the CDS edge
            if strand == '+':
                start = cds_start
            else:
                end = cds_end
    return chrom, start, end, strand, gene_id


def gtf_iterator(gtf_fn, extend_gene_five_prime, ignore_5utr, extend_gene_three_prime):
    '''
    Parse a GTF file and yield one genomic interval per gene.

    All exon and CDS records sharing a gene_id are collapsed to the
    min/max range per feature type. Once the whole file has been read,
    yields (chrom, start, end, gene_id, strand) tuples in 0-based
    half-open coordinates, in file (first-seen) gene order.

    Parameters
    ----------
    gtf_fn : str
        Path to the GTF annotation file.
    extend_gene_five_prime : int
        Bases to extend each gene at its strand-aware 5' end.
    ignore_5utr : bool
        If True, trim the strand-aware 5' end of each gene to the CDS
        boundary, removing the 5'UTR. Genes without CDS records
        (non-coding RNAs) are left untrimmed.
    extend_gene_three_prime : int
        Bases to extend each gene at its strand-aware 3' end.

    Yields
    ------
    tuple
        (chrom, start, end, gene_id, strand)

    Raises
    ------
    ValueError
        If an exon/CDS record has no parseable gene_id attribute.
    '''
    gene_id_pat = re.compile(r'gene_id "(.+?)";')
    gtf_records = {}
    with open(gtf_fn) as gtf:
        # enumerate from 1 so error messages report human (1-based) line numbers
        for i, record in enumerate(gtf, 1):
            if record.startswith('#'):
                # GTF files commonly start with '#' header/comment lines
                continue
            record = record.split('\t')
            if len(record) < 9:
                # skip blank or truncated lines rather than raising IndexError
                continue
            feat_type = record[2]
            if feat_type == 'CDS' or feat_type == 'exon':
                try:
                    gene_id = gene_id_pat.search(record[8]).group(1)
                except AttributeError:
                    raise ValueError(f'Could not parse gene_id from GTF line {i}')
                if gene_id not in gtf_records:
                    gtf_records[gene_id] = {
                        'chrom': record[0],
                        'strand': record[6]
                    }
                # GTF coords are 1-based inclusive; convert to 0-based half-open
                start = int(record[3]) - 1
                end = int(record[4])
                if feat_type not in gtf_records[gene_id]:
                    gtf_records[gene_id][feat_type] = (start, end)
                else:
                    # merge into the running min/max range for this feature type
                    curr_range = gtf_records[gene_id][feat_type]
                    gtf_records[gene_id][feat_type] = (
                        min(curr_range[0], start),
                        max(curr_range[1], end),
                    )

    # once whole file is parsed yield the intervals
    for gene_id, gene_info in gtf_records.items():
        chrom = gene_info['chrom']
        strand = gene_info['strand']
        exon_start, exon_end = gene_info['exon']
        try:
            cds_start, cds_end = gene_info['CDS']
        except KeyError:
            # non-coding RNA: no CDS records, fall back to the exon range
            cds_start, cds_end = gene_info['exon']

        # remove region corresponding to 5'UTR if necessary
        if ignore_5utr:
            gene_start = cds_start if strand == '+' else exon_start
            gene_end = exon_end if strand == '+' else cds_end
        else:
            gene_start = exon_start
            gene_end = exon_end

        # add extensions to 3' and 5' ends, swapping for the minus strand
        start_ext, end_ext = extend_gene_five_prime, extend_gene_three_prime
        if strand == '-':
            start_ext, end_ext = end_ext, start_ext
        gene_start = max(0, gene_start - start_ext)
        gene_end = gene_end + end_ext

        yield chrom, gene_start, gene_end, gene_id, strand


def write_output_bed(output_bed_fn, results):
+3 −3
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ from .ref_guided import ref_guided_diff_tpe
@click.option('-c', '--control-fns', required=True, multiple=True)
@click.option('-o', '--output-prefix', required=True)
@click.option('--write-apa-sites/--no-apa-sites', required=False, default=True)
@click.option('-a', '--annotation-bed12', required=True)
@click.option('-a', '--annotation-gtf-fn', required=True)
@click.option('--read-strand', type=click.Choice(['same', 'opposite', 'unstranded']), default='same')
@click.option('--read-end', type=click.Choice(['3', '5']), default='3')
@click.option('--paired-end-read', type=click.Choice(['1', '2', 'both', 'single']), default='single')
@@ -26,7 +26,7 @@ from .ref_guided import ref_guided_diff_tpe
@click.option('-p', '--processes', default=4)
def d3pendr(treatment_fns, control_fns,
            output_prefix, write_apa_sites,
            annotation_bed12,
            annotation_gtf_fn,
            read_strand, read_end, paired_end_read,
            min_read_overlap, min_reads_per_rep,
            extend_gene_five_prime, use_5utr,
@@ -50,7 +50,7 @@ def d3pendr(treatment_fns, control_fns,
        paired_end_read = 'both'

    results, apa_sites = ref_guided_diff_tpe(
        annotation_bed12,
        annotation_gtf_fn,
        treatment_fns, control_fns,
        read_strand, read_end, paired_end_read,
        min_read_overlap, min_reads_per_rep,
+15 −17
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ from joblib import Parallel, delayed

from .io import (
    MultiBamParser, MultiBigWigParser, bam_or_bw,
    bed12_iterator, write_output_bed
    gtf_iterator, write_output_bed
)
from .invs import relative_tpe, intersect_spliced_invs
from .stats import tpe_stats
@@ -157,9 +157,7 @@ def get_apa_tpes(cntrl_distrib, nreads_cntrl,
                       cntrl_frac, treat_frac, relative_change)
        



def process_bed_records(bed_records, treat_bam_fns, cntrl_bam_fns,
def process_gtf_records(gtf_records, treat_bam_fns, cntrl_bam_fns,
                        read_strand, read_end, paired_end_read,
                        min_read_overlap, min_reads,
                        bootstraps, threshold, is_bam,
@@ -172,7 +170,7 @@ def process_bed_records(bed_records, treat_bam_fns, cntrl_bam_fns,

    with parser(cntrl_bam_fns) as cntrl_bam, parser(treat_bam_fns) as treat_bam:

        for chrom, start, end, gene_id, strand in bed_records:
        for chrom, start, end, gene_id, strand in gtf_records:
            cntrl_distribs, nreads_cntrl = get_tpe_distribution(
                cntrl_bam, chrom, start, end, strand,
                min_read_overlap, read_strand, read_end,
@@ -226,20 +224,20 @@ def process_bed_records(bed_records, treat_bam_fns, cntrl_bam_fns,
    return results, tpe_apa_res


def chunk_gtf_records(gtf_records, processes):
    '''
    Split gtf_records into `processes` contiguous, near-equal chunks.

    The iterator is consumed in full first (the whole annotation fits
    in memory). When the record count does not divide evenly, the first
    nrecords % processes chunks each receive one extra record. Always
    yields exactly `processes` lists, some possibly empty.
    '''
    # read the whole gtf file so it can be sliced
    gtf_records = list(gtf_records)
    n, r = divmod(len(gtf_records), processes)
    # chunk sizes: r chunks of n + 1 records, then (processes - r) of n
    split_points = np.cumsum([0] + r * [n + 1] + (processes - r) * [n])
    for i in range(processes):
        yield gtf_records[split_points[i]: split_points[i + 1]]


def ref_guided_diff_tpe(bed_fn, treat_bam_fns, cntrl_bam_fns,
def ref_guided_diff_tpe(gtf_fn, treat_bam_fns, cntrl_bam_fns,
                        read_strand, read_end, paired_end_read,
                        min_read_overlap, min_reads_per_cond,
                        extend_gene_five_prime, use_5utr,
@@ -260,18 +258,18 @@ def ref_guided_diff_tpe(bed_fn, treat_bam_fns, cntrl_bam_fns,
        find_tpe_sites, tpe_sigma,
        tpe_min_reads, tpe_min_rel_change,
    )
    bed_it = bed12_iterator(
        bed_fn, extend_gene_five_prime, use_5utr, extend_gene_three_prime
    gtf_it = gtf_iterator(
        gtf_fn, extend_gene_five_prime, use_5utr, extend_gene_three_prime
    )
    if processes == 1:
        # run on main process
        results, tpa_apa_results = process_bed_records(
            bed_it, *args
        results, tpa_apa_results = process_gtf_records(
            gtf_it, *args
        )
    else:
        results = Parallel(n_jobs=processes)(
            delayed(process_bed_records)(bed_chunk, *args)
            for bed_chunk in chunk_bed_records(bed_it, processes)
            delayed(process_gtf_records)(gtf_chunk, *args)
            for gtf_chunk in chunk_gtf_records(gtf_it, processes)
        )
        results, tpe_apa_results = zip(*results)
        results = pd.concat(results)