Commit 30f9048
committed
serialize: doctest doesn't like save_to_disk(PosixPath)
The error message:
Warning, treated as error:
**********************************************************************
File "algorithms/dagger.rst", line 45, in default
Failed example:
import tempfile
import numpy as np
import gymnasium as gym
from stable_baselines3.common.evaluation import evaluate_policy
from imitation.algorithms import bc
from imitation.algorithms.dagger import SimpleDAggerTrainer
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
rng = np.random.default_rng(0)
env = make_vec_env(
"seals:seals/CartPole-v0",
rng=rng,
)
expert = load_policy(
"ppo-huggingface",
organization="HumanCompatibleAI",
env_name="seals-CartPole-v0",
venv=env,
)
bc_trainer = bc.BC(
observation_space=env.observation_space,
action_space=env.action_space,
rng=rng,
)
with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir:
print(tmpdir)
dagger_trainer = SimpleDAggerTrainer(
venv=env,
scratch_dir=tmpdir,
expert_policy=expert,
bc_trainer=bc_trainer,
rng=rng,
)
dagger_trainer.train(8_000)
reward, _ = evaluate_policy(dagger_trainer.policy, env, 10)
print("Reward:", reward)
Exception raised:
Traceback (most recent call last):
File "/usr/lib/python3.8/doctest.py", line 1336, in __run
exec(compile(example.source, filename, "single",
File "<doctest default[0]>", line 38, in <module>
dagger_trainer.train(8_000)
File "/venv/lib/python3.8/site-packages/imitation/algorithms/dagger.py", line 669, in train
trajectories = rollout.generate_trajectories(
File "/venv/lib/python3.8/site-packages/imitation/data/rollout.py", line 447, in generate_trajectories
obs, rews, dones, infos = venv.step(acts)
File "/venv/lib/python3.8/site-packages/stable_baselines3/common/vec_env/base_vec_env.py", line 206, in step
return self.step_wait()
File "/venv/lib/python3.8/site-packages/imitation/algorithms/dagger.py", line 285, in step_wait
_save_dagger_demo(traj, traj_index, self.save_dir, self.rng)
File "/venv/lib/python3.8/site-packages/imitation/algorithms/dagger.py", line 147, in _save_dagger_demo
serialize.save(npz_path, [trajectory])
File "/venv/lib/python3.8/site-packages/imitation/data/serialize.py", line 23, in save
huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(p)
File "/venv/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 1470, in save_to_disk
fs, _ = url_to_fs(dataset_path, **(storage_options or {}))
File "/venv/lib/python3.8/site-packages/fsspec/core.py", line 383, in url_to_fs
chain = _un_chain(url, kwargs)
File "/venv/lib/python3.8/site-packages/fsspec/core.py", line 323, in _un_chain
if "::" in path
TypeError: argument of type 'PosixPath' is not iterable1 parent b765aff commit 30f9048
1 file changed
+1
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
20 | 20 | | |
21 | 21 | | |
22 | 22 | | |
23 | | - | |
| 23 | + | |
24 | 24 | | |
25 | 25 | | |
26 | 26 | | |
| |||
0 commit comments