Skip to content

Commit

Permalink
rename input_dict, document parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
arogozhnikov committed Feb 4, 2025
1 parent 43c11f8 commit 728beec
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions einops/_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def to_numpy(self, x):
def create_symbol(self, shape):
raise NotImplementedError("framework doesn't support symbolic computations")

def eval_symbol(self, symbol, input_dict):
def eval_symbol(self, symbol, symbol_value_pairs):
# symbol-value pairs is list[tuple[symbol, value-tensor]]
raise NotImplementedError("framework doesn't support symbolic computations")

def arange(self, start, stop):
Expand Down Expand Up @@ -431,9 +432,9 @@ def is_appropriate_type(self, tensor):
def create_symbol(self, shape):
return self.keras.Input(batch_shape=shape)

def eval_symbol(self, symbol, input_dict):
model = self.keras.models.Model([var for (var, _) in input_dict], symbol)
return model.predict_on_batch([val for (_, val) in input_dict])
def eval_symbol(self, symbol, symbol_value_pairs):
model = self.keras.models.Model([var for (var, _) in symbol_value_pairs], symbol)
return model.predict_on_batch([val for (_, val) in symbol_value_pairs])

def arange(self, start, stop):
return self.K.arange(start, stop)
Expand Down Expand Up @@ -689,9 +690,8 @@ def create_symbol(self, shape):
shape = (shape,)
return self.pt.tensor(shape=shape)

def eval_symbol(self, symbol, input_dict):
# input_dict is actually a list of tuple?
return symbol.eval(dict(input_dict))
def eval_symbol(self, symbol, symbol_value_pairs):
return symbol.eval(dict(symbol_value_pairs))

def arange(self, start, stop):
return self.pt.arange(start, stop)
Expand Down

0 comments on commit 728beec

Please sign in to comment.