Skip to content

Commit 40f7972

Browse files
authored
fixing dtype promotion in where (#1734)
1 parent 52ec9a5 commit 40f7972

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

thunder/executors/nvfuserex_impl.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -1965,7 +1965,18 @@ def where(
19651965
nva = getnv(a, fd, lc_to_nv_map)
19661966
nvb = getnv(b, fd, lc_to_nv_map)
19671967

1968-
return fd.ops.where(nvpred, nva, nvb)
1968+
# explicit type promotion is necessary, since nvfuser can't do this properly with scalar inputs. See
1969+
# issue: https://github.com/NVIDIA/Fuser/issues/3816
1970+
# Determine the result dtype
1971+
numbertype, tensordtype = utils.check_same_dtype(a, b)
1972+
dtype = tensordtype if tensordtype is not None else numbertype
1973+
1974+
# NOTE: for scalar inputs, dtype mapping is different. e.g. float -> double. We convert dtypes to strong
1975+
# type if the output is supposed to be a tensor proxy
1976+
if any(map(lambda x: isinstance(x, TensorProxy), (pred, a, b))):
1977+
dtype = dtypes.to_strong_dtype(dtype)
1978+
1979+
return fd.ops.cast(fd.ops.where(nvpred, nva, nvb), lcdtype_to_nvdtype(dtype))
19691980

19701981

19711982
register_supported(PrimIDs.WHERE, where, _elementwise_ternary_check)

thunder/tests/opinfos.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -2537,18 +2537,19 @@ def div_sample_generator(op, device, dtype, requires_grad, **kwargs):
25372537
torch_reference=torch.div,
25382538
test_directives=(
25392539
# NOTE: PyTorch doesn't support boolean division
2540-
# TODO: fix dtype mismatch when using nvfuser executors
25412540
DecorateInfo(
25422541
pytest.mark.xfail,
25432542
"test_core_vs_torch_consistency",
25442543
dtypes=(datatypes.bool8,),
25452544
devicetypes=(devices.DeviceType.CPU, devices.DeviceType.CUDA),
25462545
),
2546+
# NOTE: bfloat16 and float16 are skipped
2547+
# See: https://github.com/Lightning-AI/lightning-thunder/issues/1724
25472548
DecorateInfo(
25482549
pytest.mark.xfail,
25492550
"test_core_vs_torch_consistency",
25502551
executors=("nvfuser",),
2551-
dtypes=(datatypes.bool8, datatypes.bfloat16, datatypes.float16, datatypes.float32),
2552+
dtypes=(datatypes.bool8, datatypes.bfloat16, datatypes.float16),
25522553
),
25532554
DecorateInfo(pytest.mark.xfail, "test_vjp_correctness"),
25542555
),
@@ -2718,6 +2719,17 @@ def where_sample_generator(op, device, dtype, requires_grad, **kwargs):
27182719
pred, a, b = make(pred_shape, dtype=torch.bool, requires_grad=False), make(a_shape), make(b_shape)
27192720
yield SampleInput(pred, a, b)
27202721

2722+
# NOTE: when requires_grad is set, the non-pred arguments must be tensor inputs.
2723+
if not requires_grad:
2724+
# generate scalar inputs
2725+
dtypes = [float, int, bool, complex]
2726+
2727+
for dtype in dtypes:
2728+
pred = make([2, 3], dtype=torch.bool, requires_grad=False)
2729+
a = dtype(1.0)
2730+
b = dtype(0.0)
2731+
yield SampleInput(pred, a, b)
2732+
27212733

27222734
def where_error_generator(op, device, dtype=torch.float32, **kwargs):
27232735
make = partial(make_tensor, device=device, dtype=dtype)

0 commit comments

Comments
 (0)