
Commit

generalize loading project hts and checking for multiple sample types on project table load
jklugherz committed Oct 3, 2024
1 parent 5442a7f commit 31136be
Showing 2 changed files with 116 additions and 93 deletions.
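
For orientation, the base.py change below turns _load_project_hts into a template method whose batching is customized through four optional hook arguments. A minimal, self-contained sketch of that pattern, with hypothetical names and plain lists standing in for Hail tables:

def load_in_chunks(items, chunk_size=2, init_batch=None, add_item=None, flush_batch=None):
    # Default hooks keep the original flat-list behavior
    if init_batch is None:
        init_batch = list
    if add_item is None:
        add_item = lambda batch, item: batch.append(item)
    if flush_batch is None:
        flush_batch = lambda results, batch: results.append(list(batch))

    results = []
    batch = init_batch()
    for item in items:
        add_item(batch, item)
        if len(batch) >= chunk_size:  # flush a full chunk, then start a new batch
            flush_batch(results, batch)
            batch = init_batch()
    if batch:
        flush_batch(results, batch)
    return results

print(load_in_chunks([1, 2, 3, 4, 5]))  # [[1, 2], [3, 4], [5]]

# Overriding the hooks batches into a dict instead, the way mito.py groups
# tables per sample type (keys here are invented)
print(load_in_chunks(
    ['a', 'B', 'c'], chunk_size=2,
    init_batch=dict,
    add_item=lambda batch, item: batch.setdefault(item.isupper(), []).append(item),
    flush_batch=lambda results, batch: results.append(dict(batch)),
))  # [{False: ['a'], True: ['B']}, {False: ['c']}]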
71 changes: 43 additions & 28 deletions hail_search/queries/base.py
@@ -21,11 +21,6 @@
 # https://github.com/broadinstitute/seqr-private/issues/1283#issuecomment-1973392719
 MAX_PARTITIONS = 12
 
-# Need to chunk tables or else evaluating table globals throws LineTooLong exception
-# However, minimizing number of chunks minimizes number of aggregations/ evals and improves performance
-# Adapted from https://discuss.hail.is/t/importing-many-sample-specific-vcfs/2002/8
-MAX_HTS_TO_JOIN = 64
-
 logger = logging.getLogger(__name__)


@@ -340,35 +335,49 @@ def _import_and_filter_multiple_project_hts(
 
         return self._merge_filtered_hts(filtered_comp_het_project_hts, filtered_project_hts, n_partitions)
 
-    def _load_project_hts(self, project_samples: dict, n_partitions: int, **kwargs) -> list[tuple[hl.Table, dict]]:
+    def _load_project_hts(
+        self, project_samples, n_partitions,
+        initialize_project_hts=None, initialize_sample_data=None, aggregate_project_data=None, load_project_ht_chunks=None,
+        **kwargs
+    ):
+        # Need to chunk tables or else evaluating table globals throws LineTooLong exception
+        # However, minimizing number of chunks minimizes number of aggregations/ evals and improves performance
+        # Adapted from https://discuss.hail.is/t/importing-many-sample-specific-vcfs/2002/8
+        chunk_size = 64
         all_project_hts = []
-        project_hts = []
-        sample_data = {}
-        for project_guid, project_sample_type_data in project_samples.items():
-            sample_type, family_sample_data = list(project_sample_type_data.items())[0]
-            project_ht = self._read_project_data(family_sample_data, project_guid, sample_type, project_hts, sample_data)
-            if project_ht is None:
-                continue
+        if initialize_project_hts is None:
+            initialize_project_hts = lambda: []
 
-            if len(project_hts) >= MAX_HTS_TO_JOIN:
+        if initialize_sample_data is None:
+            initialize_sample_data = lambda: {}
+
+        if aggregate_project_data is None:
+            def aggregate_project_data(family_sample_data, ht, project_hts, sample_data, _):
+                project_hts.append(ht)
+                sample_data.update(family_sample_data)
+
+        if load_project_ht_chunks is None:
+            def load_project_ht_chunks(all_project_hts, n_partitions, project_hts, sample_data, **kwargs):
                 ht = self._prefilter_merged_project_hts(project_hts, n_partitions, **kwargs)
                 all_project_hts.append((ht, sample_data))
-                project_hts = []
-                sample_data = {}
-
-        if project_hts:
-            ht = self._prefilter_merged_project_hts(project_hts, n_partitions, **kwargs)
-            all_project_hts.append((ht, sample_data))
-        return all_project_hts
 
-    def _read_project_data(self, family_sample_data, project_guid, sample_type, project_hts, sample_data):
-        project_ht = self._read_project_table(project_guid, sample_type)
-        if project_ht is not None:
-            project_ht = project_ht.select_globals('sample_type', 'family_guids', 'family_samples')
-            project_hts.append(project_ht)
-            sample_data.update(family_sample_data)
-
-        return project_ht
+        project_hts = initialize_project_hts()
+        sample_data = initialize_sample_data()
+
+        for project_guid, project_sample_type_data in project_samples.items():
+            for sample_type, family_sample_data in project_sample_type_data.items():
+                project_ht = self._read_project_data(project_guid, sample_type)
+                if project_ht is None:
+                    continue
+                aggregate_project_data(family_sample_data, project_ht, project_hts, sample_data, sample_type)
+
+                if len(project_hts) >= chunk_size:
+                    load_project_ht_chunks(all_project_hts, n_partitions, project_hts, sample_data, **kwargs)
+                    project_hts = initialize_project_hts()
+                    sample_data = initialize_sample_data()
+
+        load_project_ht_chunks(all_project_hts, n_partitions, project_hts, sample_data, **kwargs)
+        return all_project_hts
 
     def import_filtered_table(self, project_samples: dict, num_families: int, **kwargs):
         if num_families == 1 or len(project_samples) == 1:
@@ -426,6 +435,12 @@ def _load_project_ht(
     def _read_project_table(self, project_guid: str, sample_type: str):
         return self._read_table(f'projects/{sample_type}/{project_guid}.ht')
 
+    def _read_project_data(self, project_guid: str, sample_type: str):
+        project_ht = self._read_project_table(project_guid, sample_type)
+        if project_ht is not None:
+            project_ht = project_ht.select_globals('sample_type', 'family_guids', 'family_samples')
+        return project_ht
+
     def _prefilter_merged_project_hts(self, project_hts, n_partitions, **kwargs):
         ht = self._merge_project_hts(project_hts, n_partitions, include_all_globals=True)
         return self._prefilter_entries_table(ht, **kwargs)
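
The new _read_project_data helper trims each project table's globals to the three fields the later merging steps rely on. A toy demonstration of Table.select_globals, assuming a local Hail installation (the table contents and the extra global here are invented):

import hail as hl

ht = hl.utils.range_table(1)
ht = ht.annotate_globals(
    sample_type='WES', family_guids=['F1'], family_samples={'F1': ['S1']}, extra_meta='dropped',
)
# Keep only the globals downstream merging depends on; all others are discarded
ht = ht.select_globals('sample_type', 'family_guids', 'family_samples')
print(hl.eval(ht.globals))
# Struct(sample_type='WES', family_guids=['F1'], family_samples={'F1': ['S1']})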
138 changes: 73 additions & 65 deletions hail_search/queries/mito.py
@@ -11,8 +11,7 @@
     PATHOGENICTY_SORT_KEY, CONSEQUENCE_SORT, \
     PATHOGENICTY_HGMD_SORT_KEY, MAX_LOAD_INTERVALS
 from hail_search.definitions import SampleType
-from hail_search.queries.base import BaseHailTableQuery, PredictionPath, QualityFilterFormat, MAX_PARTITIONS, \
-    MAX_HTS_TO_JOIN
+from hail_search.queries.base import BaseHailTableQuery, PredictionPath, QualityFilterFormat, MAX_PARTITIONS
 
 REFERENCE_DATASETS_DIR = os.environ.get('REFERENCE_DATASETS_DIR', '/seqr/seqr-reference-data')
 REFERENCE_DATASET_SUBDIR = 'cached_reference_dataset_queries'
@@ -145,61 +144,69 @@ def _import_and_filter_entries_ht(
     def _import_and_filter_multiple_project_hts(
         self, project_samples: dict, n_partitions=MAX_PARTITIONS, **kwargs
     ) -> tuple[hl.Table, hl.Table]:
-        sample_types = set()
-        for sample_dict in project_samples.values():
-            sample_types.update(sample_dict.keys())
-        if len(sample_types) == 1:
-            return super()._import_and_filter_multiple_project_hts(project_samples, n_partitions, **kwargs)
-
-        self._has_both_sample_types = True
-        entries = self._load_project_hts_both_sample_types(project_samples, n_partitions, **kwargs)
+        single_sample_type_project_samples = {}
+        both_sample_type_project_samples = {}
+        for project_guid, sample_dict in project_samples.items():
+            if len(sample_dict) == 1:
+                single_sample_type_project_samples[project_guid] = sample_dict
+            else:
+                both_sample_type_project_samples[project_guid] = sample_dict
 
         filtered_project_hts = []
         filtered_comp_het_project_hts = []
-        for entry in entries:
-            wes_ht, wes_project_samples = entry[SampleType.WES.value]
-            wgs_ht, wgs_project_samples = entry[SampleType.WGS.value]
-            ht, comp_het_ht = self._filter_entries_ht_both_sample_types(
-                wes_ht, wes_project_samples, wgs_ht, wgs_project_samples,
-                is_merged_ht=True, **kwargs
-            )
+
+        # Process projects with only one sample type separately
+        if single_sample_type_project_samples:
+            ht, ch_ht = super()._import_and_filter_multiple_project_hts(single_sample_type_project_samples, n_partitions, **kwargs)
             if ht is not None:
                 filtered_project_hts.append(ht)
-            if comp_het_ht is not None:
-                filtered_comp_het_project_hts.append(comp_het_ht)
+            if ch_ht is not None:
+                filtered_comp_het_project_hts.append(ch_ht)
+
+        if both_sample_type_project_samples:
+            self._has_both_sample_types = True
+            entries = self._load_project_hts_both_sample_types(project_samples, n_partitions, **kwargs)
+            for entry in entries:
+                wes_ht, wes_project_samples = entry[SampleType.WES.value]
+                wgs_ht, wgs_project_samples = entry[SampleType.WGS.value]
+                ht, comp_het_ht = self._filter_entries_ht_both_sample_types(
+                    wes_ht, wes_project_samples, wgs_ht, wgs_project_samples,
+                    is_merged_ht=True, **kwargs
+                )
+                if ht is not None:
+                    filtered_project_hts.append(ht)
+                if comp_het_ht is not None:
+                    filtered_comp_het_project_hts.append(comp_het_ht)
 
         return self._merge_filtered_hts(filtered_comp_het_project_hts, filtered_project_hts, n_partitions)
 
-    def _load_project_hts_both_sample_types(
-        self, project_samples: dict, n_partitions: int, **kwargs
-    ) -> list[dict[str, tuple[hl.Table, dict]]]:
-        all_project_hts = []
-        project_hts = defaultdict(list)
-        sample_data = defaultdict(dict)
-
-        for project_guid, project_sample_type_data in project_samples.items():
-            for sample_type, family_sample_data in project_sample_type_data.items():
-                project_ht = self._read_project_data(
-                    family_sample_data, project_guid, sample_type, project_hts[sample_type], sample_data[sample_type]
-                )
-                if project_ht is None:
-                    continue
-
-                # Merge both WES and WGS project_hts when either of their lengths reaches the chunk size
-                if len(project_hts[SampleType.WES.value]) >= MAX_HTS_TO_JOIN or len(project_hts[SampleType.WGS.value]) >= MAX_HTS_TO_JOIN:
-                    self._load_project_ht_chunks(all_project_hts, kwargs, n_partitions, project_hts, sample_data)
-                    project_hts = defaultdict(list)
-                    sample_data = defaultdict(dict)
-
-        self._load_project_ht_chunks(all_project_hts, kwargs, n_partitions, project_hts, sample_data)
-        return all_project_hts
-
-    def _load_project_ht_chunks(self, all_project_hts, kwargs, n_partitions, project_hts, sample_data):
-        project_ht_dict = {}
-        for sample_type in project_hts:
-            ht = self._prefilter_merged_project_hts(project_hts[sample_type], n_partitions, **kwargs)
-            project_ht_dict[sample_type] = (ht, sample_data[sample_type])
-        all_project_hts.append(project_ht_dict)
+    def _load_project_hts_both_sample_types(self, project_samples, n_partitions, **kwargs):
+        def initialize_project_hts():
+            return defaultdict(list)
+
+        def initialize_sample_data():
+            return defaultdict(dict)
+
+        def aggregate_project_data(family_sample_data, ht, project_hts, sample_data, sample_type):
+            project_hts[sample_type].append(ht)
+            sample_data[sample_type].update(family_sample_data)
+
+        def load_project_ht_chunks(all_project_hts, n_partitions, project_hts, sample_data, **kwargs):
+            project_ht_dict = {}
+            for sample_type in project_hts:
+                ht = self._prefilter_merged_project_hts(project_hts[sample_type], n_partitions, **kwargs)
+                project_ht_dict[sample_type] = (ht, sample_data[sample_type])
+            all_project_hts.append(project_ht_dict)
+
+        return super()._load_project_hts(
+            project_samples,
+            n_partitions,
+            initialize_project_hts,
+            initialize_sample_data,
+            aggregate_project_data,
+            load_project_ht_chunks,
+            **kwargs
+        )
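
For reference, each chunk yielded by _load_project_hts_both_sample_types is a dict keyed by sample type, pairing the merged table with its per-family sample metadata. A hypothetical sketch of how the caller unpacks one chunk (None stands in for the merged hl.Table objects, and the family data is invented):

WES, WGS = 'WES', 'WGS'  # assumed SampleType.value strings

entry = {
    WES: (None, {'F1': [{'familyGuid': 'F1', 'sampleId': 'S1'}]}),
    WGS: (None, {'F1': [{'familyGuid': 'F1', 'sampleId': 'S2'}]}),
}
# Mirrors the unpacking in _import_and_filter_multiple_project_hts above
wes_ht, wes_project_samples = entry[WES]
wgs_ht, wgs_project_samples = entry[WGS]
print(set(wes_project_samples) == set(wgs_project_samples))  # True: same families in both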

@@ -222,27 +229,21 @@
     def _filter_entries_ht_both_sample_types(
         self, wes_ht, wes_project_samples, wgs_ht, wgs_project_samples, inheritance_filter=None, quality_filter=None,
)

         ch_ht = None
-        for sample_type, family_sample_data in sample_types:
+        family_guid_idx_map = defaultdict(dict)
+        for sample_type, sorted_family_sample_data in sample_types:
             ht, ch_ht = self._filter_inheritance(
-                ht, ch_ht, inheritance_filter, family_sample_data,
+                ht, ch_ht, inheritance_filter, sorted_family_sample_data,
                 annotation=sample_type.passes_inheritance_field, entries_ht_field=sample_type.family_entries_field
             )
+            for family_idx, samples in enumerate(sorted_family_sample_data):
+                family_guid = samples[0]['familyGuid']
+                family_guid_idx_map[family_guid][sample_type.value] = family_idx
 
-        family_idx_map = self._build_family_index_map(sample_types, sorted_wes_family_sample_data, sorted_wgs_family_sample_data)
+        family_idx_map = hl.dict(family_guid_idx_map)
         ht = self._apply_multi_sample_type_entry_filters(ht, family_idx_map)
         ch_ht = self._apply_multi_sample_type_entry_filters(ch_ht, family_idx_map)
 
         return ht, ch_ht
 
-    @staticmethod
-    def _build_family_index_map(sample_types, sorted_wes_family_sample_data, sorted_wgs_family_sample_data):
-        family_guid_idx_map = defaultdict(dict)
-        for sample_type, sorted_family_sample_data in sample_types:
-            for family_idx, samples in enumerate(sorted_family_sample_data):
-                family_guid = samples[0]['familyGuid']
-                family_guid_idx_map[family_guid][sample_type.value] = family_idx
-        return hl.dict(family_guid_idx_map)

def _apply_multi_sample_type_entry_filters(self, ht, family_idx_map):
if ht is None:
return ht
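
The family index map introduced in the hunk above records, for each family, where its row sits in each sample type's sorted entries. A plain-Python sketch of that bookkeeping (family and sample-type values are invented; the real code wraps the result in hl.dict):

from collections import defaultdict

sample_types = [
    ('WES', [[{'familyGuid': 'F1'}], [{'familyGuid': 'F2'}]]),  # two WES families
    ('WGS', [[{'familyGuid': 'F2'}]]),                          # one WGS family
]
family_guid_idx_map = defaultdict(dict)
for sample_type, sorted_family_sample_data in sample_types:
    for family_idx, samples in enumerate(sorted_family_sample_data):
        family_guid_idx_map[samples[0]['familyGuid']][sample_type] = family_idx
print(dict(family_guid_idx_map))  # {'F1': {'WES': 0}, 'F2': {'WES': 1, 'WGS': 0}}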
@@ -269,6 +270,9 @@ def _apply_multi_sample_type_entry_filters(self, ht, family_idx_map)
                 ht[SampleType.WGS.family_entries_field], hl.empty_array(ht[SampleType.WGS.family_entries_field].dtype.element_type)
             )).filter(lambda entries: entries.any(hl.is_defined))
         )
+        ht = ht.select('family_entries')
+        ht = ht.select_globals('family_guids')
+
         # Filter out families with no valid entries in either sample type
         return ht.filter(ht.family_entries.any(hl.is_defined))

@@ -297,9 +301,13 @@ def _get_sample_genotype(self, samples, r=None, include_genotype_overrides=False
         if not self._has_both_sample_types:
             return super()._get_sample_genotype(samples, r, include_genotype_overrides, select_fields)
 
-        return samples.map(lambda sample: self._select_genotype_for_sample(
-            sample, r, include_genotype_overrides, select_fields
-        ))
+        return hl.if_else(
+            hl.len(hl.set(samples.map(lambda sample: sample.sampleType))) > 1,
+            samples.map(lambda sample: self._select_genotype_for_sample(
+                sample, r, include_genotype_overrides, select_fields
+            )),
+            [super()._get_sample_genotype(samples, r, include_genotype_overrides, select_fields)]
+        )

@staticmethod
def _selected_main_transcript_expr(ht):
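
The reworked _get_sample_genotype only takes the per-sample selection path when a family's entries actually span both sample types. A toy evaluation of that guard expression, assuming a local Hail installation (the sample structs are invented):

import hail as hl

samples = hl.array([
    hl.struct(sampleId='S1', sampleType='WES'),
    hl.struct(sampleId='S2', sampleType='WGS'),
])
# True when entries span more than one sample type, so each sample gets its own
# genotype selection; otherwise the single-sample-type formatting is reused
has_both = hl.len(hl.set(samples.map(lambda s: s.sampleType))) > 1
print(hl.eval(has_both))  # True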
