File tree Expand file tree Collapse file tree 4 files changed +13
-10
lines changed Expand file tree Collapse file tree 4 files changed +13
-10
lines changed Original file line number Diff line number Diff line change @@ -179,20 +179,21 @@ def learn(
179
179
progress_bar = progress_bar ,
180
180
)
181
181
182
- def train (self , batch_size : int , gradient_steps : int ):
182
+ def train (self , gradient_steps : int , batch_size : int ) -> None :
183
+ assert self .replay_buffer is not None
183
184
# Sample all at once for efficiency (so we can jit the for loop)
184
185
data = self .replay_buffer .sample (batch_size * gradient_steps , env = self ._vec_normalize_env )
185
186
186
187
if isinstance (data .observations , dict ):
187
- keys = list (self .observation_space .keys ())
188
+ keys = list (self .observation_space .keys ()) # type: ignore[attr-defined]
188
189
obs = np .concatenate ([data .observations [key ].numpy () for key in keys ], axis = 1 )
189
190
next_obs = np .concatenate ([data .next_observations [key ].numpy () for key in keys ], axis = 1 )
190
191
else :
191
192
obs = data .observations .numpy ()
192
193
next_obs = data .next_observations .numpy ()
193
194
194
195
# Convert to numpy
195
- data = ReplayBufferSamplesNp (
196
+ data = ReplayBufferSamplesNp ( # type: ignore[assignment]
196
197
obs ,
197
198
data .actions .numpy (),
198
199
next_obs ,
Original file line number Diff line number Diff line change @@ -120,20 +120,21 @@ def learn(
120
120
progress_bar = progress_bar ,
121
121
)
122
122
123
- def train (self , batch_size , gradient_steps ):
123
+ def train (self , gradient_steps : int , batch_size : int ) -> None :
124
+ assert self .replay_buffer is not None
124
125
# Sample all at once for efficiency (so we can jit the for loop)
125
126
data = self .replay_buffer .sample (batch_size * gradient_steps , env = self ._vec_normalize_env )
126
127
127
128
if isinstance (data .observations , dict ):
128
- keys = list (self .observation_space .keys ())
129
+ keys = list (self .observation_space .keys ()) # type: ignore[attr-defined]
129
130
obs = np .concatenate ([data .observations [key ].numpy () for key in keys ], axis = 1 )
130
131
next_obs = np .concatenate ([data .next_observations [key ].numpy () for key in keys ], axis = 1 )
131
132
else :
132
133
obs = data .observations .numpy ()
133
134
next_obs = data .next_observations .numpy ()
134
135
135
136
# Convert to numpy
136
- data = ReplayBufferSamplesNp (
137
+ data = ReplayBufferSamplesNp ( # type: ignore[assignment]
137
138
obs ,
138
139
data .actions .numpy (),
139
140
next_obs ,
Original file line number Diff line number Diff line change @@ -180,20 +180,21 @@ def learn(
180
180
progress_bar = progress_bar ,
181
181
)
182
182
183
- def train (self , batch_size , gradient_steps ):
183
+ def train (self , gradient_steps : int , batch_size : int ) -> None :
184
+ assert self .replay_buffer is not None
184
185
# Sample all at once for efficiency (so we can jit the for loop)
185
186
data = self .replay_buffer .sample (batch_size * gradient_steps , env = self ._vec_normalize_env )
186
187
187
188
if isinstance (data .observations , dict ):
188
- keys = list (self .observation_space .keys ())
189
+ keys = list (self .observation_space .keys ()) # type: ignore[attr-defined]
189
190
obs = np .concatenate ([data .observations [key ].numpy () for key in keys ], axis = 1 )
190
191
next_obs = np .concatenate ([data .next_observations [key ].numpy () for key in keys ], axis = 1 )
191
192
else :
192
193
obs = data .observations .numpy ()
193
194
next_obs = data .next_observations .numpy ()
194
195
195
196
# Convert to numpy
196
- data = ReplayBufferSamplesNp (
197
+ data = ReplayBufferSamplesNp ( # type: ignore[assignment]
197
198
obs ,
198
199
data .actions .numpy (),
199
200
next_obs ,
Original file line number Diff line number Diff line change 1
- 0.9.1
1
+ 0.10.0
You can’t perform that action at this time.
0 commit comments