From 47c742ff94b21dbe2de35ea14fca17d6632f8f73 Mon Sep 17 00:00:00 2001 From: Alex Rogozhnikov Date: Tue, 4 Feb 2025 02:26:36 -0800 Subject: [PATCH] rename input_dict, document parameter (#363) --- einops/_backends.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/einops/_backends.py b/einops/_backends.py index 8215977c..3d2f9160 100644 --- a/einops/_backends.py +++ b/einops/_backends.py @@ -77,7 +77,8 @@ def to_numpy(self, x): def create_symbol(self, shape): raise NotImplementedError("framework doesn't support symbolic computations") - def eval_symbol(self, symbol, input_dict): + def eval_symbol(self, symbol, symbol_value_pairs): + # symbol-value pairs is list[tuple[symbol, value-tensor]] raise NotImplementedError("framework doesn't support symbolic computations") def arange(self, start, stop): @@ -431,9 +432,9 @@ def is_appropriate_type(self, tensor): def create_symbol(self, shape): return self.keras.Input(batch_shape=shape) - def eval_symbol(self, symbol, input_dict): - model = self.keras.models.Model([var for (var, _) in input_dict], symbol) - return model.predict_on_batch([val for (_, val) in input_dict]) + def eval_symbol(self, symbol, symbol_value_pairs): + model = self.keras.models.Model([var for (var, _) in symbol_value_pairs], symbol) + return model.predict_on_batch([val for (_, val) in symbol_value_pairs]) def arange(self, start, stop): return self.K.arange(start, stop) @@ -689,9 +690,8 @@ def create_symbol(self, shape): shape = (shape,) return self.pt.tensor(shape=shape) - def eval_symbol(self, symbol, input_dict): - # input_dict is actually a list of tuple? - return symbol.eval(dict(input_dict)) + def eval_symbol(self, symbol, symbol_value_pairs): + return symbol.eval(dict(symbol_value_pairs)) def arange(self, start, stop): return self.pt.arange(start, stop)