from operator import itemgetter
from intervaltree import Interval
from collections.abc import Iterable
from scipy.stats import chi2_contingency
from scipy.signal import find_peaks
from Bio.Seq import reverse_complement, translate
from Bio.Data.CodonTable import TranslationError
from pysam import FastaFile
import numpy as np
import copy
import itertools
from .splice_graph import SegmentGraph
from .short_read import Coverage
from ._transcriptome_filter import SPLICE_CATEGORY
from ._utils import pairwise, _filter_event, find_orfs, smooth, get_quantiles, _filter_function, \
pairwise_event_test, prepare_contingency_table
import logging
logger = logging.getLogger('isotools')
class Gene(Interval):
'This class stores all gene information and transcripts. It is derived from intervaltree.Interval.'
required_infos = ['ID', 'chr', 'strand']
# initialization
def __new__(cls, begin, end, data, transcriptome):
return super().__new__(cls, begin, end, data) # required as Interval (and Gene) is immutable
def __init__(self, begin, end, data, transcriptome):
self._transcriptome = transcriptome
def __str__(self):
return 'Gene {} {}({}), {} reference transcripts, {} expressed transcripts'.format(
self.name, self.region, self.strand, self.n_ref_transcripts, self.n_transcripts)
def __repr__(self):
return object.__repr__(self)
from ._gene_plots import sashimi_plot, gene_track, sashimi_plot_short_reads, sashimi_figure, plot_domains
from .domains import add_interpro_domains
    def short_reads(self, idx):
'''Returns the short read coverage profile for a short read sample.
:param idx: The index of the short read sample. '''
try:
return self.data['short_reads'][idx]
except (KeyError, IndexError):
            srdf = self._transcriptome.infos['short_reads']  # raises KeyError if no short reads were added
self.data.setdefault('short_reads', [])
for i in range(len(self.data['short_reads']), len(srdf)):
self.data['short_reads'].append(Coverage.from_bam(srdf.file[i], self))
return self.data['short_reads'][idx]
    def correct_fuzzy_junctions(self, trid, size, modify=True):
        '''Corrects for splicing shifts.
        This function looks for "shifted junctions" (i.e. the same shift relative to the reference annotation at both donor and acceptor),
        presumably caused by ambiguous alignments. In these cases the positions are adapted to the reference position (if modify is set).
        :param trid: The transcript dictionary to be checked.
:param size: The maximum shift to be corrected.
:param modify: If set, the exon positions are corrected according to the reference.'''
exons = trid['exons']
shifts = self.ref_segment_graph.fuzzy_junction(exons, size)
if shifts and modify:
for i, sh in shifts.items():
if exons[i][0] <= exons[i][1] + sh and exons[i + 1][0] + sh <= exons[i + 1][1]:
exons[i][1] += sh
exons[i + 1][0] += sh
trid['exons'] = [e for e in exons if e[0] < e[1]] # remove zero length exons
return shifts
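    # Example sketch (hypothetical coordinates; `g` is this Gene, `tr` one of its transcript dicts):
    #     tr['exons'] = [[100, 200], [250, 350]]           # junction at 200/250
    #     shifts = g.correct_fuzzy_junctions(tr, size=5)   # reference junction at 205/255
    #     tr['exons']                                      # -> [[100, 205], [255, 350]]
    # The bounds check above ensures that no exon is shifted past its own boundaries.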
def _to_gtf(self, trids, source='isoseq'):
'''Creates the gtf lines of the gene as strings.'''
donotshow = {'transcripts', 'short_exons', 'segment_graph'}
info = {'gene_id': self.id, 'gene_name': self.name}
lines = [None]
starts = []
ends = []
for i in trids:
tr = self.transcripts[i]
info['transcript_id'] = f'{info["gene_id"]}_{i}'
starts.append(tr['exons'][0][0] + 1)
ends.append(tr['exons'][-1][1])
trinfo = info.copy()
if 'downstream_A_content' in tr:
trinfo['downstream_A_content'] = f'{tr["downstream_A_content"]:0.3f}'
if tr['annotation'][0] == 0: # FSM
refinfo = {}
for refid in tr['annotation'][1]['FSM']:
for k in self.ref_transcripts[refid]:
if k == 'exons':
continue
elif k == 'CDS':
if self.strand == '+':
cds_start, cds_end = self.ref_transcripts[refid]['CDS']
else:
cds_end, cds_start = self.ref_transcripts[refid]['CDS']
refinfo.setdefault('CDS_start', []).append(str(cds_start))
refinfo.setdefault('CDS_end', []).append(str(cds_end))
else:
refinfo.setdefault(k, []).append(str(self.ref_transcripts[refid][k]))
for k, vlist in refinfo.items():
trinfo[f'ref_{k}'] = ','.join(vlist)
else:
trinfo['novelty'] = ','.join(k for k in tr['annotation'][1])
lines.append((self.chrom, source, 'transcript', tr['exons'][0][0] + 1, tr['exons'][-1][1], '.',
self.strand, '.', '; '.join(f'{k} "{v}"' for k, v in trinfo.items())))
            noncanonical = dict(tr.get('noncanonical_splicing', []))  # intron index -> donor/acceptor dinucleotides
            for enr, pos in enumerate(tr['exons']):
                exon_info = info.copy()
                exon_info['exon_id'] = f'{info["gene_id"]}_{i}_{enr}'
                if enr in noncanonical:  # the intron following this exon (genomic orientation)
                    exon_info['noncanonical_donor'] = noncanonical[enr][:2]
                if enr - 1 in noncanonical:  # the intron preceding this exon
                    exon_info['noncanonical_acceptor'] = noncanonical[enr - 1][2:]
lines.append((self.chrom, source, 'exon', pos[0] + 1, pos[1], '.', self.strand, '.', '; '.join(f'{k} "{v}"' for k, v in exon_info.items())))
if len(lines) > 1:
# add gene line
if 'reference' in self.data:
info.update({k: v for k, v in self.data['reference'].items() if k not in donotshow}) # add reference gene specific fields
lines[0] = (self.chrom, source, 'gene', min(starts), max(ends), '.', self.strand, '.', '; '.join(f'{k} "{v}"' for k, v in info.items()))
return lines
return []
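    # Sketch of one resulting transcript line (hypothetical values; the fields are joined
    # tab-separated when the lines are written to file):
    #     ('chr1', 'isoseq', 'transcript', 1000, 5000, '.', '+', '.',
    #      'gene_id "G1"; gene_name "GENE1"; transcript_id "G1_0"')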
    def add_noncanonical_splicing(self, genome_fh):
        '''Adds information on noncanonical splicing.
        For all transcripts of the gene, scan for noncanonical (i.e. not GT-AG) splice sites.
        If noncanonical splice sites are present, the corresponding intron index (in genomic orientation) and the sequence,
        i.e. the dinucleotides of donor and acceptor as XX-YY string, are stored in the "noncanonical_splicing" field of the transcript dicts.
        True noncanonical splicing is rare, so it may indicate technical artifacts (template switching, misalignment, ...).
        :param genome_fh: A file handle of the genome fasta file.'''
ss_seq = {}
for tr in self.transcripts:
pos = [(tr['exons'][i][1], tr['exons'][i + 1][0] - 2) for i in range(len(tr['exons']) - 1)]
new_ss_seq = {site: genome_fh.fetch(self.chrom, site, site + 2).upper() for intron in pos for site in intron if site not in ss_seq}
if new_ss_seq:
ss_seq.update(new_ss_seq)
if self.strand == '+':
sj_seq = [ss_seq[d] + ss_seq[a] for d, a in pos]
else:
sj_seq = [reverse_complement(ss_seq[d] + ss_seq[a]) for d, a in pos]
nc = [(i, seq) for i, seq in enumerate(sj_seq) if seq != 'GTAG']
if nc:
tr['noncanonical_splicing'] = nc
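    # Note: splice site dinucleotides are evaluated in transcription direction. A canonical
    # minus-strand intron reads CT..AC on the genomic plus strand, and
    # reverse_complement('CT' + 'AC') == 'GTAG', so it is recognized as canonical above.
    # A stored entry like (2, 'GCAG') means intron 2 has a GC donor and an AG acceptor.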
    def add_direct_repeat_len(self, genome_fh, delta=15, max_mm=2, wobble=2):
        '''Computes direct repeat length.
        This function counts the number of consecutive equal bases at donor and acceptor sites of the splice junctions.
        This information is stored in the "direct_repeat_len" field of the transcript dictionaries.
        Direct repeats longer than expected by chance indicate template switching.
        :param genome_fh: The file handle to the genome fasta.
        :param delta: The maximum length of direct repeats that can be found.
        :param max_mm: The maximum number of mismatches allowed within the repeat.
        :param wobble: The number of bases the repeat may be shifted with respect to the junction.'''
intron_seq = {}
score = {}
for tr in self.transcripts:
for intron in ((tr['exons'][i][1], tr['exons'][i + 1][0]) for i in range(len(tr['exons']) - 1)):
for pos in intron:
try:
intron_seq.setdefault(pos, genome_fh.fetch(self.chrom, pos - delta, pos + delta))
except (ValueError, IndexError): # N padding at start/end of the chromosomes
chr_len = genome_fh.get_reference_length(self.chrom)
seq = genome_fh.fetch(self.chrom, max(0, pos - delta), min(chr_len, pos + delta))
                            if pos - delta < 0:
                                seq = ''.join(['N'] * (delta - pos)) + seq
if pos + delta > chr_len:
seq += ''.join(['N'] * (pos + delta - chr_len))
intron_seq.setdefault(pos, seq)
if intron not in score:
score[intron] = repeat_len(intron_seq[intron[0]], intron_seq[intron[1]], wobble=wobble, max_mm=max_mm)
for tr in self.transcripts:
tr['direct_repeat_len'] = [min(score[(e1[1], e2[0])], delta) for e1, e2 in pairwise(tr['exons'])]
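    # Worked example (toy sequences): for identical flanking 8-mers,
    #     repeat_len('AAGGTTCC', 'AAGGTTCC', wobble=1, max_mm=0)  # -> 6
    # the centered window matches all 6 inner bases, 3 on each side of the junction.
    # With the defaults, 2*delta == 30 bases around donor and acceptor are compared,
    # and the stored scores are capped at delta.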
    def add_threeprime_a_content(self, genome_fh, length=30):
        '''Adds the information of the genomic A content downstream of the transcript.
        High values of genomic A content indicate internal priming and hence a genomic origin of the LRTS read.
        This function populates the 'downstream_A_content' field of the transcript dictionaries.
        :param genome_fh: A file handle for the indexed genome fasta file.
        :param length: The length of the downstream region to be considered.
        '''
a_content = {}
for tr in (t for tL in (self.transcripts, self.ref_transcripts) for t in tL):
if self.strand == '+':
pos = tr['exons'][-1][1]
else:
pos = tr['exons'][0][0] - length
if pos not in a_content:
seq = genome_fh.fetch(self.chrom, max(0, pos), pos + length)
if self.strand == '+':
a_content[pos] = seq.upper().count('A') / length
else:
a_content[pos] = seq.upper().count('T') / length
tr['downstream_A_content'] = a_content[pos]
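    # Sketch of the computation: for a plus-strand transcript ending at genomic position p,
    # the fraction of 'A' in genome[p:p+length] is stored; on the minus strand, the fraction
    # of 'T' upstream of the transcript start is used, i.e. the A content of the reverse
    # complement. E.g. 24 A's in a 30 base window (hypothetical values) yield
    # downstream_A_content == 0.8, suggesting internal priming at a genomic poly-A stretch.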
    def get_sequence(self, genome_fh, trids=None, reference=False, protein=False):
        '''Returns the nucleotide sequence of the specified transcripts.
        :param genome_fh: The path to the genome fasta file, or FastaFile handle.
        :param trids: List of transcript ids for which the sequences are requested.
        :param reference: Specify whether the sequence is fetched for reference transcripts (True)
            or long read transcripts (False, default).
        :param protein: Return protein sequences instead of transcript sequences.'''
trL = [(i, tr) for i, tr in enumerate(self.ref_transcripts if reference else self.transcripts) if trids is None or i in trids]
if not trL:
return {}
pos = (min(tr['exons'][0][0] for _, tr in trL), max(tr['exons'][-1][1] for _, tr in trL))
        try:  # assume it's an open FastaFile handle
seq = genome_fh.fetch(self.chrom, *pos)
except AttributeError:
genome_fn = genome_fh
with FastaFile(genome_fn) as genome_fh:
seq = genome_fh.fetch(self.chrom, *pos)
tr_seqs = {}
for i, tr in trL:
trseq = ''
for e in tr['exons']:
trseq += seq[e[0]-pos[0]:e[1]-pos[0]]
tr_seqs[i] = trseq
if self.strand == '-':
tr_seqs = {i: reverse_complement(ts) for i, ts in tr_seqs.items()}
if not protein:
return tr_seqs
prot_seqs = {}
for i, tr in trL:
orf = tr.get("CDS", tr.get("ORF"))
if not orf:
continue
pos = sorted(self.find_transcript_positions(i, orf[:2], reference=reference))
try:
prot_seqs[i] = translate(tr_seqs[i][pos[0]:pos[1]], cds=True)
except TranslationError:
logger.warning(f'CDS sequence of {self.id} {"reference" if reference else ""} transcript {i} cannot be translated.')
return prot_seqs
    def add_orfs(self, genome_fh, reference=False, minlen=30, start_codons=["ATG"], stop_codons=['TAA', 'TAG', 'TGA']):
        '''Finds the longest ORF for each transcript and adds it to the transcript properties as tr["ORF"].'''
trL = self.ref_transcripts if reference else self.transcripts
for (_, orfs), tr in zip(self.get_all_orf(genome_fh, reference, minlen, start_codons, stop_codons), trL):
if orfs:
tr["ORF"] = max(orfs, key=lambda x: x[2]['length'])
    def get_all_orf(self, genome_fh, reference=False, minlen=30, start_codons=["ATG"], stop_codons=['TAA', 'TAG', 'TGA']):
        '''Predicts ORFs for all transcripts of the gene.
        Returns, per transcript, the transcript sequence and a list of ORFs as
        (genomic start, genomic end, info dict) tuples.'''
orf_list = []
trL = self.ref_transcripts if reference else self.transcripts
for trid, tr_seq in self.get_sequence(genome_fh, reference=reference).items():
tr = trL[trid]
tr_start = tr['exons'][0][0]
            cum_exon_len = np.cumsum([end-start for start, end in tr['exons']])  # cumulative exon length
            cum_intron_len = np.cumsum([0]+[end-start for (_, start), (end, _) in pairwise(tr['exons'])])  # cumulative intron length
orf_list.append((tr_seq, []))
            for start, stop, frame, seq_start, seq_end in find_orfs(tr_seq, minlen=minlen, start_codons=start_codons, stop_codons=stop_codons):
if self.strand == '-':
start, stop = cum_exon_len[-1]-stop, cum_exon_len[-1]-start
start_exon = next(i for i in range(len(cum_exon_len)) if cum_exon_len[i] >= start)
stop_exon = next(i for i in range(start_exon, len(cum_exon_len)) if cum_exon_len[i] >= stop)
genome_pos = (tr_start+start+cum_intron_len[start_exon],
tr_start+stop+cum_intron_len[stop_exon])
            dist_pas = 0  # distance of the stop codon to the last exon-exon junction (used for NMD prediction)
if self.strand == '+' and stop_exon < len(cum_exon_len)-1:
dist_pas = cum_exon_len[-2]-stop
if self.strand == '-' and start_exon > 0:
dist_pas = start-cum_exon_len[0]
orf_list[-1][1].append((*genome_pos, {'start': start, 'length': stop-start,
'start_codon': seq_start, 'stop_codon': seq_end, 'NMD': dist_pas > 55}))
return orf_list
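    # The NMD flag follows the "55 nt rule": a stop codon more than ~55 nt upstream of the
    # last exon-exon junction predicts nonsense-mediated decay. Hypothetical example:
    # exon lengths (300, 200) give cum_exon_len == [300, 500]; an ORF ending at transcript
    # position 200 has dist_pas == 300 - 200 == 100 > 55, so NMD is flagged True.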
    def add_fragments(self):
        '''Checks for transcripts that are fully contained in other transcripts.
        Transcripts that are fully contained in other transcripts are potential truncations.
        This function populates the 'fragments' field of the transcript dictionaries with the indices of the containing transcripts,
        and the exon ids that match the first and last exons.'''
for trid, containers in self.segment_graph.find_fragments().items():
self.transcripts[trid]['fragments'] = containers # list of (containing transcript id, first 5' exons, first 3'exons)
    def coding_len(self, trid):
'''Returns length of 5\'UTR, coding sequence and 3\'UTR.
:param trid: The transcript index for which the coding length is requested. '''
try:
exons = self.transcripts[trid]['exons']
cds = self.transcripts[trid]['CDS']
except KeyError:
return None
else:
coding_len = _coding_len(exons, cds)
if self.strand == '-':
coding_len.reverse()
return coding_len
    def get_infos(self, trid, keys, sample_i, group_i, **kwargs):
'''Returns the transcript information specified in "keys" as a list.'''
return [value for k in keys for value in self._get_info(trid, k, sample_i, group_i)]
def _get_info(self, trid, key, sample_i, group_i, **kwargs):
# returns tuples (as some keys return multiple values)
if key == 'length':
return sum((e - b for b, e in self.transcripts[trid]['exons'])),
elif key == 'n_exons':
return len(self.transcripts[trid]['exons']),
elif key == 'exon_starts':
return ','.join(str(e[0]) for e in self.transcripts[trid]['exons']),
elif key == 'exon_ends':
return ','.join(str(e[1]) for e in self.transcripts[trid]['exons']),
elif key == 'annotation':
# sel=['sj_i','base_i', 'as']
if 'annotation' not in self.transcripts[trid]:
return ('NA',) * 2
nov_class, subcat = self.transcripts[trid]['annotation']
# subcat_string = ';'.join(k if v is None else '{}:{}'.format(k, v) for k, v in subcat.items())
return SPLICE_CATEGORY[nov_class], ','.join(subcat) # only the names of the subcategories
elif key == 'coverage':
return self.coverage[sample_i, trid]
elif key == 'tpm':
return self.tpm(kwargs.get('pseudocount', 1))[sample_i, trid]
elif key == 'group_coverage_sum':
return tuple(self.coverage[si, trid].sum() for si in group_i)
elif key == 'group_tpm_mean':
return tuple(self.tpm(kwargs.get('pseudocount', 1))[si, trid].mean() for si in group_i)
elif key in self.transcripts[trid]:
val = self.transcripts[trid][key]
if isinstance(val, Iterable): # iterables get converted to string
return str(val),
else:
return val, # atomic (e.g. numeric)
return 'NA',
def _set_coverage(self, force=False):
samples = self._transcriptome.samples
cov = np.zeros((len(samples), self.n_transcripts), dtype=int)
if not force: # keep the segment graph if no new transcripts
known = self.data.get('coverage', None)
if known is not None and known.shape[1] == self.n_transcripts:
if known.shape == cov.shape:
return
cov[:known.shape[0], :] = known
for i in range(known.shape[0], len(samples)):
for j, tr in enumerate(self.transcripts):
cov[i, j] = tr['coverage'].get(samples[i], 0)
self.data['coverage'] = cov
return
for i, sa in enumerate(samples):
for j, tr in enumerate(self.transcripts):
cov[i, j] = tr['coverage'].get(sa, 0)
self.data['coverage'] = cov
self.data['segment_graph'] = None
    def tpm(self, pseudocount=1):
        '''Returns the transcripts per million (TPM).
        TPM is returned as a numpy array, with samples in rows and transcript isoforms in columns.'''
return (self.coverage+pseudocount)/self._transcriptome.sample_table['nonchimeric_reads'].values.reshape(-1, 1)*1e6
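    # Worked example (hypothetical counts): a transcript with 49 reads in a sample with
    # 1e6 nonchimeric reads gets (49 + 1) / 1e6 * 1e6 == 50.0 with the default pseudocount
    # of 1. Note that normalization uses the sample's total nonchimeric read count, not the
    # per-gene total.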
    def find_transcript_positions(self, trid, pos, reference=False):
        '''Converts genomic positions to positions within the transcript.
        :param trid: The transcript id.
        :param pos: List of sorted genomic positions, for which the transcript positions are computed.
        :param reference: If set, trid refers to a reference transcript instead of a LRTS transcript.'''
tr_pos = []
exons = self.ref_transcripts[trid]['exons'] if reference else self.transcripts[trid]['exons']
e_idx = 0
offset = 0
for p in sorted(pos):
try:
while p > exons[e_idx][1]:
offset += (exons[e_idx][1]-exons[e_idx][0])
e_idx += 1
except IndexError:
for _ in range(len(pos)-len(tr_pos)):
tr_pos.append(None)
break
tr_pos.append(offset+p-exons[e_idx][0] if p >= exons[e_idx][0] else None)
if self.strand == '-':
trlen = sum(end-start for start, end in exons)
tr_pos = [None if p is None else trlen-p for p in tr_pos]
return tr_pos
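    # Example (hypothetical coordinates): for a plus-strand transcript with exons
    # [[100, 200], [300, 400]], find_transcript_positions(trid, (150, 350)) returns
    # [50, 150], as the intron is skipped; positions outside the exons map to None.
    # On the minus strand the positions are mirrored on the 200 nt transcript length,
    # giving [150, 50].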
@property
def coverage(self):
        '''Returns the transcript coverage.
        Coverage is returned as a numpy array, with samples in rows and transcript isoforms in columns.'''
cov = self.data.get('coverage', None)
if cov is not None:
return cov
self._set_coverage()
return self.data['coverage']
@property
def gene_coverage(self):
'''Returns the gene coverage.
        Total coverage of the gene for each sample.'''
return self.coverage.sum(1)
@property
def chrom(self):
        '''Returns the gene's chromosome.'''
return self.data['chr']
@property
def start(self): # alias for begin
return self.begin
@property
def region(self):
'''Returns the region of the gene as a string in the format "chr:start-end".'''
        return '{}:{}-{}'.format(self.chrom, self.start, self.end)
@property
def id(self):
'''Returns the gene id'''
try:
return self.data['ID']
except KeyError:
logger.error(self.data)
raise
@property
def name(self):
'''Returns the gene name'''
try:
return self.data['name']
except KeyError:
return self.id # e.g. novel genes do not have a name (but id)
@property
def is_annotated(self):
'''Returns "True" iff reference annotation is present for the gene.'''
return 'reference' in self.data
@property
def is_expressed(self):
'''Returns "True" iff gene is covered by at least one long read in at least one sample.'''
return bool(self.transcripts)
@property
def strand(self):
        '''Returns the strand of the gene, i.e. "+" or "-"'''
return self.data['strand']
@property
def transcripts(self):
'''Returns the list of transcripts of the gene, as found by LRTS.'''
try:
return self.data['transcripts']
except KeyError:
return []
@property
def ref_transcripts(self):
'''Returns the list of reference transcripts of the gene.'''
try:
return self.data['reference']['transcripts']
except KeyError:
return []
@property
def n_transcripts(self):
'''Returns number of transcripts of the gene, as found by LRTS.'''
return len(self.transcripts)
@property
def n_ref_transcripts(self):
'''Returns number of reference transcripts of the gene.'''
return len(self.ref_transcripts)
@property
def ref_segment_graph(self): # raises key error if not self.is_annotated
'''Returns the segment graph of the reference transcripts for the gene'''
assert self.is_annotated, "reference segment graph requested on novel gene"
if 'segment_graph' not in self.data['reference'] or self.data['reference']['segment_graph'] is None:
exons = [tr['exons'] for tr in self.ref_transcripts]
self.data['reference']['segment_graph'] = SegmentGraph(exons, self.strand)
return self.data['reference']['segment_graph']
@property
def segment_graph(self):
'''Returns the segment graph of the LRTS transcripts for the gene'''
if 'segment_graph' not in self.data or self.data['segment_graph'] is None:
exons = [tr['exons'] for tr in self.transcripts]
try:
self.data['segment_graph'] = SegmentGraph(exons, self.strand)
except Exception:
                logger.error('Error initializing segment graph on strand %s with exons %s', self.strand, exons)
raise
return self.data['segment_graph']
def __copy__(self):
return Gene(self.start, self.end, self.data, self._transcriptome)
def __deepcopy__(self, memo): # does not copy _transcriptome!
return Gene(self.start, self.end, copy.deepcopy(self.data, memo), self._transcriptome)
def __reduce__(self):
return Gene, (self.start, self.end, self.data, self._transcriptome)
    def copy(self):
'Returns a shallow copy of self.'
return self.__copy__()
    def filter_transcripts(self, query=None, min_coverage=None, max_coverage=None):
        '''Returns the indices of transcripts passing the filter query and the coverage thresholds.'''
tr_filter = self._transcriptome.filter['transcript']
if query:
query_fun, used_tags = _filter_function(query)
msg = 'did not find the following filter rules: {}\nvalid rules are: {}'
assert all(f in tr_filter for f in used_tags), msg.format(
', '.join(f for f in used_tags if f not in tr_filter), ', '.join(tr_filter))
tr_filter_fun = {tag: _filter_function(tr_filter[tag])[0] for tag in used_tags if tag in tr_filter}
trids = []
for i, tr in enumerate(self.transcripts):
if min_coverage and self.coverage[:, i].sum() < min_coverage:
continue
if max_coverage and self.coverage[:, i].sum() > max_coverage:
continue
if query is None or query_fun(
**{tag: f(g=self, trid=i, **tr) for tag, f in tr_filter_fun.items()}):
trids.append(i)
return trids
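    # Usage sketch (assuming 'FSM' is a defined transcript filter tag, as in the isotools
    # default filters):
    #     trids = g.filter_transcripts(query='FSM', min_coverage=5)
    # returns the indices of full splice match transcripts supported by at least five reads
    # over all samples.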
    def filter_ref_transcripts(self, query=None):
        '''Returns the indices of reference transcripts passing the filter query.'''
tr_filter = self._transcriptome.filter['reference']
if query:
query_fun, used_tags = _filter_function(query)
msg = 'did not find the following filter rules: {}\nvalid rules are: {}'
assert all(f in tr_filter for f in used_tags), msg.format(
', '.join(f for f in used_tags if f not in tr_filter), ', '.join(tr_filter))
tr_filter_fun = {tag: _filter_function(tr_filter[tag])[0] for tag in used_tags if tag in tr_filter}
else:
            return list(range(len(self.ref_transcripts)))
trids = []
for i, tr in enumerate(self.ref_transcripts):
if query_fun(**{tag: f(g=self, trid=i, **tr) for tag, f in tr_filter_fun.items()}):
trids.append(i)
return trids
    def _find_splice_sites(exons, transcripts):
        '''Checks whether the splice sites of a new transcript are present in the set of transcripts.
        Avoids the computation of the segment graph, which provides the same functionality.
        :param exons: A list of exon tuples representing the transcript
        :type exons: list
        :param transcripts: The transcripts to check against.
        :return: boolean array indicating for each splice junction whether it is contained or not'''
        intron_iter = [pairwise(tr['exons']) for tr in transcripts]
        current = [next(tr) for tr in intron_iter]
        contained = np.zeros(len(exons)-1, dtype=bool)
        for j, (e1, e2) in enumerate(pairwise(exons)):
            for i, tr in enumerate(intron_iter):
                try:
                    while current[i][0][1] < e1[1]:  # advance to the first intron not upstream of the query donor
                        current[i] = next(tr)
                except StopIteration:
                    continue
                if e1[1] == current[i][0][1] and e2[0] == current[i][1][0]:
                    contained[j] = True
        return contained
    def coordination_test(self, samples=None, test="chi2", min_dist=1, min_total=100, min_alt_fraction=.1,
                          events=None, event_type=("ES", "5AS", "3AS", "IR", "ME")):
        '''Performs pairwise independence tests for all pairs of Alternative Splicing Events (ASEs) in a gene.
        For all pairs of ASEs in a gene, a contingency table is created and an independence test is performed.
        All ASEs A have two states, pri_A and alt_A, the primary and the alternative state respectively.
        Thus, given two events A and B, we have four possible ways in which these events can occur on
        a transcript, that is, pri_A and pri_B, pri_A and alt_B, alt_A and pri_B, and alt_A and alt_B.
        These four values can be put in a contingency table and independence, or coordination,
        between the two events can be tested.
        :param samples: Specify the samples that should be considered in the test.
The samples can be provided either as a single group name, a list of sample names, or a list of sample indices.
:param test: Test to be performed. One of ("chi2", "fisher")
:type test: str
:param min_dist: Minimum distance (in nucleotides) between the two
alternative splicing events for the pair to be tested.
:type min_dist: int
:param min_total: The minimum total number of reads for an event pair to pass the filter.
:type min_total: int
:param min_alt_fraction: The minimum fraction of reads supporting the minor alternative of the two events.
:type min_alt_fraction: float
        :param events: To speed up testing on different groups of the same transcriptome object, events can be precomputed
            with the isotools._utils.precompute_events_dict function.
:param event_type: A tuple with event types to test. Valid types are
("ES", "3AS", "5AS", "IR", "ME", "TSS", "PAS"). Default is ("ES", "5AS", "3AS", "IR", "ME").
Not used if the event parameter is already given.
:return: A list of tuples with the test results: (gene_id, gene_name, strand, eventA_type, eventB_type,
eventA_start, eventA_end, eventB_start, eventB_end, p_value, test_stat, log2OR, dcPSI_AB, dcPSI_BA,
priA_priB, priA_altB, altA_priB, altA_altB, priA_priB_trids, priA_altB_trids, altA_priB_trids, altA_altB_trids).
'''
if samples is None:
cov = self.coverage.sum(axis=0)
else:
try:
# Fast mode when testing several genes
cov = self.coverage[samples].sum(0)
except IndexError:
# Fall back to looking up the sample indices
from isotools._transcriptome_stats import _check_groups
_, _, groups = _check_groups(self._transcriptome, [samples], 1)
cov = self.coverage[groups[0]].sum(0)
sg = self.segment_graph
if events is None:
events = sg.find_splice_bubbles(types=event_type)
events = [e for e in events if _filter_event(cov, e, min_total=min_total,
min_alt_fraction=min_alt_fraction)]
        # make sure the events are sorted (according to gene strand)
if self.strand == '+':
events.sort(key=itemgetter(2, 3), reverse=False) # sort by starting node
else:
events.sort(key=itemgetter(3, 2), reverse=True) # reverse sort by end node
test_res = []
for i, j in itertools.combinations(range(len(events)), 2):
if sg.events_dist(events[i], events[j]) < min_dist:
continue
if (events[i][4], events[j][4]) == ("TSS", "TSS") or (events[i][4], events[j][4]) == ("PAS", "PAS"):
continue
con_tab, tr_ID_tab = prepare_contingency_table(events[i], events[j], cov)
if con_tab.sum(None) < min_total: # check that the joint occurrence of the two events passes the threshold
continue
if min(con_tab.sum(1).min(), con_tab.sum(0).min())/con_tab.sum(None) < min_alt_fraction:
continue
            test_result = pairwise_event_test(con_tab, test=test)  # run the independence test
coordinate1 = sg._get_event_coordinate(events[i])
coordinate2 = sg._get_event_coordinate(events[j])
attr = (self.id, self.name, self.strand, events[i][4], events[j][4]) + \
coordinate1 + coordinate2 + test_result + \
tuple(con_tab.flatten()) + tuple(tr_ID_tab.flatten())
            # events[i][4] is the type of events[i]
            # coordinate1[0] is the starting coordinate of event 1
            # coordinate1[1] is the ending coordinate of event 1
            # coordinate2[0] is the starting coordinate of event 2
            # coordinate2[1] is the ending coordinate of event 2
test_res.append(attr)
return test_res
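    # The contingency table prepared above has this layout (read counts of transcripts
    # informative for both events):
    #                 pri_B   alt_B
    #     pri_A       n_pp    n_pa
    #     alt_A       n_ap    n_aa
    # pairwise_event_test applies the chi2 or Fisher test to these four counts.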
    def die_test(self, groups, min_cov=25, n_isoforms=10):
        ''' Reimplementation of the DIE test, suggested by Joglekar et al. in Nat Commun 12, 463 (2021):
        "A spatially resolved brain region- and cell type-specific isoform atlas of the postnatal mouse brain"
        Syntax and parameters follow the original implementation in
        https://github.com/noush-joglekar/scisorseqr/blob/master/inst/RScript/IsoformTest.R
        :param groups: Define the two groups to compare, as lists of sample indices into the coverage matrix.
        :param min_cov: Minimal number of reads per group for the gene.
        :param n_isoforms: Number of isoforms to consider in the test for the gene. All additional least expressed isoforms get summarized.'''
# select the samples and sum the group counts
try:
# Fast mode when testing several genes
cov = np.array([self.coverage[grp].sum(0) for grp in groups]).T
except IndexError:
# Fall back to looking up the sample indices
from isotools._transcriptome_stats import _check_groups
_, _, groups = _check_groups(self._transcriptome, groups)
cov = np.array([self.coverage[grp].sum(0) for grp in groups]).T
if np.any(cov.sum(0) < min_cov):
return np.nan, np.nan, []
        # if there are more than n_isoforms isoforms of the gene, all additional least expressed ones get summarized
if cov.shape[0] > n_isoforms:
idx = np.argpartition(-cov.sum(1), n_isoforms) # take the n_isoforms most expressed isoforms (random order)
additional = cov[idx[n_isoforms:]].sum(0)
cov = cov[idx[:n_isoforms]]
cov[n_isoforms-1] += additional
            idx[n_isoforms-1] = -1  # this isoform now aggregates all less expressed ones; mark it with index -1
elif cov.shape[0] < 2:
return np.nan, np.nan, []
else:
idx = np.array(range(cov.shape[0]))
try:
_, pval, _, _ = chi2_contingency(cov)
except ValueError:
logger.error(f'chi2_contingency({cov})')
raise
iso_frac = cov/cov.sum(0)
deltaPI = iso_frac[..., 0]-iso_frac[..., 1]
order = np.argsort(deltaPI)
pos_idx = [order[-i] for i in range(1, 3) if deltaPI[order[-i]] > 0]
neg_idx = [order[i] for i in range(2) if deltaPI[order[i]] < 0]
deltaPI_pos = deltaPI[pos_idx].sum()
deltaPI_neg = deltaPI[neg_idx].sum()
if deltaPI_pos > -deltaPI_neg:
return pval, deltaPI_pos, idx[pos_idx]
else:
return pval, deltaPI_neg, idx[neg_idx]
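    # Usage sketch (hypothetical grouping; samples 0,1 vs. samples 2,3):
    #     pval, deltaPI, trids = g.die_test([[0, 1], [2, 3]])
    # deltaPI is the summed change in isoform fraction of the (up to two) most shifted
    # isoforms, and trids holds their indices, with -1 denoting the aggregate of the least
    # expressed isoforms.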
    def _unify_ends(self, smooth_window=31, rel_prominence=1, search_range=(.1, .9)):
        ''' Find common TSS/PAS for transcripts of the gene'''
if not self.transcripts:
# nothing to do here
return
assert 0 <= search_range[0] <= .5 <= search_range[1] <= 1
# get gene tss/pas profiles
tss = {}
pas = {}
for tr in self.transcripts:
for sa in tr['TSS']:
for pos, c in tr['TSS'][sa].items():
tss[pos] = tss.get(pos, 0)+c
for sa in tr['PAS']:
for pos, c in tr['PAS'][sa].items():
pas[pos] = pas.get(pos, 0)+c
tss_pos = [min(tss), max(tss)]
if tss_pos[1]-tss_pos[0] < smooth_window:
tss_pos[0] -= (smooth_window + tss_pos[0]-tss_pos[1] - 1)
pas_pos = [min(pas), max(pas)]
if pas_pos[1]-pas_pos[0] < smooth_window:
pas_pos[0] -= (smooth_window + pas_pos[0]-pas_pos[1] - 1)
tss = [tss.get(pos, 0) for pos in range(tss_pos[0], tss_pos[1]+1)]
pas = [pas.get(pos, 0) for pos in range(pas_pos[0], pas_pos[1]+1)]
        # smooth the profiles and find maxima
tss_smooth = smooth(np.array(tss), smooth_window)
pas_smooth = smooth(np.array(pas), smooth_window)
# at least half of smooth_window reads required to call a peak
# minimal distance between peaks is > ~ smooth_window
        # rel_prominence=1 -> the smaller peak must have twice the height of the valley to be called a separate peak
tss_peaks, _ = find_peaks(np.log2(tss_smooth+1), prominence=(rel_prominence, None))
tss_peak_pos = tss_peaks+tss_pos[0]
pas_peaks, _ = find_peaks(np.log2(pas_smooth+1), prominence=(rel_prominence, None))
pas_peak_pos = pas_peaks+pas_pos[0]
# find transcripts with common first/last splice site
starts = {}
ends = {}
for trid, tr in enumerate(self.transcripts):
starts.setdefault(tr['exons'][0][1], []).append(trid)
ends.setdefault(tr['exons'][-1][0], []).append(trid)
if self.strand == '-':
starts, ends = ends, starts
        # for each site, find consistent TSS/PAS "peaks"
        # if none are found, use the median of all read starts
        for tr_ids in starts.values():
profile = {}
for trid in tr_ids:
for sa_tss in self.transcripts[trid]['TSS'].values():
for pos, c in sa_tss.items():
profile[pos] = profile.get(pos, 0)+c
quantiles = get_quantiles(sorted(profile.items()), [search_range[0], .5, search_range[1]])
            # one or several peaks within the search range? -> assign read starts to the nearest peak
            # else use the median
ol_peaks = [p for p in tss_peak_pos if quantiles[0] < p <= quantiles[-1]]
if not ol_peaks:
ol_peaks = [quantiles[1]]
for trid in tr_ids:
tr = self.transcripts[trid]
tr['TSS_unified'] = {}
for sa, sa_tss in tr['TSS'].items():
tss_unified = {}
for pos, c in sa_tss.items():
next_peak = min((p for p in ol_peaks if p < tr['exons'][0][1]),
default=pos, key=lambda x: abs(x-pos))
tss_unified[next_peak] = tss_unified.get(next_peak, 0)+c
tr['TSS_unified'][sa] = tss_unified
# same for PAS
        for tr_ids in ends.values():
profile = {}
for trid in tr_ids:
for sa_pas in self.transcripts[trid]['PAS'].values():
for pos, c in sa_pas.items():
profile[pos] = profile.get(pos, 0)+c
quantiles = get_quantiles(sorted(profile.items()), [search_range[0], .5, search_range[1]])
            # one or several peaks within the search range? -> assign read ends to the nearest peak
            # else use the median
ol_peaks = [p for p in pas_peak_pos if quantiles[0] < p <= quantiles[-1]]
if not ol_peaks:
ol_peaks = [quantiles[1]]
for trid in tr_ids:
tr = self.transcripts[trid]
tr['PAS_unified'] = {}
for sa, sa_pas in tr['PAS'].items():
pas_unified = {}
for pos, c in sa_pas.items():
next_peak = min((p for p in ol_peaks if p > tr['exons'][-1][0]),
default=pos, key=lambda x: abs(x-pos))
pas_unified[next_peak] = pas_unified.get(next_peak, 0)+c
tr['PAS_unified'][sa] = pas_unified
for tr in self.transcripts:
# find the most common tss/pas per transcript, and set the exon boundaries
sum_tss = {}
sum_pas = {}
start = end = max_tss = max_pas = 0
for sa_tss in tr['TSS_unified'].values():
for pos, cov in sa_tss.items():
sum_tss[pos] = sum_tss.get(pos, 0)+cov
for pos, cov in sum_tss.items():
if cov > max_tss:
max_tss = cov
start = pos
for sa_pas in tr['PAS_unified'].values():
for pos, cov in sa_pas.items():
sum_pas[pos] = sum_pas.get(pos, 0)+cov
for pos, cov in sum_pas.items():
if cov > max_pas:
max_pas = cov
end = pos
if self.strand == '-':
start, end = end, start
if start >= end: # for monoexons this may happen in rare situations
assert len(tr['exons']) == 1
tr['TSS_unified'] = None
tr['PAS_unified'] = None
else:
try:
# issues if the new exon start is behind the exon end
assert start < tr['exons'][0][1] or len(tr['exons']) == 1, 'error unifying %s: %s>=%s' % (tr["exons"], start, tr['exons'][0][1])
tr['exons'][0][0] = start
assert end > tr['exons'][-1][0] or len(tr['exons']) == 1, 'error unifying %s: %s<=%s' % (tr["exons"], end, tr['exons'][-1][0])
tr['exons'][-1][1] = end
except AssertionError:
logger.error('%s TSS= %s, PAS=%s -> TSS_unified= %s, PAS_unified=%s', self, tr['TSS'], tr['PAS'], tr['TSS_unified'], tr['PAS_unified'])
raise
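    # Peak calling above operates on log2(coverage + 1) of the smoothed profiles, so
    # rel_prominence=1 demands that a peak rises one doubling above the separating valley:
    # two candidate TSS peaks of 40 reads each with a valley of 10 between them
    # (log2 ratio ~1.9) stay separate, while a valley of 25 would merge them.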
def _coding_len(exons, cds):
coding_len = [0, 0, 0]
state = 0
for e in exons:
if state < 2 and e[1] >= cds[state]:
coding_len[state] += cds[state] - e[0]
if state == 0 and cds[1] <= e[1]: # special case: CDS start and end in same exon
coding_len[1] = cds[1] - cds[0]
coding_len[2] = e[1] - cds[1]
state += 2
else:
coding_len[state + 1] = e[1] - cds[state]
state += 1
else:
coding_len[state] += e[1] - e[0]
return coding_len
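# Example (hypothetical coordinates): exons [(0, 100), (200, 300)] with cds == (50, 250)
# yield _coding_len(exons, cds) == [50, 100, 50]: a 50 nt 5'UTR, a 100 nt CDS split across
# the junction, and a 50 nt 3'UTR. For minus-strand genes, Gene.coding_len reverses the
# list to report it in transcript orientation.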
def repeat_len(seq1, seq2, wobble, max_mm):
    '''Calculate the direct repeat length between seq1 and seq2.
    '''
score = [0]*(2*wobble+1)
delta = int(len(seq1)/2-wobble)
for w in range(2*wobble+1): # wobble
s1 = seq1[w:len(seq1)-(2*wobble-w)]
s2 = seq2[wobble:len(seq2)-wobble]
align = [a == b for a, b in zip(s1, s2)]
score_left = find_runlength(reversed(align[:delta]), max_mm)
score_right = find_runlength(align[delta:], max_mm)
score[w] = max([score_left[fmm]+score_right[max_mm-fmm] for fmm in range(max_mm+1)])
return max(score)
def find_runlength(align, max_mm):
    '''Find the run length, i.e. the number of True values in the list before the (max_mm+1)-th False occurs.
    '''
score = [0]*(max_mm+1)
mm = 0
for a in align:
if not a:
mm += 1
if mm > max_mm:
return score
score[mm] = score[mm-1]
else:
score[mm] += 1
for i in range(mm+1, max_mm+1):
score[i] = score[i-1]
return score
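# Example (hypothetical alignment): find_runlength([True, True, False, True], max_mm=1)
# returns [2, 3]: a run of 2 matches with no mismatch, and 3 matched bases when one
# mismatch is tolerated. repeat_len combines such runs to the left and right of the
# junction for each of the 2*wobble+1 window offsets and reports the best total.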