Skip to content

Commit 2945a19

Browse files
committed
fix (gtf): created GTFSourceInferrer class
1 parent 8685d08 commit 2945a19

File tree

4 files changed

+45
-32
lines changed

4 files changed

+45
-32
lines changed

moPepGen/cli/generate_index.py

+3
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ def create_gtf_copy(file:Path, output_dir:Path, symlink:bool=True) -> Path:
8383
if file.suffix.lower() == '.gz':
8484
if symlink:
8585
symlink = False
86+
logger(
87+
"--gtf-symlink was suppressed because compressed GTF file was received. "
88+
)
8689
elif file.suffix.lower() != '.gtf':
8790
raise ValueError(f"Cannot handle gtf file {file}")
8891

moPepGen/gtf/GTFPointer.py

+3-14
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
TranscriptAnnotationModel,
99
GTF_FEATURE_TYPES
1010
)
11+
from moPepGen.gtf.GTFSourceInferrer import GTFSourceInferrer
1112

1213

1314
GENE_DICT_CACHE_SIZE = 10
@@ -104,8 +105,7 @@ def to_line(self) -> str:
104105
def iterate_pointer(handle:IO, source:str=None) -> Iterable[Union[GenePointer, TranscriptPointer]]:
105106
""" Iterate over a GTF file and yield pointers. """
106107
if not source:
107-
count = 0
108-
inferred = {}
108+
inferrer = GTFSourceInferrer()
109109

110110
cur_gene_id:str = None
111111
cur_tx_id:str = None
@@ -124,18 +124,7 @@ def iterate_pointer(handle:IO, source:str=None) -> Iterable[Union[GenePointer, T
124124
record = GtfIO.line_to_seq_feature(line)
125125

126126
if not source:
127-
if count > 100:
128-
inferred = sorted(inferred.items(), key=lambda x: x[1])
129-
source = inferred[-1][0]
130-
record.source = source
131-
else:
132-
count += 1
133-
record.infer_annotation_source()
134-
inferred_source = record.source
135-
if inferred_source not in inferred:
136-
inferred[inferred_source] = 1
137-
else:
138-
inferred[inferred_source] += 1
127+
record.source = inferrer.infer(record)
139128
else:
140129
record.source = source
141130

moPepGen/gtf/GTFSourceInferrer.py

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
""" Infer GTF source (e.g. GENCODE/ENSEMBL) """
2+
from __future__ import annotations
3+
from typing import Dict, TYPE_CHECKING
4+
5+
6+
if TYPE_CHECKING:
7+
from moPepGen.gtf.GTFSeqFeature import GTFSeqFeature
8+
9+
class GTFSourceInferrer():
    """ Infer GTF source (e.g. GENCODE/ENSEMBL) by majority vote.

    The first ``max_iter + 1`` records passed to `infer` are inspected
    individually via ``record.infer_annotation_source()`` and their sources
    are tallied in `data`. After that sampling window, the most frequently
    seen source is locked in and returned for every subsequent record
    without inspecting it.

    Attributes:
        max_iter: Sampling threshold; records are sampled while
            ``count <= max_iter``.
        data: Tally of how many sampled records were inferred as each source.
        count: Number of records sampled so far.
        source: The inferred source. Before the sampling window is exhausted
            this is the current majority of `data` (or None if nothing has
            been sampled yet), so it is also usable for short GTF files.
    """
    def __init__(self, max_iter:int=100):
        """ Constructor

        Args:
            max_iter: Sampling threshold. Defaults to 100.
        """
        self.max_iter = max_iter
        self.data:Dict[str,int] = {}
        self.count = 0
        # Locked-in result; stays None until the sampling window is exhausted
        # or a caller assigns `source` explicitly.
        self._source:str = None

    @property
    def source(self) -> str:
        """ The inferred source. Falls back to the majority of the records
        sampled so far when the final value has not been locked in yet, and
        to None when no record has been sampled. """
        if not self._source:
            return self._most_frequent()
        return self._source

    @source.setter
    def source(self, value:str) -> None:
        """ Explicitly set (lock in) the source. """
        self._source = value

    def _most_frequent(self) -> str:
        """ Return the source with the highest tally, or None if `data` is
        empty. """
        if not self.data:
            return None
        # sorted() is stable, so among equal tallies the source that was
        # first seen latest wins, exactly as in the original implementation.
        return sorted(self.data.items(), key=lambda x:x[1])[-1][0]

    def infer(self, record:'GTFSeqFeature') -> str:
        """ Infer the source of a GTF record.

        While sampling, the record's own inferred source is tallied and
        returned. Once more than `max_iter` records have been sampled, the
        majority source is returned and the record is left untouched.

        Args:
            record: The GTF record. It must provide
                ``infer_annotation_source()`` and a ``source`` attribute.

        Returns:
            The source to be used for this record.
        """
        if self.count > self.max_iter:
            if not self._source:
                self._source = self._most_frequent()
            return self._source

        self.count += 1
        record.infer_annotation_source()
        source = record.source
        self.data[source] = self.data.get(source, 0) + 1
        return source

moPepGen/gtf/GenomicAnnotation.py

+8-18
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .TranscriptAnnotationModel import TranscriptAnnotationModel, GTF_FEATURE_TYPES
1111
from .GeneAnnotationModel import GeneAnnotationModel
1212
from .GTFSeqFeature import GTFSeqFeature
13+
from .GTFSourceInferrer import GTFSourceInferrer
1314

1415

1516
if TYPE_CHECKING:
@@ -113,25 +114,14 @@ def dump_gtf(self, handle:Union[str, IO], biotype:List[str]=None, source:str=Non
113114
"""
114115
record:GTFSeqFeature
115116
if not source:
116-
count = 0
117-
inferred = {}
117+
inferrer = GTFSourceInferrer()
118+
118119
for record in GtfIO.parse(handle):
119120
if biotype is not None and record.biotype not in biotype:
120121
continue
121122

122123
if not source:
123-
if count > 100:
124-
inferred = sorted(inferred.items(), key=lambda x: x[1])
125-
source = inferred[-1][0]
126-
record.source = source
127-
else:
128-
count += 1
129-
record.infer_annotation_source()
130-
inferred_source = record.source
131-
if inferred_source not in inferred:
132-
inferred[inferred_source] = 1
133-
else:
134-
inferred[inferred_source] += 1
124+
record.source = inferrer.infer(record)
135125
else:
136126
record.source = source
137127

@@ -142,11 +132,11 @@ def dump_gtf(self, handle:Union[str, IO], biotype:List[str]=None, source:str=Non
142132

143133
self.add_transcript_record(record)
144134

145-
if not source:
146-
inferred = sorted(inferred.items(), key=lambda x: x[1])
147-
source = inferred[-1][0]
148135

149-
self.source = source
136+
if not source:
137+
source = inferrer.source
138+
else:
139+
self.source = source
150140

151141
for transcript_model in self.transcripts.values():
152142
transcript_model.sort_records()

0 commit comments

Comments
 (0)