
Commit 38241ff

env(zjow): fix evogym replay video problem (#527)
* fix video save
* fix video save
* fix gym repo
* Add carrier config
* minor change
1 parent f5f219b commit 38241ff

File tree

8 files changed: +220, -179 lines

dizoo/evogym/config/bridgewalker_ddpg_config.py

Lines changed: 7 additions & 7 deletions
@@ -6,17 +6,17 @@
         env_id='BridgeWalker-v0',
         robot='speed_bot',
         robot_dir='../envs',
-        collector_env_num=1,
-        evaluator_env_num=1,
-        n_evaluator_episode=1,
-        stop_value=1,
-        manager=dict(shared_memory=False, ),
+        collector_env_num=8,
+        evaluator_env_num=8,
+        n_evaluator_episode=8,
+        stop_value=10,
+        manager=dict(shared_memory=True, ),
         # The path to save the game replay
-        replay_path='./evogym_walker_ddpg_seed0/video',
+        # replay_path='./evogym_walker_ddpg_seed0/video',
     ),
     policy=dict(
         cuda=True,
-        load_path="./evogym_walker_ddpg_seed0/ckpt/ckpt_best.pth.tar",
+        # load_path="./evogym_walker_ddpg_seed0/ckpt/ckpt_best.pth.tar",
         random_collect_size=1000,
         model=dict(
             obs_shape=59,
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
from easydict import EasyDict

carry_ppo_config = dict(
    exp_name='evogym_carrier_ppo_seed1',
    env=dict(
        env_id='Carrier-v0',
        robot='carry_bot',
        robot_dir='./dizoo/evogym/envs',
        collector_env_num=8,
        evaluator_env_num=8,
        n_evaluator_episode=8,
        stop_value=10,
        manager=dict(shared_memory=True, ),
        # The path to save the game replay
        # replay_path='./evogym_carry_ppo_seed0/video',
    ),
    policy=dict(
        cuda=True,
        recompute_adv=True,
        # load_path="./evogym_carry_ppo_seed0/ckpt/ckpt_best.pth.tar",
        model=dict(
            obs_shape=70,
            action_shape=12,
            action_space='continuous',
        ),
        action_space='continuous',
        learn=dict(
            epoch_per_collect=10,
            batch_size=256,
            learning_rate=3e-3,
            value_weight=0.5,
            entropy_weight=0.01,
            clip_ratio=0.2,
            adv_norm=True,
            value_norm=True,
        ),
        collect=dict(
            n_sample=2048,
            gae_lambda=0.97,
        ),
        eval=dict(evaluator=dict(eval_freq=5000, )),
    )
)
carry_ppo_config = EasyDict(carry_ppo_config)
main_config = carry_ppo_config

carry_ppo_create_config = dict(
    env=dict(
        type='evogym',
        import_names=['dizoo.evogym.envs.evogym_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(
        type='ppo',
        import_names=['ding.policy.ppo'],
    ),
    replay_buffer=dict(type='naive', ),
)
carry_ppo_create_config = EasyDict(carry_ppo_create_config)
create_config = carry_ppo_create_config

if __name__ == "__main__":
    # or you can enter `ding -m serial -c evogym_carry_ppo_config.py -s 0 --env-step 1e7`
    from ding.entry import serial_pipeline_onpolicy
    serial_pipeline_onpolicy((main_config, create_config), seed=0)

dizoo/evogym/config/walker_ddpg_config.py

Lines changed: 8 additions & 8 deletions
@@ -5,18 +5,18 @@
     env=dict(
         env_id='Walker-v0',
         robot='speed_bot',
-        robot_dir='../envs',
-        collector_env_num=1,
-        evaluator_env_num=1,
-        n_evaluator_episode=1,
-        stop_value=-0.5,
-        manager=dict(shared_memory=False, ),
+        robot_dir='./dizoo/evogym/envs',
+        collector_env_num=8,
+        evaluator_env_num=8,
+        n_evaluator_episode=8,
+        stop_value=10,
+        manager=dict(shared_memory=True, ),
         # The path to save the game replay
-        replay_path='./evogym_walker_ddpg_seed0/video',
+        # replay_path='./evogym_walker_ddpg_seed0/video',
     ),
     policy=dict(
         cuda=True,
-        load_path="./evogym_walker_ddpg_seed0/ckpt/ckpt_best.pth.tar",
+        # load_path="./evogym_walker_ddpg_seed0/ckpt/ckpt_best.pth.tar",
         random_collect_size=1000,
         model=dict(
             obs_shape=58,

dizoo/evogym/config/walker_ddpg_eval_config.py

Lines changed: 0 additions & 70 deletions
This file was deleted.
dizoo/evogym/config/walker_ppo_config.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
from easydict import EasyDict

walker_ppo_config = dict(
    exp_name='evogym_walker_ppo_seed0',
    env=dict(
        env_id='Walker-v0',
        robot='speed_bot',
        robot_dir='./dizoo/evogym/envs',
        collector_env_num=1,
        evaluator_env_num=1,
        n_evaluator_episode=1,
        stop_value=10,
        manager=dict(shared_memory=True, ),
        # The path to save the game replay
        # replay_path='./evogym_walker_ppo_seed0/video',
    ),
    policy=dict(
        cuda=True,
        recompute_adv=True,
        # load_path="./evogym_walker_ppo_seed0/ckpt/ckpt_best.pth.tar",
        model=dict(
            obs_shape=58,
            action_shape=10,
            action_space='continuous',
        ),
        action_space='continuous',
        learn=dict(
            epoch_per_collect=10,
            batch_size=256,
            learning_rate=3e-4,
            value_weight=0.5,
            entropy_weight=0.0,
            clip_ratio=0.2,
            adv_norm=True,
            value_norm=True,
        ),
        collect=dict(
            n_sample=2048,
            gae_lambda=0.97,
        ),
        eval=dict(evaluator=dict(eval_freq=5000, )),
    )
)
walker_ppo_config = EasyDict(walker_ppo_config)
main_config = walker_ppo_config

walker_ppo_create_config = dict(
    env=dict(
        type='evogym',
        import_names=['dizoo.evogym.envs.evogym_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(
        type='ppo',
        import_names=['ding.policy.ppo'],
    ),
    replay_buffer=dict(type='naive', ),
)
walker_ppo_create_config = EasyDict(walker_ppo_create_config)
create_config = walker_ppo_create_config

if __name__ == "__main__":
    # or you can enter `ding -m serial -c evogym_walker_ppo_config.py -s 0 --env-step 1e7`
    from ding.entry import serial_pipeline_onpolicy
    serial_pipeline_onpolicy((main_config, create_config), seed=0)

dizoo/evogym/entry/walker_ppo_eval.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
import os
import gym
import torch
from tensorboardX import SummaryWriter
from easydict import EasyDict
from functools import partial

from ding.config import compile_config
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
from ding.envs import BaseEnvManager
from ding.envs import get_vec_env_setting, create_env_manager
from ding.policy import PPOPolicy
from ding.utils import set_pkg_seed

from dizoo.evogym.config.walker_ppo_config import main_config, create_config


def main(cfg, create_cfg, seed=0):
    cfg = compile_config(
        cfg,
        BaseEnvManager,
        PPOPolicy,
        BaseLearner,
        SampleSerialCollector,
        InteractionSerialEvaluator,
        AdvancedReplayBuffer,
        create_cfg=create_cfg,
        save_cfg=True
    )
    create_cfg.policy.type = create_cfg.policy.type + '_command'
    env_fn = None
    cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
    # Create main components: env, policy
    env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

    evaluator_env.enable_save_replay(cfg.env.replay_path)

    # Set random seed for all package and instance
    evaluator_env.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

    # Set up RL Policy
    policy = PPOPolicy(cfg.policy)
    policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))

    # evaluate
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    evaluator.eval()


if __name__ == "__main__":
    main(main_config, create_config, seed=0)
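A minimal usage sketch for this eval entry, not part of the commit: the script reads cfg.env.replay_path and cfg.policy.load_path, both of which ship commented out in the new walker PPO config, so they have to be set before main() is called. The paths below are placeholders, and the import of the entry module assumes dizoo.evogym.entry is importable as a package.

# Illustrative only: fill in the fields the eval entry expects before running it.
from dizoo.evogym.config.walker_ppo_config import main_config, create_config
from dizoo.evogym.entry.walker_ppo_eval import main

main_config.env.replay_path = './evogym_walker_ppo_seed0/video'  # placeholder; passed to enable_save_replay
main_config.policy.load_path = './evogym_walker_ppo_seed0/ckpt/ckpt_best.pth.tar'  # placeholder checkpoint path

if __name__ == '__main__':
    main(main_config, create_config, seed=0)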

dizoo/evogym/envs/evogym_env.py

Lines changed: 18 additions & 11 deletions
@@ -1,18 +1,19 @@
 from typing import Any, Union, List, Optional
+import os
+import time
 import copy
 import numpy as np
-from easydict import EasyDict
 import gym
-import evogym.envs
-from evogym import WorldObject, sample_robot
-from .viewer import DingEvoViewer
-from evogym.sim import EvoSim
-import os
+from easydict import EasyDict
+
 from ding.envs import BaseEnv, BaseEnvTimestep, FinalEvalRewardEnv
 from ding.envs.common.common_function import affine_transform
 from ding.torch_utils import to_ndarray, to_list
 from ding.utils import ENV_REGISTRY
 
+import evogym.envs
+from evogym import WorldObject, sample_robot
+from evogym.sim import EvoSim
 
 @ENV_REGISTRY.register('evogym')
 class EvoGymEnv(BaseEnv):
@@ -59,11 +60,17 @@ def reset(self) -> np.ndarray:
             self._env.seed(self._seed)
         if self._replay_path is not None:
             gym.logger.set_level(gym.logger.DEBUG)
-            # use our own 'viewer' to make 'render' compatible with gym
-            self._env.default_viewer = DingEvoViewer(EvoSim(self._env.world))
-            self._env.__class__.render = self._env.default_viewer.render
-            self._env.metadata['render.modes'] = 'rgb_array'  # make render mode compatible with gym
-            self._env = gym.wrappers.RecordVideo(self._env, './videos/' + str('time()') + '/')  # time()
+            # make render mode compatible with gym
+            if gym.version.VERSION > '0.22.0':
+                self._env.metadata.update({'render_modes': ["rgb_array"]})
+            else:
+                self._env.metadata.update({'render.modes': ["rgb_array"]})
+            self._env = gym.wrappers.RecordVideo(
+                self._env,
+                video_folder=self._replay_path,
+                episode_trigger=lambda episode_id: True,
+                name_prefix='rl-video-{}-{}'.format(id(self), time.time())
+            )
         obs = self._env.reset()
         obs = to_ndarray(obs).astype('float32')
         return obs
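For reference, a minimal standalone sketch of the recording pattern the new reset() code follows. It is illustrative only: 'CartPole-v1' is a stand-in environment rather than an evogym task, the './videos' folder stands in for cfg.env.replay_path, and the gym version is compared as a string exactly as the diff does.

# Sketch of the replay-saving pattern adopted by this commit (assumptions noted above):
# advertise the rgb_array render mode under the metadata key the installed gym expects,
# then wrap the env in RecordVideo so every episode is written to the replay folder.
import time
import gym

env = gym.make('CartPole-v1')  # stand-in env; the commit applies this to evogym envs
key = 'render_modes' if gym.version.VERSION > '0.22.0' else 'render.modes'
env.metadata.update({key: ['rgb_array']})
env = gym.wrappers.RecordVideo(
    env,
    video_folder='./videos',  # in the commit this comes from cfg.env.replay_path
    episode_trigger=lambda episode_id: True,  # record every episode
    name_prefix='rl-video-{}-{}'.format(id(env), time.time()),  # unique prefix avoids overwriting files
)
obs = env.reset()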
