    PATHOGENICTY_SORT_KEY, CONSEQUENCE_SORT, \
    PATHOGENICTY_HGMD_SORT_KEY, MAX_LOAD_INTERVALS
from hail_search.definitions import SampleType
-from hail_search.queries.base import BaseHailTableQuery, PredictionPath, QualityFilterFormat, MAX_PARTITIONS, \
-    MAX_HTS_TO_JOIN
+from hail_search.queries.base import BaseHailTableQuery, PredictionPath, QualityFilterFormat, MAX_PARTITIONS

REFERENCE_DATASETS_DIR = os.environ.get('REFERENCE_DATASETS_DIR', '/seqr/seqr-reference-data')
REFERENCE_DATASET_SUBDIR = 'cached_reference_dataset_queries'
@@ -145,61 +144,69 @@ def _import_and_filter_entries_ht(
    def _import_and_filter_multiple_project_hts(
        self, project_samples: dict, n_partitions=MAX_PARTITIONS, **kwargs
    ) -> tuple[hl.Table, hl.Table]:
-        sample_types = set()
-        for sample_dict in project_samples.values():
-            sample_types.update(sample_dict.keys())
-        if len(sample_types) == 1:
-            return super()._import_and_filter_multiple_project_hts(project_samples, n_partitions, **kwargs)
-
-        self._has_both_sample_types = True
-        entries = self._load_project_hts_both_sample_types(project_samples, n_partitions, **kwargs)
+        single_sample_type_project_samples = {}
+        both_sample_type_project_samples = {}
+        for project_guid, sample_dict in project_samples.items():
+            if len(sample_dict) == 1:
+                single_sample_type_project_samples[project_guid] = sample_dict
+            else:
+                both_sample_type_project_samples[project_guid] = sample_dict

        filtered_project_hts = []
        filtered_comp_het_project_hts = []
-        for entry in entries:
-            wes_ht, wes_project_samples = entry[SampleType.WES.value]
-            wgs_ht, wgs_project_samples = entry[SampleType.WGS.value]
-            ht, comp_het_ht = self._filter_entries_ht_both_sample_types(
-                wes_ht, wes_project_samples, wgs_ht, wgs_project_samples,
-                is_merged_ht=True, **kwargs
-            )
+
+        # Process projects with only one sample type separately
+        if single_sample_type_project_samples:
+            ht, ch_ht = super()._import_and_filter_multiple_project_hts(single_sample_type_project_samples, n_partitions, **kwargs)
            if ht is not None:
                filtered_project_hts.append(ht)
-            if comp_het_ht is not None:
-                filtered_comp_het_project_hts.append(comp_het_ht)
-
-        return self._merge_filtered_hts(filtered_comp_het_project_hts, filtered_project_hts, n_partitions)
-
-    def _load_project_hts_both_sample_types(
-        self, project_samples: dict, n_partitions: int, **kwargs
-    ) -> list[dict[str, tuple[hl.Table, dict]]]:
-        all_project_hts = []
-        project_hts = defaultdict(list)
-        sample_data = defaultdict(dict)
-
-        for project_guid, project_sample_type_data in project_samples.items():
-            for sample_type, family_sample_data in project_sample_type_data.items():
-                project_ht = self._read_project_data(
-                    family_sample_data, project_guid, sample_type, project_hts[sample_type], sample_data[sample_type]
+            if ch_ht is not None:
+                filtered_comp_het_project_hts.append(ch_ht)
+
+        if both_sample_type_project_samples:
+            self._has_both_sample_types = True
+            entries = self._load_project_hts_both_sample_types(project_samples, n_partitions, **kwargs)
+            for entry in entries:
+                wes_ht, wes_project_samples = entry[SampleType.WES.value]
+                wgs_ht, wgs_project_samples = entry[SampleType.WGS.value]
+                ht, comp_het_ht = self._filter_entries_ht_both_sample_types(
+                    wes_ht, wes_project_samples, wgs_ht, wgs_project_samples,
+                    is_merged_ht=True, **kwargs
                )
-                if project_ht is None:
-                    continue
+                if ht is not None:
+                    filtered_project_hts.append(ht)
+                if comp_het_ht is not None:
+                    filtered_comp_het_project_hts.append(comp_het_ht)

-                # Merge both WES and WGS project_hts when either of their lengths reaches the chunk size
-                if len(project_hts[SampleType.WES.value]) >= MAX_HTS_TO_JOIN or len(project_hts[SampleType.WGS.value]) >= MAX_HTS_TO_JOIN:
-                    self._load_project_ht_chunks(all_project_hts, kwargs, n_partitions, project_hts, sample_data)
-                    project_hts = defaultdict(list)
-                    sample_data = defaultdict(dict)
-
-        self._load_project_ht_chunks(all_project_hts, kwargs, n_partitions, project_hts, sample_data)
-        return all_project_hts
+        return self._merge_filtered_hts(filtered_comp_het_project_hts, filtered_project_hts, n_partitions)

-    def _load_project_ht_chunks(self, all_project_hts, kwargs, n_partitions, project_hts, sample_data):
-        project_ht_dict = {}
-        for sample_type in project_hts:
-            ht = self._prefilter_merged_project_hts(project_hts[sample_type], n_partitions, **kwargs)
-            project_ht_dict[sample_type] = (ht, sample_data[sample_type])
-        all_project_hts.append(project_ht_dict)
+    def _load_project_hts_both_sample_types(self, project_samples, n_partitions, **kwargs):
+        def initialize_project_hts():
+            return defaultdict(list)
+
+        def initialize_sample_data():
+            return defaultdict(dict)
+
+        def aggregate_project_data(family_sample_data, ht, project_hts, sample_data, sample_type):
+            project_hts[sample_type].append(ht)
+            sample_data[sample_type].update(family_sample_data)
+
+        def load_project_ht_chunks(all_project_hts, n_partitions, project_hts, sample_data, **kwargs):
+            project_ht_dict = {}
+            for sample_type in project_hts:
+                ht = self._prefilter_merged_project_hts(project_hts[sample_type], n_partitions, **kwargs)
+                project_ht_dict[sample_type] = (ht, sample_data[sample_type])
+            all_project_hts.append(project_ht_dict)
+
+        return super()._load_project_hts(
+            project_samples,
+            n_partitions,
+            initialize_project_hts,
+            initialize_sample_data,
+            aggregate_project_data,
+            load_project_ht_chunks,
+            **kwargs
+        )

    def _filter_entries_ht_both_sample_types(
        self, wes_ht, wes_project_samples, wgs_ht, wgs_project_samples, inheritance_filter=None, quality_filter=None,
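
Note: the routing added in this hunk can be read in isolation. Below is a minimal standalone sketch of how a project_samples mapping gets split into single-sample-type and both-sample-type buckets before filtering; the project GUIDs, sample types, and family data are invented for illustration and are not taken from this PR.

# Toy project_samples shaped like {project_guid: {sample_type: family_sample_data}}.
project_samples = {
    'R0001_project_a': {'WES': {'F000001': ['S1']}},
    'R0002_project_b': {'WES': {'F000002': ['S2']}, 'WGS': {'F000002': ['S3']}},
}

single_sample_type_project_samples = {}
both_sample_type_project_samples = {}
for project_guid, sample_dict in project_samples.items():
    if len(sample_dict) == 1:
        single_sample_type_project_samples[project_guid] = sample_dict
    else:
        both_sample_type_project_samples[project_guid] = sample_dict

assert list(single_sample_type_project_samples) == ['R0001_project_a']
assert list(both_sample_type_project_samples) == ['R0002_project_b']

The single-sample-type bucket goes through the existing superclass path, while the both-sample-type bucket triggers the merged WES/WGS filtering shown in the hunk.
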
@@ -222,27 +229,21 @@ def _filter_entries_ht_both_sample_types(
        )

        ch_ht = None
-        for sample_type, family_sample_data in sample_types:
+        family_guid_idx_map = defaultdict(dict)
+        for sample_type, sorted_family_sample_data in sample_types:
            ht, ch_ht = self._filter_inheritance(
-                ht, ch_ht, inheritance_filter, family_sample_data,
+                ht, ch_ht, inheritance_filter, sorted_family_sample_data,
                annotation=sample_type.passes_inheritance_field, entries_ht_field=sample_type.family_entries_field
            )
+            for family_idx, samples in enumerate(sorted_family_sample_data):
+                family_guid = samples[0]['familyGuid']
+                family_guid_idx_map[family_guid][sample_type.value] = family_idx

-        family_idx_map = self._build_family_index_map(sample_types, sorted_wes_family_sample_data, sorted_wgs_family_sample_data)
+        family_idx_map = hl.dict(family_guid_idx_map)
        ht = self._apply_multi_sample_type_entry_filters(ht, family_idx_map)
        ch_ht = self._apply_multi_sample_type_entry_filters(ch_ht, family_idx_map)
-
        return ht, ch_ht

-    @staticmethod
-    def _build_family_index_map(sample_types, sorted_wes_family_sample_data, sorted_wgs_family_sample_data):
-        family_guid_idx_map = defaultdict(dict)
-        for sample_type, sorted_family_sample_data in sample_types:
-            for family_idx, samples in enumerate(sorted_family_sample_data):
-                family_guid = samples[0]['familyGuid']
-                family_guid_idx_map[family_guid][sample_type.value] = family_idx
-        return hl.dict(family_guid_idx_map)
-
    def _apply_multi_sample_type_entry_filters(self, ht, family_idx_map):
        if ht is None:
            return ht
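
For reference, the index map inlined in this hunk is a plain nested dict keyed by family GUID and then by sample type, recording each family's position in that sample type's sorted family entries; the query then wraps it in hl.dict for use inside Hail expressions. A minimal sketch with invented family GUIDs and plain-string sample types (the real code uses SampleType enum values):

from collections import defaultdict

# Hypothetical sorted family sample data per sample type; each inner list holds the
# sample dicts for one family, in the order of that sample type's family entries.
sample_types = [
    ('WES', [[{'familyGuid': 'F000001'}], [{'familyGuid': 'F000002'}]]),
    ('WGS', [[{'familyGuid': 'F000002'}]]),
]

family_guid_idx_map = defaultdict(dict)
for sample_type, sorted_family_sample_data in sample_types:
    for family_idx, samples in enumerate(sorted_family_sample_data):
        family_guid_idx_map[samples[0]['familyGuid']][sample_type] = family_idx

assert dict(family_guid_idx_map) == {'F000001': {'WES': 0}, 'F000002': {'WES': 1, 'WGS': 0}}
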
@@ -269,6 +270,9 @@ def _apply_multi_sample_type_entry_filters(self, ht, family_idx_map):
                ht[SampleType.WGS.family_entries_field], hl.empty_array(ht[SampleType.WGS.family_entries_field].dtype.element_type)
            )).filter(lambda entries: entries.any(hl.is_defined))
        )
+        ht = ht.select('family_entries')
+        ht = ht.select_globals('family_guids')
+
        # Filter out families with no valid entries in either sample type
        return ht.filter(ht.family_entries.any(hl.is_defined))

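The two select calls added here keep only the merged fields. As a rough sketch of the Hail semantics (a toy table with invented field names and values, assuming a working Hail installation; this is not the query's actual schema):

import hail as hl

ht = hl.utils.range_table(1)
ht = ht.annotate(
    family_entries=[[hl.struct(GT=hl.call(0, 1))]],
    wes_passes_inheritance=[True],
)
ht = ht.annotate_globals(family_guids=['F000001'], wes_family_guids=['F000001'])

# select / select_globals keep only the listed fields (plus the table key),
# dropping the per-sample-type intermediates once they have been merged.
ht = ht.select('family_entries').select_globals('family_guids')
ht.describe()  # row: idx, family_entries; globals: family_guids
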
@@ -297,9 +301,13 @@ def _get_sample_genotype(self, samples, r=None, include_genotype_overrides=False
        if not self._has_both_sample_types:
            return super()._get_sample_genotype(samples, r, include_genotype_overrides, select_fields)

-        return samples.map(lambda sample: self._select_genotype_for_sample(
-            sample, r, include_genotype_overrides, select_fields
-        ))
+        return hl.if_else(
+            hl.len(hl.set(samples.map(lambda sample: sample.sampleType))) > 1,
+            samples.map(lambda sample: self._select_genotype_for_sample(
+                sample, r, include_genotype_overrides, select_fields
+            )),
+            [super()._get_sample_genotype(samples, r, include_genotype_overrides, select_fields)]
+        )

    @staticmethod
    def _selected_main_transcript_expr(ht):
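
The guard introduced in this final hunk checks whether a family's samples span more than one sample type before switching to per-sample genotype selection. A minimal, hypothetical sketch of just that predicate (invented sample IDs and types; assumes a working Hail installation):

import hail as hl

samples = hl.literal([
    hl.Struct(sampleId='S1', sampleType='WES'),
    hl.Struct(sampleId='S2', sampleType='WGS'),
])

# A family "has both sample types" when its samples' distinct sampleType values
# number more than one, which selects the per-sample branch of the hl.if_else above.
has_both_sample_types = hl.len(hl.set(samples.map(lambda sample: sample.sampleType))) > 1
print(hl.eval(has_both_sample_types))  # True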