Commit f9ca748

[CI] Speed up slow tests in tests-gpu/tests-cpu (#3395)

1 parent 519b92a

4 files changed: +94 −24 lines

.github/unittest/linux/scripts/run_all.sh (54 additions, 6 deletions)

@@ -115,6 +115,7 @@ uv_pip_install \
     pytest-forked \
     pytest-asyncio \
     pytest-isolate \
+    pytest-xdist \
     expecttest \
     "pybind11[global]>=2.13" \
     pyyaml \
@@ -285,18 +286,65 @@ run_distributed_tests() {
         echo "TORCHRL_TEST_SUITE=${TORCHRL_TEST_SUITE}: distributed tests require GPU (CU_VERSION != cpu)."
         return 1
     fi
-    python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_distributed.py \
+    # Run both test_distributed.py and test_rb_distributed.py (both use torch.distributed)
+    python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_distributed.py test/test_rb_distributed.py \
         --instafail --durations 200 -vv --capture no \
         --timeout=120 --mp_fork_if_no_cuda
 }

 run_non_distributed_tests() {
     # Note: we always ignore distributed tests here (they can be run in a separate job).
-    python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
-        --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
-        --ignore test/test_distributed.py \
-        --ignore test/llm \
-        --timeout=120 --mp_fork_if_no_cuda
+    # Also ignore test_setup.py as it's tested in the dedicated test-setup-minimal job.
+    #
+    # Test sharding: Split tests into groups for parallel execution.
+    # TORCHRL_TEST_SHARD can be: "all" (default), "1", "2", or "3"
+    #   - Shard 1: test_transforms.py (heaviest file, 571 parametrize decorators)
+    #   - Shard 2: test_envs.py, test_collectors.py (multiprocessing-heavy)
+    #   - Shard 3: Everything else (can use pytest-xdist for parallelism)
+    local shard="${TORCHRL_TEST_SHARD:-all}"
+    local common_ignores="--ignore test/test_rlhf.py --ignore test/test_distributed.py --ignore test/test_rb_distributed.py --ignore test/llm --ignore test/test_setup.py"
+    local common_args="--instafail --durations 200 -vv --capture no --timeout=120 --mp_fork_if_no_cuda"
+
+    # pytest-xdist parallelism: use -n auto for shard 3 (fewer multiprocessing tests).
+    # Set TORCHRL_XDIST=0 to disable parallel execution.
+    local xdist_args=""
+    if [ "${TORCHRL_XDIST:-1}" = "1" ] && [ "${shard}" = "3" ]; then
+        xdist_args="-n auto --dist loadgroup"
+        echo "Using pytest-xdist for parallel execution"
+    fi
+
+    case "${shard}" in
+        1)
+            echo "Running shard 1: test_transforms.py only"
+            python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_transforms.py \
+                ${common_args}
+            ;;
+        2)
+            echo "Running shard 2: test_envs.py and test_collectors.py"
+            python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_envs.py test/test_collectors.py \
+                ${common_args}
+            ;;
+        3)
+            echo "Running shard 3: All other tests"
+            python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
+                ${common_ignores} \
+                --ignore test/test_transforms.py \
+                --ignore test/test_envs.py \
+                --ignore test/test_collectors.py \
+                ${xdist_args} \
+                ${common_args}
+            ;;
+        all|"")
+            echo "Running all tests (no sharding)"
+            python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
+                ${common_ignores} \
+                ${common_args}
+            ;;
+        *)
+            echo "Unknown TORCHRL_TEST_SHARD='${shard}'. Expected: all|1|2|3."
+            exit 2
+            ;;
+    esac
 }

 case "${TORCHRL_TEST_SUITE}" in

.github/workflows/test-linux.yml (12 additions, 0 deletions)

@@ -77,6 +77,11 @@ jobs:
       matrix:
         python_version: ["3.12"]
         cuda_arch_version: ["13.0"]
+        # Test sharding: split tests into 3 parallel jobs for faster execution
+        #   Shard 1: test_transforms.py (heaviest)
+        #   Shard 2: test_envs.py + test_collectors.py (multiprocessing-heavy)
+        #   Shard 3: all other tests
+        shard: ["1", "2", "3"]
       fail-fast: false
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
@@ -103,12 +108,14 @@ jobs:

       # Run everything except distributed tests; those run in parallel in tests-gpu-distributed.
       export TORCHRL_TEST_SUITE=nondistributed
+      export TORCHRL_TEST_SHARD=${{ matrix.shard }}

       # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines
       #export CU_VERSION="cpu"

       echo "PYTHON_VERSION: $PYTHON_VERSION"
       echo "CU_VERSION: $CU_VERSION"
+      echo "TORCHRL_TEST_SHARD: $TORCHRL_TEST_SHARD"

       ## setup_env.sh
       bash .github/unittest/linux/scripts/run_all.sh
@@ -227,6 +234,8 @@ jobs:
       matrix:
         python_version: ["3.12"] # "3.9", "3.10", "3.11"
         cuda_arch_version: ["13.0"] # "11.6", "11.7"
+        # Test sharding: split tests into 3 parallel jobs for faster execution
+        shard: ["1", "2", "3"]
       fail-fast: false
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
@@ -260,6 +269,9 @@ jobs:

       # Run everything except distributed tests; those run in parallel in tests-stable-gpu-distributed.
       export TORCHRL_TEST_SUITE=nondistributed
+      export TORCHRL_TEST_SHARD=${{ matrix.shard }}
+
+      echo "TORCHRL_TEST_SHARD: $TORCHRL_TEST_SHARD"

       ## setup_env.sh
       bash .github/unittest/linux/scripts/run_all.sh
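Conceptually, each of these matrix entries fans the job out into three parallel runners that differ only in the shard they export before calling the shared script. A sketch of the effective per-job command follows, where ${SHARD} is a stand-in for the matrix value, not a variable the workflow defines:

# Effective command per matrix job (SHARD is a placeholder for 1, 2, or 3):
export TORCHRL_TEST_SUITE=nondistributed
export TORCHRL_TEST_SHARD="${SHARD}"
bash .github/unittest/linux/scripts/run_all.sh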

test/conftest.py (22 additions, 12 deletions)

@@ -18,26 +18,36 @@
 IS_OSX = sys.platform == "darwin"


-def pytest_sessionfinish(maxprint=50):
-    out_str = """
-Call times:
-===========
-"""
+def pytest_sessionfinish(session, exitstatus, maxprint=50):
+    """Print aggregated test times per function (across all parametrizations)."""
     keys = list(CALL_TIMES.keys())
-    if len(keys) > 1:
-        maxchar = max(*[len(key) for key in keys])
-    elif len(keys) == 1:
-        maxchar = len(keys[0])
-    else:
+    if not keys:
         return
+
+    # Calculate total time
+    total_time = sum(CALL_TIMES.values())
+
+    out_str = f"""
+================================================================================
+AGGREGATED TEST TIMES (by function, across all parametrizations)
+================================================================================
+Total test time: {total_time:.1f}s ({total_time/60:.1f} min)
+Top {min(maxprint, len(keys))} slowest test functions:
+--------------------------------------------------------------------------------
+"""
+    maxchar = max(len(key) for key in keys)
     for i, (key, item) in enumerate(
         sorted(CALL_TIMES.items(), key=lambda x: x[1], reverse=True)
     ):
-        spaces = " " + " " * (maxchar - len(key))
-        out_str += f"\t{key}{spaces}{item: 4.4f}s\n"
+        spaces = " " * (maxchar - len(key) + 2)
+        pct = (item / total_time) * 100 if total_time > 0 else 0
+        out_str += f"  {key}{spaces}{item:7.2f}s ({pct:5.1f}%)\n"
         if i == maxprint - 1:
             break

+    out_str += "================================================================================\n"
+    sys.stdout.write(out_str)
+

 @pytest.fixture(autouse=True)
 def measure_duration(request: pytest.FixtureRequest):
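Because pytest_sessionfinish is a standard pytest hook, the aggregated report now also prints at the end of any local run. An illustrative invocation is sketched below; the test selection, function name, and all numbers are made up, and only the report shape follows the f-string above:

# Any pytest session in this repo ends with the aggregated timing report:
python -m pytest test/test_collectors.py -q
# Expected shape of the output (values illustrative, name hypothetical):
#   AGGREGATED TEST TIMES (by function, across all parametrizations)
#   Total test time: 84.3s (1.4 min)
#   Top 50 slowest test functions:
#     test_example_function    12.31s ( 14.6%)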

test/test_collectors.py (6 additions, 6 deletions)

@@ -759,8 +759,8 @@ def env_fn(seed):
             create_env_kwargs={"seed": seed},
             policy=policy,
             frames_per_batch=20,
-            max_frames_per_traj=2000,
-            total_frames=20000,
+            max_frames_per_traj=200,
+            total_frames=200,
             device="cpu",
         )
         torchrl_logger.info("Loop")
@@ -932,7 +932,7 @@ def _set_seed(self, seed: Optional[int]) -> None:
         result = subprocess.run(
             ["python", "-c", script], capture_output=True, text=True
         )
-        # This errors if the timeout is 5 secs, not 15
+        # This errors if the timeout is too short (3), succeeds if long enough (10)
         assert result.returncode == int(
             to == 3
         ), f"Test failed with output: {result.stdout}"
@@ -1136,7 +1136,7 @@ def make_and_test_policy(
         c = collector_type(
             envs,
             policy=policy,
-            total_frames=1000,
+            total_frames=100,
             frames_per_batch=10,
             policy_device=policy_device,
             env_device=env_device,
@@ -1779,7 +1779,7 @@ def _reset(self, tensordict: TensorDict | None = None, **kwargs) -> TensorDict:
             # Random sleep up to 10ms
             time.sleep(torch.rand(1).item() * 0.01)
         elif self.env_id % 2 == 1:
-            time.sleep(1)
+            time.sleep(0.1)

         self._step_count = 0
         return TensorDict(
@@ -1800,7 +1800,7 @@ def _step(self, tensordict: TensorDict) -> TensorDict:
         done = self._step_count >= self.max_steps

         if self.sleep_odd_only and self.env_id % 2 == 1:
-            time.sleep(1)
+            time.sleep(0.1)

         return TensorDict(
             {
