Skip to content

Commit 674a978

Browse files
authored
[Tests] Fix tests 4th ed. (#220)
1. Record tests conditions in log file. 2. Fix op name. 3. Print mode in benchmark results. 4. Check pytest.mark.{OP_NAME} in CI. --------- Co-authored-by: zhengyang <[email protected]>
1 parent f4b2495 commit 674a978

File tree

7 files changed

+214
-24
lines changed

7 files changed

+214
-24
lines changed

OperatorList.md

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,29 +29,28 @@
2929
- mean
3030

3131
## v2.0
32-
33-
- mv
3432
- all
3533
- any
3634
- bitwise_and
3735
- bitwise_not
3836
- bitwise_or
3937
- cos
38+
- clamp
4039
- eq
4140
- ge
4241
- gt
42+
- mv
4343
- isinf
4444
- isnan
4545
- le
4646
- lt
4747
- ne
4848
- neg
49-
- or
49+
- or_
5050
- sin
5151
- tanh
5252
- amax
5353
- argmax
54-
- clamp
5554
- max
5655
- min
5756
- outer
@@ -66,13 +65,12 @@
6665
- sigmoid
6766

6867
## v3.0
69-
7068
- _conv_depthwise2d
7169
- _convolution
7270
- conv1d
7371
- conv2d
7472
- convolution
75-
- cudnn_convolution
73+
- cudnn_convolution (N/A)
7674
- multinomial
7775
- nonzero
7876
- normal
@@ -99,12 +97,11 @@
9997
- resolve_neg
10098
- arange
10199
- cat
102-
- chunk
103-
- chunk
100+
- chunk (N/A)
104101
- concat
105102
- constant_pad_nd
106-
- contiguous
107-
- copy_
103+
- contiguous (N/A)
104+
- copy_ (N/A)
108105
- fill
109106
- flip
110107
- full
@@ -117,30 +114,48 @@
117114
- narrow
118115
- ones
119116
- pad
120-
- permute
117+
- permute (N/A)
121118
- repeat
122119
- repeat_interleave
123-
- resize
120+
- resize_ (N/A)
124121
- scatter
125-
- select
122+
- index_add
123+
- select (N/A)
126124
- select_scatter
127125
- slice
128126
- slice_scatter
129127
- sort
130-
- split
131-
- split_with_sizes
128+
- split (N/A)
129+
- split_with_sizes (N/A)
132130
- stack
133131
- tile
134-
- transpose
135-
- unfold
132+
- transpose (N/A)
133+
- unfold (N/A)
136134
- where
137135
- zeros
138136

139137
## v4.0
140-
141138
- rms_norm
142-
- skip_layernorm
143-
- skip_rmsnorm
144-
- apply_rotary_position_embedding
145-
- silu_and_mul
139+
- skip_rms_norm
140+
- skip_layer_norm
146141
- gelu_and_mul
142+
- silu_and_mul
143+
- apply_rotary_pos_emb
144+
145+
## v5.0
146+
- unique
147+
- isin
148+
- allclose
149+
- isclose
150+
- isfinite
151+
- exponential_
152+
- vstack
153+
- minimum
154+
- maximum
155+
- ones_like
156+
- zeros_like
157+
- randn_like
158+
- floor_divide
159+
- masked_select
160+
- trunc_divide
161+
- remainder

benchmark/performance_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,12 @@ def profile(self, op, *args, **kwargs):
6666
return latency
6767

6868
def run(self):
69+
mode_str = "cpu" if CPU_MODE else "cuda"
70+
print("")
6971
for dtype in self.dtypes:
70-
print(f"Operator {self.op_name} Performance Test ({dtype})")
72+
print(
73+
f"Operator {self.op_name} Performance Test (dtype={dtype}, mode={mode_str})"
74+
)
7175
print("Size Torch Latency (ms) Gems Latency (ms) Gems Speedup")
7276
print("---------------------------------------------------------------")
7377
for size in self.sizes:

tests/conftest.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import json
2+
import logging
3+
4+
15
def pytest_addoption(parser):
26
parser.addoption(
37
"--ref",
@@ -15,6 +19,14 @@ def pytest_addoption(parser):
1519
choices=["normal", "quick"],
1620
help="run tests on normal or quick mode",
1721
)
22+
parser.addoption(
23+
"--record",
24+
action="store",
25+
default="none",
26+
required=False,
27+
choices=["none", "log"],
28+
help="tests function param recorded in log files or not",
29+
)
1830

1931

2032
def pytest_configure(config):
@@ -23,3 +35,60 @@ def pytest_configure(config):
2335

2436
global QUICK_MODE
2537
QUICK_MODE = config.getoption("--mode") == "quick"
38+
39+
global RECORD_LOG
40+
RECORD_LOG = config.getoption("--record") == "log"
41+
if RECORD_LOG:
42+
global RUNTEST_INFO, BUILTIN_MARKS, REGISTERED_MARKERS
43+
RUNTEST_INFO = {}
44+
BUILTIN_MARKS = {
45+
"parametrize",
46+
"skip",
47+
"skipif",
48+
"xfail",
49+
"usefixtures",
50+
"filterwarnings",
51+
"timeout",
52+
"tryfirst",
53+
"trylast",
54+
}
55+
REGISTERED_MARKERS = {
56+
marker.split(":")[0].strip() for marker in config.getini("markers")
57+
}
58+
cmd_args = [
59+
arg.replace(".py", "").replace("=", "_").replace("/", "_")
60+
for arg in config.invocation_params.args
61+
]
62+
logging.basicConfig(
63+
filename="result_{}.log".format("_".join(cmd_args)).replace("_-", "-"),
64+
filemode="w",
65+
level=logging.INFO,
66+
format="[%(levelname)s] %(message)s",
67+
)
68+
69+
70+
def pytest_runtest_teardown(item, nextitem):
71+
if not RECORD_LOG:
72+
return
73+
if hasattr(item, "callspec"):
74+
all_marks = list(item.iter_markers())
75+
op_marks = [
76+
mark.name
77+
for mark in all_marks
78+
if mark.name not in BUILTIN_MARKS and mark.name not in REGISTERED_MARKERS
79+
]
80+
if len(op_marks) > 0:
81+
params = str(item.callspec.params)
82+
for op_mark in op_marks:
83+
if op_mark not in RUNTEST_INFO:
84+
RUNTEST_INFO[op_mark] = [params]
85+
else:
86+
RUNTEST_INFO[op_mark].append(params)
87+
else:
88+
func_name = item.function.__name__
89+
logging.warning("There is no mark at {}".format(func_name))
90+
91+
92+
def pytest_sessionfinish(session, exitstatus):
93+
if RECORD_LOG:
94+
logging.info(json.dumps(RUNTEST_INFO, indent=2))

tests/test_binary_pointwise_ops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def test_accuracy_bitwiseand_scalar_tensor(shape, dtype):
142142
gems_assert_equal(res_out, ref_out)
143143

144144

145+
@pytest.mark.or_
145146
@pytest.mark.bitwise_or
146147
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
147148
@pytest.mark.parametrize("dtype", INT_DTYPES + BOOL_TYPES)
@@ -166,6 +167,7 @@ def test_accuracy_bitwiseor(shape, dtype):
166167
gems_assert_equal(res_out, ref_out)
167168

168169

170+
@pytest.mark.or_
169171
@pytest.mark.bitwise_or
170172
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
171173
@pytest.mark.parametrize("dtype", INT_DTYPES + BOOL_TYPES)
@@ -187,6 +189,7 @@ def test_accuracy_bitwiseor_scalar(shape, dtype):
187189
gems_assert_equal(res_out, ref_out)
188190

189191

192+
@pytest.mark.or_
190193
@pytest.mark.bitwise_or
191194
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
192195
@pytest.mark.parametrize("dtype", INT_DTYPES + BOOL_TYPES)

tests/test_special_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,7 @@ def test_accuracy_cat(shape, dim, dtype):
668668
]
669669

