Add test requirements, upgrade black (#437)
* ignoring virtual env, adding test requirements, and ran 'black .' to auto-format files

* updated documentation for change of adding tests requirements as extra

* commented on fix in changelog

* need minigrid for https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/ppo.yml#L290C16-L290C48

* make format

* fixed typo

* fixed typo

* Fix ruff config

* Simplify dependencies

---------

Co-authored-by: Nikolaus Schuetz <[email protected]>
araffin and nikolauspschuetz authored Feb 19, 2024
1 parent 8cecab4 commit aa38145
Showing 19 changed files with 51 additions and 48 deletions.
8 changes: 2 additions & 6 deletions .github/workflows/ci.yml
@@ -40,15 +40,11 @@ jobs:
# cpu version of pytorch - faster to download
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
pip install pybullet==3.2.5
# for v4 MuJoCo envs:
pip install mujoco
pip install -r requirements.txt
# Use headless version
pip install opencv-python-headless
# install parking-env to test HER
pip install highway-env==1.8.1
pip install -e .
pip install -e .[plots,tests]
- name: Lint with ruff
run: |
make lint
4 changes: 1 addition & 3 deletions .github/workflows/trained_agents.yml
@@ -41,12 +41,10 @@ jobs:
# cpu version of pytorch - faster to download
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
pip install pybullet==3.2.5
pip install -r requirements.txt
# Use headless version
pip install opencv-python-headless
pip install highway-env==1.8.1
pip install -e .
pip install -e .[plots,tests]
- name: Check trained agents
run: |
make check-trained-agents
9 changes: 9 additions & 0 deletions .gitignore
@@ -31,3 +31,12 @@ keys/
.cache
*.lprof
*.prof

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
5 changes: 3 additions & 2 deletions CHANGELOG.md
@@ -1,4 +1,4 @@
## Release 2.3.0a1 (WIP)
## Release 2.3.0a2 (WIP)

### Breaking Changes
- Updated defaults hyperparameters for TD3/DDPG to be more consistent with SAC
@@ -13,7 +13,8 @@
### Documentation

### Other

- Added test dependencies to `setup.py` (@power-edge)
- Simplify dependencies of `requirements.txt` (remove duplicates from `setup.py`)


## Release 2.2.1 (2023-11-17)
2 changes: 1 addition & 1 deletion Makefile
@@ -16,7 +16,7 @@ type: mypy
lint:
# stop the build if there are Python syntax errors or undefined names
# see https://www.flake8rules.com/
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
# exit-zero treats all errors as warnings.
ruff ${LINT_PATHS} --exit-zero

1 change: 1 addition & 0 deletions README.md
@@ -52,6 +52,7 @@ Note: you can do `python -m rl_zoo3.train` from any folder and you have access t
```
apt-get install swig cmake ffmpeg
pip install -r requirements.txt
pip install -e .[plots,tests]
```

Please see [Stable Baselines3 documentation](https://stable-baselines3.readthedocs.io/en/master/) for alternatives to install stable baselines3.
1 change: 1 addition & 0 deletions docs/guide/install.rst
@@ -48,6 +48,7 @@ With extra envs and test dependencies:
apt-get install swig cmake ffmpeg
pip install -r requirements.txt
pip install -e .[plots,tests]
Please see `Stable Baselines3 documentation <https://stable-baselines3.readthedocs.io/en/master/>`_ for alternatives to install stable baselines3.
2 changes: 1 addition & 1 deletion hyperparams/ppo.yml
@@ -297,7 +297,7 @@ MiniGrid-Empty-Random-5x5-v0: &minigrid-defaults
gae_lambda: 0.95 # Factor for trade-off of bias vs variance for Generalized Advantage Estimator
gamma: 0.99
n_epochs: 10 # Number of epoch when optimizing the surrogate
ent_coef: 0.0 # Entropy coefficient for the loss caculation
ent_coef: 0.0 # Entropy coefficient for the loss calculation
learning_rate: 2.5e-4 # The learning rate, it can be a function
clip_range: 0.2 # Clipping parameter, it can be a function

1 change: 1 addition & 0 deletions hyperparams/python/ppo_config_example.py
@@ -1,5 +1,6 @@
"""This file just serves as an example on how to configure the zoo
using python scripts instead of yaml files."""

import torch

hyperparams = {
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -3,17 +3,19 @@
line-length = 127
# Assume Python 3.8
target-version = "py38"

[tool.ruff.lint]
# See https://beta.ruff.rs/docs/rules/
select = ["E", "F", "B", "UP", "C90", "RUF"]
# Ignore explicit stacklevel`
ignore = ["B028"]

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"./rl_zoo3/import_envs.py"= ["F401"]
# "./rl_zoo3/plots/plot_train.py"= ["E501"]


[tool.ruff.mccabe]
[tool.ruff.lint.mccabe]
# Unlike Flake8, default to a complexity level of 10.
max-complexity = 15

12 changes: 1 addition & 11 deletions requirements.txt
@@ -1,22 +1,12 @@
gym==0.26.2
stable-baselines3[extra_no_roms,tests,docs]>=2.3.0a1,<3.0
sb3-contrib>=2.3.0a1,<3.0
box2d-py==2.3.8
pybullet
pybullet_envs_gymnasium>=0.4.0
# minigrid
# scikit-optimize
optuna~=3.0
pyyaml>=5.1
cloudpickle>=2.2.1
# optuna plots:
plotly
# need to upgrade to gymnasium:
# panda-gym~=3.0.1
rliable>=1.0.5
wandb
huggingface_sb3>=3.0,<4.0
seaborn
tqdm
rich
moviepy
ruff
1 change: 1 addition & 0 deletions rl_zoo3/plots/plot_train.py
@@ -1,6 +1,7 @@
"""
Plot training reward/success rate
"""

import argparse
import os

1 change: 1 addition & 0 deletions rl_zoo3/plots/score_normalization.py
@@ -4,6 +4,7 @@
Max score corresponds to acceptable performance, for instance
human level performance in the case of Atari games.
"""

from typing import NamedTuple

import numpy as np
2 changes: 1 addition & 1 deletion rl_zoo3/train.py
@@ -159,7 +159,7 @@ def train() -> None:

args = parser.parse_args()

# Going through custom gym packages to let them register in the global registory
# Going through custom gym packages to let them register in the global registry
for env_module in args.gym_packages:
importlib.import_module(env_module)

2 changes: 1 addition & 1 deletion rl_zoo3/version.txt
@@ -1 +1 @@
2.3.0a1
2.3.0a2
1 change: 1 addition & 0 deletions scripts/create_cluster_jobs.py
@@ -1,6 +1,7 @@
"""
Send multiple jobs to the cluster.
"""

import os
import subprocess
import time
1 change: 1 addition & 0 deletions scripts/run_jobs.py
@@ -1,6 +1,7 @@
"""
Run multiple experiments on a single machine.
"""

import subprocess
from typing import List

33 changes: 19 additions & 14 deletions setup.py
@@ -14,6 +14,23 @@
See https://github.com/DLR-RM/rl-baselines3-zoo
"""
install_requires = [
"sb3_contrib>=2.3.0a1,<3.0",
"gymnasium~=0.29.1",
"huggingface_sb3>=3.0,<4.0",
"tqdm",
"rich",
"optuna>=3.0",
"pyyaml>=5.1",
"pytablewriter~=1.2",
]
plots_requires = ["seaborn", "rliable>=1.0.5", "scipy~=1.10"]
test_requires = [
# for MuJoCo envs v4:
"mujoco~=2.3",
# install parking-env to test HER
"highway-env==1.8.2",
]

setup(
name="rl_zoo3",
@@ -26,20 +43,8 @@
]
},
entry_points={"console_scripts": ["rl_zoo3=rl_zoo3.cli:main"]},
install_requires=[
"sb3_contrib>=2.3.0a1,<3.0",
"gymnasium~=0.29.1",
"huggingface_sb3>=3.0,<4.0",
"tqdm",
"rich",
"optuna>=3.0",
"pyyaml>=5.1",
"pytablewriter~=1.2",
# TODO: add test dependencies
],
extras_require={
"plots": ["seaborn", "rliable>=1.0.5", "scipy~=1.10"],
},
install_requires=install_requires,
extras_require={"plots": plots_requires, "tests": test_requires},
description="A Training Framework for Stable Baselines3 Reinforcement Learning Agents",
author="Antonin Raffin",
url="https://github.com/DLR-RM/rl-baselines3-zoo",
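The `setup.py` change above moves the dependency lists into module-level variables and registers a new `tests` extra next to `plots`. As a rough illustration of what `pip install -e .[plots,tests]` then pulls in, the sketch below merges the base list with the requested extras. The helper name `requirements_for` is hypothetical, and this is a simplification that only concatenates lists; it is not how pip actually resolves versions.

```python
from typing import Dict, Iterable, List

# Lists copied from the setup.py hunk above.
install_requires: List[str] = [
    "sb3_contrib>=2.3.0a1,<3.0",
    "gymnasium~=0.29.1",
    "huggingface_sb3>=3.0,<4.0",
    "tqdm",
    "rich",
    "optuna>=3.0",
    "pyyaml>=5.1",
    "pytablewriter~=1.2",
]
plots_requires: List[str] = ["seaborn", "rliable>=1.0.5", "scipy~=1.10"]
test_requires: List[str] = [
    # for MuJoCo envs v4:
    "mujoco~=2.3",
    # install parking-env to test HER
    "highway-env==1.8.2",
]
extras_require: Dict[str, List[str]] = {"plots": plots_requires, "tests": test_requires}


def requirements_for(extras: Iterable[str]) -> List[str]:
    """Hypothetical helper: base requirements plus those of each requested extra."""
    requirements = list(install_requires)
    for extra in extras:
        # Raises KeyError for an unknown extra name.
        requirements.extend(extras_require[extra])
    return requirements


if __name__ == "__main__":
    # Roughly what `.[plots,tests]` asks the installer for.
    for requirement in requirements_for(["plots", "tests"]):
        print(requirement)
```

With no extras requested, the helper returns only the base list, which mirrors a plain `pip install -e .`.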
7 changes: 1 addition & 6 deletions tests/test_train.py
@@ -15,12 +15,7 @@ def _assert_eq(left, right):
# 'BreakoutNoFrameskip-v4'
ENV_IDS = ("CartPole-v1",)

experiments = {}

for algo in ALGOS:
for env_id in ENV_IDS:
experiments[f"{algo}-{env_id}"] = (algo, env_id)

experiments = {f"{algo}-{env_id}": (algo, env_id) for algo in ALGOS for env_id in ENV_IDS}
# Test for vecnormalize and frame-stack
experiments["ppo-BipedalWalkerHardcore-v3"] = ("ppo", "BipedalWalkerHardcore-v3")
# Test for SAC
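The `tests/test_train.py` hunk above replaces a nested loop with a dict comprehension. A minimal sketch of why the two forms are interchangeable is given below; the `ALGOS` values are placeholders chosen for illustration (the real tuple is defined outside the lines shown in the diff), and the name `experiments_loop` is introduced here only for the comparison.

```python
# Assumption: placeholder algorithm names, not the real ALGOS from the test file.
ALGOS = ("ppo", "a2c")
ENV_IDS = ("CartPole-v1",)

# Old form: nested loops filling a dict one key at a time.
experiments_loop = {}
for algo in ALGOS:
    for env_id in ENV_IDS:
        experiments_loop[f"{algo}-{env_id}"] = (algo, env_id)

# New form: one dict comprehension; the `for` clauses run in the same order
# as the nested loops, so keys are inserted in the same order.
experiments = {f"{algo}-{env_id}": (algo, env_id) for algo in ALGOS for env_id in ENV_IDS}

# Same keys, same values, same insertion order.
assert experiments == experiments_loop
assert list(experiments) == list(experiments_loop)
```

Because the comprehension still produces a regular `dict`, the later `experiments[...] = ...` assignments in the test file keep working unchanged.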