Skip to content
Merged
7 changes: 7 additions & 0 deletions hail_search/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ def family_entries_field(self) -> str:
SampleType.WGS: 'wgs_family_entries',
}[self]

@property
def failed_family_sample_field(self) -> str:
    """Return the name of the failed-family-sample-indices field for this sample type.

    Raises:
        KeyError: if this member is neither ``SampleType.WES`` nor ``SampleType.WGS``.
    """
    # Plain string literals: nothing is interpolated, so no f-string prefix is needed
    # (keeps this mapping consistent with the sibling ``*_field`` properties).
    return {
        SampleType.WES: 'wes_failed_family_sample_indices',
        SampleType.WGS: 'wgs_failed_family_sample_indices',
    }[self]

@property
def passes_inheritance_field(self) -> str:
return {
Expand Down
11 changes: 6 additions & 5 deletions hail_search/queries/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def _apply_entry_filters(ht):
def _filter_single_entries_table(self, ht, project_families, inheritance_filter=None, quality_filter=None, is_merged_ht=False, **kwargs):
ht, sorted_family_sample_data = self._add_entry_sample_families(ht, project_families, is_merged_ht)
ht = self._filter_quality(ht, quality_filter, **kwargs)
ht, ch_ht = self._filter_inheritance(
ht, ch_ht, _, _ = self._filter_inheritance(
ht, None, inheritance_filter, sorted_family_sample_data,
)
ht = self._apply_entry_filters(ht)
Expand Down Expand Up @@ -588,8 +588,9 @@ def _filter_inheritance(
lambda entries: hl.or_missing(entries.any(any_valid_entry), entries)
)})

