Skip to content

Commit 21a16d0

Browse files
committed
Fix get_tfreq_msca typo (thanks @Sotdo, #315)
1 parent 41e1404 commit 21a16d0

File tree

2 files changed

+765
-102
lines changed

2 files changed

+765
-102
lines changed

goatools/semantic.py

Lines changed: 135 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,33 @@
77
notebooks/semantic_similarity.ipynb
88
"""
99

10-
from __future__ import print_function
11-
1210
import sys
13-
from collections import Counter
14-
from collections import defaultdict
15-
from goatools.godag.consts import NAMESPACE2GO
16-
from goatools.godag.consts import NAMESPACE2NS
17-
from goatools.godag.go_tasks import get_go2ancestors
18-
from goatools.gosubdag.gosubdag import GoSubDag
19-
from goatools.godag.relationship_combos import RelationshipCombos
20-
from goatools.anno.update_association import clean_anno
21-
from goatools.utils import get_b2aset
11+
from collections import Counter, defaultdict
12+
13+
from .anno.update_association import clean_anno
14+
from .godag.consts import NAMESPACE2GO, NAMESPACE2NS
15+
from .godag.go_tasks import get_go2ancestors
16+
from .godag.relationship_combos import RelationshipCombos
17+
from .gosubdag.gosubdag import GoSubDag
18+
from .utils import get_b2aset
2219

2320

2421
class TermCounts:
25-
'''
26-
TermCounts counts the term counts for each
27-
'''
22+
"""
23+
TermCounts counts the genes annotated to each GO term
24+
"""
25+
2826
# pylint: disable=too-many-instance-attributes
2927
def __init__(self, go2obj, annots, relationships=None, **kws):
30-
'''
31-
Initialise the counts and
32-
'''
33-
_prt = kws.get('prt')
28+
"""
29+
Initialise the counts and the GO sub-DAG
30+
"""
31+
_prt = kws.get("prt")
32+
# Handle boolean prt parameter by converting to sys.stdout if True
33+
if _prt is True:
34+
_prt = sys.stdout
35+
elif _prt is False:
36+
_prt = None
3437
# Backup
3538
self.go2obj = go2obj # Full GODag
3639
self.annots, go_alts = clean_anno(annots, go2obj, _prt)[:2]
@@ -40,33 +43,46 @@ def __init__(self, go2obj, annots, relationships=None, **kws):
4043
self.gene2gos = get_b2aset(self.go2genes)
4144
# Annotation main GO IDs (prefer main id to alt_id)
4245
self.goids = set(self.go2genes.keys())
43-
self.gocnts = Counter({go:len(geneset) for go, geneset in self.go2genes.items()})
46+
self.gocnts = Counter(
47+
{go: len(geneset) for go, geneset in self.go2genes.items()}
48+
)
4449
# Get total count for each branch: BP MF CC
4550
self.aspect_counts = {
46-
'biological_process': self.gocnts.get(NAMESPACE2GO['biological_process'], 0),
47-
'molecular_function': self.gocnts.get(NAMESPACE2GO['molecular_function'], 0),
48-
'cellular_component': self.gocnts.get(NAMESPACE2GO['cellular_component'], 0)}
51+
"biological_process": self.gocnts.get(
52+
NAMESPACE2GO["biological_process"], 0
53+
),
54+
"molecular_function": self.gocnts.get(
55+
NAMESPACE2GO["molecular_function"], 0
56+
),
57+
"cellular_component": self.gocnts.get(
58+
NAMESPACE2GO["cellular_component"], 0
59+
),
60+
}
4961
self._init_add_goid_alt(go_alts)
5062
self.gosubdag = GoSubDag(
5163
set(self.gocnts.keys()),
5264
go2obj,
5365
tcntobj=self,
5466
relationships=_relationship_set,
55-
prt=None)
67+
prt=None,
68+
)
5669
if _prt:
5770
self.prt_objdesc(_prt)
5871

5972
def get_annotations_reversed(self):
6073
"""Return go2geneset for all GO IDs explicitly annotated to a gene"""
61-
return set.union(*get_b2aset(self.annots))
74+
go2genes = get_b2aset(self.annots)
75+
if go2genes:
76+
return set.union(*go2genes.values())
77+
return set()
6278

6379
def _init_go2genes(self, relationship_set, godag):
64-
'''
65-
Fills in the genes annotated to each GO, including ancestors
80+
"""
81+
Fills in the genes annotated to each GO, including ancestors
6682
67-
Due to the ontology structure, gene products annotated to
68-
a GO Terma are also annotated to all ancestors.
69-
'''
83+
Due to the ontology structure, gene products annotated to
84+
a GO Term are also annotated to all ancestors.
85+
"""
7086
go2geneset = defaultdict(set)
7187
go2up = get_go2ancestors(set(godag.values()), relationship_set)
7288
# Fill go-geneset dict with GO IDs in annotations and their corresponding counts
@@ -84,9 +100,9 @@ def _init_go2genes(self, relationship_set, godag):
84100
return dict(go2geneset)
85101

86102
def _init_add_goid_alt(self, not_main):
87-
'''
88-
Add alternate GO IDs to term counts. Report GO IDs not found in GO DAG.
89-
'''
103+
"""
104+
Add alternate GO IDs to term counts. Report GO IDs not found in GO DAG.
105+
"""
90106
if not not_main:
91107
return
92108
for go_id in not_main:
@@ -96,54 +112,63 @@ def _init_add_goid_alt(self, not_main):
96112
self.go2genes[go_id] = self.go2genes[goid_main]
97113

98114
def get_count(self, go_id):
99-
'''
100-
Returns the count of that GO term observed in the annotations.
101-
'''
115+
"""
116+
Returns the count of that GO term observed in the annotations.
117+
"""
102118
return self.gocnts[go_id]
103119

104120
def get_total_count(self, aspect):
105-
'''
106-
Gets the total count that's been precomputed.
107-
'''
121+
"""
122+
Gets the total count that's been precomputed.
123+
"""
108124
return self.aspect_counts[aspect]
109125

110126
def get_term_freq(self, go_id):
111-
'''
112-
Returns the frequency at which a particular GO term has
113-
been observed in the annotations.
114-
'''
127+
"""
128+
Returns the frequency at which a particular GO term has
129+
been observed in the annotations.
130+
"""
115131
num_ns = float(self.get_total_count(self.go2obj[go_id].namespace))
116-
return float(self.get_count(go_id))/num_ns if num_ns != 0 else 0
132+
return float(self.get_count(go_id)) / num_ns if num_ns != 0 else 0
117133

118134
def get_gosubdag_all(self, prt=sys.stdout):
119-
'''
120-
Get GO DAG subset include descendants which are not included in the annotations
121-
'''
135+
"""
136+
Get GO DAG subset, including descendants which are not included in the annotations
137+
"""
122138
goids = set()
123139
for gos in self.gosubdag.rcntobj.go2descendants.values():
124140
goids.update(gos)
125-
return GoSubDag(goids, self.go2obj, self.gosubdag.relationships, tcntobj=self, prt=prt)
141+
return GoSubDag(
142+
goids, self.go2obj, self.gosubdag.relationships, tcntobj=self, prt=prt
143+
)
126144

127145
def prt_objdesc(self, prt=sys.stdout):
128146
"""Print TermCount object description"""
129147
ns_tot = sorted(self.aspect_counts.items())
130-
cnts = ['{NS}({N:,})'.format(NS=NAMESPACE2NS.get(ns, ns), N=n) for ns, n in ns_tot if n != 0]
131-
go_msg = "TermCounts {CNT}".format(CNT=' '.join(cnts))
132-
prt.write('{GO_MSG} {N:,} genes\n'.format(GO_MSG=go_msg, N=len(self.gene2gos)))
148+
cnts = [
149+
"{NS}({N:,})".format(NS=NAMESPACE2NS.get(ns, ns), N=n)
150+
for ns, n in ns_tot
151+
if n != 0
152+
]
153+
go_msg = "TermCounts {CNT}".format(CNT=" ".join(cnts))
154+
prt.write("{GO_MSG} {N:,} genes\n".format(GO_MSG=go_msg, N=len(self.gene2gos)))
133155
self.gosubdag.prt_objdesc(prt, go_msg)
134156

135157

136158
def get_info_content(go_id, termcounts):
137-
'''
138-
Retrieve the information content of a GO term.
139-
'''
159+
"""
160+
Retrieve the information content of a GO term.
161+
"""
162+
if termcounts is None:
163+
return 0.0
140164
ntd = termcounts.gosubdag.go2nt.get(go_id)
141165
return ntd.tinfo if ntd else 0.0
142166

167+
143168
def resnik_sim(go_id1, go_id2, godag, termcounts):
144-
'''
145-
Computes Resnik's similarity measure.
146-
'''
169+
"""
170+
Computes Resnik's similarity measure.
171+
"""
147172
goterm1 = godag[go_id1]
148173
goterm2 = godag[go_id2]
149174
if goterm1.namespace == goterm2.namespace:
@@ -153,76 +178,84 @@ def resnik_sim(go_id1, go_id2, godag, termcounts):
153178

154179

155180
def lin_sim(goid1, goid2, godag, termcnts, dfltval=None):
156-
'''
157-
Computes Lin's similarity measure.
158-
'''
181+
"""
182+
Computes Lin's similarity measure.
183+
"""
159184
sim_r = resnik_sim(goid1, goid2, godag, termcnts)
160185
return lin_sim_calc(goid1, goid2, sim_r, termcnts, dfltval)
161186

162187

163188
def lin_sim_calc(goid1, goid2, sim_r, termcnts, dfltval=None):
164-
'''
165-
Computes Lin's similarity measure using pre-calculated Resnik's similarities.
166-
'''
189+
"""
190+
Computes Lin's similarity measure using pre-calculated Resnik's similarities.
191+
"""
167192
# If goid1 and goid2 are in the same namespace
168193
if sim_r is not None:
169194
tinfo1 = get_info_content(goid1, termcnts)
170195
tinfo2 = get_info_content(goid2, termcnts)
171196
info = tinfo1 + tinfo2
172197
# Both GO IDs must be annotated
173198
if tinfo1 != 0.0 and tinfo2 != 0.0 and info != 0:
174-
return (2*sim_r)/info
199+
return (2 * sim_r) / info
200+
# Check if they are identical terms - for identical terms with zero info content, return 1.0
175201
if termcnts.go2obj[goid1].item_id == termcnts.go2obj[goid2].item_id:
176202
return 1.0
177203
# The GOs are separated by the root term, so are not similar
178204
if sim_r == 0.0:
179205
return 0.0
180206
return dfltval
181207

208+
182209
def get_freq_msca(go_id1, go_id2, godag, termcounts):
183-
'''
184-
Retrieve the frequency of the MSCA of two GO terms.
185-
'''
186-
goterm1 = godag[go_id1]
187-
goterm2 = godag[go_id2]
188-
if goterm1.namespace == goterm2.namespace:
189-
msca_goid = deepest_common_ancestor([go_id1, go_id2], godag)
190-
ntd = termcounts.gosubdag.go2nt.get(msca_goid)
191-
return ntd.tfreq
192-
return 0
210+
"""
211+
Retrieve the frequency of the MSCA of two GO terms.
212+
"""
213+
try:
214+
goterm1 = godag[go_id1]
215+
goterm2 = godag[go_id2]
216+
if goterm1.namespace == goterm2.namespace:
217+
msca_goid = deepest_common_ancestor([go_id1, go_id2], godag)
218+
ntd = termcounts.gosubdag.go2nt.get(msca_goid)
219+
return ntd.tfreq if ntd else 0
220+
return 0
221+
except KeyError:
222+
return 0
223+
193224

194225
def schlicker_sim(goid1, goid2, godag, termcnts, dfltval=None):
195-
'''
196-
Computes Schlicker's similarity measure.
197-
'''
226+
"""
227+
Computes Schlicker's similarity measure.
228+
"""
198229
sim_r = resnik_sim(goid1, goid2, godag, termcnts)
199-
tfreq = get_tfreq_msca(goid1, goid2, godag, termcnts)
230+
tfreq = get_freq_msca(goid1, goid2, godag, termcnts)
200231
return schlicker_sim_calc(goid1, goid2, sim_r, tfreq, termcnts, dfltval)
201232

233+
202234
def schlicker_sim_calc(goid1, goid2, sim_r, tfreq, termcnts, dfltval=None):
203-
'''
204-
Computes Schlicker's similarity measure using pre-calculated Resnik's similarities.
205-
'''
235+
"""
236+
Computes Schlicker's similarity measure using pre-calculated Resnik's similarities.
237+
"""
206238
# If goid1 and goid2 are in the same namespace
207239
if sim_r is not None:
208240
tinfo1 = get_info_content(goid1, termcnts)
209241
tinfo2 = get_info_content(goid2, termcnts)
210242
info = tinfo1 + tinfo2
211243
# Both GO IDs must be annotated
212244
if tinfo1 != 0.0 and tinfo2 != 0.0 and info != 0:
213-
return (2*sim_r)/(info) * (1 - tfreq)
245+
return (2 * sim_r) / (info) * (1 - tfreq)
214246
if termcnts.go2obj[goid1].item_id == termcnts.go2obj[goid2].item_id:
215-
return (1.0 - tfreq)
247+
return 1.0 - tfreq
216248
# The GOs are separated by the root term, so are not similar
217249
if sim_r == 0.0:
218250
return 0.0
219251
return dfltval
220252

253+
221254
def common_parent_go_ids(goids, godag):
222-
'''
223-
This function finds the common ancestors in the GO
224-
tree of the list of goids in the input.
225-
'''
255+
"""
256+
This function finds the common ancestors in the GO
257+
tree of the list of goids in the input.
258+
"""
226259
# Find main GO ID candidates from first main or alt GO ID
227260
rec = godag[goids[0]]
228261
candidates = rec.get_all_parents()
@@ -240,19 +273,19 @@ def common_parent_go_ids(goids, godag):
240273

241274

242275
def deepest_common_ancestor(goterms, godag):
243-
'''
244-
This function gets the nearest common ancestor
245-
using the above function.
246-
Only returns single most specific - assumes unique exists.
247-
'''
276+
"""
277+
This function gets the nearest common ancestor
278+
using the above function.
279+
Only returns single most specific - assumes unique exists.
280+
"""
248281
# Take the element at maximum depth.
249282
return max(common_parent_go_ids(goterms, godag), key=lambda t: godag[t].depth)
250283

251284

252285
def min_branch_length(go_id1, go_id2, godag, branch_dist):
253-
'''
254-
Finds the minimum branch length between two terms in the GO DAG.
255-
'''
286+
"""
287+
Finds the minimum branch length between two terms in the GO DAG.
288+
"""
256289
# First get the deepest common ancestor
257290
goterm1 = godag[go_id1]
258291
goterm2 = godag[go_id2]
@@ -273,18 +306,18 @@ def min_branch_length(go_id1, go_id2, godag, branch_dist):
273306

274307

275308
def semantic_distance(go_id1, go_id2, godag, branch_dist=None):
276-
'''
277-
Finds the semantic distance (minimum number of connecting branches)
278-
between two GO terms.
279-
'''
309+
"""
310+
Finds the semantic distance (minimum number of connecting branches)
311+
between two GO terms.
312+
"""
280313
return min_branch_length(go_id1, go_id2, godag, branch_dist)
281314

282315

283316
def semantic_similarity(go_id1, go_id2, godag, branch_dist=None):
284-
'''
285-
Finds the semantic similarity (inverse of the semantic distance)
286-
between two GO terms.
287-
'''
317+
"""
318+
Finds the semantic similarity (inverse of the semantic distance)
319+
between two GO terms.
320+
"""
288321
dist = semantic_distance(go_id1, go_id2, godag, branch_dist)
289322
if dist is not None:
290323
return 1.0 / float(dist) if dist != 0 else 1.0

0 commit comments

Comments
 (0)