#!/usr/bin/env python3
"""
Find ssDNA breaks and print a BED file with detected break sites to stdout.
The input is a BAM or CRAM file with aligned reads.
The output BED file has a special "#gffTags" header that is not part of the BED
specification, but that IGV understands: it makes IGV display annotations when
hovering over an annotated event with the mouse.
"""
# /// script
# dependencies = [
#     "pyfaidx",
#     "pysam",
# ]
# ///
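#
# Example invocation (hypothetical file names; the reference needs a .fai
# index, and the BAM/CRAM should be coordinate-sorted and indexed):
#
#     ./detectbreak.py ref.fa aligned.bam > breaks.bed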
import sys
import argparse
import functools
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Iterator
import time
import pysam
from pyfaidx import Fasta

# Time interval in seconds between log messages printing how many records
# have been processed
UPDATE_INTERVAL = 60


class CommandlineError(Exception):
    pass


@dataclass
class BreakEvent:
    start: int  # position of first mutated base (reference coordinate)
    end: int  # position of last mutated base
    start_unmodified: int  # position of last unmodified base preceding first modified base
    end_unmodified: int  # position of first unmodified base following last modified base
    record: pysam.AlignedSegment
    query_bases: list[str | None]  # None marks a deleted base
    base_qualities: list[int | None]  # None marks a deleted base
    is_revcomp: bool

    def region_tuple(self):
        return (self.record.reference_name, self.start, self.end)

    def bed_record(self, error_rate: float) -> str:
        """Format this event as a BED record with gffTags attributes"""
        record = self.record
        # Use "-" for deleted bases, matching the convention used for `bases`
        base_qual = ",".join(str(q) if q is not None else "-" for q in self.base_qualities)
        bases = "".join(b if b is not None else "-" for b in self.query_bases)
        number_passes = f"number_passes={record.get_tag('np')};" if record.has_tag("np") else ""
        formatted_tags = (
            f"Name={bases}{'/rc' if self.is_revcomp else ''};"
            f"read_name={record.query_name};"
            f"mapping_quality={record.mapping_quality};"
            f"bases={bases};"
            f"base_qualities={base_qual};"
            + number_passes
            + f"error_rate={error_rate:.2%}25;"  # %25 is a URL-encoded '%'
            f"mutated_region={record.reference_name}:{self.start + 1}-{self.end};"
            f"reverse_complement={'yes' if self.is_revcomp else 'no'}"
        )
        # The reported interval is the gap between the mutated region and the
        # adjacent unmodified base: upstream of the mutated region on the
        # forward strand, downstream of it for the reverse complement.
        if self.is_revcomp:
            start, end = self.end, self.end_unmodified
        else:
            start, end = self.start_unmodified, self.start
        assert start <= end
        return f"{self.record.reference_name}\t{start}\t{end}\t{formatted_tags}"

    def count_mismatches(self) -> int:
        """Number of mismatched (not deleted) bases in this event"""
        return len(self.query_bases) - self.query_bases.count(None)

    @property
    def count(self) -> int:
        """Total number of affected bases (mismatches and deletions)"""
        return len(self.query_bases)
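
# An illustrative, made-up record as produced by BreakEvent.bed_record()
# (a single tab-separated line in the real output, wrapped here for
# readability):
#
#   chrI  1041  1047  Name=T-TT;read_name=read17;mapping_quality=60;
#         bases=T-TT;base_qualities=30,-,28,31;error_rate=1.20%25;
#         mutated_region=chrI:1048-1052;reverse_complement=no
#
# IGV shows these gffTags attributes in a tooltip when hovering over the
# feature.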


class Statistics:
    def __init__(self):
        # Record statistics
        self.unfiltered_records: int = 0
        self.filtered_not_primary: int = 0
        self.filtered_min_passes: int = 0
        self.filtered_max_error_rate: int = 0
        self.records: int = 0
        self.event_counts = [0, 0, 0]  # records with zero, one, two or more events

        # Event filtering statistics
        self.unfiltered_events: int = 0
        self.events: int = 0  # Final number of events
        self.filtered_min_affected: int = 0
        self.filtered_min_mismatches: int = 0
        self.filtered_min_quality: int = 0
        self.filtered_min_average_quality: int = 0


def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--region",
        "-r",
        help="Only work on reads within this region",
    )
    parser.add_argument(
        "--output-bam",
        metavar="PATH",
        help="Write alignments on which events were detected to PATH",
    )
    parser.add_argument(
        "--min-passes",
        metavar="N",
        type=int,
        help="Skip HiFi reads with fewer than N passes (requires 'np' tag)",
    )
    parser.add_argument(
        "--max-error-rate",
        "-e",
        metavar="RATE",
        type=float,
        default=0.1,
        help="Skip reads with an error rate higher than RATE",
    )
    parser.add_argument(
        "--min-affected",
        "-N",
        metavar="N",
        type=int,
        default=5,
        help="Require at least N affected bases (mismatch or deletion)",
    )
    parser.add_argument(
        "--min-mismatches",
        "-n",
        metavar="N",
        type=int,
        default=3,
        help="Require at least N mismatching (not deleted) bases",
    )
    parser.add_argument(
        "--min-average-quality",
        "-b",
        metavar="QUAL",
        type=int,
        default=10,
        help="Require at least QUAL average base quality for mismatching bases",
    )
    parser.add_argument(
        "--min-base-quality",
        "-q",
        metavar="QUAL",
        type=int,
        default=10,
        help="Require at least QUAL base quality for all mismatching bases",
    )
    parser.add_argument("ref", metavar="fasta", help="Indexed reference FASTA")
    parser.add_argument("bam")
    args = parser.parse_args()
    try:
        run(**vars(args))
    except CommandlineError as e:
        sys.exit(str(e))


def run(
    ref: str,
    bam: str,
    output_bam: str | None,
    region: str | None,
    min_passes: int | None,
    max_error_rate: float | None,
    min_affected: int,
    min_mismatches: int,
    min_average_quality: int,
    min_base_quality: int,
):
    if min_mismatches > min_affected:
        raise CommandlineError("--min-mismatches must not be larger than --min-affected")
    with open(ref) as f:
        if f.read(1) != ">":
            raise CommandlineError(
                f"file '{ref}' does not appear to be a FASTA file "
                "as it does not start with the character '>'."
            )
    with ExitStack() as stack:
        fasta = stack.enter_context(Fasta(ref))
        af = stack.enter_context(pysam.AlignmentFile(bam))
        if output_bam is not None:
            # Choose the output format based on the file name extension
            if output_bam.endswith(".bam"):
                mode = "wb"
            elif output_bam.endswith(".cram"):
                mode = "wc"
            else:
                mode = "w"  # SAM
            output_alignments = stack.enter_context(
                pysam.AlignmentFile(output_bam, mode=mode, template=af)
            )
        else:
            output_alignments = None
        stats = Statistics()
        start_time = time.time()
        next_update = start_time + UPDATE_INTERVAL
        stderr_is_a_tty = sys.stderr.isatty()
        print("#gffTags")
        n = 0  # Number of alignment records read so far
        for n, record in enumerate(af.fetch(region=region), start=1):
            if stderr_is_a_tty and n % 1000 == 0 and (now := time.time()) >= next_update:
                rate = n / (now - start_time)
                print(
                    f"Processed {n} alignment records in {now - start_time:.1f} s at {rate:.0f} records/s",
                    file=sys.stderr,
                )
                next_update += UPDATE_INTERVAL

            # Filter alignments
            if record.is_secondary or record.is_supplementary or record.is_unmapped:
                stats.filtered_not_primary += 1
                continue
            if (
                min_passes is not None
                and record.has_tag("np")
                and record.get_tag("np") < min_passes
            ):
                stats.filtered_min_passes += 1
                continue
            contig_sequence = get_contig_sequence(fasta, record.reference_name)
            reference_sequence = contig_sequence[record.reference_start:record.reference_end]
            try:
                errors = record.get_tag("NM")
            except KeyError:
                errors = len(alignment_error_tuples(record, reference_sequence))
            error_rate = errors / len(reference_sequence)
            if max_error_rate is not None and error_rate > max_error_rate:
                stats.filtered_max_error_rate += 1
                continue

            # Detect breaks on forward and reverse strand
            events = list(detect_break(record, reference_sequence, revcomp=False))
            events.extend(detect_break(record, reference_sequence, revcomp=True))
            stats.event_counts[min(len(events), 2)] += 1

            # Filter events and print BED records for those that remain
            written_events = 0
            for event in events:
                stats.unfiltered_events += 1
                if event.count < min_affected:
                    stats.filtered_min_affected += 1
                    continue
                if event.count_mismatches() < min_mismatches:
                    stats.filtered_min_mismatches += 1
                    continue
                if any(bq < min_base_quality for bq in event.base_qualities if bq is not None):
                    stats.filtered_min_quality += 1
                    continue
                if mean([bq for bq in event.base_qualities if bq is not None]) < min_average_quality:
                    stats.filtered_min_average_quality += 1
                    continue
                written_events += 1
                print(event.bed_record(error_rate))
            if written_events > 0 and output_alignments is not None:
                output_alignments.write(record)
            stats.events += written_events
            stats.records += 1
        stats.unfiltered_records = n

        # Final rate update
        now = time.time()
        rate = n / (now - start_time)
        print(
            f"Done. Processed {n} alignment records in {now - start_time:.1f} s at {rate:.0f} records/s",
            file=sys.stderr,
        )

    def log(n, *args, **kwargs):
        print(f"{n:9}", *args, **kwargs, file=sys.stderr)

    print("Discarding reads with fewer than", min_passes if min_passes is not None else 0, "passes (np tag)", file=sys.stderr)
    print("Discarding reads with error rate higher than", max_error_rate, file=sys.stderr)
    print("Discarding events with fewer than", min_affected, "consecutive affected bases", file=sys.stderr)
    print("Discarding events with fewer than", min_mismatches, "mismatching bases", file=sys.stderr)
    print(file=sys.stderr)
    log(stats.unfiltered_records, "total alignments in input file")
    log(stats.filtered_not_primary, "non-primary alignments filtered out")
    log(stats.filtered_min_passes, "alignments with too few passes (np tag) filtered out")
    log(stats.filtered_max_error_rate, "alignments with too high error rate filtered out")
    log(stats.records, "alignments remained after filtering and were analyzed for events")
    print(file=sys.stderr)
    log(stats.event_counts[0], "alignments had no event")
    log(stats.event_counts[1], "alignments had one event")
    log(stats.event_counts[2], "alignments had two or more events")
    print(file=sys.stderr)
    log(stats.unfiltered_events, "events found")
    log(stats.filtered_min_affected, "events filtered because they had too few consecutive affected bases")
    log(stats.filtered_min_mismatches, "events filtered because they had too few mismatching bases")
    log(stats.filtered_min_quality, "events filtered because at least one base had too low quality")
    log(stats.filtered_min_average_quality, "events filtered because they had too low average base quality")
    log(stats.events, "events reported after filtering")


@functools.lru_cache(maxsize=2)
def get_contig_sequence(fasta, contig_name):
    # Cache the most recently used contigs so that the sequence is not
    # re-extracted for every record
    return fasta[contig_name][:].seq.upper()
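

# detect_break() scans the reference "A" positions of an alignment (or "T"
# positions for the reverse complement) and groups consecutive mutated ones
# into BreakEvents. An illustrative, made-up alignment:
#
#     ref:   G A T A A C A A
#     read:  G A T T - C A A
#                  ^ ^
#
# The two marked positions (a T mismatch and a deletion, both at reference
# "A" bases) form one event; the intervening "C" is not inspected because
# only reference "A" positions matter here.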
def detect_break(
    record: pysam.AlignedSegment, reference_sequence: str, revcomp: bool
) -> Iterator[BreakEvent]:
    """Detect ssDNA break events on a single AlignedSegment"""
    base = "T" if revcomp else "A"
    mutated = False  # are we currently in a mutated region?
    prev_position = record.reference_start
    event = None
    for query_pos, ref_pos in aligned_pairs_without_softclips(record):
        if ref_pos is None:
            # Insertion
            continue
        reference_base = reference_sequence[ref_pos - record.reference_start]
        if reference_base != base:
            continue
        query_base = None if query_pos is None else record.query_sequence[query_pos]
        base_quality = None if query_pos is None else record.query_qualities[query_pos]
        cur_mutated = query_base != base
        if not mutated:
            if cur_mutated:
                # New break starts
                event = BreakEvent(
                    start=ref_pos,
                    end=ref_pos + 1,
                    start_unmodified=prev_position,
                    end_unmodified=record.reference_end,  # fixed later
                    record=record,
                    query_bases=[query_base],
                    base_qualities=[base_quality],
                    is_revcomp=revcomp,
                )
        else:
            if cur_mutated:
                # Extend detected region
                event.end = ref_pos + 1
                event.query_bases.append(query_base)
                event.base_qualities.append(base_quality)
            else:
                # Unmutated replacement base: End of detected region
                pos = reference_sequence.find(base, event.end - record.reference_start)
                if pos != -1:
                    event.end_unmodified = record.reference_start + pos
                yield event
        mutated = cur_mutated
        prev_position = ref_pos
    if mutated:
        # The event extends to the end of the alignment
        yield event


def alignment_error_tuples(
    record: pysam.AlignedSegment, reference_sequence: str
) -> list[tuple[int | None, int | None, str | None, str | None]]:
    """
    Similar to get_aligned_pairs(), but excludes soft-clipped positions and
    positions where query and reference are identical.

    Return a list of tuples (query_pos, ref_pos, query_base, ref_base).
    """
    result = []
    for query_pos, ref_pos in aligned_pairs_without_softclips(record):
        ref_base = (
            None
            if ref_pos is None
            else reference_sequence[ref_pos - record.reference_start]
        )
        query_base = None if query_pos is None else record.query_sequence[query_pos]
        if ref_base != query_base:
            result.append((query_pos, ref_pos, query_base, ref_base))
    return result
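

# pysam's get_aligned_pairs() returns one (query_pos, ref_pos) tuple per
# aligned, inserted, deleted or soft-clipped base, and soft clips can occur
# only at the two ends of an alignment, so slicing off the clip lengths
# removes exactly the soft-clipped positions.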
def aligned_pairs_without_softclips(record: pysam.AlignedSegment):
    clip_start, clip_end_length = soft_clip_lengths(record.cigartuples)
    aligned_pairs = record.get_aligned_pairs()[clip_start:]
    if clip_end_length > 0:
        aligned_pairs = aligned_pairs[:-clip_end_length]
    return aligned_pairs


def soft_clip_lengths(cigar: list[tuple[int, int]]) -> tuple[int, int]:
    """
    Return the lengths of the soft-clipped regions at the start and at the
    end of an alignment with the given CIGAR.

    >>> soft_clip_lengths([(4, 99), (0, 5), (4, 22)])
    (99, 22)
    >>> soft_clip_lengths([(4, 99)])
    (99, 0)
    >>> soft_clip_lengths([(0, 5), (4, 22)])
    (0, 22)
    >>> soft_clip_lengths([(0, 5)])
    (0, 0)
    """
    clip_start, clip_end_length = 0, 0
    op, length = cigar[0]
    if op == 4:  # S
        clip_start = length
    if len(cigar) > 1:
        op, length = cigar[-1]
        if op == 4:
            clip_end_length = length
    return clip_start, clip_end_length


def median(values):
    """
    >>> median([1])
    1
    >>> median([2, 3])
    2.5
    >>> median([3, 2, 3, 2])
    2.5
    >>> median([10, 3, 10, 5, 3])
    5
    """
    n = len(values)
    assert n > 0
    if n % 2 != 0:
        return sorted(values)[n // 2]
    else:
        return sum(sorted(values)[n // 2 - 1 : n // 2 + 1]) / 2


def mean(values) -> float:
    """Arithmetic mean of the values (0 if the list is empty)"""
    if not values:
        return 0
    return sum(values) / len(values)


if __name__ == "__main__":
    main()