Collapse PPO and DD-PPO trainers, faster RNN code, and double buffered sampling
Motivation and Context
This contains most of my local improvements to habitat-baselines. I also collapsed the PPO and DD-PPO trainers, as there is no real reason for them to be separate besides history.
I also made some changes to lower test time; iterating on this was a pain otherwise.
- Collapse PPO and DD-PPO trainers
- Faster RNN code -- it is definitely faster and can make a noticeable impact during early training (~20% faster in some cases), but good luck reading it :-)
- NUM_SIMULATORS. The fact that the simulators run in different processes is an implementation detail. A backwards-compatibility check has been added, though.
- Support specifying training length both in terms of number of updates and number of frames
- Specify checkpointing via the total number of checkpoints instead of a checkpoint interval
- Introduce a TensorDict class for more cleanly interacting with (potentially nested) dictionaries of tensors. This also makes RolloutStorage about 100x cleaner.
- Store RGB observations as their proper dtype in the rollout storage (this can save a lot of memory)
- Some refactoring of PPOTrainer.train to be less of a script wrapped in a function
- Double-buffered sampling. This can improve performance when simulation time is equal to or greater than policy inference time
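To illustrate the idea behind the faster RNN code: instead of looping over the rollout one timestep at a time, the sequence only needs to be cut at timesteps where some episode ends, and the RNN can run over each contiguous chunk in a single call (zeroing the hidden state for the environments that reset at the chunk start). A minimal sketch of the chunking logic, with a hypothetical `rnn_chunks` helper that is not the PR's actual API:

```python
import numpy as np

def rnn_chunks(masks):
    """Given per-step not-done masks of shape (T, N) (0 where an episode
    just reset), return [(start, end), ...] time ranges such that no episode
    boundary falls strictly inside a range.  The RNN can then be invoked
    once per chunk instead of once per timestep."""
    T = masks.shape[0]
    # Timesteps where at least one environment just reset force a cut.
    cuts = np.nonzero((masks == 0).any(axis=1))[0].tolist()
    starts = [0] + [t for t in cuts if t != 0]
    bounds = starts + [T]
    return [(bounds[i], bounds[i + 1]) for i in range(len(bounds) - 1)]
```

For example, with 6 steps and 2 environments where env 0 resets at t=2 and env 1 at t=4, only 3 RNN calls are needed rather than 6; with no resets at all, a single call covers the whole rollout.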
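On specifying training length in updates or frames, and checkpoints as a count: the two are easy to reconcile by deriving a schedule from a frame budget. A sketch with a hypothetical `compute_schedule` helper (not the PR's actual config handling):

```python
def compute_schedule(total_frames, frames_per_update, num_checkpoints):
    """Derive the number of updates and a checkpoint interval (in updates)
    from a total frame budget and a desired checkpoint count.
    Hypothetical helper for illustration only."""
    num_updates = -(-total_frames // frames_per_update)  # ceil division
    ckpt_interval = max(num_updates // num_checkpoints, 1)
    return num_updates, ckpt_interval
```

E.g. a 1M-frame budget with 8192 frames collected per update and 10 requested checkpoints yields 123 updates with a checkpoint every 12 updates.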
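The TensorDict idea can be sketched in a few lines: a dict whose non-string indices fan out to every (possibly nested) leaf, so rollout storage reads/writes become one-liners. This is a minimal sketch using numpy arrays as leaves; the class in the PR wraps torch tensors and has more functionality:

```python
import numpy as np

class TensorDict(dict):
    """Dict of (possibly nested) arrays.  String keys behave like a normal
    dict; any other index (int, slice, ...) is applied to every leaf."""

    def __getitem__(self, index):
        if isinstance(index, str):
            return super().__getitem__(index)
        # Recurse: nested TensorDicts re-enter this method, arrays index natively.
        return TensorDict({k: v[index] for k, v in self.items()})

    def __setitem__(self, index, value):
        if isinstance(index, str):
            super().__setitem__(index, value)
        else:
            # Write each leaf of `value` into the corresponding slice.
            for k, v in value.items():
                self[k][index] = v
```

With this, `storage[step]` returns all observations/actions at a timestep in one expression, and `storage[step] = batch` writes them back, regardless of how deeply the observation dict is nested.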
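The memory saving from storing RGB in its proper dtype is straightforward to see: uint8 frames are 4x smaller than float32 ones, and the cast to float can be deferred until a minibatch is built. A small illustration (buffer shape is hypothetical):

```python
import numpy as np

# Hypothetical rollout buffer shape: (steps, envs, H, W, RGB channels).
steps, envs, H, W = 32, 4, 64, 64

as_float32 = np.zeros((steps, envs, H, W, 3), dtype=np.float32)
as_uint8 = np.zeros((steps, envs, H, W, 3), dtype=np.uint8)

# Storing frames in their native uint8 dtype is 4x smaller than float32.
ratio = as_float32.nbytes // as_uint8.nbytes

# Convert to float only when a minibatch is actually consumed:
minibatch = as_uint8[0].astype(np.float32) / 255.0
```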
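To convey the shape of double-buffered sampling: the environments are split into two buffers, and while the policy runs inference for one buffer, the other buffer's simulators can be stepping. The sketch below uses hypothetical stand-ins (`DummyEnv`, `double_buffered_rollout`) and steps synchronously for clarity; the real trainer overlaps the two halves via asynchronous env workers:

```python
class DummyEnv:
    """Hypothetical stand-in for a simulator process."""
    def __init__(self):
        self.t = 0
    def reset(self):
        self.t = 0
        return self.t
    def step(self, action):
        self.t += 1
        return self.t

def double_buffered_rollout(envs, policy, num_steps):
    # Split the environments into two buffers.
    half = len(envs) // 2
    buffers = [envs[:half], envs[half:]]
    obs = [[e.reset() for e in b] for b in buffers]
    for _ in range(num_steps):
        # Alternate between the buffers: inference on buffer i would overlap
        # with buffer 1 - i stepping in the background.
        for i, buf in enumerate(buffers):
            actions = policy(obs[i])  # inference on buffer i
            obs[i] = [e.step(a) for e, a in zip(buf, actions)]
    return obs
```

This alternation is why the gain shows up when a simulation step costs at least as much as a policy forward pass: each side's wait is hidden behind the other's work.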
A note for reviewing: there are three (mostly) conceptually independent blocks to this: the PPOTrainer refactor plus double-buffered sampling, TensorDict plus the RolloutStorage refactor, and the RNN code.
To come in a different PR:
- torch.cuda.amp support
- FP16 inference and mixed-precision training support (i.e. apex.amp O2 mode without apex, because apex is kind of a pain)
How Has This Been Tested
Existing tests, plus a new test for the RNN code.
Types of changes
- Docs change / refactoring / dependency upgrade
- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing functionality to change)