Hotfix for train signature
araffin committed Dec 13, 2023
1 parent 36febf0 commit dbe8760
Showing 4 changed files with 13 additions and 10 deletions.
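The change itself is small: in SAC, TD3, and TQC the train() override now takes gradient_steps before batch_size, matching the argument order of the parent class in stable-baselines3. A minimal, hypothetical sketch of the intended alignment (the class name MyOffPolicyAlgo is illustrative, and it assumes stable-baselines3 is installed and that OffPolicyAlgorithm.train uses this ordering, which is not shown in this commit):

from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm

class MyOffPolicyAlgo(OffPolicyAlgorithm):
    # Keep the parent ordering: gradient_steps first, then batch_size,
    # so positional calls and static type checkers see a consistent signature.
    def train(self, gradient_steps: int, batch_size: int) -> None:
        ...  # run gradient_steps updates, each on a batch of batch_size transitions

If train() is always invoked with keyword arguments (as the off-policy learn() loop in stable-baselines3 does), runtime behavior should be unchanged; the swap mainly matters for positional callers, subclass overrides, and type checkers.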
7 changes: 4 additions & 3 deletions sbx/sac/sac.py
@@ -179,20 +179,21 @@ def learn(
progress_bar=progress_bar,
)

- def train(self, batch_size: int, gradient_steps: int):
+ def train(self, gradient_steps: int, batch_size: int) -> None:
assert self.replay_buffer is not None
# Sample all at once for efficiency (so we can jit the for loop)
data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)

if isinstance(data.observations, dict):
- keys = list(self.observation_space.keys())
+ keys = list(self.observation_space.keys())  # type: ignore[attr-defined]
obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1)
next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1)
else:
obs = data.observations.numpy()
next_obs = data.next_observations.numpy()

# Convert to numpy
- data = ReplayBufferSamplesNp(
+ data = ReplayBufferSamplesNp(  # type: ignore[assignment]
obs,
data.actions.numpy(),
next_obs,
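For context (an illustration, not part of the commit): when the sampled observations are a dict, the branch above flattens them by concatenating each key's batch along the feature axis, after drawing batch_size * gradient_steps transitions in a single sample() call so the update loop can be jitted. The added # type: ignore[attr-defined] silences mypy, which types observation_space as a generic gym.Space without a keys() method. A small NumPy sketch of the concatenation (shapes and key names are made up):

import numpy as np

# Hypothetical dict batch: 2048 transitions (e.g. batch_size=256 * gradient_steps=8),
# with two observation keys of feature sizes 3 and 10.
batch = {
    "achieved_goal": np.zeros((2048, 3), dtype=np.float32),
    "observation": np.zeros((2048, 10), dtype=np.float32),
}
keys = list(batch.keys())
# Concatenating along axis=1 turns each transition into one flat feature vector.
obs = np.concatenate([batch[key] for key in keys], axis=1)
assert obs.shape == (2048, 13)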
7 changes: 4 additions & 3 deletions sbx/td3/td3.py
@@ -120,20 +120,21 @@ def learn(
progress_bar=progress_bar,
)

- def train(self, batch_size, gradient_steps):
+ def train(self, gradient_steps: int, batch_size: int) -> None:
assert self.replay_buffer is not None
# Sample all at once for efficiency (so we can jit the for loop)
data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)

if isinstance(data.observations, dict):
- keys = list(self.observation_space.keys())
+ keys = list(self.observation_space.keys())  # type: ignore[attr-defined]
obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1)
next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1)
else:
obs = data.observations.numpy()
next_obs = data.next_observations.numpy()

# Convert to numpy
- data = ReplayBufferSamplesNp(
+ data = ReplayBufferSamplesNp(  # type: ignore[assignment]
obs,
data.actions.numpy(),
next_obs,
7 changes: 4 additions & 3 deletions sbx/tqc/tqc.py
@@ -180,20 +180,21 @@ def learn(
progress_bar=progress_bar,
)

- def train(self, batch_size, gradient_steps):
+ def train(self, gradient_steps: int, batch_size: int) -> None:
assert self.replay_buffer is not None
# Sample all at once for efficiency (so we can jit the for loop)
data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)

if isinstance(data.observations, dict):
- keys = list(self.observation_space.keys())
+ keys = list(self.observation_space.keys())  # type: ignore[attr-defined]
obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1)
next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1)
else:
obs = data.observations.numpy()
next_obs = data.next_observations.numpy()

# Convert to numpy
- data = ReplayBufferSamplesNp(
+ data = ReplayBufferSamplesNp(  # type: ignore[assignment]
obs,
data.actions.numpy(),
next_obs,
2 changes: 1 addition & 1 deletion sbx/version.txt
@@ -1 +1 @@
- 0.9.1
+ 0.10.0
