feature(rjy): add mamujoco env and related configs #153

Open · wants to merge 8 commits into base: main

Changes from 1 commit

algo(rjy): add pipeline of sez ma (train+eval)
nighood committed Dec 8, 2023
commit c54c0a549e81c22c1fe1e8905fb2142e75f73263
95 changes: 59 additions & 36 deletions lzero/policy/sampled_efficientzero.py
@@ -1034,7 +1034,16 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read
``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``.
"""
self._eval_model.eval()
active_eval_env_num = data.shape[0]
if isinstance(data, dict):
# If data is a dictionary, find the first non-dictionary element and get its shape[0]
for k, v in data.items():
if not isinstance(v, dict):
active_eval_env_num = v.shape[0]*v.shape[1]
agent_num = v.shape[1] # multi-agent
elif isinstance(data, torch.Tensor):
# If data is a torch.tensor, directly return its shape[0]
active_eval_env_num = data.shape[0]
agent_num = 1 # single-agent
with torch.no_grad():
# data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)}
network_output = self._eval_model.initial_inference(data)
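For context on the hunk above: the added branch infers the evaluation batch size from either a plain tensor (single-agent) or a dict of per-agent tensors (multi-agent), where each dict value is assumed to be shaped (env_num, agent_num, ...). A minimal standalone sketch of the same idea; the function and key names below are illustrative, not taken from the repository.

import torch

def infer_batch_and_agents(data):
    """Return (active_env_num, agent_num) for tensor or dict observations.

    Sketch only: a dict is assumed to hold per-agent tensors shaped
    (env_num, agent_num, ...); a plain tensor is treated as single-agent.
    """
    if isinstance(data, dict):
        for value in data.values():
            if isinstance(value, torch.Tensor):
                # env_num * agent_num leaf observations, i.e. one MCTS root per agent
                return value.shape[0] * value.shape[1], value.shape[1]
        raise ValueError("dict observation contains no tensor leaf")
    if isinstance(data, torch.Tensor):
        return data.shape[0], 1
    raise TypeError(f"unsupported observation type: {type(data)}")

obs = {'agent_state': torch.zeros(8, 2, 19)}  # 8 envs, 2 agents, 19-dim observation
print(infer_batch_and_agents(obs))            # -> (16, 2)
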
@@ -1088,51 +1097,65 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read
roots_sampled_actions = roots.get_sampled_actions(
) # shape: ``{list: batch_size} ->{list: action_space_size}``

if self._multi_agent:
active_eval_env_num = active_eval_env_num // agent_num
data_id = [i for i in range(active_eval_env_num)]
output = {i: None for i in data_id}

if ready_env_id is None:
ready_env_id = np.arange(active_eval_env_num)

for i, env_id in enumerate(ready_env_id):
distributions, value = roots_visit_count_distributions[i], roots_values[i]
try:
root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]])
except Exception:
# logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list')
root_sampled_actions = np.array([action for action in roots_sampled_actions[i]])
# NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
# the index within the legal action set, rather than the index in the entire action set.
# Setting deterministic=True implies choosing the action with the highest value (argmax) rather than sampling during the evaluation phase.
action, visit_count_distribution_entropy = select_action(
distributions, temperature=1, deterministic=True
)
# ==============================================================
# sampled related core code
# ==============================================================
output[env_id] = {
'action': [],
'visit_count_distributions': [],
'root_sampled_actions': [],
'visit_count_distribution_entropy': [],
'searched_value': [],
'predicted_value': [],
'predicted_policy_logits': [],
}
for j in range(agent_num):
index = i * agent_num + j
distributions, value = roots_visit_count_distributions[index], roots_values[index]
try:
root_sampled_actions = np.array([action.value for action in roots_sampled_actions[index]])
except Exception:
# logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list')
root_sampled_actions = np.array([action for action in roots_sampled_actions[index]])
# NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
# the index within the legal action set, rather than the index in the entire action set.
# Setting deterministic=True implies choosing the action with the highest value (argmax) rather than sampling during the evaluation phase.
action, visit_count_distribution_entropy = select_action(
distributions, temperature=1, deterministic=True
)
# ==============================================================
# sampled related core code
# ==============================================================

try:
action = roots_sampled_actions[i][action].value
# logging.warning('ptree_sampled_efficientzero roots.get_sampled_actions() return array')
except Exception:
# logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list')
action = np.array(roots_sampled_actions[i][action])
try:
action = roots_sampled_actions[index][action].value
[Review comment from a collaborator: "Clean up this debugging code." A possible refactor is sketched after this file's diff.]
# logging.warning('ptree_sampled_efficientzero roots.get_sampled_actions() return array')
except Exception:
# logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list')
action = np.array(roots_sampled_actions[index][action])

if not self._cfg.model.continuous_action_space:
if len(action.shape) == 0:
action = int(action)
elif len(action.shape) == 1:
action = int(action[0])
if not self._cfg.model.continuous_action_space:
if len(action.shape) == 0:
action = int(action)
elif len(action.shape) == 1:
action = int(action[0])

output[env_id] = {
'action': action,
'visit_count_distributions': distributions,
'root_sampled_actions': root_sampled_actions,
'visit_count_distribution_entropy': visit_count_distribution_entropy,
'searched_value': value,
'predicted_value': pred_values[i],
'predicted_policy_logits': policy_logits[i],
}
output[env_id]['action'].append(action)
output[env_id]['visit_count_distributions'].append(distributions)
output[env_id]['root_sampled_actions'].append(root_sampled_actions)
output[env_id]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy)
output[env_id]['searched_value'].append(value)
output[env_id]['predicted_value'].append(pred_values[index])
output[env_id]['predicted_policy_logits'].append(policy_logits[index])

for k,v in output[env_id].items():
output[env_id][k] = np.array(v)

return output
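Regarding the collaborator's comment above about cleaning up the debugging code: one possible refactor, shown only as a sketch and not what this commit implements, replaces the broad try/except around roots.get_sampled_actions() with an attribute check. The assumption, suggested by the commented-out warnings, is that the Python tree (ptree) yields action objects exposing a .value field while the C++ tree (ctree) yields plain values.

import numpy as np

def sampled_actions_to_array(sampled_actions):
    """Convert one root's sampled actions to an ndarray without a bare try/except.

    Assumption: ptree roots return objects with a ``.value`` attribute, ctree roots
    return plain scalars/arrays, so hasattr() distinguishes the two backends.
    """
    return np.array([a.value if hasattr(a, 'value') else a for a in sampled_actions])

The same helper shape could also replace the second try/except that extracts the single chosen action.
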

12 changes: 7 additions & 5 deletions lzero/worker/muzero_evaluator.py
@@ -11,6 +11,8 @@
from ding.utils import get_world_size, get_rank, broadcast_object_list
from ding.worker.collector.base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor
from easydict import EasyDict
from ding.torch_utils import to_ndarray, to_device
from ding.utils.data import default_collate

from lzero.mcts.buffer.game_segment import GameSegment
from lzero.mcts.utils import prepare_observation
@@ -271,18 +273,18 @@ def eval(
ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode]))
remain_episode -= min(len(new_available_env_id), remain_episode)

stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id}
stack_obs = {env_id: game_segments[env_id].get_obs()[0] for env_id in ready_env_id}
stack_obs = list(stack_obs.values())
stack_obs = default_collate(stack_obs)
if not isinstance(stack_obs, dict):
stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type)
stack_obs = to_device(stack_obs, self.policy_config.device)

action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id}
to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id}
action_mask = [action_mask_dict[env_id] for env_id in ready_env_id]
to_play = [to_play_dict[env_id] for env_id in ready_env_id]

stack_obs = to_ndarray(stack_obs)
stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type)
stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float()

# ==============================================================
# policy forward
# ==============================================================
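The evaluator change above collates dict-style (multi-agent) observations with default_collate and skips prepare_observation for them. A simplified standalone sketch of that path follows; default_collate and to_device are the DI-engine utilities imported in this diff, while the observation key and shapes are made up for illustration.

import torch
from ding.torch_utils import to_device
from ding.utils.data import default_collate

# Two ready envs, each returning a per-agent observation dict (2 agents, 19-dim obs).
stack_obs = [
    {'agent_state': torch.zeros(2, 19)},  # env 0
    {'agent_state': torch.zeros(2, 19)},  # env 1
]
stack_obs = default_collate(stack_obs)   # expected: {'agent_state': Tensor of shape (2, 2, 19)}
# Dict observations skip prepare_observation; a flat single-agent batch would still be
# reshaped by prepare_observation(stack_obs, model_type) before the policy forward.
stack_obs = to_device(stack_obs, 'cpu')  # the real code uses self.policy_config.device
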
Changes to a configuration file (filename not shown in this capture):
@@ -101,7 +101,8 @@
num_simulations=num_simulations,
reanalyze_ratio=reanalyze_ratio,
n_episode=n_episode,
eval_freq=int(2e3),
# eval_freq=int(2e3),
eval_freq=int(2),
replay_buffer_size=int(1e6),
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,