ch_ht_entry_indices_by_gt = None
if self._has_comp_het_search:
comp_het_ht = self._annotate_families_inheritance(
comp_het_ht, ch_ht_entry_indices_by_gt = self._annotate_families_inheritance(
comp_het_ht if comp_het_ht is not None else ht, COMPOUND_HET, inheritance_filter,
sorted_family_sample_data, annotation, entries_ht_field
)
Expand All @@ -598,12 +599,12 @@ def _filter_inheritance(
# No sample-specific inheritance filtering needed
sorted_family_sample_data = []

ht = None if self._inheritance_mode == COMPOUND_HET else self._annotate_families_inheritance(
ht, ht_entry_indices_by_gt = (None, None) if self._inheritance_mode == COMPOUND_HET else self._annotate_families_inheritance(
ht, self._inheritance_mode, inheritance_filter, sorted_family_sample_data,
annotation, entries_ht_field
)

return ht, comp_het_ht
return ht, comp_het_ht, ht_entry_indices_by_gt, ch_ht_entry_indices_by_gt

def _annotate_families_inheritance(
self, ht, inheritance_mode, inheritance_filter, sorted_family_sample_data,
Expand Down Expand Up @@ -644,7 +645,7 @@ def _annotate_families_inheritance(
)
})

return ht
return ht, entry_indices_by_gt

def _get_family_passes_quality_filter(self, quality_filter, ht, **kwargs):
quality_filter = quality_filter or {}
Expand Down
118 changes: 97 additions & 21 deletions hail_search/queries/mito.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,37 +206,70 @@ def _filter_entries_ht_both_sample_types(

ch_ht = None
family_guid_idx_map = defaultdict(dict)
family_sample_idx_map = defaultdict(lambda: defaultdict(dict))
for sample_type, sorted_family_sample_data in sample_types:
ht, ch_ht = self._filter_inheritance(
ht, ch_ht, ht_idx_by_gt_map, ch_idx_by_gt_map = self._filter_inheritance(
ht, ch_ht, inheritance_filter, sorted_family_sample_data,
annotation=sample_type.passes_inheritance_field, entries_ht_field=sample_type.family_entries_field
)
ht = self._annotate_failed_family_samples_inheritance(
ht, ht_idx_by_gt_map,
annotation=sample_type.failed_family_sample_field, entries_ht_field=sample_type.family_entries_field
)
ch_ht = self._annotate_failed_family_samples_inheritance(
ch_ht, ch_idx_by_gt_map,
annotation=sample_type.failed_family_sample_field, entries_ht_field=sample_type.family_entries_field
)

for family_idx, samples in enumerate(sorted_family_sample_data):
family_guid = samples[0]['familyGuid']
family_guid_idx_map[family_guid][sample_type.value] = family_idx
for sample_idx, sample in enumerate(samples):
family_sample_idx_map[family_guid][sample['sampleId']][sample_type.value] = sample_idx

family_idx_map = hl.dict(family_guid_idx_map)
ht = self._apply_multi_sample_type_entry_filters(ht, family_idx_map)
ch_ht = self._apply_multi_sample_type_entry_filters(ch_ht, family_idx_map)
family_guid_idx_map = hl.dict(family_guid_idx_map)
family_sample_idx_map = hl.dict(family_sample_idx_map)
ht = self._apply_multi_sample_type_entry_filters(ht, family_guid_idx_map, family_sample_idx_map)
ch_ht = self._apply_multi_sample_type_entry_filters(ch_ht, family_guid_idx_map, family_sample_idx_map)
return ht, ch_ht

def _apply_multi_sample_type_entry_filters(self, ht, family_idx_map):
def _annotate_failed_family_samples_inheritance(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels incredibly similar to `_annotate_families_inheritance`. Rather than making a whole separate function for this, you could write a much more tightly scoped conditional helper and pass it into `_annotate_families_inheritance`, perhaps just for the lambda function applied in the `hl.enumerate(ht[entries_ht_field]).starmap(...)` call.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I implemented this, but because I don't want the single-sample-type families that originally come through the mito class code path to call the mito `family_passes_inheritance_filter` function, I'm still passing a `family_passes_inheritance_filter` function to `_filter_inheritance` instead of using Python inheritance here. Do you know if there's a cleaner way to do this?

self, ht, entry_indices_by_gt, annotation, entries_ht_field,
):
if ht is None:
return ht

ht = ht.annotate(**{annotation: hl.empty_dict(hl.tint32, hl.tarray(hl.tint32))})
# print(annotation, ht[annotation].collect())

for genotype, entry_indices in entry_indices_by_gt.items():
if not entry_indices:
continue
# print(genotype, entry_indices)
entry_indices = hl.dict(entry_indices)
ht = ht.annotate(
**{annotation: hl.dict(
hl.enumerate(ht[entries_ht_field]).starmap(
lambda family_index, entries: hl.bind(
lambda failed_samples: hl.tuple((
family_index,
ht[annotation].get(family_index, hl.empty_array(hl.tint32)).extend(failed_samples)
)),
entry_indices.get(family_index).filter(lambda sample_i: ~self.GENOTYPE_QUERY_MAP[genotype](entries[sample_i].GT))
)
)
)})
# print(annotation, ht[annotation].collect())
return ht

def _apply_multi_sample_type_entry_filters(self, ht, family_idx_map, sample_idx_map):
if ht is None:
return ht

# Keep family from both sample types if either passes quality AND inheritance
for sample_type in SampleType:
ht = ht.annotate(**{
sample_type.family_entries_field: hl.enumerate(ht[sample_type.family_entries_field]).starmap(
lambda i, family_samples: hl.or_missing(
hl.bind(
lambda other_sample_type_idx: (
self._family_has_valid_sample_type_entries(ht, sample_type, i) |
self._family_has_valid_sample_type_entries(ht, sample_type.other_sample_type, other_sample_type_idx)
),
family_idx_map.get(hl.coalesce(family_samples)[0]['familyGuid']).get(sample_type.other_sample_type.value),
), family_samples)
)})
ht = self._apply_quality_entry_filters(ht, sample_type, family_idx_map)
ht = self._apply_inheritance_entry_filters(ht, sample_type, family_idx_map, sample_idx_map)

# Merge family entries and filters from both sample types
ht = ht.transmute(
Expand All @@ -252,14 +285,57 @@ def _apply_multi_sample_type_entry_filters(self, ht, family_idx_map):
# Filter out families with no valid entries in either sample type
return ht.filter(ht.family_entries.any(hl.is_defined))

def _apply_quality_entry_filters(self, ht, sample_type, family_idx_map):
    """Set a family's entries to missing unless it has valid quality in either sample type.

    For every family in this sample type's family-entries field, the family's
    entries are kept when ``_family_has_valid_quality`` holds for this sample
    type at the family's own index, or for the other sample type at the index
    found in ``family_idx_map`` (keyed by family guid, then by the other sample
    type's value). Otherwise the family's entries become missing.

    Returns the table with that field re-annotated.
    """
    entries_field = sample_type.family_entries_field
    other_sample_type = sample_type.other_sample_type

    def _valid_in_either_sample_type(family_idx, family_samples):
        # Index of the same family within the other sample type's entries, if any.
        family_guid = hl.coalesce(family_samples)[0]['familyGuid']
        other_family_idx = family_idx_map.get(family_guid).get(other_sample_type.value)
        return hl.bind(
            lambda other_idx: (
                self._family_has_valid_quality(ht, sample_type, family_idx) |
                self._family_has_valid_quality(ht, other_sample_type, other_idx)
            ),
            other_family_idx,
        )

    filtered_entries = hl.enumerate(ht[entries_field]).starmap(
        lambda family_idx, family_samples: hl.or_missing(
            _valid_in_either_sample_type(family_idx, family_samples),
            family_samples,
        )
    )
    return ht.annotate(**{entries_field: filtered_entries})

@staticmethod
def _family_has_valid_sample_type_entries(ht, sample_type, sample_type_family_idx):
# Note: This logic does not sufficiently handle case 2 here https://docs.google.com/presentation/d/1hqDV8ulhviUcR5C4PtNUqkCLXKDsc6pccgFVlFmWUAU/edit?usp=sharing
# and will need to be changed to support it - https://github.com/broadinstitute/seqr/issues/4403
def _family_has_valid_quality(ht, sample_type, sample_type_family_idx):
return (
hl.is_defined(sample_type_family_idx) &
hl.is_defined(ht[sample_type.passes_quality_field][sample_type_family_idx]) &
hl.is_defined(ht[sample_type.passes_inheritance_field][sample_type_family_idx])
hl.is_defined(ht[sample_type.passes_quality_field][sample_type_family_idx])
)

def _apply_inheritance_entry_filters(self, ht, sample_type, family_idx_map, sample_idx_map):
# Re-annotates this sample type's family-entries field on ht and returns the resulting table.
# For each family (by index), the family's entries are kept only when the per-sample
# check below holds for every sample in the family (hl.all); otherwise that family's
# entries become missing (hl.or_missing).
# family_idx_map is looked up by familyGuid and then by the other sample type's value;
# sample_idx_map is looked up by familyGuid, then sampleId, then the other sample type's value.
# The two lookups are bound together as the pair (other family index, other sample index).
return ht.annotate(**{
sample_type.family_entries_field: hl.if_else(
# NOTE(review): this missingness test is applied to the whole entries field, not to an individual family - confirm that is intended.
hl.is_missing(ht[sample_type.family_entries_field]), # If family entries has already been filtered due to quality do nothing
ht[sample_type.family_entries_field],
hl.enumerate(ht[sample_type.family_entries_field]).starmap( # Else,
lambda family_i, family_samples: hl.or_missing(
hl.all(hl.enumerate(family_samples).starmap(
lambda sample_i, sample: hl.any( # For each sample in a family,
hl.bind(lambda other_sample_type_indices: ( # Get the sample and family index of the sample in the other sample type family_entries
hl.if_else(
# NOTE(review): sample_i is the enumerate index, so hl.is_defined(sample_i) presumably always holds; the effective test looks like "other sample index is defined" - confirm.
hl.is_defined(sample_i) & hl.is_defined(other_sample_type_indices[1]), # If samples are present for both sample types,
( # Keep the family entries if family passes inheritance in either sample type.
hl.is_defined(ht[sample_type.passes_inheritance_field][family_i]) |
hl.is_defined(ht[sample_type.other_sample_type.passes_inheritance_field][other_sample_type_indices[0]])
), # Else, if sample is in only one sample type, check if that sample did not fail inheritance
self._family_sample_has_valid_inheritance(ht, sample_type, family_i, sample_i) |
self._family_sample_has_valid_inheritance(ht, sample_type.other_sample_type, other_sample_type_indices[0], other_sample_type_indices[1])
)
),(
# Element 0: family index in the other sample type; element 1: sample index in the other sample type.
family_idx_map.get(hl.coalesce(sample)['familyGuid']).get(sample_type.other_sample_type.value),
sample_idx_map.get(hl.coalesce(sample)['familyGuid']).get(hl.coalesce(sample)['sampleId']).get(sample_type.other_sample_type.value)),
))
)), family_samples)
))
})

@staticmethod
def _family_sample_has_valid_inheritance(ht, sample_type, family_idx, sample_idx):
    """Return a boolean expression: the sample is not listed as failing inheritance.

    The expression is true when both indices are defined and ``sample_idx`` is
    not in the collection stored for ``family_idx`` in the sample type's
    failed-family-sample field on ``ht``.

    The membership test itself is negated. Wrapping it in ``hl.is_defined``
    would only ask whether the membership result is missing, which makes the
    whole expression false for every sample whose indices are defined,
    regardless of whether that sample actually failed.
    """
    failed_sample_indices = ht[sample_type.failed_family_sample_field][family_idx]
    return (
        hl.is_defined(family_idx) &
        hl.is_defined(sample_idx) &
        ~failed_sample_indices.contains(sample_idx)
    )

def _get_sample_genotype(self, samples, r=None, include_genotype_overrides=False, select_fields=None, **kwargs):
Expand Down
Loading