Commit 724010f

Merge pull request #214 from stanfordnlp/revert-191-main
Revert "feat: add intervenable_model to forward's function signature"
2 parents f6dbee1 + 5acd02f commit 724010f
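In practical terms, this revert removes the **kwargs plumbing that let callers pass extra keyword arguments (such as intervenable_model) through IntervenableModel.forward and generate down to each intervention's forward. A minimal before/after sketch of the intervention interface, reconstructed from the diffs below (signatures only, bodies elided):

    # before this revert: extra keyword arguments flowed through to interventions
    def forward(self, base, source, subspaces=None, **kwargs):
        ...

    # after this revert: interventions receive only base, source, and subspaces
    def forward(self, base, source, subspaces=None):
        ...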

5 files changed: +20 −68 lines changed

pyvene/models/intervenable_base.py

+1 −22
@@ -804,7 +804,6 @@ def _intervention_setter(
 keys,
 unit_locations_base,
 subspaces,
-**intervention_forward_kwargs
 ) -> HandlerList:
 """
 Create a list of setter tracer that will set activations
@@ -849,7 +848,6 @@ def _intervention_setter(
 None,
 intervention,
 subspaces[key_i] if subspaces is not None else None,
-**intervention_forward_kwargs
 )
 # fail if this is not a fresh collect
 assert key not in self.activations
@@ -864,7 +862,6 @@ def _intervention_setter(
 None,
 intervention,
 subspaces[key_i] if subspaces is not None else None,
-**intervention_forward_kwargs
 )
 else:
 intervened_representation = do_intervention(
@@ -876,7 +873,6 @@ def _intervention_setter(
 ),
 intervention,
 subspaces[key_i] if subspaces is not None else None,
-**intervention_forward_kwargs
 )
 else:
 # highly unlikely it's a primitive intervention type
@@ -889,7 +885,6 @@ def _intervention_setter(
 ),
 intervention,
 subspaces[key_i] if subspaces is not None else None,
-**intervention_forward_kwargs
 )
 if intervened_representation is None:
 return
@@ -975,7 +970,6 @@ def _sync_forward_with_parallel_intervention(
 ]
 if subspaces is not None
 else None,
-**kwargs
 )
 counterfactual_outputs = self.model.output.save()

@@ -1003,7 +997,6 @@ def forward(
 output_original_output: Optional[bool] = False,
 return_dict: Optional[bool] = None,
 use_cache: Optional[bool] = None,
-**kwargs
 ):
 activations_sources = source_representations
 if sources is not None and not isinstance(sources, list):
@@ -1043,7 +1036,7 @@ def forward(
 try:

 # run intervened forward
-model_kwargs = { **kwargs }
+model_kwargs = {}
 if labels is not None: # for training
 model_kwargs["labels"] = labels
 if use_cache is not None and 'use_cache' in self.model.config.to_dict(): # for transformer models
@@ -1533,7 +1526,6 @@ def _intervention_setter(
 keys,
 unit_locations_base,
 subspaces,
-**intervention_forward_kwargs
 ) -> HandlerList:
 """
 Create a list of setter handlers that will set activations
@@ -1581,7 +1573,6 @@ def hook_callback(model, args, kwargs, output=None):
 None,
 intervention,
 subspaces[key_i] if subspaces is not None else None,
-**intervention_forward_kwargs
 )
 # fail if this is not a fresh collect
 assert key not in self.activations
@@ -1597,7 +1588,6 @@ def hook_callback(model, args, kwargs, output=None):
 None,
 intervention,
 subspaces[key_i] if subspaces is not None else None,
-**intervention_forward_kwargs
 )
 if isinstance(raw_intervened_representation, InterventionOutput):
 self.full_intervention_outputs.append(raw_intervened_representation)
@@ -1614,7 +1604,6 @@ def hook_callback(model, args, kwargs, output=None):
 ),
 intervention,
 subspaces[key_i] if subspaces is not None else None,
-**intervention_forward_kwargs
 )
 else:
 # highly unlikely it's a primitive intervention type
@@ -1627,7 +1616,6 @@ def hook_callback(model, args, kwargs, output=None):
 ),
 intervention,
 subspaces[key_i] if subspaces is not None else None,
-**intervention_forward_kwargs
 )
 if intervened_representation is None:
 return
@@ -1695,7 +1683,6 @@ def _wait_for_forward_with_parallel_intervention(
 unit_locations,
 activations_sources: Optional[Dict] = None,
 subspaces: Optional[List] = None,
-**intervention_forward_kwargs
 ):
 # torch.autograd.set_detect_anomaly(True)
 all_set_handlers = HandlerList([])
@@ -1751,7 +1738,6 @@ def _wait_for_forward_with_parallel_intervention(
 ]
 if subspaces is not None
 else None,
-**intervention_forward_kwargs
 )
 # for setters, we don't remove them.
 all_set_handlers.extend(set_handlers)
@@ -1763,7 +1749,6 @@ def _wait_for_forward_with_serial_intervention(
 unit_locations,
 activations_sources: Optional[Dict] = None,
 subspaces: Optional[List] = None,
-**intervention_forward_kwargs
 ):
 all_set_handlers = HandlerList([])
 for group_id, keys in self._intervention_group.items():
@@ -1820,7 +1805,6 @@ def _wait_for_forward_with_serial_intervention(
 ]
 if subspaces is not None
 else None,
-**intervention_forward_kwargs
 )
 # for setters, we don't remove them.
 all_set_handlers.extend(set_handlers)
@@ -1837,7 +1821,6 @@ def forward(
 output_original_output: Optional[bool] = False,
 return_dict: Optional[bool] = None,
 use_cache: Optional[bool] = None,
-**intervention_forward_kwargs
 ):
 """
 Main forward function that serves a wrapper to
@@ -1946,7 +1929,6 @@ def forward(
 unit_locations,
 activations_sources,
 subspaces,
-**intervention_forward_kwargs
 )
 )
 elif self.mode == "serial":
@@ -1956,7 +1938,6 @@ def forward(
 unit_locations,
 activations_sources,
 subspaces,
-**intervention_forward_kwargs
 )
 )

@@ -2090,7 +2071,6 @@ def generate(
 unit_locations,
 activations_sources,
 subspaces,
-**kwargs
 )
 )
 elif self.mode == "serial":
@@ -2100,7 +2080,6 @@ def generate(
 unit_locations,
 activations_sources,
 subspaces,
-**kwargs
 )
 )

pyvene/models/interventions.py

+16 −16
@@ -75,7 +75,7 @@ def set_interchange_dim(self, interchange_dim):
 self.interchange_dim = interchange_dim

 @abstractmethod
-def forward(self, base, source, subspaces=None, **kwargs):
+def forward(self, base, source, subspaces=None):
 pass


@@ -153,7 +153,7 @@ class ZeroIntervention(ConstantSourceIntervention, LocalistRepresentationInterve
 def __init__(self, **kwargs):
 super().__init__(**kwargs)

-def forward(self, base, source=None, subspaces=None, **kwargs):
+def forward(self, base, source=None, subspaces=None):
 return _do_intervention_by_swap(
 base,
 torch.zeros_like(base),
@@ -175,7 +175,7 @@ class CollectIntervention(ConstantSourceIntervention):
 def __init__(self, **kwargs):
 super().__init__(**kwargs)

-def forward(self, base, source=None, subspaces=None, **kwargs):
+def forward(self, base, source=None, subspaces=None):
 return _do_intervention_by_swap(
 base,
 source,
@@ -197,7 +197,7 @@ class SkipIntervention(BasisAgnosticIntervention, LocalistRepresentationInterven
 def __init__(self, **kwargs):
 super().__init__(**kwargs)

-def forward(self, base, source, subspaces=None, **kwargs):
+def forward(self, base, source, subspaces=None):
 # source here is the base example input to the hook
 return _do_intervention_by_swap(
 base,
@@ -220,7 +220,7 @@ class VanillaIntervention(Intervention, LocalistRepresentationIntervention):
 def __init__(self, **kwargs):
 super().__init__(**kwargs)

-def forward(self, base, source, subspaces=None, **kwargs):
+def forward(self, base, source, subspaces=None):
 return _do_intervention_by_swap(
 base,
 source if self.source_representation is None else self.source_representation,
@@ -242,7 +242,7 @@ class AdditionIntervention(BasisAgnosticIntervention, LocalistRepresentationInte
 def __init__(self, **kwargs):
 super().__init__(**kwargs)

-def forward(self, base, source, subspaces=None, **kwargs):
+def forward(self, base, source, subspaces=None):
 return _do_intervention_by_swap(
 base,
 source if self.source_representation is None else self.source_representation,
@@ -264,7 +264,7 @@ class SubtractionIntervention(BasisAgnosticIntervention, LocalistRepresentationI
 def __init__(self, **kwargs):
 super().__init__(**kwargs)

-def forward(self, base, source, subspaces=None, **kwargs):
+def forward(self, base, source, subspaces=None):

 return _do_intervention_by_swap(
 base,
@@ -289,7 +289,7 @@ def __init__(self, **kwargs):
 rotate_layer = RotateLayer(self.embed_dim)
 self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)

-def forward(self, base, source, subspaces=None, **kwargs):
+def forward(self, base, source, subspaces=None):
 rotated_base = self.rotate_layer(base)
 rotated_source = self.rotate_layer(source)
 # interchange
@@ -340,7 +340,7 @@ def set_intervention_boundaries(self, intervention_boundaries):
 torch.tensor([intervention_boundaries]), requires_grad=True
 )

-def forward(self, base, source, subspaces=None, **kwargs):
+def forward(self, base, source, subspaces=None):
 batch_size = base.shape[0]
 rotated_base = self.rotate_layer(base)
 rotated_source = self.rotate_layer(source)
@@ -391,7 +391,7 @@ def get_temperature(self):
 def set_temperature(self, temp: torch.Tensor):
 self.temperature.data = temp

-def forward(self, base, source, subspaces=None, **kwargs):
+def forward(self, base, source, subspaces=None):
 batch_size = base.shape[0]
 rotated_base = self.rotate_layer(base)
 rotated_source = self.rotate_layer(source)
@@ -431,7 +431,7 @@ def get_temperature(self):
 def set_temperature(self, temp: torch.Tensor):
 self.temperature.data = temp

-def forward(self, base, source, subspaces=None, **kwargs):
+def forward(self, base, source, subspaces=None):
 batch_size = base.shape[0]
 # get boundary mask between 0 and 1 from sigmoid
 mask_sigmoid = torch.sigmoid(self.mask / torch.tensor(self.temperature))
@@ -456,7 +456,7 @@ def __init__(self, **kwargs):
 rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"])
 self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)

-def forward(self, base, source, subspaces=None, **kwargs):
+def forward(self, base, source, subspaces=None):
 rotated_base = self.rotate_layer(base)
 rotated_source = self.rotate_layer(source)
 if subspaces is not None:
@@ -529,7 +529,7 @@ def __init__(self, **kwargs):
 )
 self.trainable = False

-def forward(self, base, source, subspaces=None, **kwargs):
+def forward(self, base, source, subspaces=None):
 base_norm = (base - self.pca_mean) / self.pca_std
 source_norm = (source - self.pca_mean) / self.pca_std

@@ -565,7 +565,7 @@ def __init__(self, **kwargs):
 prng(1, 4, self.embed_dim)))
 self.register_buffer('noise_level', torch.tensor(noise_level))

-def forward(self, base, source=None, subspaces=None, **kwargs):
+def forward(self, base, source=None, subspaces=None):
 base[..., : self.interchange_dim] += self.noise * self.noise_level
 return base

@@ -585,7 +585,7 @@ def __init__(self, **kwargs):
 self.autoencoder = AutoencoderLayer(
 self.embed_dim, kwargs["latent_dim"])

-def forward(self, base, source, subspaces=None, **kwargs):
+def forward(self, base, source, subspaces=None):
 base_dtype = base.dtype
 base = base.to(self.autoencoder.encoder[0].weight.dtype)
 base_latent = self.autoencoder.encode(base)
@@ -619,7 +619,7 @@ def encode(self, input_acts):
 def decode(self, acts):
 return acts @ self.W_dec + self.b_dec

-def forward(self, base, source=None, subspaces=None, **kwargs):
+def forward(self, base, source=None, subspaces=None):
 # generate latents for base and source runs.
 base_latent = self.encode(base)
 source_latent = self.encode(source)

pyvene/models/modeling_utils.py

+2 −3
@@ -446,7 +446,7 @@ def scatter_neurons(


 def do_intervention(
-base_representation, source_representation, intervention, subspaces, **intervention_forward_kwargs
+base_representation, source_representation, intervention, subspaces
 ):
 """Do the actual intervention."""

@@ -478,8 +478,7 @@ def do_intervention(
 assert False # what's going on?

 intervention_output = intervention(
-base_representation_f, source_representation_f, subspaces,
-**intervention_forward_kwargs
+base_representation_f, source_representation_f, subspaces
 )
 if isinstance(intervention_output, InterventionOutput):
 intervened_representation = intervention_output.output

tests/integration_tests/IntervenableBasicTestCase.py

+1 −1
@@ -232,7 +232,7 @@ class MultiplierIntervention(
 def __init__(self, embed_dim, **kwargs):
 super().__init__()
 def forward(
-self, base, source=None, subspaces=None, **kwargs):
+self, base, source=None, subspaces=None):
 return base * 99.0
 # run with new intervention type
 pv_gpt2 = pv.IntervenableModel({
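The updated test above shows the post-revert contract for user-defined interventions: forward takes only base, source, and subspaces. A hedged, self-contained usage sketch; the ConstantSourceIntervention base class, the "block_output" component key, the embed_dim value, and the gpt2 model handle are illustrative assumptions, not taken from this diff:

    import pyvene as pv

    class MultiplierIntervention(pv.ConstantSourceIntervention):  # base class assumed
        def __init__(self, embed_dim, **kwargs):
            super().__init__()
        def forward(self, base, source=None, subspaces=None):
            # no **kwargs in the signature after this revert
            return base * 99.0

    # component key and model handle are illustrative
    pv_gpt2 = pv.IntervenableModel(
        {"component": "block_output", "intervention": MultiplierIntervention(embed_dim=768)},
        model=gpt2,
    )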

tests/integration_tests/InterventionWithLlamaTestCase.py

−26
@@ -156,32 +156,6 @@ def test_with_multiple_heads_positions_vanilla_intervention_positive(self):
 heads=[4, 1],
 positions=[7, 2],
 )
-
-def test_with_llm_head(self):
-that = self
-_lm_head_collection = {}
-class AccessIntervenableModelIntervention:
-is_source_constant = True
-keep_last_dim = True
-intervention_types = 'access_intervenable_model_intervention'
-def __init__(self, layer_index, *args, **kwargs):
-super().__init__()
-self.layer_index = layer_index
-def __call__(self, base, source=None, subspaces=None, model=None, **kwargs):
-intervenable_model = kwargs.get('intervenable_model', None)
-assert intervenable_model is not None
-_lm_head_collection[self.layer_index] = intervenable_model.model.lm_head(base.to(that.device))
-return base
-# run with new intervention type
-pv_llama = IntervenableModel([{
-"intervention": AccessIntervenableModelIntervention(layer_index=layer),
-"component": f"model.layers.{layer}.input"
-} for layer in [1, 3]], model=self.llama)
-intervened_outputs = pv_llama(
-base=self.tokenizer("The capital of Spain is", return_tensors="pt").to(that.device),
-unit_locations={"base": 3},
-intervenable_model=pv_llama
-)


 def suite():
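The deleted test_with_llm_head relied on recovering the wrapping IntervenableModel from the intervention's forward kwargs (kwargs.get('intervenable_model', None)), which is exactly the pathway this commit removes. One hedged alternative, not part of this commit: hand the intervention whatever module it needs at construction time. The lm_head_module argument below is hypothetical:

    class AccessLMHeadIntervention:
        is_source_constant = True
        keep_last_dim = True
        def __init__(self, layer_index, lm_head_module=None, **kwargs):
            # lm_head_module is a hypothetical constructor argument; pass in
            # model.lm_head when building the intervention instead of reading
            # it from forward kwargs at call time
            self.layer_index = layer_index
            self.lm_head_module = lm_head_module
        def __call__(self, base, source=None, subspaces=None, model=None, **kwargs):
            if self.lm_head_module is not None:
                self.collected_logits = self.lm_head_module(base)
            return base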
