
GAIL is not training with image-based data #864

@aha85b


Bug description

I am new to the imitation learning package. I want to train an agent using GAIL, but I keep running into errors about data shapes. I am using a custom environment built on Gymnasium. The first error was about a tuple, which I believe was triggered by env.reset() returning (observation, info); when I changed reset() to return only the observation that error went away, but then I ran into the issue below, which I could not work around.

While debugging I also noticed that gail.GAIL does not seem to accept a Gymnasium environment, while sb3.PPO() does not accept a legacy Gym environment, so the GAIL class throws an error due to incompatibility with gym, and the other way around for PPO.
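If it helps, here is a minimal sketch of how I understand the environment is supposed to be wrapped so that both GAIL and PPO see the same SB3 VecEnv interface (MyCustomEnv is a placeholder for my actual Gymnasium class; this is only my reading of the docs, not something I have verified):

from stable_baselines3.common.vec_env import DummyVecEnv
from imitation.data.wrappers import RolloutInfoWrapper

def make_env():
    # Build the custom Gymnasium env; reset() keeps returning (obs, info).
    env = MyCustomEnv()
    # RolloutInfoWrapper records complete rollouts, as the imitation examples do.
    return RolloutInfoWrapper(env)

# Both GAIL (venv=...) and PPO (env=...) would then receive this same VecEnv,
# which adds the batch dimension and handles the Gymnasium reset()/step() API.
venv = DummyVecEnv([make_env])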

round: 0%| | 0/4 [00:00<?, ?it/s]


RuntimeError Traceback (most recent call last)
Cell In[27], line 1
----> 1 gail_trainer.train(10000)

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/imitation/algorithms/adversarial/common.py:454, in AdversarialTrainer.train(self, total_timesteps, callback)
448 assert n_rounds >= 1, (
449 "No updates (need at least "
450 f"{self.gen_train_timesteps} timesteps, have only "
451 f"total_timesteps={total_timesteps})!"
452 )
453 for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
--> 454 self.train_gen(self.gen_train_timesteps)
455 for _ in range(self.n_disc_updates_per_round):
456 with networks.training(self.reward_train):
457 # switch to training mode (affects dropout, normalization)

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/imitation/algorithms/adversarial/common.py:414, in AdversarialTrainer.train_gen(self, total_timesteps, learn_kwargs)
411 learn_kwargs = {}
413 with self.logger.accumulate_means("gen"):
--> 414 self.gen_algo.learn(
415 total_timesteps=total_timesteps,
416 reset_num_timesteps=False,
417 callback=self.gen_callback,
418 **learn_kwargs,
419 )
420 self._global_step += 1
422 gen_trajs, ep_lens = self.venv_buffering.pop_trajectories()

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/stable_baselines3/ppo/ppo.py:315, in PPO.learn(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, progress_bar)
306 def learn(
307 self: SelfPPO,
308 total_timesteps: int,
(...)
313 progress_bar: bool = False,
314 ) -> SelfPPO:
--> 315 return super().learn(
316 total_timesteps=total_timesteps,
317 callback=callback,
318 log_interval=log_interval,
319 tb_log_name=tb_log_name,
320 reset_num_timesteps=reset_num_timesteps,
321 progress_bar=progress_bar,
322 )

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py:300, in OnPolicyAlgorithm.learn(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, progress_bar)
297 assert self.env is not None
299 while self.num_timesteps < total_timesteps:
--> 300 continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
302 if not continue_training:
303 break

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py:179, in OnPolicyAlgorithm.collect_rollouts(self, env, callback, rollout_buffer, n_rollout_steps)
176 with th.no_grad():
177 # Convert to pytorch tensor or to TensorDict
178 obs_tensor = obs_as_tensor(self._last_obs, self.device)
--> 179 actions, values, log_probs = self.policy(obs_tensor)
180 actions = actions.cpu().numpy()
182 # Rescale and perform action

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/stable_baselines3/common/policies.py:645, in ActorCriticPolicy.forward(self, obs, deterministic)
637 """
638 Forward pass in all the networks (actor and critic)
639
(...)
642 :return: action, value and log probability of the action
643 """
644 # Preprocess the observation if needed
--> 645 features = self.extract_features(obs)
646 if self.share_features_extractor:
647 latent_pi, latent_vf = self.mlp_extractor(features)

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/stable_baselines3/common/policies.py:672, in ActorCriticPolicy.extract_features(self, obs, features_extractor)
663 """
664 Preprocess the observation if needed and extract features.
665
(...)
669 features for the actor and the features for the critic.
670 """
671 if self.share_features_extractor:
--> 672 return super().extract_features(obs, self.features_extractor if features_extractor is None else features_extractor)
673 else:
674 if features_extractor is not None:

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/stable_baselines3/common/policies.py:131, in BaseModel.extract_features(self, obs, features_extractor)
123 """
124 Preprocess the observation if needed and extract features.
125
(...)
128 :return: The extracted features
129 """
130 preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
--> 131 return features_extractor(preprocessed_obs)

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/stable_baselines3/common/torch_layers.py:106, in NatureCNN.forward(self, observations)
105 def forward(self, observations: th.Tensor) -> th.Tensor:
--> 106 return self.linear(self.cnn(observations))

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input)
215 def forward(self, input):
216 for module in self:
--> 217 input = module(input)
218 return input

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/linear.py:116, in Linear.forward(self, input)
115 def forward(self, input: Tensor) -> Tensor:
--> 116 return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x49 and 3136x512)

Steps to reproduce

This is the code that produces the issue:
from stable_baselines3 import PPO
from stable_baselines3.ppo import CnnPolicy
from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import CnnRewardNet

# `env` is my custom Gymnasium environment and `transitions` are the expert
# demonstrations; both are created earlier in my script.
SEED = 42

learner = PPO(
    env=env,
    policy=CnnPolicy,
    batch_size=32,
    ent_coef=0.0,
    learning_rate=0.0004,
    gamma=0.95,
    n_epochs=5,
    seed=SEED,
    # device='cpu'
)

reward_net = CnnRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    use_state=True,
    use_action=True,
    use_next_state=False,
    use_done=False,
    hwc_format=False,
)

gail_trainer = GAIL(
    demonstrations=transitions,
    demo_batch_size=32,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
)

gail_trainer.train(10000)
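If I read the traceback correctly, the shapes in the RuntimeError (64x49 vs 3136x512) look like an 84x84 image reaching NatureCNN without a batch dimension (64 conv channels times a 7x7 feature map, instead of one flattened 3136-vector per sample), which again points at the raw Gymnasium env being passed where a VecEnv is expected. This is only my guess; below is the change I would try, reusing the hypothetical make_env factory from the sketch in the description above:

from stable_baselines3.common.vec_env import DummyVecEnv

# Share one VecEnv between the generator and GAIL instead of the raw env.
venv = DummyVecEnv([make_env])

learner = PPO(env=venv, policy=CnnPolicy, batch_size=32, ent_coef=0.0,
              learning_rate=0.0004, gamma=0.95, n_epochs=5, seed=SEED)

gail_trainer = GAIL(demonstrations=transitions, demo_batch_size=32,
                    gen_replay_buffer_capacity=512, n_disc_updates_per_round=8,
                    venv=venv, gen_algo=learner, reward_net=reward_net)

gail_trainer.train(10000)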
