[Feature,Example] Add MCTS algorithm and example #2796
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2796
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: There is 1 currently active SEV. If your PR is affected, please view it below.
❌ 4 New Failures, 2 Cancelled Jobs, 10 Unrelated Failures
As of commit e7dc7d3 with merge base a31dca3:
NEW FAILURES - The following jobs have failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
BROKEN TRUNK - The following jobs failed but were present on the merge base: 👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This seems to work, but at the moment it is about 100x slower than the one I implemented outside of TorchRL here. I will see what I can do to speed it up. Once I improve performance, I'll think about how to add a good API for it.

Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_simple | 0.6670s | 0.5657s | 1.7678 Ops/s | 1.8609 Ops/s | |
test_transformed | 1.1937s | 1.1151s | 0.8968 Ops/s | 0.9544 Ops/s | |
test_serial | 1.5891s | 1.5822s | 0.6320 Ops/s | 0.6516 Ops/s | |
test_parallel | 1.4041s | 1.3072s | 0.7650 Ops/s | 0.7655 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1281ms | 30.6692μs | 32.6060 KOps/s | 31.8336 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 48.5910μs | 18.2776μs | 54.7118 KOps/s | 56.3213 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 58.3490μs | 17.9955μs | 55.5695 KOps/s | 58.8661 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 52.4480μs | 10.5994μs | 94.3453 KOps/s | 97.4872 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 77.4040μs | 33.4255μs | 29.9173 KOps/s | 31.2275 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 47.3890μs | 20.4616μs | 48.8721 KOps/s | 50.2605 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 0.6047ms | 19.8741μs | 50.3168 KOps/s | 52.5319 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 38.8520μs | 12.4776μs | 80.1436 KOps/s | 83.1997 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 78.1160μs | 35.1778μs | 28.4270 KOps/s | 29.5983 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 63.1480μs | 22.4285μs | 44.5862 KOps/s | 45.9753 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 56.7050μs | 19.8779μs | 50.3071 KOps/s | 52.3006 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 44.2430μs | 12.3023μs | 81.2858 KOps/s | 83.5345 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 79.4780μs | 36.8559μs | 27.1327 KOps/s | 27.8850 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 65.3020μs | 24.0665μs | 41.5516 KOps/s | 42.5832 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 67.5660μs | 21.4726μs | 46.5709 KOps/s | 48.9753 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 44.1620μs | 14.1554μs | 70.6443 KOps/s | 73.3485 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 89.2260μs | 35.1502μs | 28.4493 KOps/s | 29.3951 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 53.0590μs | 22.2377μs | 44.9688 KOps/s | 46.7607 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 55.3930μs | 22.4340μs | 44.5752 KOps/s | 46.5204 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 46.6070μs | 13.8630μs | 72.1345 KOps/s | 74.9763 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 79.1870μs | 36.8129μs | 27.1644 KOps/s | 27.4907 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 2.4669ms | 24.4106μs | 40.9657 KOps/s | 42.3422 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 63.7790μs | 24.3232μs | 41.1130 KOps/s | 42.4322 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 0.5936ms | 15.7628μs | 63.4404 KOps/s | 65.9291 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 0.1398ms | 38.7365μs | 25.8154 KOps/s | 26.5864 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 62.1860μs | 25.8683μs | 38.6573 KOps/s | 39.6015 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 76.6530μs | 24.1655μs | 41.3813 KOps/s | 42.9654 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 61.3640μs | 15.6321μs | 63.9710 KOps/s | 65.6277 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 85.9110μs | 40.1345μs | 24.9162 KOps/s | 25.8407 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 92.3440μs | 27.6099μs | 36.2189 KOps/s | 37.4261 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 60.5130μs | 25.7848μs | 38.7825 KOps/s | 37.7366 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 94.3540μs | 17.2691μs | 57.9068 KOps/s | 59.4643 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 10.6862ms | 10.0041ms | 99.9587 Ops/s | 102.1639 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 27.6173ms | 24.5006ms | 40.8154 Ops/s | 37.8485 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.3514ms | 0.2258ms | 4.4283 KOps/s | 5.5216 KOps/s | |
test_values[td1_return_estimate-False-False] | 28.3776ms | 24.8935ms | 40.1711 Ops/s | 41.7830 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 27.3800ms | 24.6668ms | 40.5403 Ops/s | 37.4782 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 35.8925ms | 35.0271ms | 28.5493 Ops/s | 28.6664 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 26.7276ms | 24.6828ms | 40.5140 Ops/s | 37.3563 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 9.5037ms | 8.4806ms | 117.9166 Ops/s | 118.9131 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 2.4665ms | 1.9874ms | 503.1822 Ops/s | 524.9830 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.6568ms | 0.3774ms | 2.6496 KOps/s | 2.7049 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 49.4232ms | 45.4212ms | 22.0162 Ops/s | 20.8984 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 4.6964ms | 3.5654ms | 280.4697 Ops/s | 289.4240 Ops/s | |
test_dqn_speed[False-None] | 6.2207ms | 1.4725ms | 679.0958 Ops/s | 712.1553 Ops/s | |
test_dqn_speed[False-backward] | 2.0657ms | 1.9452ms | 514.0792 Ops/s | 527.0985 Ops/s | |
test_dqn_speed[True-None] | 0.9742ms | 0.5752ms | 1.7385 KOps/s | 1.7572 KOps/s | |
test_dqn_speed[True-backward] | 1.1054ms | 0.9898ms | 1.0103 KOps/s | 785.4052 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.9310ms | 0.5628ms | 1.7769 KOps/s | 1.7587 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.0243ms | 0.9810ms | 1.0194 KOps/s | 1.0131 KOps/s | |
test_ddpg_speed[False-None] | 3.7434ms | 2.9791ms | 335.6763 Ops/s | 344.3897 Ops/s | |
test_ddpg_speed[False-backward] | 4.2266ms | 4.1262ms | 242.3543 Ops/s | 246.8492 Ops/s | |
test_ddpg_speed[True-None] | 1.9377ms | 1.4427ms | 693.1523 Ops/s | 684.2267 Ops/s | |
test_ddpg_speed[True-backward] | 2.4205ms | 2.3307ms | 429.0576 Ops/s | 417.1818 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.9018ms | 1.4416ms | 693.6696 Ops/s | 678.2458 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 2.4246ms | 2.3261ms | 429.9062 Ops/s | 422.7577 Ops/s | |
test_sac_speed[False-None] | 8.6580ms | 8.2573ms | 121.1052 Ops/s | 123.1160 Ops/s | |
test_sac_speed[False-backward] | 11.4724ms | 10.9531ms | 91.2983 Ops/s | 90.7021 Ops/s | |
test_sac_speed[True-None] | 3.3506ms | 2.5689ms | 389.2706 Ops/s | 382.8235 Ops/s | |
test_sac_speed[True-backward] | 5.2565ms | 4.2676ms | 234.3244 Ops/s | 231.5071 Ops/s | |
test_sac_speed[reduce-overhead-None] | 3.1729ms | 2.5712ms | 388.9194 Ops/s | 382.0138 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 4.3214ms | 4.2427ms | 235.6994 Ops/s | 229.2568 Ops/s | |
test_redq_speed[False-None] | 13.7217ms | 13.0737ms | 76.4894 Ops/s | 74.9000 Ops/s | |
test_redq_speed[False-backward] | 24.0607ms | 22.5596ms | 44.3270 Ops/s | 42.6657 Ops/s | |
test_redq_speed[True-None] | 7.3458ms | 6.6228ms | 150.9942 Ops/s | 142.2213 Ops/s | |
test_redq_speed[True-backward] | 16.1895ms | 14.3244ms | 69.8108 Ops/s | 67.5794 Ops/s | |
test_redq_speed[reduce-overhead-None] | 8.0604ms | 6.7333ms | 148.5151 Ops/s | 140.1701 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 15.3748ms | 14.2692ms | 70.0812 Ops/s | 69.3793 Ops/s | |
test_redq_deprec_speed[False-None] | 13.6675ms | 12.9704ms | 77.0984 Ops/s | 75.8742 Ops/s | |
test_redq_deprec_speed[False-backward] | 19.6393ms | 18.6507ms | 53.6174 Ops/s | 52.3380 Ops/s | |
test_redq_deprec_speed[True-None] | 5.8744ms | 5.1439ms | 194.4067 Ops/s | 192.1556 Ops/s | |
test_redq_deprec_speed[True-backward] | 10.5701ms | 9.9102ms | 100.9058 Ops/s | 99.7643 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 5.5362ms | 5.1505ms | 194.1567 Ops/s | 192.4768 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 10.7469ms | 9.9701ms | 100.2998 Ops/s | 99.4953 Ops/s | |
test_td3_speed[False-None] | 8.4482ms | 8.2259ms | 121.5678 Ops/s | 123.8643 Ops/s | |
test_td3_speed[False-backward] | 12.0238ms | 10.7250ms | 93.2405 Ops/s | 94.9045 Ops/s | |
test_td3_speed[True-None] | 3.7342ms | 2.4258ms | 412.2346 Ops/s | 436.2082 Ops/s | |
test_td3_speed[True-backward] | 5.8898ms | 4.1622ms | 240.2558 Ops/s | 247.2589 Ops/s | |
test_td3_speed[reduce-overhead-None] | 3.4836ms | 2.4274ms | 411.9705 Ops/s | 436.0162 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 4.7328ms | 3.9302ms | 254.4390 Ops/s | 251.6302 Ops/s | |
test_cql_speed[False-None] | 39.3408ms | 36.4926ms | 27.4028 Ops/s | 27.2499 Ops/s | |
test_cql_speed[False-backward] | 63.8771ms | 48.7278ms | 20.5221 Ops/s | 20.7983 Ops/s | |
test_cql_speed[True-None] | 23.4197ms | 22.3086ms | 44.8257 Ops/s | 43.1851 Ops/s | |
test_cql_speed[True-backward] | 29.9610ms | 29.0382ms | 34.4374 Ops/s | 33.8017 Ops/s | |
test_cql_speed[reduce-overhead-None] | 24.0368ms | 22.4104ms | 44.6221 Ops/s | 43.9143 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 30.4437ms | 29.1107ms | 34.3516 Ops/s | 33.8345 Ops/s | |
test_a2c_speed[False-None] | 7.9690ms | 7.1985ms | 138.9183 Ops/s | 136.1004 Ops/s | |
test_a2c_speed[False-backward] | 16.4940ms | 14.3265ms | 69.8009 Ops/s | 69.3241 Ops/s | |
test_a2c_speed[True-None] | 5.5012ms | 4.6716ms | 214.0586 Ops/s | 213.7634 Ops/s | |
test_a2c_speed[True-backward] | 11.6603ms | 11.1312ms | 89.8376 Ops/s | 88.2366 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 5.0690ms | 4.6514ms | 214.9881 Ops/s | 213.3272 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 12.3482ms | 11.1782ms | 89.4601 Ops/s | 88.3800 Ops/s | |
test_ppo_speed[False-None] | 8.9908ms | 7.5480ms | 132.4857 Ops/s | 132.9993 Ops/s | |
test_ppo_speed[False-backward] | 15.9991ms | 14.8850ms | 67.1818 Ops/s | 66.8864 Ops/s | |
test_ppo_speed[True-None] | 7.0669ms | 5.6722ms | 176.2993 Ops/s | 196.3270 Ops/s | |
test_ppo_speed[True-backward] | 13.1540ms | 11.9436ms | 83.7266 Ops/s | 91.1146 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 5.9280ms | 5.0496ms | 198.0340 Ops/s | 196.7899 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 12.9746ms | 11.0568ms | 90.4422 Ops/s | 90.1317 Ops/s | |
test_reinforce_speed[False-None] | 7.8501ms | 6.6475ms | 150.4328 Ops/s | 151.8951 Ops/s | |
test_reinforce_speed[False-backward] | 10.3156ms | 10.0050ms | 99.9499 Ops/s | 101.2409 Ops/s | |
test_reinforce_speed[True-None] | 4.8988ms | 4.0596ms | 246.3294 Ops/s | 242.3324 Ops/s | |
test_reinforce_speed[True-backward] | 11.3386ms | 10.0807ms | 99.1991 Ops/s | 98.2430 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 5.6285ms | 4.1157ms | 242.9698 Ops/s | 243.6155 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 10.3976ms | 10.0302ms | 99.6992 Ops/s | 99.6690 Ops/s | |
test_iql_speed[False-None] | 38.5091ms | 32.8647ms | 30.4278 Ops/s | 30.6699 Ops/s | |
test_iql_speed[False-backward] | 66.0998ms | 46.3514ms | 21.5743 Ops/s | 21.8439 Ops/s | |
test_iql_speed[True-None] | 16.8175ms | 15.7538ms | 63.4769 Ops/s | 62.1946 Ops/s | |
test_iql_speed[True-backward] | 27.6974ms | 27.0039ms | 37.0317 Ops/s | 36.8073 Ops/s | |
test_iql_speed[reduce-overhead-None] | 16.7057ms | 15.7787ms | 63.3766 Ops/s | 62.3173 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 28.5708ms | 27.0915ms | 36.9119 Ops/s | 36.4693 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.5242ms | 4.8603ms | 205.7501 Ops/s | 205.1682 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.8807ms | 0.5223ms | 1.9144 KOps/s | 1.9520 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.8289ms | 0.4993ms | 2.0027 KOps/s | 2.0203 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 5.1899ms | 4.6744ms | 213.9300 Ops/s | 218.1507 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.7189ms | 0.5112ms | 1.9561 KOps/s | 1.9739 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.9171ms | 0.4892ms | 2.0442 KOps/s | 2.0881 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 2.3909ms | 1.6847ms | 593.5674 Ops/s | 607.5345 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 2.4403ms | 1.5996ms | 625.1756 Ops/s | 638.7079 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 7.2508ms | 4.7626ms | 209.9684 Ops/s | 211.5272 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.3785ms | 0.6683ms | 1.4964 KOps/s | 1.5403 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 1.0514ms | 0.6354ms | 1.5737 KOps/s | 1.5987 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.5881ms | 4.6224ms | 216.3362 Ops/s | 215.5646 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 2.9884ms | 0.5251ms | 1.9043 KOps/s | 1.9002 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.7819ms | 0.4966ms | 2.0136 KOps/s | 2.0338 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 4.8889ms | 4.5685ms | 218.8907 Ops/s | 218.8448 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.4492ms | 0.5083ms | 1.9673 KOps/s | 1.9853 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.9461ms | 0.4937ms | 2.0256 KOps/s | 1.9903 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 4.8796ms | 4.6731ms | 213.9890 Ops/s | 210.7630 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.2775ms | 0.6575ms | 1.5210 KOps/s | 1.5512 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 1.0146ms | 0.6375ms | 1.5687 KOps/s | 1.5910 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.7474s | 19.1502ms | 52.2188 Ops/s | 252.6653 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 7.5421ms | 2.4836ms | 402.6348 Ops/s | 439.7529 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 1.9992ms | 1.2654ms | 790.2648 Ops/s | 781.9049 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 5.7821ms | 4.4504ms | 224.6987 Ops/s | 24.5577 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 7.0013ms | 2.3783ms | 420.4668 Ops/s | 434.1877 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 5.5109ms | 1.4076ms | 710.4371 Ops/s | 773.8216 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 5.9358ms | 4.5892ms | 217.9010 Ops/s | 218.1338 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 8.9733ms | 2.6094ms | 383.2256 Ops/s | 408.1929 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 4.2783ms | 1.4517ms | 688.8326 Ops/s | 640.0356 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 60.2168ms | 51.2038ms | 19.5298 Ops/s | 19.5130 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 15.8485ms | 14.6016ms | 68.4858 Ops/s | 69.6231 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 60.1486ms | 51.1486ms | 19.5509 Ops/s | 19.7582 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 15.6831ms | 14.6469ms | 68.2738 Ops/s | 68.8215 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 59.7138ms | 50.2430ms | 19.9033 Ops/s | 19.3723 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 17.5019ms | 15.9584ms | 62.6628 Ops/s | 61.3305 Ops/s |
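
For reading the table above: the Ops column appears to be the reciprocal of the Mean column (for example, 1 / 0.5657 s ≈ 1.7678 Ops/s for `test_simple`). The Change column is empty in this copy of the table, so the sketch below assumes it is meant to hold the relative difference between Ops and Ops on Repo HEAD. That definition, and the helper names `ops_per_second` and `relative_change`, are illustrative assumptions and not part of the benchmark tooling.

```python
# Values copied from the test_simple and test_serial rows of the table above.
ROWS = [
    # (name, mean in seconds, Ops/s on this PR, Ops/s on repo HEAD)
    ("test_simple", 0.5657, 1.7678, 1.8609),
    ("test_serial", 1.5822, 0.6320, 0.6516),
]


def ops_per_second(mean_seconds):
    """Throughput implied by a mean runtime: one operation per `mean_seconds`."""
    return 1.0 / mean_seconds


def relative_change(ops, ops_head):
    """Assumed meaning of "Change": fractional difference versus repo HEAD."""
    return (ops - ops_head) / ops_head


for name, mean_s, ops, ops_head in ROWS:
    print(
        f"{name}: 1/mean = {ops_per_second(mean_s):.4f} Ops/s, "
        f"reported = {ops:.4f} Ops/s, "
        f"change vs HEAD = {relative_change(ops, ops_head):+.2%}"
    )
```

A negative value means the PR commit ran fewer operations per second than repo HEAD for that benchmark. Differences of a few percent on a single run may be noise.
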

Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_simple | 0.9247s | 0.8352s | 1.1974 Ops/s | 1.2342 Ops/s | |
test_transformed | 1.5751s | 1.4831s | 0.6743 Ops/s | 0.7204 Ops/s | |
test_serial | 2.3348s | 2.3276s | 0.4296 Ops/s | 0.4395 Ops/s | |
test_parallel | 1.8737s | 1.8501s | 0.5405 Ops/s | 0.5247 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.2181ms | 39.4480μs | 25.3499 KOps/s | 24.8669 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 0.2084ms | 23.3873μs | 42.7583 KOps/s | 41.7959 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 0.2152ms | 22.7146μs | 44.0245 KOps/s | 44.3122 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 0.1913ms | 12.9656μs | 77.1273 KOps/s | 76.8508 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 0.2512ms | 42.8422μs | 23.3415 KOps/s | 23.1463 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 59.3010μs | 25.7878μs | 38.7780 KOps/s | 37.9085 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 57.0410μs | 25.0397μs | 39.9365 KOps/s | 39.6334 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 0.2081ms | 15.3561μs | 65.1208 KOps/s | 64.2779 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 0.1249ms | 43.6469μs | 22.9111 KOps/s | 22.1216 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 51.7410μs | 28.2656μs | 35.3787 KOps/s | 34.8829 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 51.7710μs | 24.8089μs | 40.3081 KOps/s | 40.1668 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 41.5800μs | 15.2632μs | 65.5173 KOps/s | 64.0967 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 81.8010μs | 47.3826μs | 21.1048 KOps/s | 21.1678 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 59.7110μs | 30.5440μs | 32.7396 KOps/s | 32.8600 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 54.7010μs | 26.7415μs | 37.3951 KOps/s | 37.1751 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 59.2710μs | 17.4605μs | 57.2721 KOps/s | 56.2251 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 82.1310μs | 44.8012μs | 22.3208 KOps/s | 22.1909 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 0.2168ms | 27.6648μs | 36.1470 KOps/s | 35.2172 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 2.6289ms | 29.0503μs | 34.4230 KOps/s | 34.7866 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 0.1943ms | 17.1984μs | 58.1449 KOps/s | 58.4806 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 0.2138ms | 47.8154μs | 20.9138 KOps/s | 20.7955 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 68.9910μs | 30.5250μs | 32.7600 KOps/s | 32.0963 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 0.1353ms | 31.0988μs | 32.1556 KOps/s | 31.9362 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 46.3300μs | 19.2687μs | 51.8975 KOps/s | 50.9353 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 79.8110μs | 49.6547μs | 20.1391 KOps/s | 19.7188 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 0.1200ms | 32.7121μs | 30.5697 KOps/s | 29.8974 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 0.1674ms | 30.5696μs | 32.7122 KOps/s | 32.4523 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 42.2810μs | 19.4340μs | 51.4561 KOps/s | 50.9271 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 76.7710μs | 51.5037μs | 19.4161 KOps/s | 19.2582 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 69.6310μs | 35.3274μs | 28.3067 KOps/s | 27.8861 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 66.3210μs | 32.1517μs | 31.1026 KOps/s | 29.9902 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 0.1321ms | 21.6747μs | 46.1368 KOps/s | 45.8517 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 25.5069ms | 25.0141ms | 39.9774 Ops/s | 38.9588 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 0.1123s | 3.1487ms | 317.5872 Ops/s | 351.7154 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.1059ms | 79.0372μs | 12.6523 KOps/s | 12.5983 KOps/s | |
test_values[td1_return_estimate-False-False] | 58.2833ms | 55.6867ms | 17.9576 Ops/s | 18.2474 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 1.3975ms | 1.0869ms | 920.0721 Ops/s | 918.2950 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 92.5602ms | 87.4584ms | 11.4340 Ops/s | 11.5141 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 1.3603ms | 1.0876ms | 919.4267 Ops/s | 923.3755 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 24.8300ms | 24.4789ms | 40.8515 Ops/s | 38.5277 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.0378ms | 0.7412ms | 1.3493 KOps/s | 1.3283 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.8303ms | 0.6613ms | 1.5122 KOps/s | 1.5002 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.6134ms | 1.4798ms | 675.7894 Ops/s | 672.6477 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.8405ms | 0.6766ms | 1.4780 KOps/s | 1.4736 KOps/s | |
test_dqn_speed[False-None] | 1.6871ms | 1.5261ms | 655.2808 Ops/s | 670.0316 Ops/s | |
test_dqn_speed[False-backward] | 2.3038ms | 2.1482ms | 465.5000 Ops/s | 445.7629 Ops/s | |
test_dqn_speed[True-None] | 0.7694ms | 0.5443ms | 1.8372 KOps/s | 1.7646 KOps/s | |
test_dqn_speed[True-backward] | 1.3331ms | 1.1110ms | 900.1268 Ops/s | 878.7431 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.7192ms | 0.5631ms | 1.7757 KOps/s | 1.7556 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.1137ms | 1.0647ms | 939.2745 Ops/s | 925.1307 Ops/s | |
test_ddpg_speed[False-None] | 3.2907ms | 2.8505ms | 350.8194 Ops/s | 354.8874 Ops/s | |
test_ddpg_speed[False-backward] | 4.8511ms | 4.2849ms | 233.3757 Ops/s | 238.1097 Ops/s | |
test_ddpg_speed[True-None] | 1.5396ms | 1.3308ms | 751.4217 Ops/s | 743.2127 Ops/s | |
test_ddpg_speed[True-backward] | 2.7147ms | 2.5545ms | 391.4644 Ops/s | 382.5545 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.4892ms | 1.3380ms | 747.3729 Ops/s | 735.1310 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 2.1870ms | 2.0401ms | 490.1733 Ops/s | 486.1986 Ops/s | |
test_sac_speed[False-None] | 8.5707ms | 8.1243ms | 123.0875 Ops/s | 125.0589 Ops/s | |
test_sac_speed[False-backward] | 11.9029ms | 11.3386ms | 88.1943 Ops/s | 88.7735 Ops/s | |
test_sac_speed[True-None] | 2.1374ms | 1.8132ms | 551.5036 Ops/s | 540.9074 Ops/s | |
test_sac_speed[True-backward] | 4.0323ms | 3.7406ms | 267.3333 Ops/s | 261.2852 Ops/s | |
test_sac_speed[reduce-overhead-None] | 21.6437ms | 12.2161ms | 81.8594 Ops/s | 83.5995 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 1.9507ms | 1.7730ms | 564.0299 Ops/s | 553.3016 Ops/s | |
test_redq_speed[False-None] | 8.1382ms | 7.6603ms | 130.5424 Ops/s | 130.4825 Ops/s | |
test_redq_speed[False-backward] | 12.4383ms | 11.8767ms | 84.1982 Ops/s | 83.4695 Ops/s | |
test_redq_speed[True-None] | 2.5718ms | 2.2913ms | 436.4341 Ops/s | 427.1178 Ops/s | |
test_redq_speed[True-backward] | 4.4798ms | 4.0745ms | 245.4297 Ops/s | 243.6069 Ops/s | |
test_redq_speed[reduce-overhead-None] | 2.7349ms | 2.3772ms | 420.6668 Ops/s | 422.9766 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 4.3400ms | 4.0420ms | 247.4048 Ops/s | 243.3898 Ops/s | |
test_redq_deprec_speed[False-None] | 9.6577ms | 9.2638ms | 107.9474 Ops/s | 111.3625 Ops/s | |
test_redq_deprec_speed[False-backward] | 12.6487ms | 12.2379ms | 81.7133 Ops/s | 82.1957 Ops/s | |
test_redq_deprec_speed[True-None] | 2.7759ms | 2.5938ms | 385.5378 Ops/s | 371.5645 Ops/s | |
test_redq_deprec_speed[True-backward] | 4.7637ms | 4.4363ms | 225.4128 Ops/s | 226.9919 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 3.0391ms | 2.7009ms | 370.2492 Ops/s | 376.6509 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 4.7234ms | 4.2839ms | 233.4295 Ops/s | 222.0647 Ops/s | |
test_td3_speed[False-None] | 8.3547ms | 8.0963ms | 123.5130 Ops/s | 124.2848 Ops/s | |
test_td3_speed[False-backward] | 11.0561ms | 10.4129ms | 96.0349 Ops/s | 95.6305 Ops/s | |
test_td3_speed[True-None] | 1.7327ms | 1.6289ms | 613.9008 Ops/s | 608.7264 Ops/s | |
test_td3_speed[True-backward] | 3.5758ms | 3.3355ms | 299.8050 Ops/s | 311.3793 Ops/s | |
test_td3_speed[reduce-overhead-None] | 73.3739ms | 26.6845ms | 37.4749 Ops/s | 37.3839 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 1.6258ms | 1.4791ms | 676.0930 Ops/s | 714.1742 Ops/s | |
test_cql_speed[False-None] | 17.6116ms | 16.9611ms | 58.9584 Ops/s | 59.6010 Ops/s | |
test_cql_speed[False-backward] | 22.9305ms | 22.4954ms | 44.4535 Ops/s | 45.5974 Ops/s | |
test_cql_speed[True-None] | 3.5751ms | 3.2247ms | 310.1067 Ops/s | 302.7467 Ops/s | |
test_cql_speed[True-backward] | 5.7191ms | 5.4708ms | 182.7874 Ops/s | 179.8188 Ops/s | |
test_cql_speed[reduce-overhead-None] | 21.0894ms | 13.2963ms | 75.2091 Ops/s | 74.1865 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 2.1614ms | 1.9906ms | 502.3519 Ops/s | 531.5386 Ops/s | |
test_a2c_speed[False-None] | 3.4929ms | 3.1988ms | 312.6212 Ops/s | 317.0462 Ops/s | |
test_a2c_speed[False-backward] | 7.3119ms | 6.3702ms | 156.9821 Ops/s | 163.7294 Ops/s | |
test_a2c_speed[True-None] | 1.5664ms | 1.3538ms | 738.6746 Ops/s | 744.0483 Ops/s | |
test_a2c_speed[True-backward] | 3.3250ms | 3.0508ms | 327.7825 Ops/s | 336.9712 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 16.3836ms | 9.1722ms | 109.0246 Ops/s | 117.1250 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 1.9038ms | 1.6042ms | 623.3805 Ops/s | 675.0885 Ops/s | |
test_ppo_speed[False-None] | 3.9721ms | 3.6807ms | 271.6885 Ops/s | 273.8093 Ops/s | |
test_ppo_speed[False-backward] | 7.5975ms | 7.1092ms | 140.6627 Ops/s | 146.1732 Ops/s | |
test_ppo_speed[True-None] | 1.6882ms | 1.4093ms | 709.5722 Ops/s | 691.9887 Ops/s | |
test_ppo_speed[True-backward] | 3.3696ms | 3.2261ms | 309.9759 Ops/s | 318.1851 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 1.1717ms | 0.9693ms | 1.0317 KOps/s | 1.0273 KOps/s | |
test_ppo_speed[reduce-overhead-backward] | 1.7267ms | 1.5599ms | 641.0786 Ops/s | 630.5956 Ops/s | |
test_reinforce_speed[False-None] | 2.6018ms | 2.2696ms | 440.6004 Ops/s | 446.0347 Ops/s | |
test_reinforce_speed[False-backward] | 3.9756ms | 3.3929ms | 294.7367 Ops/s | 289.9044 Ops/s | |
test_reinforce_speed[True-None] | 1.5696ms | 1.2903ms | 775.0094 Ops/s | 757.9694 Ops/s | |
test_reinforce_speed[True-backward] | 3.2300ms | 3.0634ms | 326.4336 Ops/s | 338.8911 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 21.9539ms | 10.5809ms | 94.5097 Ops/s | 94.1066 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 1.7475ms | 1.6265ms | 614.8139 Ops/s | 655.4789 Ops/s | |
test_iql_speed[False-None] | 9.6613ms | 9.2388ms | 108.2389 Ops/s | 108.0869 Ops/s | |
test_iql_speed[False-backward] | 13.8969ms | 13.2550ms | 75.4434 Ops/s | 75.9674 Ops/s | |
test_iql_speed[True-None] | 2.4939ms | 2.2014ms | 454.2576 Ops/s | 440.8169 Ops/s | |
test_iql_speed[True-backward] | 5.3249ms | 4.9093ms | 203.6963 Ops/s | 202.7484 Ops/s | |
test_iql_speed[reduce-overhead-None] | 0.5182s | 13.2163ms | 75.6644 Ops/s | 89.2245 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 2.2104ms | 2.0611ms | 485.1695 Ops/s | 522.5459 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.8196ms | 6.2835ms | 159.1468 Ops/s | 158.0969 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.5696ms | 0.3329ms | 3.0042 KOps/s | 3.0289 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6628ms | 0.3187ms | 3.1379 KOps/s | 3.1384 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.3445ms | 5.9421ms | 168.2906 Ops/s | 166.1735 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.9804ms | 0.2803ms | 3.5676 KOps/s | 3.1510 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.6041ms | 0.2804ms | 3.5657 KOps/s | 3.4266 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.5649ms | 1.3297ms | 752.0365 Ops/s | 766.9202 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.6469ms | 1.2274ms | 814.7461 Ops/s | 817.7161 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.4650ms | 6.1805ms | 161.8003 Ops/s | 160.6027 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.3247ms | 0.4212ms | 2.3743 KOps/s | 2.3295 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.6676ms | 0.4400ms | 2.2726 KOps/s | 2.5751 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 10.1238ms | 6.0721ms | 164.6876 Ops/s | 164.9081 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.9174ms | 0.3644ms | 2.7440 KOps/s | 3.6106 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 1.1762ms | 0.3385ms | 2.9543 KOps/s | 4.1138 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.4168ms | 5.9597ms | 167.7937 Ops/s | 165.6002 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.8844ms | 0.2648ms | 3.7763 KOps/s | 3.4537 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.5909ms | 0.3176ms | 3.1484 KOps/s | 3.8983 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.3663ms | 6.1368ms | 162.9524 Ops/s | 159.4591 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.9650ms | 0.4729ms | 2.1146 KOps/s | 2.1019 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7605ms | 0.4502ms | 2.2213 KOps/s | 2.2477 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 7.0833ms | 5.5023ms | 181.7423 Ops/s | 177.5842 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 11.0300ms | 2.1469ms | 465.7930 Ops/s | 438.5409 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 2.1413ms | 1.1575ms | 863.9579 Ops/s | 818.4623 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 9.1307ms | 5.6813ms | 176.0172 Ops/s | 175.5703 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 6.3653ms | 2.0418ms | 489.7683 Ops/s | 432.2026 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 8.9541ms | 1.3005ms | 768.9551 Ops/s | 887.3345 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.5457s | 16.5737ms | 60.3367 Ops/s | 30.0885 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 9.4022ms | 2.2350ms | 447.4188 Ops/s | 542.2584 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 7.9889ms | 1.3735ms | 728.0730 Ops/s | 819.3996 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 13.9209ms | 13.1925ms | 75.8007 Ops/s | 73.8555 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 19.1655ms | 17.2223ms | 58.0644 Ops/s | 56.9792 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 18.8522ms | 18.3473ms | 54.5041 Ops/s | 53.6428 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 19.4398ms | 17.5398ms | 57.0133 Ops/s | 57.3872 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 18.7709ms | 17.9678ms | 55.6552 Ops/s | 54.4511 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 20.5085ms | 18.9700ms | 52.7147 Ops/s | 52.9621 Ops/s |
ghstack-source-id: 8994418 Pull Request resolved: pytorch#2796
ghstack-source-id: 08ebabd Pull Request resolved: pytorch#2796
@kurtamohler LMK if you need help with this!
Just wanted to mention that I am actively working on this. It was a little tricky to get ChessEnv working correctly with |
ghstack-source-id: bd98430 Pull Request resolved: pytorch#2796
Super cool, thanks!
I was able to improve the performance a fair bit, but it is still around 17x slower (down from 100x) than my standalone example code. I suppose we can make further performance improvements later down the road. At the moment, this PR is kind of a mess, so I'll fix it up and probably split out the stuff that is not directly related to MCTS into a separate PR.
Do you know what's causing the slowdown? TensorDict overhead?
I've been using py-spy to find bottlenecks and improve them. Here's the flamegraph that it produces right now: At the moment, My overall runtime measurements have just been the time it takes to run the entire script, including module imports. Importing pytorch and torchrl is evidently a significant part of the whole runtime (~25%), so performance is actually a little better than what I said before. After taking that into account, it's about 13x slower than the standalone script.
ghstack-source-id: 15144df Pull Request resolved: pytorch#2796
I think this is ready for review. I'll add MCTS APIs in subsequent PRs and update the example script accordingly |
ghstack-source-id: 4cf2a16 Pull Request resolved: pytorch#2796
ghstack-source-id: 9ee7dc3 Pull Request resolved: pytorch#2796
ghstack-source-id: cfaa730 Pull Request resolved: pytorch#2796
# If it's black's turn, flip the reward, since black wants to optimize for
# the lowest reward, not highest.
# TODO: Need a more generic way to do this, since not all use cases of MCTS
We'll need to come up with a good way to specify how to process the exploitation value for each different player.
During an MCTS traversal, when we visit each node, we have to rank all the child nodes and decide which one to traverse to, using one of the standard exploration/exploitation formulas. I only know about the UCB1 formula at the moment, so I'll focus on that one. To calculate the exploitation value of each child node, the UCB1 formula uses the "reward average" attached to that node: the sum of the rewards of all rollouts performed under the child node, divided by the number of times the child node has been visited during traversals.
Let's say that at the end of a rollout, the reward value that the chess environment gives for a white win is 1, black win is -1, and draw is 0. In order to do the exploitation part correctly, we should assume that each player wants to maximize their chances of winning. So if it's white's turn at a particular node, we want exploitation actions to maximize the reward average. But if it's black's turn, we want exploitation actions to minimize the reward average, so we have to flip the sign of the reward average when we calculate the UCB1 value on black's turn.
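To make the sign flip concrete, here is a minimal sketch in plain Python. It is not the code in this PR: `ucb1`, `select_child`, the `(reward_sum, visits)` pair layout, and the exploration constant `c = sqrt(2)` are all assumptions chosen for illustration, with rewards expressed from white's point of view as described above.

```python
import math


def ucb1(reward_avg, visits, parent_visits, c=math.sqrt(2)):
    """UCB1 score: exploitation term plus an exploration bonus."""
    if visits == 0:
        return math.inf  # unvisited children are tried first
    return reward_avg + c * math.sqrt(math.log(parent_visits) / visits)


def select_child(children, parent_visits, white_to_move):
    """Return the index of the child with the highest UCB1 score.

    `children` is a list of (reward_sum, visits) pairs (hypothetical layout),
    where rewards are +1 for a white win, -1 for a black win, 0 for a draw.
    """
    # Black wants the lowest reward average, so flip the sign on black's turn.
    sign = 1.0 if white_to_move else -1.0
    best_index, best_score = None, -math.inf
    for index, (reward_sum, visits) in enumerate(children):
        reward_avg = sign * reward_sum / visits if visits else 0.0
        score = ucb1(reward_avg, visits, parent_visits)
        if score > best_score:
            best_index, best_score = index, score
    return best_index


# Two children with equal visit counts, so only the exploitation term differs.
children = [(3.0, 5), (-3.0, 5)]
white_choice = select_child(children, parent_visits=10, white_to_move=True)
black_choice = select_child(children, parent_visits=10, white_to_move=False)
```

With equal visit counts the exploration bonus cancels out, so white picks the child whose rollouts favored white and black picks the other one.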
Or let's say we instead want to use a two-element reward. A white win is `[1, -1]`, black win is `[-1, 1]`, and draw is `[0, 0]`. Now the reward average at each node of the MCTS tree has two elements. When it's white's turn at a particular node, we want to look at the first element of the reward average of each child node. When it's black's turn, we want to look at the second element.
Ideally, our MCTS API should be able to handle both of the above reward schemes, and any other sensible kind of reward scheme that users would want to use, for environments with any number of agents.
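As a rough illustration of the two-element scheme, the sketch below keeps a per-player running reward sum at a node and reads out the element for the player to move. `NodeStats` and the `WHITE`/`BLACK` index constants are hypothetical names, not part of TorchRL or this PR.

```python
class NodeStats:
    """Running rollout statistics for one tree node (hypothetical helper)."""

    def __init__(self, num_players):
        self.visits = 0
        self.reward_sum = [0.0] * num_players

    def update(self, reward):
        """Record one rollout result; `reward` has one element per player."""
        self.visits += 1
        for player, value in enumerate(reward):
            self.reward_sum[player] += value

    def reward_average(self):
        """Per-player reward average over all rollouts seen so far."""
        return [total / self.visits for total in self.reward_sum]


WHITE, BLACK = 0, 1  # assumed player indices into the reward vector

node = NodeStats(num_players=2)
# Two white wins, one black win, one draw.
for reward in ([1, -1], [1, -1], [-1, 1], [0, 0]):
    node.update(reward)

avg = node.reward_average()
white_value = avg[WHITE]  # what white would use for exploitation
black_value = avg[BLACK]  # what black would use for exploitation
```

Nothing in this layout is specific to two players, which is why the same bookkeeping extends to environments with more agents.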
I'm thinking of just adding an argument to the `MCTS` function called `select_player_reward_fn`, or something along those lines, which is a callable with the signature `(td, reward_average)`. In the two examples I gave above, the user could specify something like this:
- `select_player_reward_fn=lambda td, reward_avg: reward_avg if td["turn"] else -reward_avg` for 1-element reward between 1 and -1
- `select_player_reward_fn=lambda td, reward_avg: reward_avg[td["turn"]]` for n-element reward

In order to have a sensible default, we could assume that the reward is normally n-element, one value for each player, since that is the normal setup for multi-agent environments.
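Here is a hedged sketch of how such a callable might plug into child scoring. `score_children` and `argmax` are hypothetical helpers, `td` is a plain dict standing in for a TensorDict, and the `"turn"` encodings are assumptions: a boolean with `True` meaning white for the scalar scheme, and an integer player index with `0` meaning white for the n-element scheme.

```python
import math


def score_children(td, children, parent_visits, select_player_reward_fn,
                   c=math.sqrt(2)):
    """Return one UCB1 score per child; assumes every child has visits >= 1."""
    scores = []
    for reward_avg, visits in children:
        # The callable turns the stored reward average into the value
        # that the player to move wants to maximize.
        exploitation = select_player_reward_fn(td, reward_avg)
        exploration = c * math.sqrt(math.log(parent_visits) / visits)
        scores.append(exploitation + exploration)
    return scores


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


# Scalar reward in [-1, 1]; td["turn"] is assumed to be True on white's turn.
scalar_fn = lambda td, reward_avg: reward_avg if td["turn"] else -reward_avg
scalar_children = [(0.6, 5), (-0.6, 5)]
scalar_best = argmax(
    score_children({"turn": False}, scalar_children, 10, scalar_fn)
)

# n-element reward; td["turn"] is assumed to be a player index (1 = black).
vector_fn = lambda td, reward_avg: reward_avg[td["turn"]]
vector_children = [([0.6, -0.6], 5), ([-0.6, 0.6], 5)]
vector_best = argmax(
    score_children({"turn": 1}, vector_children, 10, vector_fn)
)
```

One detail worth pinning down in the real API: in Python, `True` indexes element 1, so if `td["turn"]` is a boolean where `True` means white, the n-element lambda would read the second element rather than the first. The sketch sidesteps this by assuming an integer player index.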
I was curious about how other libraries handle this, so I looked into a couple of them.
The `mcts` Python library has a sign flip for two-player games. The total reward is multiplied by the player value (https://github.com/pbsinclair42/MCTS/blob/f4b6226ce06840e51b66bdca4bedbe2c4c143012/mcts.py#L112), which is either 1 or -1 in the example scripts of this repo. So this library seems to be specifically for two-player games.
The MCTS implementation within the `open_spiel` library has an n-element reward, one for each player, and chooses the reward for whichever player's turn it is, like the default behavior I proposed in the last comment: https://github.com/google-deepmind/open_spiel/blob/8296179b697644cf957c7c9313f594c062cbd17c/open_spiel/python/algorithms/mcts.py#L368-L369
I think maybe our implementation of MCTS should require the reward to have one element per player, and also require that the environment output spec has a key indicating which player's turn it is. Or at least that could be the default behavior and we could allow the user to override it with a callable argument.
Let me know what you think @vmoens.
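For discussion, the default-plus-override idea could look roughly like the sketch below. Everything here is hypothetical: `make_reward_selector`, the `turn_key` argument, and the use of a plain dict in place of a TensorDict are placeholders, and the turn value is assumed to be an integer player index.

```python
def make_reward_selector(num_players, turn_key="turn",
                         select_player_reward_fn=None):
    """Return the callable used to pick the current player's reward.

    If the user passes `select_player_reward_fn`, it is used as-is.
    Otherwise the default requires one reward element per player and a
    `turn_key` entry in `td` holding the index of the player to move.
    """
    if select_player_reward_fn is not None:
        return select_player_reward_fn

    def default_selector(td, reward_avg):
        if len(reward_avg) != num_players:
            raise ValueError(
                f"expected {num_players} reward elements, got {len(reward_avg)}"
            )
        if turn_key not in td:
            raise KeyError(f"environment output is missing the {turn_key!r} key")
        return reward_avg[td[turn_key]]

    return default_selector


# Default behavior: index the reward average by the player to move.
default_fn = make_reward_selector(num_players=2)
default_value = default_fn({"turn": 1}, [0.25, -0.25])

# Override: a scalar reward with a sign flip, where True is assumed to be white.
flip_fn = make_reward_selector(
    num_players=2,
    select_player_reward_fn=lambda td, avg: avg if td["turn"] else -avg,
)
flipped_value = flip_fn({"turn": False}, 0.25)
```

Validating the reward length and the turn key up front would surface a mismatched environment early instead of producing silently wrong scores during traversal.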
Stack from ghstack (oldest at bottom):