Skip to content

Commit 31136be

Browse files
committed
generalize loading project hts and checking for multiple sample types on project table load
1 parent 5442a7f commit 31136be

File tree

2 files changed

+116
-93
lines changed

2 files changed

+116
-93
lines changed

hail_search/queries/base.py

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,6 @@
2121
# https://github.com/broadinstitute/seqr-private/issues/1283#issuecomment-1973392719
2222
MAX_PARTITIONS = 12
2323

24-
# Need to chunk tables or else evaluating table globals throws LineTooLong exception
25-
# However, minimizing number of chunks minimizes number of aggregations/ evals and improves performance
26-
# Adapted from https://discuss.hail.is/t/importing-many-sample-specific-vcfs/2002/8
27-
MAX_HTS_TO_JOIN = 64
28-
2924
logger = logging.getLogger(__name__)
3025

3126

@@ -340,35 +335,49 @@ def _import_and_filter_multiple_project_hts(
340335

341336
return self._merge_filtered_hts(filtered_comp_het_project_hts, filtered_project_hts, n_partitions)
342337

343-
def _load_project_hts(self, project_samples: dict, n_partitions: int, **kwargs) -> list[tuple[hl.Table, dict]]:
338+
def _load_project_hts(
339+
self, project_samples, n_partitions,
340+
initialize_project_hts=None, initialize_sample_data=None, aggregate_project_data=None, load_project_ht_chunks=None,
341+
**kwargs
342+
):
343+
# Need to chunk tables or else evaluating table globals throws LineTooLong exception
344+
# However, minimizing number of chunks minimizes number of aggregations/ evals and improves performance
345+
# Adapted from https://discuss.hail.is/t/importing-many-sample-specific-vcfs/2002/8
346+
chunk_size = 64
344347
all_project_hts = []
345-
project_hts = []
346-
sample_data = {}
347-
for project_guid, project_sample_type_data in project_samples.items():
348-
sample_type, family_sample_data = list(project_sample_type_data.items())[0]
349-
project_ht = self._read_project_data(family_sample_data, project_guid, sample_type, project_hts, sample_data)
350-
if project_ht is None:
351-
continue
348+
if initialize_project_hts is None:
349+
initialize_project_hts = lambda: []
350+
351+
if initialize_sample_data is None:
352+
initialize_sample_data = lambda: {}
352353

353-
if len(project_hts) >= MAX_HTS_TO_JOIN:
354+
if aggregate_project_data is None:
355+
def aggregate_project_data(family_sample_data, ht, project_hts, sample_data, _):
356+
project_hts.append(ht)
357+
sample_data.update(family_sample_data)
358+
359+
if load_project_ht_chunks is None:
360+
def load_project_ht_chunks(all_project_hts, n_partitions, project_hts, sample_data, **kwargs):
354361
ht = self._prefilter_merged_project_hts(project_hts, n_partitions, **kwargs)
355362
all_project_hts.append((ht, sample_data))
356-
project_hts = []
357-
sample_data = {}
358363

359-
if project_hts:
360-
ht = self._prefilter_merged_project_hts(project_hts, n_partitions, **kwargs)
361-
all_project_hts.append((ht, sample_data))
362-
return all_project_hts
364+
project_hts = initialize_project_hts()
365+
sample_data = initialize_sample_data()
363366

364-
def _read_project_data(self, family_sample_data, project_guid, sample_type, project_hts, sample_data):
365-
project_ht = self._read_project_table(project_guid, sample_type)
366-
if project_ht is not None:
367-
project_ht = project_ht.select_globals('sample_type', 'family_guids', 'family_samples')
368-
project_hts.append(project_ht)
369-
sample_data.update(family_sample_data)
370-
371-
return project_ht
367+
for project_guid, project_sample_type_data in project_samples.items():
368+
for sample_type, family_sample_data in project_sample_type_data.items():
369+
project_ht = self._read_project_data(project_guid, sample_type)
370+
if project_ht is None:
371+
continue
372+
aggregate_project_data(family_sample_data, project_ht, project_hts, sample_data, sample_type)
373+
374+
if len(project_hts) >= chunk_size:
375+
load_project_ht_chunks(all_project_hts, n_partitions, project_hts, sample_data, **kwargs)
376+
project_hts = initialize_project_hts()
377+
sample_data = initialize_sample_data()
378+
379+
load_project_ht_chunks(all_project_hts, n_partitions, project_hts, sample_data, **kwargs)
380+
return all_project_hts
372381

373382
def import_filtered_table(self, project_samples: dict, num_families: int, **kwargs):
374383
if num_families == 1 or len(project_samples) == 1:
@@ -426,6 +435,12 @@ def _load_project_ht(
426435
def _read_project_table(self, project_guid: str, sample_type: str):
427436
return self._read_table(f'projects/{sample_type}/{project_guid}.ht')
428437

438+
def _read_project_data(self, project_guid: str, sample_type: str):
439+
project_ht = self._read_project_table(project_guid, sample_type)
440+
if project_ht is not None:
441+
project_ht = project_ht.select_globals('sample_type', 'family_guids', 'family_samples')
442+
return project_ht
443+
429444
def _prefilter_merged_project_hts(self, project_hts, n_partitions, **kwargs):
430445
ht = self._merge_project_hts(project_hts, n_partitions, include_all_globals=True)
431446
return self._prefilter_entries_table(ht, **kwargs)

hail_search/queries/mito.py

Lines changed: 73 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
PATHOGENICTY_SORT_KEY, CONSEQUENCE_SORT, \
1212
PATHOGENICTY_HGMD_SORT_KEY, MAX_LOAD_INTERVALS
1313
from hail_search.definitions import SampleType
14-
from hail_search.queries.base import BaseHailTableQuery, PredictionPath, QualityFilterFormat, MAX_PARTITIONS, \
15-
MAX_HTS_TO_JOIN
14+
from hail_search.queries.base import BaseHailTableQuery, PredictionPath, QualityFilterFormat, MAX_PARTITIONS
1615

1716
REFERENCE_DATASETS_DIR = os.environ.get('REFERENCE_DATASETS_DIR', '/seqr/seqr-reference-data')
1817
REFERENCE_DATASET_SUBDIR = 'cached_reference_dataset_queries'
@@ -145,61 +144,69 @@ def _import_and_filter_entries_ht(
145144
def _import_and_filter_multiple_project_hts(
146145
self, project_samples: dict, n_partitions=MAX_PARTITIONS, **kwargs
147146
) -> tuple[hl.Table, hl.Table]:
148-
sample_types = set()
149-
for sample_dict in project_samples.values():
150-
sample_types.update(sample_dict.keys())
151-
if len(sample_types) == 1:
152-
return super()._import_and_filter_multiple_project_hts(project_samples, n_partitions, **kwargs)
153-
154-
self._has_both_sample_types = True
155-
entries = self._load_project_hts_both_sample_types(project_samples, n_partitions, **kwargs)
147+
single_sample_type_project_samples = {}
148+
both_sample_type_project_samples = {}
149+
for project_guid, sample_dict in project_samples.items():
150+
if len(sample_dict) == 1:
151+
single_sample_type_project_samples[project_guid] = sample_dict
152+
else:
153+
both_sample_type_project_samples[project_guid] = sample_dict
156154

157155
filtered_project_hts = []
158156
filtered_comp_het_project_hts = []
159-
for entry in entries:
160-
wes_ht, wes_project_samples = entry[SampleType.WES.value]
161-
wgs_ht, wgs_project_samples = entry[SampleType.WGS.value]
162-
ht, comp_het_ht = self._filter_entries_ht_both_sample_types(
163-
wes_ht, wes_project_samples, wgs_ht, wgs_project_samples,
164-
is_merged_ht=True, **kwargs
165-
)
157+
158+
# Process projects with only one sample type separately
159+
if single_sample_type_project_samples:
160+
ht, ch_ht = super()._import_and_filter_multiple_project_hts(single_sample_type_project_samples, n_partitions, **kwargs)
166161
if ht is not None:
167162
filtered_project_hts.append(ht)
168-
if comp_het_ht is not None:
169-
filtered_comp_het_project_hts.append(comp_het_ht)
170-
171-
return self._merge_filtered_hts(filtered_comp_het_project_hts, filtered_project_hts, n_partitions)
172-
173-
def _load_project_hts_both_sample_types(
174-
self, project_samples: dict, n_partitions: int, **kwargs
175-
) -> list[dict[str, tuple[hl.Table, dict]]]:
176-
all_project_hts = []
177-
project_hts = defaultdict(list)
178-
sample_data = defaultdict(dict)
179-
180-
for project_guid, project_sample_type_data in project_samples.items():
181-
for sample_type, family_sample_data in project_sample_type_data.items():
182-
project_ht = self._read_project_data(
183-
family_sample_data, project_guid, sample_type, project_hts[sample_type], sample_data[sample_type]
163+
if ch_ht is not None:
164+
filtered_comp_het_project_hts.append(ch_ht)
165+
166+
if both_sample_type_project_samples:
167+
self._has_both_sample_types = True
168+
entries = self._load_project_hts_both_sample_types(project_samples, n_partitions, **kwargs)
169+
for entry in entries:
170+
wes_ht, wes_project_samples = entry[SampleType.WES.value]
171+
wgs_ht, wgs_project_samples = entry[SampleType.WGS.value]
172+
ht, comp_het_ht = self._filter_entries_ht_both_sample_types(
173+
wes_ht, wes_project_samples, wgs_ht, wgs_project_samples,
174+
is_merged_ht=True, **kwargs
184175
)
185-
if project_ht is None:
186-
continue
176+
if ht is not None:
177+
filtered_project_hts.append(ht)
178+
if comp_het_ht is not None:
179+
filtered_comp_het_project_hts.append(comp_het_ht)
187180

188-
# Merge both WES and WGS project_hts when either of their lengths reaches the chunk size
189-
if len(project_hts[SampleType.WES.value]) >= MAX_HTS_TO_JOIN or len(project_hts[SampleType.WGS.value]) >= MAX_HTS_TO_JOIN:
190-
self._load_project_ht_chunks(all_project_hts, kwargs, n_partitions, project_hts, sample_data)
191-
project_hts = defaultdict(list)
192-
sample_data = defaultdict(dict)
193-
194-
self._load_project_ht_chunks(all_project_hts, kwargs, n_partitions, project_hts, sample_data)
195-
return all_project_hts
181+
return self._merge_filtered_hts(filtered_comp_het_project_hts, filtered_project_hts, n_partitions)
196182

197-
def _load_project_ht_chunks(self, all_project_hts, kwargs, n_partitions, project_hts, sample_data):
198-
project_ht_dict = {}
199-
for sample_type in project_hts:
200-
ht = self._prefilter_merged_project_hts(project_hts[sample_type], n_partitions, **kwargs)
201-
project_ht_dict[sample_type] = (ht, sample_data[sample_type])
202-
all_project_hts.append(project_ht_dict)
183+
def _load_project_hts_both_sample_types(self, project_samples, n_partitions, **kwargs):
184+
def initialize_project_hts():
185+
return defaultdict(list)
186+
187+
def initialize_sample_data():
188+
return defaultdict(dict)
189+
190+
def aggregate_project_data(family_sample_data, ht, project_hts, sample_data, sample_type):
191+
project_hts[sample_type].append(ht)
192+
sample_data[sample_type].update(family_sample_data)
193+
194+
def load_project_ht_chunks(all_project_hts, n_partitions, project_hts, sample_data, **kwargs):
195+
project_ht_dict = {}
196+
for sample_type in project_hts:
197+
ht = self._prefilter_merged_project_hts(project_hts[sample_type], n_partitions, **kwargs)
198+
project_ht_dict[sample_type] = (ht, sample_data[sample_type])
199+
all_project_hts.append(project_ht_dict)
200+
201+
return super()._load_project_hts(
202+
project_samples,
203+
n_partitions,
204+
initialize_project_hts,
205+
initialize_sample_data,
206+
aggregate_project_data,
207+
load_project_ht_chunks,
208+
**kwargs
209+
)
203210

204211
def _filter_entries_ht_both_sample_types(
205212
self, wes_ht, wes_project_samples, wgs_ht, wgs_project_samples, inheritance_filter=None, quality_filter=None,
@@ -222,27 +229,21 @@ def _filter_entries_ht_both_sample_types(
222229
)
223230

224231
ch_ht = None
225-
for sample_type, family_sample_data in sample_types:
232+
family_guid_idx_map = defaultdict(dict)
233+
for sample_type, sorted_family_sample_data in sample_types:
226234
ht, ch_ht = self._filter_inheritance(
227-
ht, ch_ht, inheritance_filter, family_sample_data,
235+
ht, ch_ht, inheritance_filter, sorted_family_sample_data,
228236
annotation=sample_type.passes_inheritance_field, entries_ht_field=sample_type.family_entries_field
229237
)
238+
for family_idx, samples in enumerate(sorted_family_sample_data):
239+
family_guid = samples[0]['familyGuid']
240+
family_guid_idx_map[family_guid][sample_type.value] = family_idx
230241

231-
family_idx_map = self._build_family_index_map(sample_types, sorted_wes_family_sample_data, sorted_wgs_family_sample_data)
242+
family_idx_map = hl.dict(family_guid_idx_map)
232243
ht = self._apply_multi_sample_type_entry_filters(ht, family_idx_map)
233244
ch_ht = self._apply_multi_sample_type_entry_filters(ch_ht, family_idx_map)
234-
235245
return ht, ch_ht
236246

237-
@staticmethod
238-
def _build_family_index_map(sample_types, sorted_wes_family_sample_data, sorted_wgs_family_sample_data):
239-
family_guid_idx_map = defaultdict(dict)
240-
for sample_type, sorted_family_sample_data in sample_types:
241-
for family_idx, samples in enumerate(sorted_family_sample_data):
242-
family_guid = samples[0]['familyGuid']
243-
family_guid_idx_map[family_guid][sample_type.value] = family_idx
244-
return hl.dict(family_guid_idx_map)
245-
246247
def _apply_multi_sample_type_entry_filters(self, ht, family_idx_map):
247248
if ht is None:
248249
return ht
@@ -269,6 +270,9 @@ def _apply_multi_sample_type_entry_filters(self, ht, family_idx_map):
269270
ht[SampleType.WGS.family_entries_field], hl.empty_array(ht[SampleType.WGS.family_entries_field].dtype.element_type)
270271
)).filter(lambda entries: entries.any(hl.is_defined))
271272
)
273+
ht = ht.select('family_entries')
274+
ht = ht.select_globals('family_guids')
275+
272276
# Filter out families with no valid entries in either sample type
273277
return ht.filter(ht.family_entries.any(hl.is_defined))
274278

@@ -297,9 +301,13 @@ def _get_sample_genotype(self, samples, r=None, include_genotype_overrides=False
297301
if not self._has_both_sample_types:
298302
return super()._get_sample_genotype(samples, r, include_genotype_overrides, select_fields)
299303

300-
return samples.map(lambda sample: self._select_genotype_for_sample(
301-
sample, r, include_genotype_overrides, select_fields
302-
))
304+
return hl.if_else(
305+
hl.len(hl.set(samples.map(lambda sample: sample.sampleType))) > 1,
306+
samples.map(lambda sample: self._select_genotype_for_sample(
307+
sample, r, include_genotype_overrides, select_fields
308+
)),
309+
[super()._get_sample_genotype(samples, r, include_genotype_overrides, select_fields)]
310+
)
303311

304312
@staticmethod
305313
def _selected_main_transcript_expr(ht):

0 commit comments

Comments
 (0)