diff --git a/hail_search/queries/base.py b/hail_search/queries/base.py
index cffbbb08c2..224b91933f 100644
--- a/hail_search/queries/base.py
+++ b/hail_search/queries/base.py
@@ -21,11 +21,6 @@
 # https://github.com/broadinstitute/seqr-private/issues/1283#issuecomment-1973392719
 MAX_PARTITIONS = 12
 
-# Need to chunk tables or else evaluating table globals throws LineTooLong exception
-# However, minimizing number of chunks minimizes number of aggregations/ evals and improves performance
-# Adapted from https://discuss.hail.is/t/importing-many-sample-specific-vcfs/2002/8
-MAX_HTS_TO_JOIN = 64
-
 logger = logging.getLogger(__name__)
 
 
@@ -340,35 +335,49 @@ def _import_and_filter_multiple_project_hts(
         return self._merge_filtered_hts(filtered_comp_het_project_hts, filtered_project_hts, n_partitions)
 
-    def _load_project_hts(self, project_samples: dict, n_partitions: int, **kwargs) -> list[tuple[hl.Table, dict]]:
+    def _load_project_hts(
+        self, project_samples: dict, n_partitions: int,
+        initialize_project_hts=None, initialize_sample_data=None, aggregate_project_data=None,
+        load_project_ht_chunks=None, chunk_is_full=None,
+        **kwargs
+    ) -> list:
+        # Need to chunk tables or else evaluating table globals throws a LineTooLong exception
+        # However, minimizing the number of chunks minimizes the number of aggregations/evals and improves performance
+        # Adapted from https://discuss.hail.is/t/importing-many-sample-specific-vcfs/2002/8
+        chunk_size = 64
         all_project_hts = []
-        project_hts = []
-        sample_data = {}
-        for project_guid, project_sample_type_data in project_samples.items():
-            sample_type, family_sample_data = list(project_sample_type_data.items())[0]
-            project_ht = self._read_project_data(family_sample_data, project_guid, sample_type, project_hts, sample_data)
-            if project_ht is None:
-                continue
+        if initialize_project_hts is None:
+            initialize_project_hts = lambda: []
+
+        if initialize_sample_data is None:
+            initialize_sample_data = lambda: {}
 
-            if len(project_hts) >= MAX_HTS_TO_JOIN:
+        if aggregate_project_data is None:
+            def aggregate_project_data(family_sample_data, ht, project_hts, sample_data, _):
+                project_hts.append(ht)
+                sample_data.update(family_sample_data)
+
+        if load_project_ht_chunks is None:
+            def load_project_ht_chunks(all_project_hts, n_partitions, project_hts, sample_data, **kwargs):
                 ht = self._prefilter_merged_project_hts(project_hts, n_partitions, **kwargs)
                 all_project_hts.append((ht, sample_data))
-                project_hts = []
-                sample_data = {}
 
-        if project_hts:
-            ht = self._prefilter_merged_project_hts(project_hts, n_partitions, **kwargs)
-            all_project_hts.append((ht, sample_data))
-        return all_project_hts
+        if chunk_is_full is None:
+            def chunk_is_full(project_hts, chunk_size):
+                return len(project_hts) >= chunk_size
+
+        project_hts = initialize_project_hts()
+        sample_data = initialize_sample_data()
 
-    def _read_project_data(self, family_sample_data, project_guid, sample_type, project_hts, sample_data):
-        project_ht = self._read_project_table(project_guid, sample_type)
-        if project_ht is not None:
-            project_ht = project_ht.select_globals('sample_type', 'family_guids', 'family_samples')
-            project_hts.append(project_ht)
-            sample_data.update(family_sample_data)
-
-        return project_ht
+        for project_guid, project_sample_type_data in project_samples.items():
+            for sample_type, family_sample_data in project_sample_type_data.items():
+                project_ht = self._read_project_data(project_guid, sample_type)
+                if project_ht is None:
+                    continue
+                aggregate_project_data(family_sample_data, project_ht, project_hts, sample_data, sample_type)
+
+            # Check for a full chunk at the project boundary so a project's WES and WGS
+            # tables always land in the same chunk
+            if chunk_is_full(project_hts, chunk_size):
+                load_project_ht_chunks(all_project_hts, n_partitions, project_hts, sample_data, **kwargs)
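+                # A chunk was just flushed: rebuild the accumulators with the injected
+                # initializers so the next chunk starts empty (a plain list and dict for
+                # the base query, per-sample-type defaultdicts in the mito override)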
+                project_hts = initialize_project_hts()
+                sample_data = initialize_sample_data()
+
+        if project_hts:
+            load_project_ht_chunks(all_project_hts, n_partitions, project_hts, sample_data, **kwargs)
+        return all_project_hts
 
     def import_filtered_table(self, project_samples: dict, num_families: int, **kwargs):
         if num_families == 1 or len(project_samples) == 1:
@@ -426,6 +435,12 @@ def _load_project_ht(
     def _read_project_table(self, project_guid: str, sample_type: str):
         return self._read_table(f'projects/{sample_type}/{project_guid}.ht')
 
+    def _read_project_data(self, project_guid: str, sample_type: str):
+        project_ht = self._read_project_table(project_guid, sample_type)
+        if project_ht is not None:
+            project_ht = project_ht.select_globals('sample_type', 'family_guids', 'family_samples')
+        return project_ht
+
     def _prefilter_merged_project_hts(self, project_hts, n_partitions, **kwargs):
         ht = self._merge_project_hts(project_hts, n_partitions, include_all_globals=True)
         return self._prefilter_entries_table(ht, **kwargs)
diff --git a/hail_search/queries/mito.py b/hail_search/queries/mito.py
index f0ba1f0429..b81b58a4a6 100644
--- a/hail_search/queries/mito.py
+++ b/hail_search/queries/mito.py
@@ -11,8 +11,7 @@
     PATHOGENICTY_SORT_KEY, CONSEQUENCE_SORT, \
     PATHOGENICTY_HGMD_SORT_KEY, MAX_LOAD_INTERVALS
 from hail_search.definitions import SampleType
-from hail_search.queries.base import BaseHailTableQuery, PredictionPath, QualityFilterFormat, MAX_PARTITIONS, \
-    MAX_HTS_TO_JOIN
+from hail_search.queries.base import BaseHailTableQuery, PredictionPath, QualityFilterFormat, MAX_PARTITIONS
 
 REFERENCE_DATASETS_DIR = os.environ.get('REFERENCE_DATASETS_DIR', '/seqr/seqr-reference-data')
 REFERENCE_DATASET_SUBDIR = 'cached_reference_dataset_queries'
@@ -145,61 +144,69 @@ def _import_and_filter_entries_ht(
     def _import_and_filter_multiple_project_hts(
         self, project_samples: dict, n_partitions=MAX_PARTITIONS, **kwargs
     ) -> tuple[hl.Table, hl.Table]:
-        sample_types = set()
-        for sample_dict in project_samples.values():
-            sample_types.update(sample_dict.keys())
-        if len(sample_types) == 1:
-            return super()._import_and_filter_multiple_project_hts(project_samples, n_partitions, **kwargs)
-
-        self._has_both_sample_types = True
-        entries = self._load_project_hts_both_sample_types(project_samples, n_partitions, **kwargs)
+        single_sample_type_project_samples = {}
+        both_sample_type_project_samples = {}
+        for project_guid, sample_dict in project_samples.items():
+            if len(sample_dict) == 1:
+                single_sample_type_project_samples[project_guid] = sample_dict
+            else:
+                both_sample_type_project_samples[project_guid] = sample_dict
         filtered_project_hts = []
         filtered_comp_het_project_hts = []
-        for entry in entries:
-            wes_ht, wes_project_samples = entry[SampleType.WES.value]
-            wgs_ht, wgs_project_samples = entry[SampleType.WGS.value]
-            ht, comp_het_ht = self._filter_entries_ht_both_sample_types(
-                wes_ht, wes_project_samples, wgs_ht, wgs_project_samples,
-                is_merged_ht=True, **kwargs
-            )
+
+        # Process projects with only one sample type through the standard single sample type path
+        if single_sample_type_project_samples:
+            ht, ch_ht = super()._import_and_filter_multiple_project_hts(single_sample_type_project_samples, n_partitions, **kwargs)
             if ht is not None:
                 filtered_project_hts.append(ht)
-            if comp_het_ht is not None:
-                filtered_comp_het_project_hts.append(comp_het_ht)
-
-        return self._merge_filtered_hts(filtered_comp_het_project_hts, filtered_project_hts, n_partitions)
-
-    def _load_project_hts_both_sample_types(
-        self, project_samples: dict, n_partitions: int, **kwargs
-    ) -> list[dict[str, tuple[hl.Table, dict]]]:
-        all_project_hts = []
-        project_hts = defaultdict(list)
-        sample_data = defaultdict(dict)
-
-        for project_guid, project_sample_type_data in project_samples.items():
-            for sample_type, family_sample_data in project_sample_type_data.items():
-                project_ht = self._read_project_data(
-                    family_sample_data, project_guid, sample_type, project_hts[sample_type], sample_data[sample_type]
-                )
-                if project_ht is None:
-                    continue
-
-            # Merge both WES and WGS project_hts when either of their lengths reaches the chunk size
-            if len(project_hts[SampleType.WES.value]) >= MAX_HTS_TO_JOIN or len(project_hts[SampleType.WGS.value]) >= MAX_HTS_TO_JOIN:
-                self._load_project_ht_chunks(all_project_hts, kwargs, n_partitions, project_hts, sample_data)
-                project_hts = defaultdict(list)
-                sample_data = defaultdict(dict)
-
-        self._load_project_ht_chunks(all_project_hts, kwargs, n_partitions, project_hts, sample_data)
-        return all_project_hts
+            if ch_ht is not None:
+                filtered_comp_het_project_hts.append(ch_ht)
+
+        if both_sample_type_project_samples:
+            self._has_both_sample_types = True
+            entries = self._load_project_hts_both_sample_types(both_sample_type_project_samples, n_partitions, **kwargs)
+            for entry in entries:
+                wes_ht, wes_project_samples = entry[SampleType.WES.value]
+                wgs_ht, wgs_project_samples = entry[SampleType.WGS.value]
+                ht, comp_het_ht = self._filter_entries_ht_both_sample_types(
+                    wes_ht, wes_project_samples, wgs_ht, wgs_project_samples,
+                    is_merged_ht=True, **kwargs
+                )
+                if ht is not None:
+                    filtered_project_hts.append(ht)
+                if comp_het_ht is not None:
+                    filtered_comp_het_project_hts.append(comp_het_ht)
 
-    def _load_project_ht_chunks(self, all_project_hts, kwargs, n_partitions, project_hts, sample_data):
-        project_ht_dict = {}
-        for sample_type in project_hts:
-            ht = self._prefilter_merged_project_hts(project_hts[sample_type], n_partitions, **kwargs)
-            project_ht_dict[sample_type] = (ht, sample_data[sample_type])
-        all_project_hts.append(project_ht_dict)
+        return self._merge_filtered_hts(filtered_comp_het_project_hts, filtered_project_hts, n_partitions)
+
+    def _load_project_hts_both_sample_types(
+        self, project_samples: dict, n_partitions: int, **kwargs
+    ) -> list[dict[str, tuple[hl.Table, dict]]]:
+        def initialize_project_hts():
+            return defaultdict(list)
+
+        def initialize_sample_data():
+            return defaultdict(dict)
+
+        def aggregate_project_data(family_sample_data, ht, project_hts, sample_data, sample_type):
+            project_hts[sample_type].append(ht)
+            sample_data[sample_type].update(family_sample_data)
+
+        def load_project_ht_chunks(all_project_hts, n_partitions, project_hts, sample_data, **kwargs):
+            project_ht_dict = {}
+            for sample_type in project_hts:
+                ht = self._prefilter_merged_project_hts(project_hts[sample_type], n_partitions, **kwargs)
+                project_ht_dict[sample_type] = (ht, sample_data[sample_type])
+            all_project_hts.append(project_ht_dict)
+
+        def chunk_is_full(project_hts, chunk_size):
+            # Merge both WES and WGS project_hts when either of their lengths reaches the chunk size
+            return any(len(hts) >= chunk_size for hts in project_hts.values())
+
+        return super()._load_project_hts(
+            project_samples,
+            n_partitions,
+            initialize_project_hts,
+            initialize_sample_data,
+            aggregate_project_data,
+            load_project_ht_chunks,
+            chunk_is_full,
+            **kwargs
+        )
 
     def _filter_entries_ht_both_sample_types(
         self, wes_ht, wes_project_samples, wgs_ht, wgs_project_samples, inheritance_filter=None, quality_filter=None,
@@ -222,27 +229,21 @@
         )
 
         ch_ht = None
-        for sample_type, family_sample_data in sample_types:
+        family_guid_idx_map = defaultdict(dict)
+        for sample_type, sorted_family_sample_data in sample_types:
             ht, ch_ht = self._filter_inheritance(
-                ht, ch_ht, inheritance_filter, family_sample_data,
+                ht, ch_ht, inheritance_filter, sorted_family_sample_data,
                 annotation=sample_type.passes_inheritance_field, entries_ht_field=sample_type.family_entries_field
             )
+            # Record each family's entry index within this sample type's sorted family data
+            for family_idx, samples in enumerate(sorted_family_sample_data):
+                family_guid = samples[0]['familyGuid']
+                family_guid_idx_map[family_guid][sample_type.value] = family_idx
 
-        family_idx_map = self._build_family_index_map(sample_types, sorted_wes_family_sample_data, sorted_wgs_family_sample_data)
+        family_idx_map = hl.dict(family_guid_idx_map)
         ht = self._apply_multi_sample_type_entry_filters(ht, family_idx_map)
         ch_ht = self._apply_multi_sample_type_entry_filters(ch_ht, family_idx_map)
-
         return ht, ch_ht
 
-    @staticmethod
-    def _build_family_index_map(sample_types, sorted_wes_family_sample_data, sorted_wgs_family_sample_data):
-        family_guid_idx_map = defaultdict(dict)
-        for sample_type, sorted_family_sample_data in sample_types:
-            for family_idx, samples in enumerate(sorted_family_sample_data):
-                family_guid = samples[0]['familyGuid']
-                family_guid_idx_map[family_guid][sample_type.value] = family_idx
-        return hl.dict(family_guid_idx_map)
-
     def _apply_multi_sample_type_entry_filters(self, ht, family_idx_map):
         if ht is None:
             return ht
@@ -269,6 +270,9 @@
             ht[SampleType.WGS.family_entries_field], hl.empty_array(ht[SampleType.WGS.family_entries_field].dtype.element_type)
         )).filter(lambda entries: entries.any(hl.is_defined))
         )
+        ht = ht.select('family_entries')
+        ht = ht.select_globals('family_guids')
+        # Filter out families with no valid entries in either sample type
         return ht.filter(ht.family_entries.any(hl.is_defined))
 
@@ -297,9 +301,13 @@ def _get_sample_genotype(self, samples, r=None, include_genotype_overrides=False, select_fields=None):
         if not self._has_both_sample_types:
             return super()._get_sample_genotype(samples, r, include_genotype_overrides, select_fields)
 
-        return samples.map(lambda sample: self._select_genotype_for_sample(
-            sample, r, include_genotype_overrides, select_fields
-        ))
+        # Families with entries from both sample types need per-sample genotype selection;
+        # families with a single sample type fall back to the base genotype for that type
+        return hl.if_else(
+            hl.len(hl.set(samples.map(lambda sample: sample.sampleType))) > 1,
+            samples.map(lambda sample: self._select_genotype_for_sample(
+                sample, r, include_genotype_overrides, select_fields
+            )),
+            [super()._get_sample_genotype(samples, r, include_genotype_overrides, select_fields)]
+        )
 
     @staticmethod
     def _selected_main_transcript_expr(ht):
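The refactor replaces two copy-pasted chunking loops with a single accumulator driven by injected hooks. A minimal, self-contained sketch of that pattern, with hypothetical names (load_in_chunks and its hooks are illustrative only, not part of seqr), assuming a chunk_is_full hook as added above so grouped accumulators still flush once any group reaches the chunk size:

from collections import defaultdict

def load_in_chunks(items, chunk_size=2, initialize_chunk=None, aggregate=None, flush=None, chunk_is_full=None):
    # Hook defaults reproduce the flat accumulation used by the base query
    initialize_chunk = initialize_chunk or list
    aggregate = aggregate or (lambda chunk, item: chunk.append(item))
    flush = flush or (lambda out, chunk: out.append(chunk))
    chunk_is_full = chunk_is_full or (lambda chunk, size: len(chunk) >= size)

    out = []
    chunk = initialize_chunk()
    for item in items:
        aggregate(chunk, item)
        if chunk_is_full(chunk, chunk_size):
            flush(out, chunk)
            chunk = initialize_chunk()  # fresh accumulator, no aliasing with the flushed chunk
    if chunk:  # trailing partial chunk, mirroring the final guarded flush in base.py
        flush(out, chunk)
    return out

# Flat accumulation (base query shape): chunks of at most chunk_size tables
assert load_in_chunks([1, 2, 3, 4, 5]) == [[1, 2], [3, 4], [5]]

# Grouped accumulation (mito query shape): one list per sample type,
# flushed when either list reaches the chunk size
grouped = load_in_chunks(
    [('WES', 1), ('WGS', 2), ('WES', 3)],
    initialize_chunk=lambda: defaultdict(list),
    aggregate=lambda chunk, item: chunk[item[0]].append(item[1]),
    flush=lambda out, chunk: out.append(dict(chunk)),
    chunk_is_full=lambda chunk, size: any(len(hts) >= size for hts in chunk.values()),
)
assert grouped == [{'WES': [1, 3], 'WGS': [2]}]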