From aa3814530ba8c5263267979bf14499734bfef86d Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 19 Feb 2024 10:46:59 +0100 Subject: [PATCH] Add test requirements, upgrade black (#437) * ignoring virtual env, adding test requirements, and ran 'black .' to auto format files' * updated documentation for change of adding tests requirements as extra * commented on fix in changelong * need minigrid for https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/ppo.yml#L290C16-L290C48 * make format * fixed typo * fixed typo * Fix ruff config * Simplify dependencies --------- Co-authored-by: Nikolaus Schuetz --- .github/workflows/ci.yml | 8 ++---- .github/workflows/trained_agents.yml | 4 +-- .gitignore | 9 +++++++ CHANGELOG.md | 5 ++-- Makefile | 2 +- README.md | 1 + docs/guide/install.rst | 1 + hyperparams/ppo.yml | 2 +- hyperparams/python/ppo_config_example.py | 1 + pyproject.toml | 6 +++-- requirements.txt | 12 +-------- rl_zoo3/plots/plot_train.py | 1 + rl_zoo3/plots/score_normalization.py | 1 + rl_zoo3/train.py | 2 +- rl_zoo3/version.txt | 2 +- scripts/create_cluster_jobs.py | 1 + scripts/run_jobs.py | 1 + setup.py | 33 ++++++++++++++---------- tests/test_train.py | 7 +---- 19 files changed, 51 insertions(+), 48 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 212a6ef19..517b813db 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -40,15 +40,11 @@ jobs: # cpu version of pytorch - faster to download pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu - pip install pybullet==3.2.5 - # for v4 MuJoCo envs: - pip install mujoco + pip install -r requirements.txt # Use headless version pip install opencv-python-headless - # install parking-env to test HER - pip install highway-env==1.8.1 - pip install -e . + pip install -e .[plots,tests] - name: Lint with ruff run: | make lint diff --git a/.github/workflows/trained_agents.yml b/.github/workflows/trained_agents.yml index 2d4dd3f01..689681fe8 100644 --- a/.github/workflows/trained_agents.yml +++ b/.github/workflows/trained_agents.yml @@ -41,12 +41,10 @@ jobs: # cpu version of pytorch - faster to download pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu - pip install pybullet==3.2.5 pip install -r requirements.txt # Use headless version pip install opencv-python-headless - pip install highway-env==1.8.1 - pip install -e . + pip install -e .[plots,tests] - name: Check trained agents run: | make check-trained-agents diff --git a/.gitignore b/.gitignore index 4a0fdfd47..46678b0b4 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,12 @@ keys/ .cache *.lprof *.prof + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 0507d9cf7..f814283ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -## Release 2.3.0a1 (WIP) +## Release 2.3.0a2 (WIP) ### Breaking Changes - Updated defaults hyperparameters for TD3/DDPG to be more consistent with SAC @@ -13,7 +13,8 @@ ### Documentation ### Other - +- Added test dependencies to `setup.py` (@power-edge) +- Simplify dependencies of `requirements.txt` (remove duplicates from `setup.py`) ## Release 2.2.1 (2023-11-17) diff --git a/Makefile b/Makefile index 3b0a463b3..8ea41952c 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ type: mypy lint: # stop the build if there are Python syntax errors or undefined names # see https://www.flake8rules.com/ - ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source + ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full # exit-zero treats all errors as warnings. ruff ${LINT_PATHS} --exit-zero diff --git a/README.md b/README.md index ad146bc36..cd798c6e5 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,7 @@ Note: you can do `python -m rl_zoo3.train` from any folder and you have access t ``` apt-get install swig cmake ffmpeg pip install -r requirements.txt +pip install -e .[plots,tests] ``` Please see [Stable Baselines3 documentation](https://stable-baselines3.readthedocs.io/en/master/) for alternatives to install stable baselines3. diff --git a/docs/guide/install.rst b/docs/guide/install.rst index 57d77164b..102e98e37 100644 --- a/docs/guide/install.rst +++ b/docs/guide/install.rst @@ -48,6 +48,7 @@ With extra envs and test dependencies: apt-get install swig cmake ffmpeg pip install -r requirements.txt + pip install -e .[plots,tests] Please see `Stable Baselines3 documentation `_ for alternatives to install stable baselines3. diff --git a/hyperparams/ppo.yml b/hyperparams/ppo.yml index fd664dbed..9339eea8e 100644 --- a/hyperparams/ppo.yml +++ b/hyperparams/ppo.yml @@ -297,7 +297,7 @@ MiniGrid-Empty-Random-5x5-v0: &minigrid-defaults gae_lambda: 0.95 # Factor for trade-off of bias vs variance for Generalized Advantage Estimator gamma: 0.99 n_epochs: 10 # Number of epoch when optimizing the surrogate - ent_coef: 0.0 # Entropy coefficient for the loss caculation + ent_coef: 0.0 # Entropy coefficient for the loss calculation learning_rate: 2.5e-4 # The learning rate, it can be a function clip_range: 0.2 # Clipping parameter, it can be a function diff --git a/hyperparams/python/ppo_config_example.py b/hyperparams/python/ppo_config_example.py index cbafbbaec..459cea8f9 100644 --- a/hyperparams/python/ppo_config_example.py +++ b/hyperparams/python/ppo_config_example.py @@ -1,5 +1,6 @@ """This file just serves as an example on how to configure the zoo using python scripts instead of yaml files.""" + import torch hyperparams = { diff --git a/pyproject.toml b/pyproject.toml index aeffa65fd..b00654161 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,17 +3,19 @@ line-length = 127 # Assume Python 3.8 target-version = "py38" + +[tool.ruff.lint] # See https://beta.ruff.rs/docs/rules/ select = ["E", "F", "B", "UP", "C90", "RUF"] # Ignore explicit stacklevel` ignore = ["B028"] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "./rl_zoo3/import_envs.py"= ["F401"] # "./rl_zoo3/plots/plot_train.py"= ["E501"] -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 15 diff --git a/requirements.txt b/requirements.txt index b380b7a45..33ff7e6cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,22 +1,12 @@ gym==0.26.2 stable-baselines3[extra_no_roms,tests,docs]>=2.3.0a1,<3.0 -sb3-contrib>=2.3.0a1,<3.0 box2d-py==2.3.8 -pybullet pybullet_envs_gymnasium>=0.4.0 # minigrid -# scikit-optimize -optuna~=3.0 -pyyaml>=5.1 cloudpickle>=2.2.1 +# optuna plots: plotly # need to upgrade to gymnasium: # panda-gym~=3.0.1 -rliable>=1.0.5 wandb -huggingface_sb3>=3.0,<4.0 -seaborn -tqdm -rich moviepy -ruff diff --git a/rl_zoo3/plots/plot_train.py b/rl_zoo3/plots/plot_train.py index fc661726c..5b6b52752 100644 --- a/rl_zoo3/plots/plot_train.py +++ b/rl_zoo3/plots/plot_train.py @@ -1,6 +1,7 @@ """ Plot training reward/success rate """ + import argparse import os diff --git a/rl_zoo3/plots/score_normalization.py b/rl_zoo3/plots/score_normalization.py index 7e19cf48d..91e220652 100644 --- a/rl_zoo3/plots/score_normalization.py +++ b/rl_zoo3/plots/score_normalization.py @@ -4,6 +4,7 @@ Max score corresponds to acceptable performance, for instance human level performance in the case of Atari games. """ + from typing import NamedTuple import numpy as np diff --git a/rl_zoo3/train.py b/rl_zoo3/train.py index 53be1683b..4e9b0a4ca 100644 --- a/rl_zoo3/train.py +++ b/rl_zoo3/train.py @@ -159,7 +159,7 @@ def train() -> None: args = parser.parse_args() - # Going through custom gym packages to let them register in the global registory + # Going through custom gym packages to let them register in the global registry for env_module in args.gym_packages: importlib.import_module(env_module) diff --git a/rl_zoo3/version.txt b/rl_zoo3/version.txt index 4d04ad95c..34109b68e 100644 --- a/rl_zoo3/version.txt +++ b/rl_zoo3/version.txt @@ -1 +1 @@ -2.3.0a1 +2.3.0a2 diff --git a/scripts/create_cluster_jobs.py b/scripts/create_cluster_jobs.py index 4795bd03f..ed84627c9 100644 --- a/scripts/create_cluster_jobs.py +++ b/scripts/create_cluster_jobs.py @@ -1,6 +1,7 @@ """ Send multiple jobs to the cluster. """ + import os import subprocess import time diff --git a/scripts/run_jobs.py b/scripts/run_jobs.py index 8252f7b27..5d5a87794 100644 --- a/scripts/run_jobs.py +++ b/scripts/run_jobs.py @@ -1,6 +1,7 @@ """ Run multiple experiments on a single machine. """ + import subprocess from typing import List diff --git a/setup.py b/setup.py index 90d177b7b..d42a7b127 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,23 @@ See https://github.com/DLR-RM/rl-baselines3-zoo """ +install_requires = [ + "sb3_contrib>=2.3.0a1,<3.0", + "gymnasium~=0.29.1", + "huggingface_sb3>=3.0,<4.0", + "tqdm", + "rich", + "optuna>=3.0", + "pyyaml>=5.1", + "pytablewriter~=1.2", +] +plots_requires = ["seaborn", "rliable>=1.0.5", "scipy~=1.10"] +test_requires = [ + # for MuJoCo envs v4: + "mujoco~=2.3", + # install parking-env to test HER + "highway-env==1.8.2", +] setup( name="rl_zoo3", @@ -26,20 +43,8 @@ ] }, entry_points={"console_scripts": ["rl_zoo3=rl_zoo3.cli:main"]}, - install_requires=[ - "sb3_contrib>=2.3.0a1,<3.0", - "gymnasium~=0.29.1", - "huggingface_sb3>=3.0,<4.0", - "tqdm", - "rich", - "optuna>=3.0", - "pyyaml>=5.1", - "pytablewriter~=1.2", - # TODO: add test dependencies - ], - extras_require={ - "plots": ["seaborn", "rliable>=1.0.5", "scipy~=1.10"], - }, + install_requires=install_requires, + extras_require={"plots": plots_requires, "tests": test_requires}, description="A Training Framework for Stable Baselines3 Reinforcement Learning Agents", author="Antonin Raffin", url="https://github.com/DLR-RM/rl-baselines3-zoo", diff --git a/tests/test_train.py b/tests/test_train.py index e26b98760..d0780acc1 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -15,12 +15,7 @@ def _assert_eq(left, right): # 'BreakoutNoFrameskip-v4' ENV_IDS = ("CartPole-v1",) -experiments = {} - -for algo in ALGOS: - for env_id in ENV_IDS: - experiments[f"{algo}-{env_id}"] = (algo, env_id) - +experiments = {f"{algo}-{env_id}": (algo, env_id) for algo in ALGOS for env_id in ENV_IDS} # Test for vecnormalize and frame-stack experiments["ppo-BipedalWalkerHardcore-v3"] = ("ppo", "BipedalWalkerHardcore-v3") # Test for SAC