18 changes: 11 additions & 7 deletions src/imitation/algorithms/adversarial/common.py
@@ -2,13 +2,14 @@
 import abc
 import dataclasses
 import logging
-from typing import Callable, Iterable, Iterator, Mapping, Optional, Type, overload
+from typing import Iterable, Iterator, List, Mapping, Optional, Type, overload
 
 import numpy as np
 import torch as th
 import torch.utils.tensorboard as thboard
 import tqdm
 from stable_baselines3.common import base_class, on_policy_algorithm, policies, vec_env
+from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.sac import policies as sac_policies
 from torch.nn import functional as F
@@ -421,7 +422,7 @@ def train_gen(
     def train(
Contributor: One more thing - if you change the arguments, train_adversarial.py will also need to be updated.
         self,
         total_timesteps: int,
-        callback: Optional[Callable[[int], None]] = None,
+        callback: Optional[List[BaseCallback]] = None,
Contributor: Do we want to change the semantics of the argument here, or should we rather deprecate the feature (and introduce a different parameter for an additional gen_callback)?

I think the suggestion in the original issue was to add a new gen_callback argument. (Btw, stable-baselines3 supports both CallbackList and a plain list of callbacks if we wanted to be fancy.)
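For context, a minimal sketch of the stable-baselines3 behavior mentioned above; the PPO/CartPole setup and the CheckpointCallback are illustrative and not part of this PR:

```python
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback

checkpoint_cb = CheckpointCallback(save_freq=10_000, save_path="./checkpoints")
model = PPO("MlpPolicy", "CartPole-v1")

# Both forms are accepted; SB3 wraps a plain list in a CallbackList internally.
model.learn(total_timesteps=50_000, callback=[checkpoint_cb])
model.learn(total_timesteps=50_000, callback=CallbackList([checkpoint_cb]))
```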

     ) -> None:
         """Alternates between training the generator and discriminator.

Contributor: The last part of the description ("and finally a call to callback(round)") is probably misleading now.

@@ -434,10 +435,15 @@ def train(
         Args:
             total_timesteps: An upper bound on the number of transitions to sample
                 from the environment during training.
-            callback: A function called at the end of every round which takes in a
-                single argument, the round number. Round numbers are in
-                `range(total_timesteps // self.gen_train_timesteps)`.
+            callback: List of stable_baselines3 callbacks to be passed to the
+                policy learning function.
         """
+        if callback is not None:
+            if self.gen_callback is None:
+                self.gen_callback = callback
+            else:
+                self.gen_callback = callback + [self.gen_callback]
Contributor: Can someone abuse the API by calling train() multiple times? If so, the value of self.gen_callback would contain a nested list, which is not correct. Generally, the value of gen_callback is currently Optional[BaseCallback], and we shouldn't change its type to a list at runtime.

Perhaps it would be better to add an optional callback argument to train_gen(), merge the callbacks there, and avoid the stateful change here?
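A rough sketch of this suggestion, assuming train_gen() gains an optional callback parameter; the signature and the learn() arguments are assumptions for illustration, not code from this PR:

```python
def train_gen(
    self,
    total_timesteps: Optional[int] = None,
    callback: Optional[List[BaseCallback]] = None,
) -> None:
    if total_timesteps is None:
        total_timesteps = self.gen_train_timesteps
    # Merge locally instead of mutating self.gen_callback, so repeated
    # train() calls cannot accumulate nested lists.
    merged: List[BaseCallback] = list(callback) if callback is not None else []
    if self.gen_callback is not None:
        merged.append(self.gen_callback)
    self.gen_algo.learn(
        total_timesteps=total_timesteps,
        reset_num_timesteps=False,
        callback=merged or None,
    )
```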

Contributor: Also, can the learn_kwargs argument of train_gen() be removed, as discussed in the original issue #607?


         n_rounds = total_timesteps // self.gen_train_timesteps
         assert n_rounds >= 1, (
             "No updates (need at least "
@@ -450,8 +456,6 @@ def train(
                 with networks.training(self.reward_train):
                     # switch to training mode (affects dropout, normalization)
                     self.train_disc()
-            if callback:
-                callback(r)
             self.logger.dump(self._global_step)

     @overload
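For illustration, a hedged usage sketch of train() after this change; the trainer construction is elided and RoundLoggerCallback is a hypothetical example:

```python
from stable_baselines3.common.callbacks import BaseCallback

class RoundLoggerCallback(BaseCallback):  # hypothetical example callback
    def _on_step(self) -> bool:
        # Called by stable-baselines3 on each generator training step.
        return True  # returning False would stop training early

# `trainer` would be a concrete AdversarialTrainer (e.g. GAIL or AIRL)
# constructed elsewhere; callback now takes a list of SB3 callbacks.
trainer.train(total_timesteps=100_000, callback=[RoundLoggerCallback()])
```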