
Commit 1c79684

jamesheald and araffin authored
Optimize the log of the entropy coeff instead of the entropy coeff (#56)

* optimize the log of the entropy coeff instead of the entropy coeff
* Update log ent coef for SAC and derivates
* Reformat yaml
* Use uv for faster downloads
* Remove TODO
* Remove redundant call

---------

Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent 19c85a1 commit 1c79684
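For context: the temperature loss is rewritten so that the gradient flows through log(alpha) rather than alpha itself. Assuming the EntropyCoef module stores log_ent_coef as its trainable parameter and returns its exponential on the forward pass (which the jnp.log calls in the diffs below rely on), the new loss is linear in the trained parameter, so the update size no longer scales with the current value of alpha, and alpha stays strictly positive by construction. A minimal numerical sketch of the difference (illustrative only, not code from this commit):

import jax
import jax.numpy as jnp

# Both losses are written directly in terms of log_alpha, the trained parameter.
def old_loss(log_alpha, entropy, target_entropy):
    # Pre-commit form: gradient w.r.t. log_alpha is alpha * (entropy - target_entropy),
    # so the update shrinks or blows up with the current scale of alpha.
    return jnp.exp(log_alpha) * (entropy - target_entropy)

def new_loss(log_alpha, entropy, target_entropy):
    # Post-commit form: log(exp(log_alpha)) == log_alpha, so the gradient
    # is simply (entropy - target_entropy), independent of alpha's scale.
    return jnp.log(jnp.exp(log_alpha)) * (entropy - target_entropy)

print(jax.grad(old_loss)(jnp.float32(0.5), -3.0, -6.0))  # exp(0.5) * 3.0 ~= 4.95
print(jax.grad(new_loss)(jnp.float32(0.5), -3.0, -6.0))  # 3.0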

File tree

6 files changed: +45 −40 lines

.github/workflows/ci.yml
sbx/crossq/crossq.py
sbx/dqn/policies.py
sbx/sac/sac.py
sbx/tqc/tqc.py
sbx/version.txt

.github/workflows/ci.yml

Lines changed: 35 additions & 32 deletions
@@ -5,9 +5,9 @@ name: CI
 
 on:
   push:
-    branches: [ master ]
+    branches: [master]
   pull_request:
-    branches: [ master ]
+    branches: [master]
 
 jobs:
   build:
@@ -23,34 +23,37 @@ jobs:
         python-version: ["3.8", "3.9", "3.10", "3.11"]
 
     steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          # cpu version of pytorch
-          pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          # Use uv for faster downloads
+          pip install uv
+          # cpu version of pytorch
+          # See https://github.com/astral-sh/uv/issues/1497
+          uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu
 
-          pip install .[tests]
-          # Use headless version
-          pip install opencv-python-headless
-      - name: Lint with ruff
-        run: |
-          make lint
-      # - name: Build the doc
-      #   run: |
-      #     make doc
-      - name: Check codestyle
-        run: |
-          make check-codestyle
-      - name: Type check
-        run: |
-          make type
-        # skip mypy, jax doesn't have its latest version for python 3.8
-        if: "!(matrix.python-version == '3.8')"
-      - name: Test with pytest
-        run: |
-          make pytest
+          uv pip install --system .[tests]
+          # Use headless version
+          uv pip install --system opencv-python-headless
+      - name: Lint with ruff
+        run: |
+          make lint
+      # - name: Build the doc
+      #   run: |
+      #     make doc
+      - name: Check codestyle
+        run: |
+          make check-codestyle
+      - name: Type check
+        run: |
+          make type
+        # skip mypy, jax doesn't have its latest version for python 3.8
+        if: "!(matrix.python-version == '3.8')"
+      - name: Test with pytest
+        run: |
+          make pytest
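A note on the workflow change (an inference from the diff, not stated in the commit message): uv pip install targets an active virtual environment by default, so --system is required to install into the CI runner's Python; the inline comment links astral-sh/uv#1497 for background on why the CPU build is pinned explicitly as torch==2.3.1+cpu against the PyTorch index.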

sbx/crossq/crossq.py

Lines changed: 3 additions & 1 deletion
@@ -355,8 +355,10 @@ def actor_loss(
     @jax.jit
     def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
         def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array:
+            # Note: we optimize the log of the entropy coeff which is slightly different from the paper
+            # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
             ent_coef_value = ent_coef_state.apply_fn({"params": temp_params})
-            ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean()  # type: ignore[union-attr]
+            ent_coef_loss = jnp.log(ent_coef_value) * (entropy - target_entropy).mean()  # type: ignore[union-attr]
             return ent_coef_loss
 
         ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params)
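The jnp.log call above only makes sense if ent_coef_state.apply_fn returns the exponential of the trained parameter. A minimal sketch of such a module (an assumption about what sbx's EntropyCoef looks like, not code from this diff):

import flax.linen as nn
import jax.numpy as jnp

class EntropyCoef(nn.Module):
    # Assumed structure: the trainable parameter is log_ent_coef,
    # and the forward pass returns ent_coef = exp(log_ent_coef) > 0.
    ent_coef_init: float = 1.0

    @nn.compact
    def __call__(self) -> jnp.ndarray:
        log_ent_coef = self.param(
            "log_ent_coef",
            lambda key: jnp.full((), jnp.log(self.ent_coef_init)),
        )
        return jnp.exp(log_ent_coef)

With this parameterization, jnp.log(ent_coef_value) recovers log_ent_coef exactly, so temperature_loss is linear in the parameter being differentiated.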

sbx/dqn/policies.py

Lines changed: 0 additions & 1 deletion
@@ -107,7 +107,6 @@ def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array:
             ),
         )
 
-        # TODO: jit qf.apply_fn too?
         self.qf.apply = jax.jit(self.qf.apply)  # type: ignore[method-assign]
 
         return key

sbx/sac/sac.py

Lines changed: 3 additions & 3 deletions
@@ -141,8 +141,6 @@ def _setup_model(self) -> None:
                 ent_coef_init = float(self.ent_coef_init.split("_")[1])
                 assert ent_coef_init > 0.0, "The initial value of ent_coef must be greater than 0"
 
-                # Note: we optimize the log of the entropy coeff which is slightly different from the paper
-                # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
                 self.ent_coef = EntropyCoef(ent_coef_init)
             else:
                 # This will throw an error if a malformed string (different from 'auto') is passed
@@ -325,8 +323,10 @@ def soft_update(tau: float, qf_state: RLTrainState) -> RLTrainState:
     @jax.jit
     def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
         def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array:
+            # Note: we optimize the log of the entropy coeff which is slightly different from the paper
+            # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
             ent_coef_value = ent_coef_state.apply_fn({"params": temp_params})
-            ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean()  # type: ignore[union-attr]
+            ent_coef_loss = jnp.log(ent_coef_value) * (entropy - target_entropy).mean()  # type: ignore[union-attr]
             return ent_coef_loss
 
         ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params)

sbx/tqc/tqc.py

Lines changed: 3 additions & 2 deletions
@@ -383,9 +383,10 @@ def soft_update(tau: float, qf1_state: RLTrainState, qf2_state: RLTrainState) ->
     @jax.jit
     def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
         def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array:
+            # Note: we optimize the log of the entropy coeff which is slightly different from the paper
+            # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
             ent_coef_value = ent_coef_state.apply_fn({"params": temp_params})
-            # ent_coef_loss = (jnp.log(ent_coef_value) * (entropy - target_entropy)).mean()
-            ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean()  # type: ignore[union-attr]
+            ent_coef_loss = jnp.log(ent_coef_value) * (entropy - target_entropy).mean()  # type: ignore[union-attr]
             return ent_coef_loss
 
         ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params)
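Worth noting: in tqc.py the previously commented-out line was already the log formulation, with .mean() applied to the whole product rather than to (entropy - target_entropy) alone; since the entropy coefficient is a scalar, the two placements of .mean() should be numerically equivalent.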

sbx/version.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-0.17.0
+0.18.0
