Functions written with the eDSL should be callable from within other computations, regardless of whether they've been wrapped with the `pm.computation` decorator. For example:
```python
import pymoose as pm

# Placement definition assumed; it is not shown in the original snippet.
alice = pm.host_placement("alice")


@pm.computation
def plus1(x: pm.Argument(alice, dtype=pm.float64)):
    with alice:
        one = pm.constant(1, dtype=pm.float64)
        return pm.add(x, one)


@pm.computation
def alice_add():
    with alice:
        x = pm.constant(3, dtype=pm.float64)
        x_plus_one = plus1(x)  # nested call to another decorated computation
        return x_plus_one


if __name__ == "__main__":
    [...]  # runtime setup elided
    runtime.set_default()
    val = alice_add()  # <-- will fail during tracing
```

When `alice_add` is called, the current behavior is as follows:
- Inside a runtime context, `alice_add.__call__` invokes `trace(alice_add)`.
- `trace(alice_add)` will then invoke `plus1.__call__`. For this call to succeed, `plus1` needs to return an `Expression` that can be used to trace the rest of `alice_add`.
- However, since the default runtime context is not `None`, `plus1` will be executed against the default runtime's `evaluate_computation` with arguments of type `Expression`.
- The Rust runtime bindings will try to interpret these `Expression` pyobjs as Moose Values, which fails with a `TypeError` because they are not concrete values.
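The failure mode above can be reproduced with a small, self-contained toy model. Everything in it (`Expression`, `interpret`, `ConcreteRuntime`, the module-global default runtime) is a hypothetical stand-in written for illustration, not pymoose's actual implementation; `ConcreteRuntime` only mimics the Rust bindings' refusal to accept `Expression` arguments.

```python
import inspect

# Hypothetical toy model of the eDSL; not pymoose's real classes.


class Expression:
    """Symbolic value produced while tracing."""

    def __init__(self, op, inputs=(), value=None):
        self.op, self.inputs, self.value = op, tuple(inputs), value


def constant(value):
    return Expression("constant", value=value)


def add(lhs, rhs):
    return Expression("add", (lhs, rhs))


def interpret(expr, args):
    """Evaluate a traced graph against concrete argument values."""
    if expr.op == "constant":
        return expr.value
    if expr.op == "argument":
        return args[expr.value]
    if expr.op == "add":
        return interpret(expr.inputs[0], args) + interpret(expr.inputs[1], args)
    raise ValueError(f"unknown op {expr.op!r}")


_current_runtime = None  # module-global "default runtime"


def get_current_runtime():
    return _current_runtime


def trace(comp):
    """Call the wrapped function with one symbolic argument per parameter."""
    n_params = len(inspect.signature(comp.func).parameters)
    args = [Expression("argument", value=i) for i in range(n_params)]
    return comp.func(*args)


class AbstractComputation:
    def __init__(self, func):
        self.func = func

    def __call__(self, *args):
        # Current behavior: always dispatch to the runtime, even mid-trace.
        return get_current_runtime().evaluate_computation(self, args)


class ConcreteRuntime:
    """Mimics the Rust bindings: arguments must be concrete values."""

    def evaluate_computation(self, comp, args):
        if any(isinstance(a, Expression) for a in args):
            raise TypeError("Expression is not a concrete Moose Value")
        return interpret(trace(comp), args)


@AbstractComputation
def plus1(x):
    return add(x, constant(1.0))


@AbstractComputation
def alice_add():
    return plus1(constant(3.0))  # receives an Expression while tracing


_current_runtime = ConcreteRuntime()  # analogous to runtime.set_default()

print(plus1(2.0))  # a top-level call with a concrete value works: prints 3.0
try:
    alice_add()
except TypeError as err:
    print("tracing alice_add failed:", err)
```

The top-level call to `plus1` succeeds, but the nested call made while tracing `alice_add` hands an `Expression` to the runtime and raises `TypeError`.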
One workaround for the user is to drop the `pm.computation` decorator from `plus1`, so that it returns an `Expression` no matter which runtime context is active. But this makes it hard for users to reuse "standard library" computations, since those will often already be decorated (i.e. wrapped in an `AbstractComputation`).
I think the simplest solution here would be to do the following:
- Inside `pm.trace`, temporarily unset the default runtime context, so that `get_current_runtime` returns `None`.
- If `AbstractComputation.__call__` is invoked without a runtime context (i.e. `get_current_runtime` returns `None`), invoke `AbstractComputation.func.__call__`. This invocation maps `Expression`s to `Expression`s, so tracing can proceed normally.
- If `AbstractComputation.__call__` is invoked inside a runtime context, invoke `get_current_runtime().evaluate_computation(...)` with the computation as usual.
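A minimal sketch of the proposed fix, using a toy model whose names are all hypothetical stand-ins rather than pymoose's real classes: `trace` clears the module-global runtime for the duration of the trace, and `AbstractComputation.__call__` falls back to calling `func` directly when no runtime is set.

```python
import inspect
from contextlib import contextmanager

# Hypothetical toy model of the eDSL; not pymoose's real classes.


class Expression:
    def __init__(self, op, inputs=(), value=None):
        self.op, self.inputs, self.value = op, tuple(inputs), value


def constant(value):
    return Expression("constant", value=value)


def add(lhs, rhs):
    return Expression("add", (lhs, rhs))


def interpret(expr, args):
    """Evaluate a traced graph against concrete argument values."""
    if expr.op == "constant":
        return expr.value
    if expr.op == "argument":
        return args[expr.value]
    if expr.op == "add":
        return interpret(expr.inputs[0], args) + interpret(expr.inputs[1], args)
    raise ValueError(f"unknown op {expr.op!r}")


_current_runtime = None


def get_current_runtime():
    return _current_runtime


@contextmanager
def runtime_unset():
    """Temporarily clear the default runtime, restoring it afterwards."""
    global _current_runtime
    saved = _current_runtime
    _current_runtime = None
    try:
        yield
    finally:
        _current_runtime = saved


def trace(comp):
    n_params = len(inspect.signature(comp.func).parameters)
    args = [Expression("argument", value=i) for i in range(n_params)]
    # Step 1: while tracing, get_current_runtime() returns None.
    with runtime_unset():
        return comp.func(*args)


class AbstractComputation:
    def __init__(self, func):
        self.func = func

    def __call__(self, *args):
        runtime = get_current_runtime()
        if runtime is None:
            # Step 2: no runtime (e.g. mid-trace), so Expressions map to Expressions.
            return self.func(*args)
        # Step 3: inside a runtime context, evaluate as usual.
        return runtime.evaluate_computation(self, args)


class ConcreteRuntime:
    """Mimics the Rust bindings: arguments must be concrete values."""

    def evaluate_computation(self, comp, args):
        if any(isinstance(a, Expression) for a in args):
            raise TypeError("Expression is not a concrete Moose Value")
        return interpret(trace(comp), args)


@AbstractComputation
def plus1(x):
    return add(x, constant(1.0))


@AbstractComputation
def alice_add():
    return plus1(constant(3.0))


_current_runtime = ConcreteRuntime()  # analogous to runtime.set_default()
print(alice_add())  # prints 4.0
```

A module global keeps the sketch short; a real implementation would likely want a `contextvars.ContextVar` so that temporarily unsetting the runtime stays safe across threads and async tasks.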
Some other options:
- Allow nesting runtime contexts and create a new "dummy" `Runtime` class whose `evaluate_computation` simply forwards to `AbstractComputation.func.__call__`.
- Something "moose-ier", e.g. accommodate `Expression` conversion in the Moose bindings and, in this case, execute symbolically, i.e. run the computation against a `SymbolicSession` instead of against the `AsyncSession` in `AsyncTestRuntime`.