Skip to content

Commit

Permalink
Add comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
araffin committed Dec 13, 2023
1 parent 7d3f328 commit 36febf0
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion sbx/sac/sac.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -179,7 +179,7 @@ def learn(
progress_bar=progress_bar,
)

- def train(self, batch_size, gradient_steps):
+ def train(self, batch_size: int, gradient_steps: int):
# Sample all at once for efficiency (so we can jit the for loop)
data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)

Expand Down Expand Up @@ -408,7 +408,9 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:

(actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key) = jax.lax.cond(
(policy_delay_offset + i) % policy_delay_interval == 0,
+ # If True:
cls.update_actor_and_temperature,
+ # If False:
lambda *_: (actor_state, qf_state, ent_coef_state, info["actor_loss"], info["ent_coef_loss"], key),
actor_state,
qf_state,
Expand Down

0 comments on commit 36febf0

Please sign in to comment.