Skip to content

Commit 74d24ea

Browse files
q10 authored and facebook-github-bot committed
Fix ROCm test reliability (#4385)
Summary: X-link: facebookresearch/FBGEMM#1455 - Fix ROCm test reliability in quantize_test.py Pull Request resolved: #4385 Reviewed By: spcyppt Differential Revision: D77038656 Pulled By: q10 fbshipit-source-id: 309843fdbfa7ff7df92d7ef7a6f08bbb124c2f49
1 parent 95bae74 commit 74d24ea

File tree

1 file changed

+17
-20
lines changed

1 file changed

+17
-20
lines changed

fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py

Lines changed: 17 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -120,6 +120,21 @@ def pack_int4(x: torch.Tensor) -> torch.Tensor:
120120
return torch.bitwise_or(low_x, high_x).contiguous()
121121

122122

def sample_scales() -> st.SearchStrategy[Optional[torch.Tensor]]:
    """Return a Hypothesis strategy over optional scale tensors.

    Always includes ``None``; when CUDA is available, also includes a
    one-element float tensor placed on the current accelerator device.
    """
    choices: list[Optional[torch.Tensor]] = [None]
    if torch.cuda.is_available():
        choices.append(
            torch.tensor(
                [1.0],
                dtype=torch.float,
                device=torch.accelerator.current_accelerator(),
            )
        )
    return st.sampled_from(choices)
137+
123138
@unittest.skipIf(
124139
not torch.cuda.is_available(),
125140
"Skip when no GPU is available. This test is only for GPU.",
@@ -1678,26 +1693,8 @@ class NVFP4Tests(unittest.TestCase):
16781693
B_T=st.sampled_from([2048, 4096]),
16791694
D=st.sampled_from([128, 256]),
16801695
HD_L=st.sampled_from([256, 512]),
1681-
static_scale=st.sampled_from(
1682-
[
1683-
None,
1684-
torch.tensor(
1685-
[1.0],
1686-
dtype=torch.float,
1687-
device=torch.accelerator.current_accelerator(),
1688-
),
1689-
]
1690-
),
1691-
scale_ub=st.sampled_from(
1692-
[
1693-
None,
1694-
torch.tensor(
1695-
[1.0],
1696-
dtype=torch.float,
1697-
device=torch.accelerator.current_accelerator(),
1698-
),
1699-
]
1700-
),
1696+
static_scale=sample_scales(),
1697+
scale_ub=sample_scales(),
17011698
)
17021699
def test_fake_quantize_nvfp4_per_tensor(
17031700
self,

0 commit comments

Comments
 (0)