diff --git a/python/oneflow/nn/modules/normalization.py b/python/oneflow/nn/modules/normalization.py index 09f73cf8995..592cd5502dd 100644 --- a/python/oneflow/nn/modules/normalization.py +++ b/python/oneflow/nn/modules/normalization.py @@ -137,8 +137,12 @@ def __init__( if dtype: factory_kwargs["dtype"] = dtype if self.affine: - self.weight = flow.nn.Parameter(flow.Tensor(num_channels, **factory_kwargs)) - self.bias = flow.nn.Parameter(flow.Tensor(num_channels, **factory_kwargs)) + self.weight = flow.nn.Parameter( + flow.Tensor(num_channels).to(**factory_kwargs) + ) + self.bias = flow.nn.Parameter( + flow.Tensor(num_channels).to(**factory_kwargs) + ) else: self.register_parameter("weight", None) self.register_parameter("bias", None)