Skip to content

Commit 152b38e

Browse files
bugfix: nt implanting discards imgt positions; minor import fixes
1 parent a40e0d7 commit 152b38e

File tree

8 files changed: 63 additions (+63) and 29 deletions (-29)

ligo/environment/Constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
class Constants:
22

3-
VERSION = "1.0.7"
3+
VERSION = "1.0.8"
44

55
# encoding constants
66
FEATURE_DELIMITER = "-"

ligo/simulation/generative_models/BackgroundSequences.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ def build_from_receptor_sequences(cls, sequences: List[ReceptorSequence]):
3232
sequence=[s.nucleotide_sequence for s in sequences],
3333
v_call=[s.metadata.v_call if s.metadata else '' for s in sequences],
3434
j_call=[s.metadata.j_call if s.metadata else '' for s in sequences],
35-
region_type=[s.metadata.region_type.name if s.metadata else '' for s in sequences],
36-
frame_type=[s.metadata.frame_type.name if s.metadata else '' for s in sequences],
35+
region_type=[s.metadata.region_type.name if s.metadata and s.metadata.region_type
36+
else '' for s in sequences],
37+
frame_type=[s.metadata.frame_type.name if s.metadata and s.metadata.frame_type
38+
else '' for s in sequences],
3739
p_gen=[-1. for _ in sequences], from_default_model=[1 for _ in sequences],
3840
duplicate_count=[s.metadata.duplicate_count for s in sequences],
3941
chain=[s.metadata.chain.value for s in sequences])

