@@ -75,7 +75,7 @@ def set_interchange_dim(self, interchange_dim):
         self.interchange_dim = interchange_dim

     @abstractmethod
-    def forward(self, base, source, subspaces=None, **kwargs):
+    def forward(self, base, source, subspaces=None):
         pass

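Note: the change is identical in every hunk below — `**kwargs` is dropped from each `forward`, pinning the interface to `(base, source, subspaces)`. A standalone sketch (not from this diff; the class name and shapes are illustrative) of what a conforming implementation looks like:

import torch

# Toy intervention conforming to the narrowed abstract signature:
# forward takes exactly (base, source, subspaces), no trailing **kwargs.
class ToySwapIntervention:
    def __init__(self, interchange_dim):
        self.interchange_dim = interchange_dim

    def forward(self, base, source, subspaces=None):
        # copy the first `interchange_dim` features of source into base
        out = base.clone()
        out[..., : self.interchange_dim] = source[..., : self.interchange_dim]
        return out

b, s = torch.zeros(2, 8), torch.ones(2, 8)
ToySwapIntervention(4).forward(b, s)  # first 4 dims become 1.0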
@@ -153,7 +153,7 @@ class ZeroIntervention(ConstantSourceIntervention, LocalistRepresentationIntervention):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

-    def forward(self, base, source=None, subspaces=None, **kwargs):
+    def forward(self, base, source=None, subspaces=None):
         return _do_intervention_by_swap(
             base,
             torch.zeros_like(base),
@@ -175,7 +175,7 @@ class CollectIntervention(ConstantSourceIntervention):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

-    def forward(self, base, source=None, subspaces=None, **kwargs):
+    def forward(self, base, source=None, subspaces=None):
         return _do_intervention_by_swap(
             base,
             source,
@@ -197,7 +197,7 @@ class SkipIntervention(BasisAgnosticIntervention, LocalistRepresentationIntervention):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

-    def forward(self, base, source, subspaces=None, **kwargs):
+    def forward(self, base, source, subspaces=None):
         # source here is the base example input to the hook
         return _do_intervention_by_swap(
             base,
@@ -220,7 +220,7 @@ class VanillaIntervention(Intervention, LocalistRepresentationIntervention):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

-    def forward(self, base, source, subspaces=None, **kwargs):
+    def forward(self, base, source, subspaces=None):
         return _do_intervention_by_swap(
             base,
             source if self.source_representation is None else self.source_representation,
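The `source_representation` fallback in this hunk lets a cached tensor stand in for a live source pass. A hedged sketch of that pattern in isolation (all names here are illustrative, not pyvene API):

import torch

class ConstantSourceSketch:
    # Use a stored representation whenever no per-call source is given.
    def __init__(self, source_representation=None):
        self.source_representation = source_representation

    def forward(self, base, source, subspaces=None):
        src = source if self.source_representation is None else self.source_representation
        return src.expand_as(base)  # stand-in for the real swap logic

cached = torch.ones(8)
ConstantSourceSketch(cached).forward(torch.zeros(2, 8), source=None)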
@@ -242,7 +242,7 @@ class AdditionIntervention(BasisAgnosticIntervention, LocalistRepresentationIntervention):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

-    def forward(self, base, source, subspaces=None, **kwargs):
+    def forward(self, base, source, subspaces=None):
         return _do_intervention_by_swap(
             base,
             source if self.source_representation is None else self.source_representation,
@@ -264,7 +264,7 @@ class SubtractionIntervention(BasisAgnosticIntervention, LocalistRepresentationIntervention):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

-    def forward(self, base, source, subspaces=None, **kwargs):
+    def forward(self, base, source, subspaces=None):

         return _do_intervention_by_swap(
             base,
@@ -289,7 +289,7 @@ def __init__(self, **kwargs):
         rotate_layer = RotateLayer(self.embed_dim)
         self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)

-    def forward(self, base, source, subspaces=None, **kwargs):
+    def forward(self, base, source, subspaces=None):
         rotated_base = self.rotate_layer(base)
         rotated_source = self.rotate_layer(source)
         # interchange
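The rotated-space classes all follow the recipe visible here: map base and source into an orthogonal basis, interchange coordinates there, then map back. A self-contained sketch of that recipe (plain PyTorch with a random fixed basis, not the file's learned RotateLayer):

import torch

embed_dim, k = 8, 3
# random orthogonal basis; the real layer learns this via parametrization
R = torch.linalg.qr(torch.randn(embed_dim, embed_dim)).Q

def rotated_swap(base, source, k):
    rotated_base, rotated_source = base @ R, source @ R  # into the rotated basis
    rotated_base[..., :k] = rotated_source[..., :k]      # interchange first k coords
    return rotated_base @ R.T                            # back to the original basis

rotated_swap(torch.randn(2, embed_dim), torch.randn(2, embed_dim), k)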
@@ -340,7 +340,7 @@ def set_intervention_boundaries(self, intervention_boundaries):
             torch.tensor([intervention_boundaries]), requires_grad=True
         )

-    def forward(self, base, source, subspaces=None, **kwargs):
+    def forward(self, base, source, subspaces=None):
         batch_size = base.shape[0]
         rotated_base = self.rotate_layer(base)
         rotated_source = self.rotate_layer(source)
@@ -391,7 +391,7 @@ def get_temperature(self):
     def set_temperature(self, temp: torch.Tensor):
         self.temperature.data = temp

-    def forward(self, base, source, subspaces=None, **kwargs):
+    def forward(self, base, source, subspaces=None):
         batch_size = base.shape[0]
         rotated_base = self.rotate_layer(base)
         rotated_source = self.rotate_layer(source)
@@ -431,7 +431,7 @@ def get_temperature(self):
     def set_temperature(self, temp: torch.Tensor):
         self.temperature.data = temp

-    def forward(self, base, source, subspaces=None, **kwargs):
+    def forward(self, base, source, subspaces=None):
         batch_size = base.shape[0]
         # get boundary mask between 0 and 1 from sigmoid
         mask_sigmoid = torch.sigmoid(self.mask / torch.tensor(self.temperature))
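This hunk's mask is a soft gate: trainable logits pass through a sigmoid scaled by a temperature, so annealing the temperature toward zero hardens the selection. A minimal sketch of that gating with illustrative names:

import torch

mask_logits = torch.randn(8, requires_grad=True)  # plays the role of self.mask
temperature = torch.tensor(0.5)                   # lower => closer to a hard 0/1 mask

gate = torch.sigmoid(mask_logits / temperature)   # boundary mask in (0, 1)
base, source = torch.randn(2, 8), torch.randn(2, 8)
output = (1 - gate) * base + gate * source        # gated dims take source values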
@@ -456,7 +456,7 @@ def __init__(self, **kwargs):
         rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"])
         self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)

-    def forward(self, base, source, subspaces=None, **kwargs):
+    def forward(self, base, source, subspaces=None):
         rotated_base = self.rotate_layer(base)
         rotated_source = self.rotate_layer(source)
         if subspaces is not None:
@@ -529,7 +529,7 @@ def __init__(self, **kwargs):
         )
         self.trainable = False

-    def forward(self, base, source, subspaces=None, **kwargs):
+    def forward(self, base, source, subspaces=None):
         base_norm = (base - self.pca_mean) / self.pca_std
         source_norm = (source - self.pca_mean) / self.pca_std

@@ -565,7 +565,7 @@ def __init__(self, **kwargs):
             prng(1, 4, self.embed_dim)))
         self.register_buffer('noise_level', torch.tensor(noise_level))

-    def forward(self, base, source=None, subspaces=None, **kwargs):
+    def forward(self, base, source=None, subspaces=None):
         base[..., : self.interchange_dim] += self.noise * self.noise_level
         return base

@@ -585,7 +585,7 @@ def __init__(self, **kwargs):
         self.autoencoder = AutoencoderLayer(
             self.embed_dim, kwargs["latent_dim"])

-    def forward(self, base, source, subspaces=None, **kwargs):
+    def forward(self, base, source, subspaces=None):
         base_dtype = base.dtype
         base = base.to(self.autoencoder.encoder[0].weight.dtype)
         base_latent = self.autoencoder.encode(base)
@@ -619,7 +619,7 @@ def encode(self, input_acts):
     def decode(self, acts):
         return acts @ self.W_dec + self.b_dec

-    def forward(self, base, source=None, subspaces=None, **kwargs):
+    def forward(self, base, source=None, subspaces=None):
         # generate latents for base and source runs.
         base_latent = self.encode(base)
         source_latent = self.encode(source)
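The final hunk encodes both runs into autoencoder latents before any interchange; its decode is shown above as `acts @ self.W_dec + self.b_dec`. A sketch of the surrounding encode-swap-decode pattern, with hypothetical dimensions and a plain ReLU standing in for the file's encoder:

import torch

d_model, d_sae = 16, 64
W_enc, b_enc = torch.randn(d_model, d_sae), torch.zeros(d_sae)
W_dec, b_dec = torch.randn(d_sae, d_model), torch.zeros(d_model)

def encode(x):
    return torch.relu(x @ W_enc + b_enc)  # ReLU stand-in for the real nonlinearity

def decode(acts):
    return acts @ W_dec + b_dec  # same form as the decode in this hunk

base, source = torch.randn(2, d_model), torch.randn(2, d_model)
base_latent, source_latent = encode(base), encode(source)
base_latent[..., :8] = source_latent[..., :8]  # interchange a block of latents
out = decode(base_latent)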