670670

671+
@pytest.mark.vstack
671672
@pytest.mark.parametrize("shape", VSTACK_SHAPES)
672673
@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES)
673674
def test_accuracy_vstack(shape, dtype):

tools/op-unit-test.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ PR_ID_DIR="PR${PR_ID}"
1818

1919
COVERAGE_ARGS="--parallel-mode --omit "*/.flaggems/*","*/usr/lib/*" --source=./src,./tests --data-file=${ID_SHA}-op"
2020
cmds=(
21+
"bash tools/pytest_mark_check.sh &"
2122
"CUDA_VISIBLE_DEVICES=0 coverage run ${COVERAGE_ARGS} -m pytest -s tests/test_blas_ops.py &"
2223
"CUDA_VISIBLE_DEVICES=1 coverage run ${COVERAGE_ARGS} -m pytest -s tests/test_reduction_ops.py &"
2324
"CUDA_VISIBLE_DEVICES=2 coverage run ${COVERAGE_ARGS} -m pytest -s tests/test_general_reduction_ops.py &"
@@ -31,7 +32,8 @@ cmds=(
3132
CUDA_VISIBLE_DEVICES=6 coverage run ${COVERAGE_ARGS} -m pytest -s tests/test_libentry.py && \
3233
CUDA_VISIBLE_DEVICES=6 coverage run ${COVERAGE_ARGS} -m pytest -s tests/test_pointwise_dynamic.py && \
3334
CUDA_VISIBLE_DEVICES=6 coverage run ${COVERAGE_ARGS} -m pytest -s tests/test_shape_utils.py && \
34-
CUDA_VISIBLE_DEVICES=6 coverage run ${COVERAGE_ARGS} -m pytest -s tests/test_tensor_wrapper.py &"
35+
CUDA_VISIBLE_DEVICES=6 coverage run ${COVERAGE_ARGS} -m pytest -s tests/test_tensor_wrapper.py && \
36+
CUDA_VISIBLE_DEVICES=6 coverage run ${COVERAGE_ARGS} -m pytest -s tests/test_unary_pointwise_ops.py -m abs --record=log &"
3537
"CUDA_VISIBLE_DEVICES=7 coverage run ${COVERAGE_ARGS} -m pytest -s tests/test_blas_ops.py --ref=cpu --mode=quick && \
3638
CUDA_VISIBLE_DEVICES=7 coverage run ${COVERAGE_ARGS} -m pytest -s tests/test_reduction_ops.py --ref=cpu --mode=quick && \
3739
CUDA_VISIBLE_DEVICES=7 coverage run ${COVERAGE_ARGS} -m pytest -s tests/test_general_reduction_ops.py --ref=cpu --mode=quick && \

tools/pytest_mark_check.sh

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#!/bin/bash
2+
3+
set -e
4+
5+
if [ -z "$BASH_VERSION" ]; then
6+
echo "[ERROR]This script must be run using bash!" >&2
7+
exit 1
8+
fi
9+
10+
SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
11+
TEST_FILES="${SCRIPT_DIR}/../tests/test_*.py"
12+
MD_FILE="${SCRIPT_DIR}/../OperatorList.md"
13+
14+
# all
15+
grep 'pytest.mark.' ${TEST_FILES} |grep -v parametrize |grep -v 'skip(' |grep -v 'skipif(' |awk -F':@pytest.mark.' '{print $1"\t"$2}' |sort -k 2 |uniq >mark.txt
16+
17+
# unique
18+
grep 'pytest.mark.' ${TEST_FILES} |grep -v parametrize |grep -v 'skip(' |grep -v 'skipif(' |awk -F':@pytest.mark.' '{print $2}' |sort |uniq |awk '!seen[$1]++' >mark.uniq.txt
19+
20+
# recorded in MD_FILE
21+
grep '-' "${MD_FILE}" |grep -v '#' |grep -v 'N/A' |awk '{print $2}' |sort >mark.md.txt
22+
23+
MARK_NUM=`wc -l mark.txt |awk '{print $1}'`
24+
MARK_UNIQ_NUM=`wc -l mark.uniq.txt |awk '{print $1}'`
25+
MARK_MD_NUM=`wc -l mark.md.txt |awk '{print $1}'`
26+
echo "Generated successfully: mark.txt (${MARK_NUM}), mark.uniq.txt (${MARK_UNIQ_NUM}), mark.md.txt (${MARK_MD_NUM})"
27+
28+
echo "-------- diff mark.uniq.txt mark.md.txt --------"
29+
diff mark.uniq.txt mark.md.txt || true
30+
echo "------------------------------------------------"
31+
32+
TEST_OP_FILES=`ls ${TEST_FILES} |grep "_ops.py" |grep -v "test_named_ops.py"`
33+
EXCLUDED_MARKS=("pytest.mark.parametrize\(" \
34+
"pytest.mark.skip\(" \
35+
"pytest.mark.skipif\(" \
36+
"pytest.mark.xfail\(" \
37+
"pytest.mark.usefixtures\(" \
38+
"pytest.mark.filterwarnings\(" \
39+
"pytest.mark.timeout\(" \
40+
"pytest.mark.tryfirst\(" \
41+
"pytest.mark.trylast\(")
42+
43+
test_file_count=0
44+
for file in ${TEST_OP_FILES}; do
45+
echo "Checking file: ${file}"
46+
set +e
47+
48+
awk -v marks="${EXCLUDED_MARKS[*]}" '
49+
BEGIN {
50+
test_func = 0; decorated = 0; error = 0;
51+
split(marks, excluded_marks, " ")
52+
}
53+
54+
/^@pytest\.mark\./ {
55+
test_func = 1
56+
excluded = 0
57+
for (i in excluded_marks) {
58+
if ($0 ~ excluded_marks[i]) {
59+
excluded = 1
60+
break
61+
}
62+
}
63+
if (excluded == 0) {
64+
decorated = 1
65+
}
66+
next
67+
}
68+
69+
/^def / {
70+
if (test_func == 1) {
71+
if (decorated == 0) {
72+
print "[ERROR]"$0
73+
error = 1
74+
}
75+
test_func = 0
76+
decorated = 0
77+
}
78+
}
79+
80+
END {
81+
if (error == 1) {
82+
exit 1
83+
}
84+
}
85+
' "$file"
86+
87+
if [ $? -ne 0 ]; then
88+
echo "[ERROR]There are some test_op_func without 'pytest.mark.{OP_NAME}' in ${file}"
89+
exit 1
90+
fi
91+
92+
set -e
93+
test_file_count=$((test_file_count + 1))
94+
done
95+
96+
echo "Finish checking ${test_file_count} files successfully."

0 commit comments

Comments
 (0)