Skip to content

Commit dbe8760

Browse files
committed
Hotfix for train signature
1 parent 36febf0 commit dbe8760

File tree

4 files changed

+13
-10
lines changed

4 files changed

+13
-10
lines changed

sbx/sac/sac.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -179,20 +179,21 @@ def learn(
179179
progress_bar=progress_bar,
180180
)
181181

182-
def train(self, batch_size: int, gradient_steps: int):
182+
def train(self, gradient_steps: int, batch_size: int) -> None:
183+
assert self.replay_buffer is not None
183184
# Sample all at once for efficiency (so we can jit the for loop)
184185
data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)
185186

186187
if isinstance(data.observations, dict):
187-
keys = list(self.observation_space.keys())
188+
keys = list(self.observation_space.keys()) # type: ignore[attr-defined]
188189
obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1)
189190
next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1)
190191
else:
191192
obs = data.observations.numpy()
192193
next_obs = data.next_observations.numpy()
193194

194195
# Convert to numpy
195-
data = ReplayBufferSamplesNp(
196+
data = ReplayBufferSamplesNp( # type: ignore[assignment]
196197
obs,
197198
data.actions.numpy(),
198199
next_obs,

sbx/td3/td3.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -120,20 +120,21 @@ def learn(
120120
progress_bar=progress_bar,
121121
)
122122

123-
def train(self, batch_size, gradient_steps):
123+
def train(self, gradient_steps: int, batch_size: int) -> None:
124+
assert self.replay_buffer is not None
124125
# Sample all at once for efficiency (so we can jit the for loop)
125126
data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)
126127

127128
if isinstance(data.observations, dict):
128-
keys = list(self.observation_space.keys())
129+
keys = list(self.observation_space.keys()) # type: ignore[attr-defined]
129130
obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1)
130131
next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1)
131132
else:
132133
obs = data.observations.numpy()
133134
next_obs = data.next_observations.numpy()
134135

135136
# Convert to numpy
136-
data = ReplayBufferSamplesNp(
137+
data = ReplayBufferSamplesNp( # type: ignore[assignment]
137138
obs,
138139
data.actions.numpy(),
139140
next_obs,

sbx/tqc/tqc.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -180,20 +180,21 @@ def learn(
180180
progress_bar=progress_bar,
181181
)
182182

183-
def train(self, batch_size, gradient_steps):
183+
def train(self, gradient_steps: int, batch_size: int) -> None:
184+
assert self.replay_buffer is not None
184185
# Sample all at once for efficiency (so we can jit the for loop)
185186
data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)
186187

187188
if isinstance(data.observations, dict):
188-
keys = list(self.observation_space.keys())
189+
keys = list(self.observation_space.keys()) # type: ignore[attr-defined]
189190
obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1)
190191
next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1)
191192
else:
192193
obs = data.observations.numpy()
193194
next_obs = data.next_observations.numpy()
194195

195196
# Convert to numpy
196-
data = ReplayBufferSamplesNp(
197+
data = ReplayBufferSamplesNp( # type: ignore[assignment]
197198
obs,
198199
data.actions.numpy(),
199200
next_obs,

sbx/version.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
0.9.1
1+
0.10.0

0 commit comments

Comments
 (0)