Skip to content

Commit 829d86d

Browse files
chosenonechosenone
chosenone
authored and
chosenone
committed
fix(yzj): fix device bug
1 parent 59c7c56 commit 829d86d

File tree

3 files changed

+3
-1
lines changed

3 files changed

+3
-1
lines changed

lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ namespace tree
381381
for (size_t iter = 0; iter < disturbed_probs.size(); iter++)
382382
{
383383
#ifdef __APPLE__
384-
disc_action_with_probs.__emplace_back(std::make_pair(iter, disturbed_probs[iter]));
384+
disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
385385
#else
386386
disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
387387
#endif

lzero/policy/efficientzero.py

+1
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
401401
# obtain the oracle latent states from representation function.
402402
beg_index, end_index = self._get_target_obs_index_in_step_k(step_k)
403403
obs_target_batch_tmp = default_collate(obs_target_batch[:, beg_index:end_index].squeeze())
404+
obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._device)
404405
network_output = self._learn_model.initial_inference(obs_target_batch_tmp)
405406

406407
latent_state = to_tensor(latent_state)

lzero/policy/muzero.py

+1
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
369369
# obtain the oracle latent states from representation function.
370370
beg_index, end_index = self._get_target_obs_index_in_step_k(step_k)
371371
obs_target_batch_tmp = default_collate(obs_target_batch[:, beg_index:end_index].squeeze())
372+
obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._device)
372373
network_output = self._learn_model.initial_inference(obs_target_batch_tmp)
373374

374375
latent_state = to_tensor(latent_state)

0 commit comments

Comments
 (0)