What a nice tool! However, the current code has trouble handling logical expressions with more than two terms. For example:
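import sympy
import torch
from sympytorch import SymPyModule

x, y, z = sympy.symbols("x,y,z")
expr = x | y | z  # sympy flattens this into a single Or(x, y, z) with three arguments
model = SymPyModule(expressions=[expr])
terms = {
"x": torch.randint(0, 2, (10,)).bool(),
"y": torch.randint(0, 2, (10,)).bool(),
"z": torch.randint(0, 2, (10,)).bool(),
}
result = model(**terms)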
Running this raises:
TypeError: logical_or() takes 2 positional arguments but 3 were given
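Here's a fix: change lines 52 to 53 of sympytorch/sympytorch/sympy_module.py (commit ca3e3f4) to
sympy.And: _reduce(torch.logical_and),
sympy.Or: _reduce(torch.logical_or),
which is the same treatment already given to mul and add. Wrapping the binary torch functions this way lets them accept any number of arguments.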
Beyond the fix above, I suggest two more improvements:
- Make the function at lines 7 to 10 of sympytorch/sympytorch/sympy_module.py (commit ca3e3f4) public.
- Change line 158 of sympytorch/sympytorch/sympy_module.py (commit ca3e3f4) to
_func_lookup = co.ChainMap(extra_funcs, _global_func_lookup)
so that users can override entries in the default lookup table. ChainMap searches its maps from first to last, so listing extra_funcs first gives user-supplied functions priority over the defaults.
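A small sketch of how the map order decides which entry wins. Plain strings stand in for sympy classes and torch functions (hypothetical placeholders), and the defaults_first ordering is only an assumption about what line 158 currently contains.

```python
import collections as co

# Hypothetical stand-ins: string keys play the role of sympy classes and
# string values play the role of the torch functions they map to.
_global_func_lookup = {"Or": "default_or", "And": "default_and"}
extra_funcs = {"Or": "user_or"}

# Defaults listed first: the default entry shadows the user's entry.
defaults_first = co.ChainMap(_global_func_lookup, extra_funcs)
# Order proposed in the issue: user entries shadow the defaults.
user_first = co.ChainMap(extra_funcs, _global_func_lookup)

print(defaults_first["Or"])  # default_or
print(user_first["Or"])      # user_or
print(user_first["And"])     # default_and (falls through to the defaults)
```

Keys the user does not supply still resolve through the default table, so the proposed order only changes behavior for entries the user explicitly overrides.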