-
Notifications
You must be signed in to change notification settings - Fork 281
Open
Labels
bugSomething isn't workingSomething isn't working
Description
Bug description
My goal is to pre-train a policy with BC and fine tune it with RL, e.g., PPO. The problem is that I cannot find an example for this, and the ways I tried do not work.
Steps to reproduce
Below I provide a minimal example by using the quickstart.
"""This is a simple example demonstrating how to clone the behavior of an expert.
Refer to the jupyter notebooks for more detailed examples of how to use the algorithms.
"""
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy
from imitation.algorithms import bc
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
rng = np.random.default_rng(0)
env = make_vec_env(
"seals:seals/CartPole-v0",
rng=rng,
post_wrappers=[lambda env, _: RolloutInfoWrapper(env)], # for computing rollouts
)
def train_expert():
# note: use `download_expert` instead to download a pretrained, competent expert
print("Training a expert.")
expert = PPO(
policy=MlpPolicy,
env=env,
seed=0,
batch_size=64,
ent_coef=0.0,
learning_rate=0.0003,
n_epochs=10,
n_steps=64,
)
expert.learn(1_000) # Note: change this to 100_000 to train a decent expert.
return expert
def sample_expert_transitions():
expert = train_expert() # uncomment to train your own expert
print("Sampling expert transitions.")
rollouts = rollout.rollout(
expert,
env,
rollout.make_sample_until(min_timesteps=None, min_episodes=50),
rng=rng,
)
return rollout.flatten_trajectories(rollouts)
transitions = sample_expert_transitions()
bc_trainer = bc.BC(
observation_space=env.observation_space,
action_space=env.action_space,
demonstrations=transitions,
rng=rng,
)
evaluation_env = make_vec_env(
"seals:seals/CartPole-v0",
rng=rng,
# env_make_kwargs={"render_mode": "human"}, # for rendering
)
bc_trainer.train(n_epochs=1)
ppo = PPO(
policy=bc_trainer.policy,
env=env,
seed=0,
batch_size=64,
ent_coef=0.0,
learning_rate=0.0003,
n_epochs=10,
n_steps=64,
)
Running the code like this gives the following error:
.../site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
TypeError: forward() got an unexpected keyword argument 'use_sde'
Using the default policy of PPO like this
from stable_baselines3.common.policies import ActorCriticPolicy
transitions = sample_expert_transitions()
bc_trainer = bc.BC(
policy=ActorCriticPolicy,
observation_space=env.observation_space,
action_space=env.action_space,
demonstrations=transitions,
rng=rng,
)
leads to another error:
.../site-packages/torch/nn/modules/module.py:1340, in Module.to(self, *args, **kwargs)
1337 else:
1338 raise
-> 1340 return self._apply(convert)
AttributeError: 'torch.device' object has no attribute '_apply'
Environment
- Python version: 3.9.18
- Output of
pip freeze --all
:
absl-py==2.1.0
aiohappyeyeballs==2.4.0
aiohttp==3.10.5
aiohttp-cors==0.7.0
aiosignal==1.3.1
ale-py==0.10.1
alembic==1.14.0
annotated-types==0.7.0
anyio==4.4.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
astunparse==1.6.3
async-lru==2.0.4
async-timeout==4.0.3
attrs==24.2.0
babel==2.16.0
beautifulsoup4==4.12.3
bleach==6.1.0
bokeh==3.4.3
cachetools==5.5.0
certifi==2024.8.30
cffi==1.17.1
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
colorama==0.4.6
colorful==0.5.6
colorlog==6.9.0
comm==0.2.2
contourpy==1.3.0
cycler==0.12.1
datasets==3.2.0
debugpy==1.8.5
decorator==4.4.2
defusedxml==0.7.1
Deprecated==1.2.14
dill==0.3.8
distlib==0.3.8
dm-tree==0.1.8
docker-pycreds==0.4.0
docopt-ng==0.9.0
docstring_parser==0.16
etils==1.5.2
eval_type_backport==0.2.0
exceptiongroup==1.2.2
executing==2.1.0
Farama-Notifications==0.0.4
fastapi==0.114.2
fastjsonschema==2.20.0
filelock==3.16.0
flatbuffers==24.3.25
fonttools==4.53.1
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.9.0
gast==0.4.0
gitdb==4.0.11
GitPython==3.1.43
glfw==2.8.0
google-api-core==2.19.2
google-auth==2.34.0
google-auth-oauthlib==1.0.0
google-pasta==0.2.0
googleapis-common-protos==1.65.0
GPUtil==1.4.0
greenlet==3.1.1
grpcio==1.62.0
gymnasium==0.29.1
h11==0.14.0
h5py==3.11.0
httpcore==1.0.5
httptools==0.6.1
httpx==0.27.2
huggingface-hub==0.26.5
huggingface-sb3==3.0
icecream==2.1.3
idna==3.10
imageio==2.35.1
imageio-ffmpeg==0.5.1
imitation==1.0.0
importlib_metadata==8.4.0
importlib_resources==6.4.5
ipykernel==6.29.5
ipython==8.18.1
ipywidgets==8.1.5
isoduration==20.11.0
jax==0.4.30
jax-jumpy==1.0.0
jaxlib==0.4.30
jedi==0.19.1
Jinja2==3.1.4
joblib==1.4.2
json5==0.9.25
jsonpickle==4.0.0
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter==1.1.1
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.6.2
jupyter_core==5.7.2
jupyter_server==2.14.2
jupyter_server_terminals==0.5.3
jupyterlab==4.2.5
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==3.0.13
keras==3.7.0
kiwisolver==1.4.7
lazy_loader==0.4
libclang==18.1.1
lightning-utilities==0.11.8
lz4==4.3.3
Mako==1.3.8
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.2
matplotlib-inline==0.1.7
mdurl==0.1.2
mistune==3.0.2
ml-dtypes==0.4.1
moviepy==1.0.3
mpmath==1.3.0
msgpack==1.1.0
mujoco==3.2.6
multidict==6.1.0
multiprocess==0.70.16
munch==4.0.0
namex==0.0.8
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.2.1
notebook==7.2.2
notebook_shim==0.2.4
numpy==1.26.4
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
oauthlib==3.2.2
opencensus==0.11.4
opencensus-context==0.1.3
opencv-python==4.10.0.84
opentelemetry-api==1.27.0
opentelemetry-exporter-otlp==1.27.0
opentelemetry-exporter-otlp-proto-common==1.27.0
opentelemetry-exporter-otlp-proto-grpc==1.27.0
opentelemetry-exporter-otlp-proto-http==1.27.0
opentelemetry-proto==1.27.0
opentelemetry-sdk==1.27.0
opentelemetry-semantic-conventions==0.48b0
opt-einsum==3.3.0
optree==0.12.1
optuna==4.1.0
overrides==7.7.0
packaging==24.1
pandas==2.2.2
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pillow==10.4.0
pip @ file:///croot/pip_1723484598856/work
platformdirs==4.3.3
plotly==5.24.1
proglog==0.1.10
prometheus_client==0.20.0
prompt_toolkit==3.0.47
proto-plus==1.24.0
protobuf==4.25.4
psutil==6.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
py-cpuinfo==9.0.0
py-spy==0.3.14
pyarrow==17.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.1
pycparser==2.22
pydantic==2.9.1
pydantic_core==2.23.3
pygame==2.6.1
Pygments==2.18.0
PyOpenGL==3.1.7
pyparsing==3.1.4
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==2.0.7
pytorch-lightning==2.4.0
pytz==2024.2
PyYAML==6.0.2
pyzmq==26.2.0
ray==2.10.0
ray-cpp==2.10.0
rdkit==2024.3.6
referencing==0.35.1
requests==2.32.3
requests-oauthlib==2.0.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.8.1
rpds-py==0.20.0
rsa==4.9
sacred==0.8.7
scikit-image==0.24.0
scikit-learn==1.5.0
scipy==1.13.1
seals==0.2.1
Send2Trash==1.8.3
sentry-sdk==2.17.0
setproctitle==1.3.3
setuptools==75.1.0
shellingham==1.5.4
shtab==1.7.1
six==1.16.0
smart-open==7.0.4
smmap==5.0.1
sniffio==1.3.1
soundfile==0.12.1
soupsieve==2.6
SQLAlchemy==2.0.36
stable_baselines3==2.4.0
stack-data==0.6.3
starlette==0.38.5
sympy==1.13.1
tenacity==9.0.0
tensorboard==2.18.0
tensorboard-data-server==0.7.2
tensorboardX==2.6.2.2
tensorflow==2.18.0
tensorflow-estimator==2.12.0
tensorflow-io-gcs-filesystem==0.37.1
termcolor==2.4.0
terminado==0.18.1
threadpoolctl==3.5.0
tifffile==2024.8.30
tinycss2==1.3.0
tomli==2.0.1
torch==2.5.1+cpu
torchaudio==2.5.1+cpu
torchmetrics==1.5.1
torchvision==0.20.1+cpu
tornado==6.4.1
tqdm==4.66.5
traitlets==5.14.3
triton==3.1.0
typer==0.12.5
types-python-dateutil==2.9.0.20240906
typing_extensions==4.12.2
tyro==0.9.1
tzdata==2024.1
uri-template==1.3.0
urllib3==2.2.3
uvicorn==0.30.6
uvloop==0.20.0
virtualenv==20.26.4
wandb==0.18.5
wasabi==1.1.3
watchfiles==0.24.0
wcwidth==0.2.13
webcolors==24.8.0
webencodings==0.5.1
websocket-client==1.8.0
websockets==13.0.1
Werkzeug==3.0.4
wheel==0.44.0
widgetsnbextension==4.0.13
wrapt==1.14.1
xxhash==3.5.0
xyzservices==2024.9.0
yarl==1.11.1
zipp==3.20.2
Metadata
Metadata
Assignees
Labels
bugSomething isn't workingSomething isn't working