- Add the ENABLE_DISPATCH environment variable
- Use the primitive wrapper
import mindspore
from mindspore.ops import Primitive
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore import ops
def hook_call(instance, use_dispatch=None):
    """Wrap a callable (e.g. a MindSpore ``Primitive``) with before/after hooks.

    The returned function prints a marker line before and after delegating to
    ``instance``. When dispatch is enabled, it first re-targets ``instance``
    to the device type of the first positional argument via ``set_device``.

    Args:
        instance: The callable to wrap. When dispatch is enabled it must
            provide ``set_device(target)``. Its current target is read from
            the ``primitive_target`` attribute (``'CPU'`` if absent).
        use_dispatch: Whether to re-target ``instance`` on each call. ``None``
            (the default) reads the module-level ``USE_DISPATCH`` flag at
            call time, which preserves the original behavior.

    Returns:
        A function taking the same arguments as ``instance`` and returning
        whatever ``instance`` returns.
    """
    def wrapped_call(*args, **kwargs):
        dispatch = USE_DISPATCH if use_dispatch is None else use_dispatch
        # Only look at a device when there is a first argument that has one,
        # so calls with no positional args no longer raise IndexError.
        device = getattr(args[0], "device", None) if args else None
        if dispatch and device:
            # Device strings may carry an index suffix (e.g. "Ascend:10").
            # The primitive is targeted by device type only, so strip the
            # suffix and compare like with like. This avoids re-issuing
            # set_device on every call, and avoids mangling suffix-less or
            # multi-digit names the way a fixed [:-2] slice would.
            target = device.split(":")[0]
            if getattr(instance, "primitive_target", "CPU") != target:
                instance.set_device(target)
        print("【装饰器Hook】调用前")  # "[Decorator Hook] before call"
        result = instance(*args, **kwargs)
        print("【装饰器Hook】调用后")  # "[Decorator Hook] after call"
        return result
    return wrapped_call
# Dispatch flag read by hook_call at call time; False keeps the primitive's
# device target as-is.
USE_DISPATCH = False
# Sample input for the demo below: a random tensor of 3 elements.
x = ops.randn(3)
@mindspore.jit(fullgraph=True, backend='GE')
def fn(x):
    """Return ``x + 1`` computed by a hook-wrapped ``ops.Add`` primitive.

    Compiled with ``mindspore.jit`` (``fullgraph=True``, GE backend). A fresh
    ``Add`` primitive is built per invocation and wrapped by ``hook_call``.
    """
    hooked_add = hook_call(ops.Add())
    return hooked_add(x, 1)
# Run the compiled function on the sample tensor, then show the result
# alongside the device it ended up on.
y = fn(x)
print(y, y.device)