Skip to content

Commit

Permalink
bugfix: nt implanting discards imgt positions; minor import fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pavlovicmilena committed Jul 18, 2024
1 parent a40e0d7 commit 152b38e
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 29 deletions.
2 changes: 1 addition & 1 deletion ligo/environment/Constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
class Constants:

VERSION = "1.0.7"
VERSION = "1.0.8"

# encoding constants
FEATURE_DELIMITER = "-"
Expand Down
6 changes: 4 additions & 2 deletions ligo/simulation/generative_models/BackgroundSequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ def build_from_receptor_sequences(cls, sequences: List[ReceptorSequence]):
sequence=[s.nucleotide_sequence for s in sequences],
v_call=[s.metadata.v_call if s.metadata else '' for s in sequences],
j_call=[s.metadata.j_call if s.metadata else '' for s in sequences],
region_type=[s.metadata.region_type.name if s.metadata else '' for s in sequences],
frame_type=[s.metadata.frame_type.name if s.metadata else '' for s in sequences],
region_type=[s.metadata.region_type.name if s.metadata and s.metadata.region_type
else '' for s in sequences],
frame_type=[s.metadata.frame_type.name if s.metadata and s.metadata.frame_type
else '' for s in sequences],
p_gen=[-1. for _ in sequences], from_default_model=[1 for _ in sequences],
duplicate_count=[s.metadata.duplicate_count for s in sequences],
chain=[s.metadata.chain.value for s in sequences])
35 changes: 21 additions & 14 deletions ligo/simulation/simulation_strategy/ImplantingStrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,28 @@ def _implant_in_sequence(self, sequence_row: BackgroundSequences, signal: Signal
-> dict | None:

limit = len(motif_instance)
sequence_length = len(getattr(sequence_row, sequence_type.value))

if sequence_type == SequenceType.NUCLEOTIDE:
limit = limit * 3
if signal.sequence_position_weights:
logging.warning(f"{ImplantingStrategy.__name__}: IMGT positions are defined for the signal {signal.id},"
f" but positions is not supported for nucleotide sequences. Positions will be ignored "
f"and signal implanted at a random position.")

sequence_length = len(getattr(sequence_row, sequence_type.value))
region_type = RegionType[getattr(sequence_row, 'region_type').to_string()]
position_weights = PositionHelper.get_imgt_position_weights_for_implanting(sequence_length, region_type,
signal.sequence_position_weights,
limit)
if sum(list(position_weights.values())) == 0:
logging.info( f"Sequence {sequence_row} has no valid positions where the signal could be implanted, "
f"skipping the sequence.")
return None
implant_position = random.choice(list(range(sequence_length - limit)))
position_weights = {}

implant_position = choose_implant_position(list(position_weights.keys()), position_weights)
else:
region_type = RegionType[getattr(sequence_row, 'region_type').to_string()]
position_weights = PositionHelper.get_imgt_position_weights_for_implanting(sequence_length, region_type,
signal.sequence_position_weights,
limit)
if sum(list(position_weights.values())) == 0:
logging.info(f"Sequence {sequence_row} has no valid positions where the signal could be implanted, "
f"skipping the sequence.")
return None

implant_position = choose_implant_position(list(position_weights.keys()), position_weights)

new_sequence = self._make_new_sequence(sequence_row, motif_instance, implant_position, sequence_type)

Expand All @@ -141,7 +149,8 @@ def _implant_in_sequence(self, sequence_row: BackgroundSequences, signal: Signal
new_sequence['p_gen'] = -1.

new_sequence[signal.id] = 1
new_sequence[f'{signal.id}_positions'] = ["m"] + ["0" if ind != implant_position else "1" for ind in range(sequence_length)]
new_sequence[f'{signal.id}_positions'] = ["m"] + ["0" if ind != implant_position else "1" for ind in
range(sequence_length)]
new_sequence[f'{signal.id}_positions'] = "".join(new_sequence[f'{signal.id}_positions'])

zero_mask = "m" + "".join(["0" for _ in range(len(new_sequence[sequence_type.value]))])
Expand All @@ -164,8 +173,6 @@ def _make_new_sequence(self, sequence_row: BackgroundSequences, motif_instance:
motif_left = motif_instance.instance
motif_right = ""

if sequence_type == SequenceType.NUCLEOTIDE:
position *= 3
sequence_string = getattr(sequence_row, sequence_type.value).to_string()

gap_start = position + len(motif_left)
Expand Down
19 changes: 13 additions & 6 deletions ligo/simulation/util/bnp_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ def pad_ragged_array(new_array, target_shape, padded_value):

return padded_array

def make_bnp_dataclass_object_from_dicts(dict_objects: List[dict], field_type_map: dict = None, signals: list = None, base_class=None) -> BNPDataClass:

def make_bnp_dataclass_object_from_dicts(dict_objects: List[dict], field_type_map: dict = None, signals: list = None,
base_class=None) -> BNPDataClass:
if not isinstance(dict_objects, list) or len(dict_objects) == 0:
raise RuntimeError("Cannot make dataclass, got empty list as input.")

Expand All @@ -31,10 +33,13 @@ def make_bnp_dataclass_object_from_dicts(dict_objects: List[dict], field_type_ma
functions = {"get_signal_matrix": lambda self: np.array([getattr(self, name) for name in signal_names]).T,
"get_signal_names": lambda self: signal_names}

new_class = bnpdataclass(dc_make_dataclass("DynamicDC", bases=tuple([base_class]) if base_class is not None else (), namespace=functions,
fields=fields_list))
new_class = bnpdataclass(
dc_make_dataclass("DynamicDC", bases=tuple([base_class]) if base_class is not None else (),
namespace=functions,
fields=fields_list))
elif base_class:
new_class = base_class.extend(fields)
base_class_fields = [f.name for f in get_fields(base_class)]
new_class = base_class.extend([(field, field_type) for field, field_type in fields_list if field not in base_class_fields])
else:
new_class = bnpdataclass(dc_make_dataclass("DynamicDC", fields=fields_list))

Expand All @@ -54,7 +59,8 @@ def _extract_fields(transformed_objs, field_type_map):
if field_type_map is not None and field_name in field_type_map:
field_type = field_type_map[field_name]
if isinstance(field_type, Encoding):
transformed_objs[field_name] = as_encoded_array(transformed_objs[field_name], field_type) if any(transformed_objs[field_name]) else None
transformed_objs[field_name] = as_encoded_array(transformed_objs[field_name], field_type) if any(
transformed_objs[field_name]) else None
elif isinstance(transformed_objs[field_name][0], EncodedArray):
field_type = transformed_objs[field_name][0].encoding
elif transformed_objs[field_name] is not None:
Expand All @@ -74,7 +80,8 @@ def merge_dataclass_objects(objects: list): # TODO: replace with equivalent fro
assert all(hasattr(obj, field) for field in field_names), ([f.name for f in get_fields(obj)], field_names)

cls = type(objects[0])
return cls(**{field_name: list(chain.from_iterable([getattr(obj, field_name) for obj in objects])) for field_name in field_names})
return cls(**{field_name: list(chain.from_iterable([getattr(obj, field_name) for obj in objects])) for field_name in
field_names})


def _make_new_fields(new_fields: dict) -> List[tuple]:
Expand Down
2 changes: 1 addition & 1 deletion ligo/simulation/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_bnp_data(sequence_path, bnp_data_class):
if sequence_path.is_file():
buff_type = delimited_buffers.get_bufferclass_for_datatype(bnp_data_class, delimiter='\t', has_header=True)

with bnp.open(sequence_path, buffer_type=buff_type, lazy=False) as file:
with bnp.open(sequence_path, 'r', buff_type, False) as file:
data = file.read()

return data
Expand Down
10 changes: 5 additions & 5 deletions ligo/util/PositionHelper.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,26 @@ def get_allowed_positions_for_annotation(input_length: int, region_type: RegionT
return [int(bool(weight)) for weight in position_weights.values()]

@staticmethod
def get_imgt_position_weights_for_implanting(input_length: int, region_type: RegionType,
def get_imgt_position_weights_for_implanting(aa_input_length: int, region_type: RegionType,
sequence_position_weights: dict, limit: int):
position_weights = PositionHelper.get_imgt_position_weights_for_annotation(input_length, region_type,
position_weights = PositionHelper.get_imgt_position_weights_for_annotation(aa_input_length, region_type,
sequence_position_weights)

for index, position in enumerate(position_weights.keys()):
if index > input_length - limit:
if index > aa_input_length - limit:
position_weights[position] = 0.

weights_sum = sum(list(position_weights.values()))
if weights_sum == 0:
logging.warning(f"Sequence of length {input_length} has no allowed positions for signal with sequence "
logging.warning(f"Sequence of length {aa_input_length} has no allowed positions for signal with sequence "
f"position weights {sequence_position_weights} and motif length {limit}, it will be discarded.")
return position_weights

position_weights = {position: np.array([weight]).astype(np.float64)[0] / weights_sum
for position, weight in position_weights.items()}

assert np.isclose(sum(list(position_weights.values())), 1.), \
(input_length, region_type.name, position_weights, sum(list(position_weights.values())), limit)
(aa_input_length, region_type.name, position_weights, sum(list(position_weights.values())), limit)

return position_weights

Expand Down
File renamed without changes.
18 changes: 18 additions & 0 deletions test/simulation/sim_strategy/test_implanting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from ligo.environment.SequenceType import SequenceType
from ligo.simulation.SimConfigItem import SimConfigItem
from ligo.simulation.generative_models.BackgroundSequences import BackgroundSequences
from ligo.simulation.implants.SeedMotif import SeedMotif
from ligo.simulation.implants.Signal import Signal
from ligo.simulation.simulation_strategy.ImplantingStrategy import ImplantingStrategy


def test_implanting():
    """Smoke-test nucleotide implanting for a signal that defines position weights.

    Builds one signal ('s1', seed motif 'AAA') whose position-weight dict is
    {'104': 0, '105': 0}, and two background sequences marked as IMGT_JUNCTION.
    It then runs ``ImplantingStrategy.process_sequences`` with
    ``SequenceType.NUCLEOTIDE`` and prints whatever is returned.

    The test makes no assertions: it only fails if the call raises.
    """
    signal = Signal('s1', [SeedMotif('m1', 'AAA')], {'104': 0, '105': 0})

    background = BackgroundSequences(
        sequence=["CCCCC", "CCCCCCCCC"],
        sequence_aa=["A", "AA"],
        v_call=["", ""],
        j_call=["", ""],
        region_type=["IMGT_JUNCTION", "IMGT_JUNCTION"],
        frame_type=["", ""],
        p_gen=[-1., -1.],
        from_default_model=[1, 1],
        duplicate_count=[1, 1],
        chain=["TRB", "TRB"],
    )

    strategy = ImplantingStrategy()
    sim_item = SimConfigItem({signal: 1.})

    # NOTE(review): {'s1': 2} presumably requests two sequences carrying signal s1, and the
    # two positional False flags are opaque from this file - confirm against process_sequences.
    implanted = strategy.process_sequences(background, {'s1': 2}, False, SequenceType.NUCLEOTIDE,
                                           sim_item, [signal], False)

    print(implanted)

0 comments on commit 152b38e

Please sign in to comment.