-
Notifications
You must be signed in to change notification settings - Fork 350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Example] Add MCTS example #2796
base: gh/kurtamohler/5/base
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2796
Note: Links to docs will display an error until the docs builds have been completed. ❌ 9 New Failures, 5 Unrelated FailuresAs of commit 0e55274 with merge base 27d3680 ( NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
ghstack-source-id: 5dc5cbdb68a621e14617734c386aef6e91edbda3 Pull Request resolved: #2796
This seems to work, but at the moment, it is about 100x slower than the one I implemented outside of TorchRL here. I will see what I can do to speed it up. Once I improve performance, then I'll think about how to add a good API for it |
ghstack-source-id: 512f8540518396b5beb68bb74aafaf8638f44156 Pull Request resolved: #2796
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_simple | 0.6670s | 0.5657s | 1.7678 Ops/s | 1.8609 Ops/s | |
test_transformed | 1.1937s | 1.1151s | 0.8968 Ops/s | 0.9544 Ops/s | |
test_serial | 1.5891s | 1.5822s | 0.6320 Ops/s | 0.6516 Ops/s | |
test_parallel | 1.4041s | 1.3072s | 0.7650 Ops/s | 0.7655 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1281ms | 30.6692μs | 32.6060 KOps/s | 31.8336 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 48.5910μs | 18.2776μs | 54.7118 KOps/s | 56.3213 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 58.3490μs | 17.9955μs | 55.5695 KOps/s | 58.8661 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 52.4480μs | 10.5994μs | 94.3453 KOps/s | 97.4872 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 77.4040μs | 33.4255μs | 29.9173 KOps/s | 31.2275 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 47.3890μs | 20.4616μs | 48.8721 KOps/s | 50.2605 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 0.6047ms | 19.8741μs | 50.3168 KOps/s | 52.5319 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 38.8520μs | 12.4776μs | 80.1436 KOps/s | 83.1997 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 78.1160μs | 35.1778μs | 28.4270 KOps/s | 29.5983 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 63.1480μs | 22.4285μs | 44.5862 KOps/s | 45.9753 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 56.7050μs | 19.8779μs | 50.3071 KOps/s | 52.3006 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 44.2430μs | 12.3023μs | 81.2858 KOps/s | 83.5345 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 79.4780μs | 36.8559μs | 27.1327 KOps/s | 27.8850 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 65.3020μs | 24.0665μs | 41.5516 KOps/s | 42.5832 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 67.5660μs | 21.4726μs | 46.5709 KOps/s | 48.9753 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 44.1620μs | 14.1554μs | 70.6443 KOps/s | 73.3485 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 89.2260μs | 35.1502μs | 28.4493 KOps/s | 29.3951 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 53.0590μs | 22.2377μs | 44.9688 KOps/s | 46.7607 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 55.3930μs | 22.4340μs | 44.5752 KOps/s | 46.5204 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 46.6070μs | 13.8630μs | 72.1345 KOps/s | 74.9763 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 79.1870μs | 36.8129μs | 27.1644 KOps/s | 27.4907 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 2.4669ms | 24.4106μs | 40.9657 KOps/s | 42.3422 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 63.7790μs | 24.3232μs | 41.1130 KOps/s | 42.4322 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 0.5936ms | 15.7628μs | 63.4404 KOps/s | 65.9291 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 0.1398ms | 38.7365μs | 25.8154 KOps/s | 26.5864 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 62.1860μs | 25.8683μs | 38.6573 KOps/s | 39.6015 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 76.6530μs | 24.1655μs | 41.3813 KOps/s | 42.9654 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 61.3640μs | 15.6321μs | 63.9710 KOps/s | 65.6277 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 85.9110μs | 40.1345μs | 24.9162 KOps/s | 25.8407 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 92.3440μs | 27.6099μs | 36.2189 KOps/s | 37.4261 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 60.5130μs | 25.7848μs | 38.7825 KOps/s | 37.7366 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 94.3540μs | 17.2691μs | 57.9068 KOps/s | 59.4643 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 10.6862ms | 10.0041ms | 99.9587 Ops/s | 102.1639 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 27.6173ms | 24.5006ms | 40.8154 Ops/s | 37.8485 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.3514ms | 0.2258ms | 4.4283 KOps/s | 5.5216 KOps/s | |
test_values[td1_return_estimate-False-False] | 28.3776ms | 24.8935ms | 40.1711 Ops/s | 41.7830 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 27.3800ms | 24.6668ms | 40.5403 Ops/s | 37.4782 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 35.8925ms | 35.0271ms | 28.5493 Ops/s | 28.6664 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 26.7276ms | 24.6828ms | 40.5140 Ops/s | 37.3563 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 9.5037ms | 8.4806ms | 117.9166 Ops/s | 118.9131 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 2.4665ms | 1.9874ms | 503.1822 Ops/s | 524.9830 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.6568ms | 0.3774ms | 2.6496 KOps/s | 2.7049 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 49.4232ms | 45.4212ms | 22.0162 Ops/s | 20.8984 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 4.6964ms | 3.5654ms | 280.4697 Ops/s | 289.4240 Ops/s | |
test_dqn_speed[False-None] | 6.2207ms | 1.4725ms | 679.0958 Ops/s | 712.1553 Ops/s | |
test_dqn_speed[False-backward] | 2.0657ms | 1.9452ms | 514.0792 Ops/s | 527.0985 Ops/s | |
test_dqn_speed[True-None] | 0.9742ms | 0.5752ms | 1.7385 KOps/s | 1.7572 KOps/s | |
test_dqn_speed[True-backward] | 1.1054ms | 0.9898ms | 1.0103 KOps/s | 785.4052 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.9310ms | 0.5628ms | 1.7769 KOps/s | 1.7587 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.0243ms | 0.9810ms | 1.0194 KOps/s | 1.0131 KOps/s | |
test_ddpg_speed[False-None] | 3.7434ms | 2.9791ms | 335.6763 Ops/s | 344.3897 Ops/s | |
test_ddpg_speed[False-backward] | 4.2266ms | 4.1262ms | 242.3543 Ops/s | 246.8492 Ops/s | |
test_ddpg_speed[True-None] | 1.9377ms | 1.4427ms | 693.1523 Ops/s | 684.2267 Ops/s | |
test_ddpg_speed[True-backward] | 2.4205ms | 2.3307ms | 429.0576 Ops/s | 417.1818 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.9018ms | 1.4416ms | 693.6696 Ops/s | 678.2458 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 2.4246ms | 2.3261ms | 429.9062 Ops/s | 422.7577 Ops/s | |
test_sac_speed[False-None] | 8.6580ms | 8.2573ms | 121.1052 Ops/s | 123.1160 Ops/s | |
test_sac_speed[False-backward] | 11.4724ms | 10.9531ms | 91.2983 Ops/s | 90.7021 Ops/s | |
test_sac_speed[True-None] | 3.3506ms | 2.5689ms | 389.2706 Ops/s | 382.8235 Ops/s | |
test_sac_speed[True-backward] | 5.2565ms | 4.2676ms | 234.3244 Ops/s | 231.5071 Ops/s | |
test_sac_speed[reduce-overhead-None] | 3.1729ms | 2.5712ms | 388.9194 Ops/s | 382.0138 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 4.3214ms | 4.2427ms | 235.6994 Ops/s | 229.2568 Ops/s | |
test_redq_speed[False-None] | 13.7217ms | 13.0737ms | 76.4894 Ops/s | 74.9000 Ops/s | |
test_redq_speed[False-backward] | 24.0607ms | 22.5596ms | 44.3270 Ops/s | 42.6657 Ops/s | |
test_redq_speed[True-None] | 7.3458ms | 6.6228ms | 150.9942 Ops/s | 142.2213 Ops/s | |
test_redq_speed[True-backward] | 16.1895ms | 14.3244ms | 69.8108 Ops/s | 67.5794 Ops/s | |
test_redq_speed[reduce-overhead-None] | 8.0604ms | 6.7333ms | 148.5151 Ops/s | 140.1701 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 15.3748ms | 14.2692ms | 70.0812 Ops/s | 69.3793 Ops/s | |
test_redq_deprec_speed[False-None] | 13.6675ms | 12.9704ms | 77.0984 Ops/s | 75.8742 Ops/s | |
test_redq_deprec_speed[False-backward] | 19.6393ms | 18.6507ms | 53.6174 Ops/s | 52.3380 Ops/s | |
test_redq_deprec_speed[True-None] | 5.8744ms | 5.1439ms | 194.4067 Ops/s | 192.1556 Ops/s | |
test_redq_deprec_speed[True-backward] | 10.5701ms | 9.9102ms | 100.9058 Ops/s | 99.7643 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 5.5362ms | 5.1505ms | 194.1567 Ops/s | 192.4768 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 10.7469ms | 9.9701ms | 100.2998 Ops/s | 99.4953 Ops/s | |
test_td3_speed[False-None] | 8.4482ms | 8.2259ms | 121.5678 Ops/s | 123.8643 Ops/s | |
test_td3_speed[False-backward] | 12.0238ms | 10.7250ms | 93.2405 Ops/s | 94.9045 Ops/s | |
test_td3_speed[True-None] | 3.7342ms | 2.4258ms | 412.2346 Ops/s | 436.2082 Ops/s | |
test_td3_speed[True-backward] | 5.8898ms | 4.1622ms | 240.2558 Ops/s | 247.2589 Ops/s | |
test_td3_speed[reduce-overhead-None] | 3.4836ms | 2.4274ms | 411.9705 Ops/s | 436.0162 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 4.7328ms | 3.9302ms | 254.4390 Ops/s | 251.6302 Ops/s | |
test_cql_speed[False-None] | 39.3408ms | 36.4926ms | 27.4028 Ops/s | 27.2499 Ops/s | |
test_cql_speed[False-backward] | 63.8771ms | 48.7278ms | 20.5221 Ops/s | 20.7983 Ops/s | |
test_cql_speed[True-None] | 23.4197ms | 22.3086ms | 44.8257 Ops/s | 43.1851 Ops/s | |
test_cql_speed[True-backward] | 29.9610ms | 29.0382ms | 34.4374 Ops/s | 33.8017 Ops/s | |
test_cql_speed[reduce-overhead-None] | 24.0368ms | 22.4104ms | 44.6221 Ops/s | 43.9143 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 30.4437ms | 29.1107ms | 34.3516 Ops/s | 33.8345 Ops/s | |
test_a2c_speed[False-None] | 7.9690ms | 7.1985ms | 138.9183 Ops/s | 136.1004 Ops/s | |
test_a2c_speed[False-backward] | 16.4940ms | 14.3265ms | 69.8009 Ops/s | 69.3241 Ops/s | |
test_a2c_speed[True-None] | 5.5012ms | 4.6716ms | 214.0586 Ops/s | 213.7634 Ops/s | |
test_a2c_speed[True-backward] | 11.6603ms | 11.1312ms | 89.8376 Ops/s | 88.2366 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 5.0690ms | 4.6514ms | 214.9881 Ops/s | 213.3272 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 12.3482ms | 11.1782ms | 89.4601 Ops/s | 88.3800 Ops/s | |
test_ppo_speed[False-None] | 8.9908ms | 7.5480ms | 132.4857 Ops/s | 132.9993 Ops/s | |
test_ppo_speed[False-backward] | 15.9991ms | 14.8850ms | 67.1818 Ops/s | 66.8864 Ops/s | |
test_ppo_speed[True-None] | 7.0669ms | 5.6722ms | 176.2993 Ops/s | 196.3270 Ops/s | |
test_ppo_speed[True-backward] | 13.1540ms | 11.9436ms | 83.7266 Ops/s | 91.1146 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 5.9280ms | 5.0496ms | 198.0340 Ops/s | 196.7899 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 12.9746ms | 11.0568ms | 90.4422 Ops/s | 90.1317 Ops/s | |
test_reinforce_speed[False-None] | 7.8501ms | 6.6475ms | 150.4328 Ops/s | 151.8951 Ops/s | |
test_reinforce_speed[False-backward] | 10.3156ms | 10.0050ms | 99.9499 Ops/s | 101.2409 Ops/s | |
test_reinforce_speed[True-None] | 4.8988ms | 4.0596ms | 246.3294 Ops/s | 242.3324 Ops/s | |
test_reinforce_speed[True-backward] | 11.3386ms | 10.0807ms | 99.1991 Ops/s | 98.2430 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 5.6285ms | 4.1157ms | 242.9698 Ops/s | 243.6155 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 10.3976ms | 10.0302ms | 99.6992 Ops/s | 99.6690 Ops/s | |
test_iql_speed[False-None] | 38.5091ms | 32.8647ms | 30.4278 Ops/s | 30.6699 Ops/s | |
test_iql_speed[False-backward] | 66.0998ms | 46.3514ms | 21.5743 Ops/s | 21.8439 Ops/s | |
test_iql_speed[True-None] | 16.8175ms | 15.7538ms | 63.4769 Ops/s | 62.1946 Ops/s | |
test_iql_speed[True-backward] | 27.6974ms | 27.0039ms | 37.0317 Ops/s | 36.8073 Ops/s | |
test_iql_speed[reduce-overhead-None] | 16.7057ms | 15.7787ms | 63.3766 Ops/s | 62.3173 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 28.5708ms | 27.0915ms | 36.9119 Ops/s | 36.4693 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.5242ms | 4.8603ms | 205.7501 Ops/s | 205.1682 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.8807ms | 0.5223ms | 1.9144 KOps/s | 1.9520 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.8289ms | 0.4993ms | 2.0027 KOps/s | 2.0203 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 5.1899ms | 4.6744ms | 213.9300 Ops/s | 218.1507 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.7189ms | 0.5112ms | 1.9561 KOps/s | 1.9739 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.9171ms | 0.4892ms | 2.0442 KOps/s | 2.0881 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 2.3909ms | 1.6847ms | 593.5674 Ops/s | 607.5345 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 2.4403ms | 1.5996ms | 625.1756 Ops/s | 638.7079 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 7.2508ms | 4.7626ms | 209.9684 Ops/s | 211.5272 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.3785ms | 0.6683ms | 1.4964 KOps/s | 1.5403 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 1.0514ms | 0.6354ms | 1.5737 KOps/s | 1.5987 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.5881ms | 4.6224ms | 216.3362 Ops/s | 215.5646 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 2.9884ms | 0.5251ms | 1.9043 KOps/s | 1.9002 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.7819ms | 0.4966ms | 2.0136 KOps/s | 2.0338 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 4.8889ms | 4.5685ms | 218.8907 Ops/s | 218.8448 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.4492ms | 0.5083ms | 1.9673 KOps/s | 1.9853 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.9461ms | 0.4937ms | 2.0256 KOps/s | 1.9903 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 4.8796ms | 4.6731ms | 213.9890 Ops/s | 210.7630 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.2775ms | 0.6575ms | 1.5210 KOps/s | 1.5512 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 1.0146ms | 0.6375ms | 1.5687 KOps/s | 1.5910 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.7474s | 19.1502ms | 52.2188 Ops/s | 252.6653 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 7.5421ms | 2.4836ms | 402.6348 Ops/s | 439.7529 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 1.9992ms | 1.2654ms | 790.2648 Ops/s | 781.9049 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 5.7821ms | 4.4504ms | 224.6987 Ops/s | 24.5577 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 7.0013ms | 2.3783ms | 420.4668 Ops/s | 434.1877 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 5.5109ms | 1.4076ms | 710.4371 Ops/s | 773.8216 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 5.9358ms | 4.5892ms | 217.9010 Ops/s | 218.1338 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 8.9733ms | 2.6094ms | 383.2256 Ops/s | 408.1929 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 4.2783ms | 1.4517ms | 688.8326 Ops/s | 640.0356 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 60.2168ms | 51.2038ms | 19.5298 Ops/s | 19.5130 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 15.8485ms | 14.6016ms | 68.4858 Ops/s | 69.6231 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 60.1486ms | 51.1486ms | 19.5509 Ops/s | 19.7582 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 15.6831ms | 14.6469ms | 68.2738 Ops/s | 68.8215 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 59.7138ms | 50.2430ms | 19.9033 Ops/s | 19.3723 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 17.5019ms | 15.9584ms | 62.6628 Ops/s | 61.3305 Ops/s |
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_simple | 0.9247s | 0.8352s | 1.1974 Ops/s | 1.2342 Ops/s | |
test_transformed | 1.5751s | 1.4831s | 0.6743 Ops/s | 0.7204 Ops/s | |
test_serial | 2.3348s | 2.3276s | 0.4296 Ops/s | 0.4395 Ops/s | |
test_parallel | 1.8737s | 1.8501s | 0.5405 Ops/s | 0.5247 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.2181ms | 39.4480μs | 25.3499 KOps/s | 24.8669 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 0.2084ms | 23.3873μs | 42.7583 KOps/s | 41.7959 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 0.2152ms | 22.7146μs | 44.0245 KOps/s | 44.3122 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 0.1913ms | 12.9656μs | 77.1273 KOps/s | 76.8508 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 0.2512ms | 42.8422μs | 23.3415 KOps/s | 23.1463 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 59.3010μs | 25.7878μs | 38.7780 KOps/s | 37.9085 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 57.0410μs | 25.0397μs | 39.9365 KOps/s | 39.6334 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 0.2081ms | 15.3561μs | 65.1208 KOps/s | 64.2779 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 0.1249ms | 43.6469μs | 22.9111 KOps/s | 22.1216 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 51.7410μs | 28.2656μs | 35.3787 KOps/s | 34.8829 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 51.7710μs | 24.8089μs | 40.3081 KOps/s | 40.1668 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 41.5800μs | 15.2632μs | 65.5173 KOps/s | 64.0967 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 81.8010μs | 47.3826μs | 21.1048 KOps/s | 21.1678 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 59.7110μs | 30.5440μs | 32.7396 KOps/s | 32.8600 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 54.7010μs | 26.7415μs | 37.3951 KOps/s | 37.1751 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 59.2710μs | 17.4605μs | 57.2721 KOps/s | 56.2251 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 82.1310μs | 44.8012μs | 22.3208 KOps/s | 22.1909 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 0.2168ms | 27.6648μs | 36.1470 KOps/s | 35.2172 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 2.6289ms | 29.0503μs | 34.4230 KOps/s | 34.7866 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 0.1943ms | 17.1984μs | 58.1449 KOps/s | 58.4806 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 0.2138ms | 47.8154μs | 20.9138 KOps/s | 20.7955 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 68.9910μs | 30.5250μs | 32.7600 KOps/s | 32.0963 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 0.1353ms | 31.0988μs | 32.1556 KOps/s | 31.9362 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 46.3300μs | 19.2687μs | 51.8975 KOps/s | 50.9353 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 79.8110μs | 49.6547μs | 20.1391 KOps/s | 19.7188 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 0.1200ms | 32.7121μs | 30.5697 KOps/s | 29.8974 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 0.1674ms | 30.5696μs | 32.7122 KOps/s | 32.4523 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 42.2810μs | 19.4340μs | 51.4561 KOps/s | 50.9271 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 76.7710μs | 51.5037μs | 19.4161 KOps/s | 19.2582 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 69.6310μs | 35.3274μs | 28.3067 KOps/s | 27.8861 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 66.3210μs | 32.1517μs | 31.1026 KOps/s | 29.9902 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 0.1321ms | 21.6747μs | 46.1368 KOps/s | 45.8517 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 25.5069ms | 25.0141ms | 39.9774 Ops/s | 38.9588 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 0.1123s | 3.1487ms | 317.5872 Ops/s | 351.7154 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.1059ms | 79.0372μs | 12.6523 KOps/s | 12.5983 KOps/s | |
test_values[td1_return_estimate-False-False] | 58.2833ms | 55.6867ms | 17.9576 Ops/s | 18.2474 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 1.3975ms | 1.0869ms | 920.0721 Ops/s | 918.2950 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 92.5602ms | 87.4584ms | 11.4340 Ops/s | 11.5141 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 1.3603ms | 1.0876ms | 919.4267 Ops/s | 923.3755 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 24.8300ms | 24.4789ms | 40.8515 Ops/s | 38.5277 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.0378ms | 0.7412ms | 1.3493 KOps/s | 1.3283 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.8303ms | 0.6613ms | 1.5122 KOps/s | 1.5002 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.6134ms | 1.4798ms | 675.7894 Ops/s | 672.6477 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.8405ms | 0.6766ms | 1.4780 KOps/s | 1.4736 KOps/s | |
test_dqn_speed[False-None] | 1.6871ms | 1.5261ms | 655.2808 Ops/s | 670.0316 Ops/s | |
test_dqn_speed[False-backward] | 2.3038ms | 2.1482ms | 465.5000 Ops/s | 445.7629 Ops/s | |
test_dqn_speed[True-None] | 0.7694ms | 0.5443ms | 1.8372 KOps/s | 1.7646 KOps/s | |
test_dqn_speed[True-backward] | 1.3331ms | 1.1110ms | 900.1268 Ops/s | 878.7431 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.7192ms | 0.5631ms | 1.7757 KOps/s | 1.7556 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.1137ms | 1.0647ms | 939.2745 Ops/s | 925.1307 Ops/s | |
test_ddpg_speed[False-None] | 3.2907ms | 2.8505ms | 350.8194 Ops/s | 354.8874 Ops/s | |
test_ddpg_speed[False-backward] | 4.8511ms | 4.2849ms | 233.3757 Ops/s | 238.1097 Ops/s | |
test_ddpg_speed[True-None] | 1.5396ms | 1.3308ms | 751.4217 Ops/s | 743.2127 Ops/s | |
test_ddpg_speed[True-backward] | 2.7147ms | 2.5545ms | 391.4644 Ops/s | 382.5545 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.4892ms | 1.3380ms | 747.3729 Ops/s | 735.1310 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 2.1870ms | 2.0401ms | 490.1733 Ops/s | 486.1986 Ops/s | |
test_sac_speed[False-None] | 8.5707ms | 8.1243ms | 123.0875 Ops/s | 125.0589 Ops/s | |
test_sac_speed[False-backward] | 11.9029ms | 11.3386ms | 88.1943 Ops/s | 88.7735 Ops/s | |
test_sac_speed[True-None] | 2.1374ms | 1.8132ms | 551.5036 Ops/s | 540.9074 Ops/s | |
test_sac_speed[True-backward] | 4.0323ms | 3.7406ms | 267.3333 Ops/s | 261.2852 Ops/s | |
test_sac_speed[reduce-overhead-None] | 21.6437ms | 12.2161ms | 81.8594 Ops/s | 83.5995 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 1.9507ms | 1.7730ms | 564.0299 Ops/s | 553.3016 Ops/s | |
test_redq_speed[False-None] | 8.1382ms | 7.6603ms | 130.5424 Ops/s | 130.4825 Ops/s | |
test_redq_speed[False-backward] | 12.4383ms | 11.8767ms | 84.1982 Ops/s | 83.4695 Ops/s | |
test_redq_speed[True-None] | 2.5718ms | 2.2913ms | 436.4341 Ops/s | 427.1178 Ops/s | |
test_redq_speed[True-backward] | 4.4798ms | 4.0745ms | 245.4297 Ops/s | 243.6069 Ops/s | |
test_redq_speed[reduce-overhead-None] | 2.7349ms | 2.3772ms | 420.6668 Ops/s | 422.9766 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 4.3400ms | 4.0420ms | 247.4048 Ops/s | 243.3898 Ops/s | |
test_redq_deprec_speed[False-None] | 9.6577ms | 9.2638ms | 107.9474 Ops/s | 111.3625 Ops/s | |
test_redq_deprec_speed[False-backward] | 12.6487ms | 12.2379ms | 81.7133 Ops/s | 82.1957 Ops/s | |
test_redq_deprec_speed[True-None] | 2.7759ms | 2.5938ms | 385.5378 Ops/s | 371.5645 Ops/s | |
test_redq_deprec_speed[True-backward] | 4.7637ms | 4.4363ms | 225.4128 Ops/s | 226.9919 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 3.0391ms | 2.7009ms | 370.2492 Ops/s | 376.6509 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 4.7234ms | 4.2839ms | 233.4295 Ops/s | 222.0647 Ops/s | |
test_td3_speed[False-None] | 8.3547ms | 8.0963ms | 123.5130 Ops/s | 124.2848 Ops/s | |
test_td3_speed[False-backward] | 11.0561ms | 10.4129ms | 96.0349 Ops/s | 95.6305 Ops/s | |
test_td3_speed[True-None] | 1.7327ms | 1.6289ms | 613.9008 Ops/s | 608.7264 Ops/s | |
test_td3_speed[True-backward] | 3.5758ms | 3.3355ms | 299.8050 Ops/s | 311.3793 Ops/s | |
test_td3_speed[reduce-overhead-None] | 73.3739ms | 26.6845ms | 37.4749 Ops/s | 37.3839 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 1.6258ms | 1.4791ms | 676.0930 Ops/s | 714.1742 Ops/s | |
test_cql_speed[False-None] | 17.6116ms | 16.9611ms | 58.9584 Ops/s | 59.6010 Ops/s | |
test_cql_speed[False-backward] | 22.9305ms | 22.4954ms | 44.4535 Ops/s | 45.5974 Ops/s | |
test_cql_speed[True-None] | 3.5751ms | 3.2247ms | 310.1067 Ops/s | 302.7467 Ops/s | |
test_cql_speed[True-backward] | 5.7191ms | 5.4708ms | 182.7874 Ops/s | 179.8188 Ops/s | |
test_cql_speed[reduce-overhead-None] | 21.0894ms | 13.2963ms | 75.2091 Ops/s | 74.1865 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 2.1614ms | 1.9906ms | 502.3519 Ops/s | 531.5386 Ops/s | |
test_a2c_speed[False-None] | 3.4929ms | 3.1988ms | 312.6212 Ops/s | 317.0462 Ops/s | |
test_a2c_speed[False-backward] | 7.3119ms | 6.3702ms | 156.9821 Ops/s | 163.7294 Ops/s | |
test_a2c_speed[True-None] | 1.5664ms | 1.3538ms | 738.6746 Ops/s | 744.0483 Ops/s | |
test_a2c_speed[True-backward] | 3.3250ms | 3.0508ms | 327.7825 Ops/s | 336.9712 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 16.3836ms | 9.1722ms | 109.0246 Ops/s | 117.1250 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 1.9038ms | 1.6042ms | 623.3805 Ops/s | 675.0885 Ops/s | |
test_ppo_speed[False-None] | 3.9721ms | 3.6807ms | 271.6885 Ops/s | 273.8093 Ops/s | |
test_ppo_speed[False-backward] | 7.5975ms | 7.1092ms | 140.6627 Ops/s | 146.1732 Ops/s | |
test_ppo_speed[True-None] | 1.6882ms | 1.4093ms | 709.5722 Ops/s | 691.9887 Ops/s | |
test_ppo_speed[True-backward] | 3.3696ms | 3.2261ms | 309.9759 Ops/s | 318.1851 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 1.1717ms | 0.9693ms | 1.0317 KOps/s | 1.0273 KOps/s | |
test_ppo_speed[reduce-overhead-backward] | 1.7267ms | 1.5599ms | 641.0786 Ops/s | 630.5956 Ops/s | |
test_reinforce_speed[False-None] | 2.6018ms | 2.2696ms | 440.6004 Ops/s | 446.0347 Ops/s | |
test_reinforce_speed[False-backward] | 3.9756ms | 3.3929ms | 294.7367 Ops/s | 289.9044 Ops/s | |
test_reinforce_speed[True-None] | 1.5696ms | 1.2903ms | 775.0094 Ops/s | 757.9694 Ops/s | |
test_reinforce_speed[True-backward] | 3.2300ms | 3.0634ms | 326.4336 Ops/s | 338.8911 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 21.9539ms | 10.5809ms | 94.5097 Ops/s | 94.1066 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 1.7475ms | 1.6265ms | 614.8139 Ops/s | 655.4789 Ops/s | |
test_iql_speed[False-None] | 9.6613ms | 9.2388ms | 108.2389 Ops/s | 108.0869 Ops/s | |
test_iql_speed[False-backward] | 13.8969ms | 13.2550ms | 75.4434 Ops/s | 75.9674 Ops/s | |
test_iql_speed[True-None] | 2.4939ms | 2.2014ms | 454.2576 Ops/s | 440.8169 Ops/s | |
test_iql_speed[True-backward] | 5.3249ms | 4.9093ms | 203.6963 Ops/s | 202.7484 Ops/s | |
test_iql_speed[reduce-overhead-None] | 0.5182s | 13.2163ms | 75.6644 Ops/s | 89.2245 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 2.2104ms | 2.0611ms | 485.1695 Ops/s | 522.5459 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.8196ms | 6.2835ms | 159.1468 Ops/s | 158.0969 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.5696ms | 0.3329ms | 3.0042 KOps/s | 3.0289 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6628ms | 0.3187ms | 3.1379 KOps/s | 3.1384 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.3445ms | 5.9421ms | 168.2906 Ops/s | 166.1735 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.9804ms | 0.2803ms | 3.5676 KOps/s | 3.1510 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.6041ms | 0.2804ms | 3.5657 KOps/s | 3.4266 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.5649ms | 1.3297ms | 752.0365 Ops/s | 766.9202 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.6469ms | 1.2274ms | 814.7461 Ops/s | 817.7161 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.4650ms | 6.1805ms | 161.8003 Ops/s | 160.6027 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.3247ms | 0.4212ms | 2.3743 KOps/s | 2.3295 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.6676ms | 0.4400ms | 2.2726 KOps/s | 2.5751 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 10.1238ms | 6.0721ms | 164.6876 Ops/s | 164.9081 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.9174ms | 0.3644ms | 2.7440 KOps/s | 3.6106 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 1.1762ms | 0.3385ms | 2.9543 KOps/s | 4.1138 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.4168ms | 5.9597ms | 167.7937 Ops/s | 165.6002 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.8844ms | 0.2648ms | 3.7763 KOps/s | 3.4537 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.5909ms | 0.3176ms | 3.1484 KOps/s | 3.8983 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.3663ms | 6.1368ms | 162.9524 Ops/s | 159.4591 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.9650ms | 0.4729ms | 2.1146 KOps/s | 2.1019 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7605ms | 0.4502ms | 2.2213 KOps/s | 2.2477 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 7.0833ms | 5.5023ms | 181.7423 Ops/s | 177.5842 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 11.0300ms | 2.1469ms | 465.7930 Ops/s | 438.5409 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 2.1413ms | 1.1575ms | 863.9579 Ops/s | 818.4623 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 9.1307ms | 5.6813ms | 176.0172 Ops/s | 175.5703 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 6.3653ms | 2.0418ms | 489.7683 Ops/s | 432.2026 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 8.9541ms | 1.3005ms | 768.9551 Ops/s | 887.3345 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.5457s | 16.5737ms | 60.3367 Ops/s | 30.0885 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 9.4022ms | 2.2350ms | 447.4188 Ops/s | 542.2584 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 7.9889ms | 1.3735ms | 728.0730 Ops/s | 819.3996 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 13.9209ms | 13.1925ms | 75.8007 Ops/s | 73.8555 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 19.1655ms | 17.2223ms | 58.0644 Ops/s | 56.9792 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 18.8522ms | 18.3473ms | 54.5041 Ops/s | 53.6428 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 19.4398ms | 17.5398ms | 57.0133 Ops/s | 57.3872 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 18.7709ms | 17.9678ms | 55.6552 Ops/s | 54.4511 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 20.5085ms | 18.9700ms | 52.7147 Ops/s | 52.9621 Ops/s |
ghstack-source-id: 899441844c058e291017de34c7be2df0f8219a31 Pull Request resolved: #2796
ghstack-source-id: 899441844c058e291017de34c7be2df0f8219a31 Pull Request resolved: pytorch#2796
ghstack-source-id: 899441844c058e291017de34c7be2df0f8219a31 Pull Request resolved: pytorch#2796
ghstack-source-id: 899441844c058e291017de34c7be2df0f8219a31 Pull Request resolved: pytorch#2796
ghstack-source-id: 08ebabd8c0b3ba0776a3b45370a056a3b90b20d2 Pull Request resolved: #2796
ghstack-source-id: 08ebabd8c0b3ba0776a3b45370a056a3b90b20d2 Pull Request resolved: pytorch#2796
ghstack-source-id: bd984300c77e8ed51adf687d1d826e9c149911f0 Pull Request resolved: #2796
@kurtamohler LMK if you need help with this! |
Just wanted to mention that I am actively working on this. It was a little tricky to get ChessEnv working correctly with |
ghstack-source-id: bd984300c77e8ed51adf687d1d826e9c149911f0 Pull Request resolved: pytorch#2796
Super cool thanks! |
ghstack-source-id: bd984300c77e8ed51adf687d1d826e9c149911f0 Pull Request resolved: pytorch#2796
ghstack-source-id: 15144dfb9e3ce724bc8c7a403b436f46ac8c5f8d Pull Request resolved: #2796
I was able to improve the performance a fair bit, but it is now around 17x slower (down from 100x) than my standalone example code. But I suppose we can make further performance improvements later down the road. At the moment, this PR is kind of a mess, so I'll fix it up and probably split out the stuff that is not directly related to MCTS into a separate PR |
Do you know what's causing the slowdown? TensorDict overhead? |
I've been using py-spy to find bottlenecks and improve them. Here's the flamegraph that it produces right now: At the moment, My overall runtime measurements have just been the time it takes to run the entire script, including module imports. Importing pytorch and torchrl is evidently a significant part of the whole runtime ~25%, so performance is actually a little better than what I said before. It's actually about 13x slower than the standalone script after taking that into account. |
def all_actions( | ||
self, tensordict: Optional[TensorDictBase] = None | ||
) -> TensorDictBase: | ||
if not self._overrides_action_generator_funcs: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In order to allow ChessEnv to override all_actions
and _rand_action
in the case where mask_actions=False
, I added this hacky _overrides_action_generator_funcs
setting. I only added that to confirm that avoiding the mask increases performance, and did not intend for it to be a long-term solution. I intended to remove _overrides_action_generator_funcs
in favor of always calling the base env's all_actions
/_rand_action
functions and then transform the result in TransformedEnv.all_actions
. But it's proving to be pretty difficult and it's slowing down my progress on MCTS itself.
There may be a good solution to this, but taking a step back to consider the bigger picture, I think it would be more pragmatic for me to stop trying to make mask_actions=False
work for now. The mcts.py
script currently takes ~9 s without the mask versus ~12-15 s with the mask. It would be nice to achieve that speedup eventually, but in order to develop a good MCTS API, it just don't think it's necessary. Furthermore, once we have the ability to combine MCTS with a neural network evaluation function (which I assume is one of the goals), then I think the time it takes to generate the mask may no longer be a significant bottleneck. Plus, there may be ways to generate the action mask more efficiently anyway.
So I'll back out the changes I made with respect to the action mask and start focusing more on the MCTS API
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a side note, the board.san(m)
call in _legal_moves_to_index
is pretty slow, and I think that avoiding it could improve performance. At the moment, we rely on the SAN format because the action indices correspond to one of all of the possible ~30K SAN strings. But we could choose to index it in a different way that would be more performant.
I would propose that we directly convert each chess.Move
into an index using its from_square
, to_square
, and promotion
properties, which are all integers. from_square
and to_square
are just numbers between 0 and 63. promotion
is a number that indicates which piece a pawn will promote to, if applicable. In the chess library, promotion
can specify any of the 6 kinds of pieces or none, so it has 7 possible values (although by the rules of chess, you can only promote to 4 different kinds of pieces). We can convert this to a number like so:
index = move.promotion * (64 * 64) + move.from_square * 64 + move.to_square
Then we wouldn't have to look up the SAN string in a big list (which I've modified to be a map in this PR) to obtain the index, and we wouldn't have to look up the index in a big list to covert an index back to a Move object--we would just extract the Move
properties directly from the index.
Additionally, this would give an slight extra benefit that we would decrease the size of the action mask from ~29K down to 5 * 64 * 64 = 20480
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I figured I might as well try out my idea for a more performant action indexing scheme, since it's relatively simple to implement. It seems to have worked fairly well. Runtime of the mcts script is typically ~10-12 s now, so a decent speedup over the previous 12-15 s. I'll split that out into a separate PR though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, that change breaks some of the ChessEnv tests. Changing the index to be something other than SAN-based would mean that in order to convert a SAN string to a move index, as is done in test_san
and test_reward
, we would first need to obtain a chess.Board
object that reflects the state of the board, since the conversion of a SAN string to a chess.Move
object depends on the state of the board. It's certainly possible to do this, but I'll drop this idea for now since it turns out not to be as straightforward as I thought
ghstack-source-id: 15144dfb9e3ce724bc8c7a403b436f46ac8c5f8d Pull Request resolved: pytorch#2796
ghstack-source-id: 4cf2a162e81a2d58bf4cedfa6b22fae100398323 Pull Request resolved: #2796
I think this is ready for review. I'll add MCTS APIs in subsequent PRs and update the example script accordingly |
Stack from ghstack (oldest at bottom):
ChessEnv.all_actions
#2849