diff --git a/python/oneflow/test/modules/test_fused_rotary_embedding.py b/python/oneflow/test/modules/test_fused_rotary_embedding.py index efa9fb90750..d671ac6422e 100644 --- a/python/oneflow/test/modules/test_fused_rotary_embedding.py +++ b/python/oneflow/test/modules/test_fused_rotary_embedding.py @@ -1387,7 +1387,7 @@ def test_fused_rotary_embedding_op_plane(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_plane] # args_dict["x_layout"] = ["MB(H3K)"] - args_dict["x_layout"] = ["BMHK", "MB(HK)"] # TODO: MB(H3K) bug; + args_dict["x_layout"] = ["MB(HK)"] # TODO: MB(H3K) bug; args_dict["mode"] = ["plane"] args_dict["base"] = [1e1] args_dict["rotary_size"] = [4, 8] @@ -1401,7 +1401,6 @@ def test_fused_rotary_embedding_op_plane(test_case): for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) - """ TODO: interval mode grad kernel def test_fused_rotary_embedding_op_interval_2d(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_with_position, @@ -1420,16 +1419,14 @@ def test_fused_rotary_embedding_op_interval_2d(test_case): for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) - """ - """ TODO: interval mode grad kernel def test_fused_rotary_embedding_op_interval_1d(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [ - #_test_without_position_sinuous, + _test_without_position_sinuous, _test_without_position, - #_test_with_position, - #_test_with_position_sinuous, + _test_with_position, + _test_with_position_sinuous, ] args_dict["x_layout"] = ["BMHK"] args_dict["mode"] = ["interval"] @@ -1444,7 +1441,6 @@ def test_fused_rotary_embedding_op_interval_1d(test_case): for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) - """ if __name__ == "__main__":