-
Notifications
You must be signed in to change notification settings - Fork 2k
Use proper dtype for RolloutBuffer storage
#2163
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use proper dtype for RolloutBuffer storage
#2163
Conversation
…rmalize_obs for rollout buffers
|
@araffin I've removed |
dtype for RolloutBuffer
dtype for RolloutBufferdtype for RolloutBuffer storage
araffin
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks =)
Description
BaseBuffer.__init__BufferDTypesdataclassactionstotorch.float32forRolloutBufferandDictRolloutBufferto avoid breaking calculations built withtorch.float32in mindTypes of changes
Motivation and Context
closes #2162
After inspecting the code in
common/buffers.py, I observe the following:RolloutBufferbloodline always usedtype=np.float32for all arraysReplayBufferbloodline usesdtypefrom observation & action spaces (gymnasium.spaces.Spaceobjects) forself.observations/self.next_observations/self.actionsThis lack of uniformness introduces a few problems:
np.uint8observationsChecklist
make format(required)make check-codestyleandmake lint(required)make pytestandmake typeboth pass. (required)make doc(required)Note: You can run most of the checks using
make commit-checks.Note: we are using a maximum length of 127 characters per line