ligo/simulation/simulation_strategy/ImplantingStrategy.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -107,20 +107,28 @@ def _implant_in_sequence(self, sequence_row: BackgroundSequences, signal: Signal
107107
-> dict | None:
108108

109109
limit = len(motif_instance)
110+
sequence_length = len(getattr(sequence_row, sequence_type.value))
111+
110112
if sequence_type == SequenceType.NUCLEOTIDE:
111-
limit = limit * 3
113+
if signal.sequence_position_weights:
114+
logging.warning(f"{ImplantingStrategy.__name__}: IMGT positions are defined for the signal {signal.id},"
115+
f" but positions is not supported for nucleotide sequences. Positions will be ignored "
116+
f"and signal implanted at a random position.")
112117

113-
sequence_length = len(getattr(sequence_row, sequence_type.value))
114-
region_type = RegionType[getattr(sequence_row, 'region_type').to_string()]
115-
position_weights = PositionHelper.get_imgt_position_weights_for_implanting(sequence_length, region_type,
116-
signal.sequence_position_weights,
117-
limit)
118-
if sum(list(position_weights.values())) == 0:
119-
logging.info( f"Sequence {sequence_row} has no valid positions where the signal could be implanted, "
120-
f"skipping the sequence.")
121-
return None
118+
implant_position = random.choice(list(range(sequence_length - limit)))
119+
position_weights = {}
122120

123-
implant_position = choose_implant_position(list(position_weights.keys()), position_weights)
121+
else:
122+
region_type = RegionType[getattr(sequence_row, 'region_type').to_string()]
123+
position_weights = PositionHelper.get_imgt_position_weights_for_implanting(sequence_length, region_type,
124+
signal.sequence_position_weights,
125+
limit)
126+
if sum(list(position_weights.values())) == 0:
127+
logging.info(f"Sequence {sequence_row} has no valid positions where the signal could be implanted, "
128+
f"skipping the sequence.")
129+
return None
130+
131+
implant_position = choose_implant_position(list(position_weights.keys()), position_weights)
124132

125133
new_sequence = self._make_new_sequence(sequence_row, motif_instance, implant_position, sequence_type)
126134

@@ -141,7 +149,8 @@ def _implant_in_sequence(self, sequence_row: BackgroundSequences, signal: Signal
141149
new_sequence['p_gen'] = -1.
142150

143151
new_sequence[signal.id] = 1
144-
new_sequence[f'{signal.id}_positions'] = ["m"] + ["0" if ind != implant_position else "1" for ind in range(sequence_length)]
152+
new_sequence[f'{signal.id}_positions'] = ["m"] + ["0" if ind != implant_position else "1" for ind in
153+
range(sequence_length)]
145154
new_sequence[f'{signal.id}_positions'] = "".join(new_sequence[f'{signal.id}_positions'])
146155

147156
zero_mask = "m" + "".join(["0" for _ in range(len(new_sequence[sequence_type.value]))])
@@ -164,8 +173,6 @@ def _make_new_sequence(self, sequence_row: BackgroundSequences, motif_instance:
164173
motif_left = motif_instance.instance
165174
motif_right = ""
166175

167-
if sequence_type == SequenceType.NUCLEOTIDE:
168-
position *= 3
169176
sequence_string = getattr(sequence_row, sequence_type.value).to_string()
170177

171178
gap_start = position + len(motif_left)

ligo/simulation/util/bnp_util.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ def pad_ragged_array(new_array, target_shape, padded_value):
1818

1919
return padded_array
2020

21-
def make_bnp_dataclass_object_from_dicts(dict_objects: List[dict], field_type_map: dict = None, signals: list = None, base_class=None) -> BNPDataClass:
21+
22+
def make_bnp_dataclass_object_from_dicts(dict_objects: List[dict], field_type_map: dict = None, signals: list = None,
23+
base_class=None) -> BNPDataClass:
2224
if not isinstance(dict_objects, list) or len(dict_objects) == 0:
2325
raise RuntimeError("Cannot make dataclass, got empty list as input.")
2426

@@ -31,10 +33,13 @@ def make_bnp_dataclass_object_from_dicts(dict_objects: List[dict], field_type_ma
3133
functions = {"get_signal_matrix": lambda self: np.array([getattr(self, name) for name in signal_names]).T,
3234
"get_signal_names": lambda self: signal_names}
3335

34-
new_class = bnpdataclass(dc_make_dataclass("DynamicDC", bases=tuple([base_class]) if base_class is not None else (), namespace=functions,
35-
fields=fields_list))
36+
new_class = bnpdataclass(
37+
dc_make_dataclass("DynamicDC", bases=tuple([base_class]) if base_class is not None else (),
38+
namespace=functions,
39+
fields=fields_list))
3640
elif base_class:
37-
new_class = base_class.extend(fields)
41+
base_class_fields = [f.name for f in get_fields(base_class)]
42+
new_class = base_class.extend([(field, field_type) for field, field_type in fields_list if field not in base_class_fields])
3843
else:
3944
new_class = bnpdataclass(dc_make_dataclass("DynamicDC", fields=fields_list))
4045

@@ -54,7 +59,8 @@ def _extract_fields(transformed_objs, field_type_map):
5459
if field_type_map is not None and field_name in field_type_map:
5560
field_type = field_type_map[field_name]
5661
if isinstance(field_type, Encoding):
57-
transformed_objs[field_name] = as_encoded_array(transformed_objs[field_name], field_type) if any(transformed_objs[field_name]) else None
62+
transformed_objs[field_name] = as_encoded_array(transformed_objs[field_name], field_type) if any(
63+
transformed_objs[field_name]) else None
5864
elif isinstance(transformed_objs[field_name][0], EncodedArray):
5965
field_type = transformed_objs[field_name][0].encoding
6066
elif transformed_objs[field_name] is not None:
@@ -74,7 +80,8 @@ def merge_dataclass_objects(objects: list): # TODO: replace with equivalent fro
7480
assert all(hasattr(obj, field) for field in field_names), ([f.name for f in get_fields(obj)], field_names)
7581

7682
cls = type(objects[0])
77-
return cls(**{field_name: list(chain.from_iterable([getattr(obj, field_name) for obj in objects])) for field_name in field_names})
83+
return cls(**{field_name: list(chain.from_iterable([getattr(obj, field_name) for obj in objects])) for field_name in
84+
field_names})
7885

7986

8087
def _make_new_fields(new_fields: dict) -> List[tuple]:

ligo/simulation/util/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def get_bnp_data(sequence_path, bnp_data_class):
6161
if sequence_path.is_file():
6262
buff_type = delimited_buffers.get_bufferclass_for_datatype(bnp_data_class, delimiter='\t', has_header=True)
6363

64-
with bnp.open(sequence_path, buffer_type=buff_type, lazy=False) as file:
64+
with bnp.open(sequence_path, 'r', buff_type, False) as file:
6565
data = file.read()
6666

6767
return data

ligo/util/PositionHelper.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,26 @@ def get_allowed_positions_for_annotation(input_length: int, region_type: RegionT
4444
return [int(bool(weight)) for weight in position_weights.values()]
4545

4646
@staticmethod
47-
def get_imgt_position_weights_for_implanting(input_length: int, region_type: RegionType,
47+
def get_imgt_position_weights_for_implanting(aa_input_length: int, region_type: RegionType,
4848
sequence_position_weights: dict, limit: int):
49-
position_weights = PositionHelper.get_imgt_position_weights_for_annotation(input_length, region_type,
49+
position_weights = PositionHelper.get_imgt_position_weights_for_annotation(aa_input_length, region_type,
5050
sequence_position_weights)
5151

5252
for index, position in enumerate(position_weights.keys()):
53-
if index > input_length - limit:
53+
if index > aa_input_length - limit:
5454
position_weights[position] = 0.
5555

5656
weights_sum = sum(list(position_weights.values()))
5757
if weights_sum == 0:
58-
logging.warning(f"Sequence of length {input_length} has no allowed positions for signal with sequence "
58+
logging.warning(f"Sequence of length {aa_input_length} has no allowed positions for signal with sequence "
5959
f"position weights {sequence_position_weights} and motif length {limit}, it will be discarded.")
6060
return position_weights
6161

6262
position_weights = {position: np.array([weight]).astype(np.float64)[0] / weights_sum
6363
for position, weight in position_weights.items()}
6464

6565
assert np.isclose(sum(list(position_weights.values())), 1.), \
66-
(input_length, region_type.name, position_weights, sum(list(position_weights.values())), limit)
66+
(aa_input_length, region_type.name, position_weights, sum(list(position_weights.values())), limit)
6767

6868
return position_weights
6969

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from ligo.environment.SequenceType import SequenceType
2+
from ligo.simulation.SimConfigItem import SimConfigItem
3+
from ligo.simulation.generative_models.BackgroundSequences import BackgroundSequences
4+
from ligo.simulation.implants.SeedMotif import SeedMotif
5+
from ligo.simulation.implants.Signal import Signal
6+
from ligo.simulation.simulation_strategy.ImplantingStrategy import ImplantingStrategy
7+
8+
9+
def test_implanting():
10+
s1 = Signal('s1', [SeedMotif('m1', 'AAA')], {'104': 0, '105': 0})
11+
seqs = BackgroundSequences(sequence=["CCCCC", "CCCCCCCCC"], sequence_aa=["A", "AA"], v_call=["", ""], j_call=["", ""],
12+
region_type=["IMGT_JUNCTION", "IMGT_JUNCTION"], frame_type=["", ""], p_gen=[-1., -1.], from_default_model=[1, 1],
13+
duplicate_count=[1,1], chain=["TRB", "TRB"])
14+
15+
processed_seqs = ImplantingStrategy().process_sequences(seqs, {'s1': 2}, False, SequenceType.NUCLEOTIDE,
16+
SimConfigItem({s1: 1.}), [s1], False)
17+
18+
print(processed_seqs)

0 commit comments

Comments
 (0)