How to Retrieve mean/gen/train/loss during GAIL Training and Create Checkpoints on New Minimum #859

@kechirojp

Question

Hello,

I am trying to implement a callback during GAIL training that retrieves the mean/gen/train/loss value and creates a checkpoint whenever this loss reaches a new minimum.

However, I am unsure how to access the mean/gen/train/loss value.

Currently, I have the logger configured as follows:

from imitation.algorithms.adversarial.gail import GAIL
from imitation.util import logger as imit_logger

# Logger that writes TensorBoard, CSV, and stdout output to optuna_log_dir
custom_logger = imit_logger.configure(
    folder=optuna_log_dir,
    format_strs=["tensorboard", "csv", "stdout"],
)

gail_trainer = GAIL(
    demonstrations=flattened_transitions,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net,
    allow_variable_horizon=True,
    log_dir=log_GAIL_dir,
    init_tensorboard=True,
    init_tensorboard_graph=True,
    custom_logger=custom_logger,
)

Could you please advise on how to retrieve mean/gen/train/loss and how to save a checkpoint whenever it reaches a new minimum?
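
For reference, here is a minimal sketch of the behaviour I am after, not a confirmed imitation API. The class name BestLossCheckpointer and its read_loss and save_fn arguments are hypothetical names of my own. It only assumes that the callback passed to gail_trainer.train is called once per round with the round number. Whether mean/gen/train/loss can actually be read from the logger at that point (for example via custom_logger.default_logger.name_to_value) is an assumption I have not been able to verify.

```python
import math
from typing import Callable, Optional


class BestLossCheckpointer:
    """Call save_fn whenever the tracked loss reaches a new minimum.

    read_loss and save_fn are hypothetical hooks supplied by the caller;
    nothing in this class depends on imitation or stable-baselines3.
    """

    def __init__(
        self,
        read_loss: Callable[[], Optional[float]],
        save_fn: Callable[[int, float], None],
    ) -> None:
        self.read_loss = read_loss
        self.save_fn = save_fn
        self.best_loss = math.inf

    def __call__(self, round_num: int) -> None:
        loss = self.read_loss()
        # Skip rounds where the value is missing or not a number.
        if loss is None or math.isnan(loss):
            return
        if loss < self.best_loss:
            self.best_loss = loss
            self.save_fn(round_num, loss)


# Stand-alone demonstration with made-up loss values.
fake_losses = iter([0.9, None, 0.5, 0.7, float("nan"), 0.2])
saved = []
checkpointer = BestLossCheckpointer(
    read_loss=lambda: next(fake_losses),
    save_fn=lambda r, loss: saved.append((r, loss)),
)
for round_num in range(6):
    checkpointer(round_num)
print(saved)

# Possible wiring into the setup above (ASSUMPTION, untested):
# checkpointer = BestLossCheckpointer(
#     read_loss=lambda: custom_logger.default_logger.name_to_value.get(
#         "mean/gen/train/loss"
#     ),
#     save_fn=lambda r, loss: learner.save(f"{log_GAIL_dir}/best_gen_round_{r}"),
# )
# gail_trainer.train(total_timesteps=100_000, callback=checkpointer)
```

The missing piece is the read_loss part: I do not know the supported way to obtain the current mean/gen/train/loss value from inside the callback. The total_timesteps value in the commented wiring is only a placeholder.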

Environment

  • Operating system and version: Google Colaboratory
  • Python version: Python 3.10.12
  • Output of pip freeze --all:
    absl-py==1.4.0
    accelerate==0.34.2
    aiohappyeyeballs==2.4.3
    aiohttp==3.10.8
    aiosignal==1.3.1
    alabaster==0.7.16
    albucore==0.0.16
    albumentations==1.4.15
    ale-py==0.8.1
    alembic==1.13.3
    altair==4.2.2
    annotated-types==0.7.0
    anyio==3.7.1
    argon2-cffi==23.1.0
    argon2-cffi-bindings==21.2.0
    array_record==0.5.1
    arviz==0.19.0
    astropy==6.1.4
    astropy-iers-data==0.2024.9.30.0.32.59
    astunparse==1.6.3
    async-timeout==4.0.3
    atpublic==4.1.0
    attrs==24.2.0
    audioread==3.0.1
    autograd==1.7.0
    AutoROM==0.6.1
    AutoROM.accept-rom-license==0.6.1
    babel==2.16.0
    backcall==0.2.0
    beautifulsoup4==4.12.3
    bigframes==1.21.0
    bigquery-magics==0.4.0
    bleach==6.1.0
    blinker==1.4
    blis==0.7.11
    blosc2==2.0.0
    bokeh==3.4.3
    Bottleneck==1.4.0
    bqplot==0.12.43
    branca==0.8.0
    build==1.2.2.post1
    CacheControl==0.14.0
    cachetools==5.5.0
    catalogue==2.0.10
    certifi==2024.8.30
    cffi==1.17.1
    chardet==5.2.0
    charset-normalizer==3.3.2
    chex==0.1.87
    clarabel==0.9.0
    click==8.1.7
    cloudpathlib==0.19.0
    cloudpickle==2.2.1
    cmake==3.30.4
    cmdstanpy==1.2.4
    colorama==0.4.6
    colorcet==3.1.0
    colorlog==6.8.2
    colorlover==0.3.0
    colour==0.1.5
    community==1.0.0b1
    confection==0.1.5
    cons==0.4.6
    contextlib2==21.6.0
    contourpy==1.3.0
    cryptography==43.0.1
    cuda-python==12.2.1
    cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-24.6.1-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=6fdd6fd412117503ad23dca75aabb5b536348d3af459b59920152be1a0af6f15
    cufflinks==0.17.3
    cupy-cuda12x==12.2.0
    cvxopt==1.3.2
    cvxpy==1.5.3
    cycler==0.12.1
    cymem==2.0.8
    Cython==3.0.11
    dask==2024.8.0
    datascience==0.17.6
    datasets==3.0.1
    db-dtypes==1.3.0
    dbus-python==1.2.18
    debugpy==1.6.6
    decorator==4.4.2
    defusedxml==0.7.1
    Deprecated==1.2.14
    dill==0.3.8
    distributed==2024.8.0
    distro==1.7.0
    dlib==19.24.2
    dm-tree==0.1.8
    docopt-ng==0.9.0
    docstring_parser==0.16
    docutils==0.18.1
    dopamine_rl==4.0.9
    duckdb==1.1.1
    earthengine-api==1.0.0
    easydict==1.13
    ecos==2.0.14
    editdistance==0.8.1
    eerepr==0.0.4
    einops==0.8.0
    en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889
    entrypoints==0.4
    et-xmlfile==1.1.0
    etils==1.9.4
    etuples==0.3.9
    eval_type_backport==0.2.0
    exceptiongroup==1.2.2
    Farama-Notifications==0.0.4
    fastai==2.7.17
    fastcore==1.7.10
    fastdownload==0.0.7
    fastjsonschema==2.20.0
    fastprogress==1.0.3
    fastrlock==0.8.2
    filelock==3.16.1
    firebase-admin==6.5.0
    Flask==2.2.5
    flatbuffers==24.3.25
    flax==0.8.5
    folium==0.17.0
    fonttools==4.54.1
    frozendict==2.4.4
    frozenlist==1.4.1
    fsspec==2024.6.1
    future==1.0.0
    gast==0.6.0
    gcsfs==2024.6.1
    GDAL==3.6.4
    gdown==5.2.0
    geemap==0.34.5
    gensim==4.3.3
    geocoder==1.38.1
    geographiclib==2.0
    geopandas==1.0.1
    geopy==2.4.1
    gin-config==0.5.0
    gitdb==4.0.11
    GitPython==3.1.43
    glob2==0.7
    google==2.0.3
    google-ai-generativelanguage==0.6.6
    google-api-core==2.19.2
    google-api-python-client==2.137.0
    google-auth==2.27.0
    google-auth-httplib2==0.2.0
    google-auth-oauthlib==1.2.1
    google-cloud-aiplatform==1.69.0
    google-cloud-bigquery==3.25.0
    google-cloud-bigquery-connection==1.15.5
    google-cloud-bigquery-storage==2.26.0
    google-cloud-bigtable==2.26.0
    google-cloud-core==2.4.1
    google-cloud-datastore==2.19.0
    google-cloud-firestore==2.16.1
    google-cloud-functions==1.16.5
    google-cloud-iam==2.15.2
    google-cloud-language==2.13.4
    google-cloud-pubsub==2.25.2
    google-cloud-resource-manager==1.12.5
    google-cloud-storage==2.8.0
    google-cloud-translate==3.15.5
    google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz#sha256=f8df4d5e53a79aac2c8af9405f533f064988497448276e4f71c56dc9a1491702
    google-crc32c==1.6.0
    google-generativeai==0.7.2
    google-pasta==0.2.0
    google-resumable-media==2.7.2
    googleapis-common-protos==1.65.0
    googledrivedownloader==0.4
    graphviz==0.20.3
    greenlet==3.1.1
    grpc-google-iam-v1==0.13.1
    grpcio==1.64.1
    grpcio-status==1.48.2
    gspread==6.0.2
    gspread-dataframe==3.3.1
    gym==0.25.2
    gym-notices==0.0.8
    gymnasium==0.29.1
    h5netcdf==1.3.0
    h5py==3.11.0
    holidays==0.57
    holoviews==1.19.1
    html5lib==1.1
    httpimport==1.4.0
    httplib2==0.22.0
    huggingface-hub==0.24.7
    huggingface-sb3==3.0
    humanize==4.10.0
    hyperopt==0.2.7
    ibis-framework==9.2.0
    idna==3.10
    imageio==2.35.1
    imageio-ffmpeg==0.5.1
    imagesize==1.4.1
    imbalanced-learn==0.12.3
    imgaug==0.4.0
    imitation @ file:///content/imitation
    immutabledict==4.2.0
    importlib_metadata==8.4.0
    importlib_resources==6.4.5
    imutils==0.5.4
    inflect==7.4.0
    iniconfig==2.0.0
    intel-cmplr-lib-ur==2024.2.1
    intel-openmp==2024.2.1
    ipyevents==2.0.2
    ipyfilechooser==0.6.0
    ipykernel==5.5.6
    ipyleaflet==0.19.2
    ipyparallel==8.8.0
    ipython==7.34.0
    ipython-genutils==0.2.0
    ipython-sql==0.5.0
    ipytree==0.2.2
    ipywidgets==7.7.1
    itsdangerous==2.2.0
    jax==0.4.33
    jax-cuda12-pjrt==0.4.33
    jax-cuda12-plugin==0.4.33
    jaxlib==0.4.33
    jeepney==0.7.1
    jellyfish==1.1.0
    jieba==0.42.1
    Jinja2==3.1.4
    joblib==1.4.2
    jsonpickle==3.3.0
    jsonschema==4.23.0
    jsonschema-specifications==2023.12.1
    jupyter-client==6.1.12
    jupyter-console==6.1.0
    jupyter-leaflet==0.19.2
    jupyter-server==1.24.0
    jupyter_core==5.7.2
    jupyterlab_pygments==0.3.0
    jupyterlab_widgets==3.0.13
    kaggle==1.6.17
    kagglehub==0.3.1
    keras==3.4.1
    keyring==23.5.0
    kiwisolver==1.4.7
    langcodes==3.4.1
    language_data==1.2.0
    launchpadlib==1.10.16
    lazr.restfulclient==0.14.4
    lazr.uri==1.0.6
    lazy_loader==0.4
    libclang==18.1.1
    librosa==0.10.2.post1
    lightgbm==4.5.0
    linkify-it-py==2.0.3
    llvmlite==0.43.0
    locket==1.0.0
    logical-unification==0.4.6
    lxml==4.9.4
    Mako==1.3.5
    marisa-trie==1.2.0
    Markdown==3.7
    markdown-it-py==3.0.0
    MarkupSafe==2.1.5
    matplotlib==3.7.1
    matplotlib-inline==0.1.7
    matplotlib-venn==1.1.1
    mdit-py-plugins==0.4.2
    mdurl==0.1.2
    miniKanren==1.0.3
    missingno==0.5.2
    mistune==0.8.4
    mizani==0.11.4
    mkl==2024.2.2
    ml-dtypes==0.4.1
    mlxtend==0.23.1
    more-itertools==10.5.0
    moviepy==1.0.3
    mpmath==1.3.0
    msgpack==1.0.8
    multidict==6.1.0
    multipledispatch==1.0.0
    multiprocess==0.70.16
    multitasking==0.0.11
    munch==4.0.0
    murmurhash==1.0.10
    music21==9.1.0
    namex==0.0.8
    natsort==8.4.0
    nbclassic==1.1.0
    nbclient==0.10.0
    nbconvert==6.5.4
    nbformat==5.10.4
    nest-asyncio==1.6.0
    networkx==3.3
    nibabel==5.2.1
    nltk==3.8.1
    notebook==6.5.5
    notebook_shim==0.2.4
    numba==0.60.0
    numexpr==2.10.1
    numpy==1.26.4
    nvidia-cublas-cu12==12.6.3.3
    nvidia-cuda-cupti-cu12==12.6.80
    nvidia-cuda-nvcc-cu12==12.6.77
    nvidia-cuda-runtime-cu12==12.6.77
    nvidia-cudnn-cu12==9.4.0.58
    nvidia-cufft-cu12==11.3.0.4
    nvidia-cusolver-cu12==11.7.1.2
    nvidia-cusparse-cu12==12.5.4.2
    nvidia-nccl-cu12==2.23.4
    nvidia-nvjitlink-cu12==12.6.77
    nvtx==0.2.10
    oauth2client==4.1.3
    oauthlib==3.2.2
    opencv-contrib-python==4.10.0.84
    opencv-python==4.10.0.84
    opencv-python-headless==4.10.0.84
    openpyxl==3.1.5
    opentelemetry-api==1.27.0
    opentelemetry-sdk==1.27.0
    opentelemetry-semantic-conventions==0.48b0
    opt_einsum==3.4.0
    optax==0.2.3
    optree==0.13.0
    optuna==4.0.0
    orbax-checkpoint==0.6.4
    osqp==0.6.7.post0
    packaging==24.1
    pandas==2.2.2
    pandas-datareader==0.10.0
    pandas-gbq==0.23.2
    pandas-stubs==2.2.2.240909
    pandocfilters==1.5.1
    panel==1.4.5
    param==2.1.1
    parso==0.8.4
    parsy==2.1
    partd==1.4.2
    pathlib==1.0.1
    patsy==0.5.6
    peewee==3.17.6
    pexpect==4.9.0
    pickleshare==0.7.5
    pillow==10.4.0
    pip==24.1.2
    pip-tools==7.4.1
    platformdirs==4.3.6
    plotly==5.24.1
    plotnine==0.13.6
    pluggy==1.5.0
    polars==1.7.1
    pooch==1.8.2
    portpicker==1.5.2
    prefetch_generator==1.0.3
    preshed==3.0.9
    prettytable==3.11.0
    proglog==0.1.10
    progressbar2==4.5.0
    prometheus_client==0.21.0
    promise==2.3
    prompt_toolkit==3.0.48
    prophet==1.1.6
    proto-plus==1.24.0
    protobuf==3.20.3
    psutil==5.9.5
    psycopg2==2.9.9
    ptyprocess==0.7.0
    py-cpuinfo==9.0.0
    py4j==0.10.9.7
    pyarrow==16.1.0
    pyarrow-hotfix==0.6
    pyasn1==0.6.1
    pyasn1_modules==0.4.1
    pycocotools==2.0.8
    pycparser==2.22
    pydantic==2.9.2
    pydantic_core==2.23.4
    pydata-google-auth==1.8.2
    pydot==3.0.2
    pydot-ng==2.0.0
    pydotplus==2.0.2
    PyDrive==1.3.1
    PyDrive2==1.20.0
    pyerfa==2.0.1.4
    pygame==2.6.1
    Pygments==2.18.0
    PyGObject==3.42.1
    PyJWT==2.9.0
    pymc==5.16.2
    pymystem3==0.2.0
    pynvjitlink-cu12==0.3.0
    pyogrio==0.10.0
    PyOpenGL==3.1.7
    pyOpenSSL==24.2.1
    pyparsing==3.1.4
    pyperclip==1.9.0
    pyproj==3.7.0
    pyproject_hooks==1.2.0
    pyshp==2.3.1
    PySocks==1.7.1
    pytensor==2.25.5
    pytest==7.4.4
    python-apt==0.0.0
    python-box==7.2.0
    python-dateutil==2.8.2
    python-louvain==0.16
    python-slugify==8.0.4
    python-utils==3.9.0
    pytz==2024.2
    pyviz_comms==3.0.3
    PyYAML==6.0.2
    pyzmq==24.0.1
    qdldl==0.1.7.post4
    ratelim==0.1.6
    referencing==0.35.1
    regex==2024.9.11
    requests==2.32.3
    requests-oauthlib==1.3.1
    requirements-parser==0.9.0
    rich==13.9.1
    rmm-cu12==24.6.0
    rpds-py==0.20.0
    rpy2==3.4.2
    rsa==4.9
    sacred==0.8.6
    safetensors==0.4.5
    scikit-image==0.24.0
    scikit-learn==1.5.2
    scipy==1.13.1
    scooby==0.10.0
    scs==3.2.7
    seaborn==0.13.1
    seals==0.2.1
    SecretStorage==3.3.1
    Send2Trash==1.8.3
    sentencepiece==0.2.0
    setuptools==71.0.4
    shapely==2.0.6
    shellingham==1.5.4
    Shimmy==1.3.0
    simple-parsing==0.1.6
    six==1.16.0
    sklearn-pandas==2.2.0
    smart-open==7.0.4
    smmap==5.0.1
    sniffio==1.3.1
    snowballstemmer==2.2.0
    sortedcontainers==2.4.0
    soundfile==0.12.1
    soupsieve==2.6
    soxr==0.5.0.post1
    spacy==3.7.5
    spacy-legacy==3.0.12
    spacy-loggers==1.0.5
    Sphinx==5.0.2
    sphinxcontrib-applehelp==2.0.0
    sphinxcontrib-devhelp==2.0.0
    sphinxcontrib-htmlhelp==2.1.0
    sphinxcontrib-jsmath==1.0.1
    sphinxcontrib-qthelp==2.0.0
    sphinxcontrib-serializinghtml==2.0.0
    SQLAlchemy==2.0.35
    sqlglot==25.1.0
    sqlparse==0.5.1
    srsly==2.4.8
    stable-baselines3==2.2.1
    stanio==0.5.1
    statsmodels==0.14.4
    StrEnum==0.4.15
    sympy==1.13.3
    tables==3.8.0
    tabulate==0.9.0
    tbb==2021.13.1
    tblib==3.0.0
    tenacity==9.0.0
    tensorboard==2.17.0
    tensorboard-data-server==0.7.2
    tensorflow==2.17.0
    tensorflow-datasets==4.9.6
    tensorflow-hub==0.16.1
    tensorflow-io-gcs-filesystem==0.37.1
    tensorflow-metadata==1.16.0
    tensorflow-probability==0.24.0
    tensorstore==0.1.66
    termcolor==2.4.0
    terminado==0.18.1
    text-unidecode==1.3
    textblob==0.17.1
    tf-slim==1.1.0
    tf_keras==2.17.0
    thinc==8.2.5
    threadpoolctl==3.5.0
    tifffile==2024.9.20
    tinycss2==1.3.0
    tokenizers==0.19.1
    toml==0.10.2
    tomli==2.0.2
    toolz==0.12.1
    torch @ https://download.pytorch.org/whl/cu121_full/torch-2.4.1%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=f3ed9a2b7f8671b2b32a2f036d1b81055eb3ad9b18ba43b705aa34bae4289e1a
    torchaudio @ https://download.pytorch.org/whl/cu121_full/torchaudio-2.4.1%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=da8c87c80a1c1376a48dc33eef30b03bbdf1df25a05bd2b1c620b8811c7b19be
    torchsummary==1.5.1
    torchvision @ https://download.pytorch.org/whl/cu121_full/torchvision-0.19.1%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=b8cc4bf381b75522995b601e07a1b433b5fd925dc3e34a7fa6cd22f449d65379
    tornado==6.3.3
    tqdm==4.66.5
    traitlets==5.7.1
    traittypes==0.2.1
    transformers==4.44.2
    tweepy==4.14.0
    typeguard==4.3.0
    typer==0.12.5
    types-pytz==2024.2.0.20241003
    types-setuptools==75.1.0.20240917
    typing_extensions==4.12.2
    tzdata==2024.2
    tzlocal==5.2
    uc-micro-py==1.0.3
    uritemplate==4.1.1
    urllib3==2.2.3
    vega-datasets==0.9.0
    wadllib==1.3.6
    wasabi==1.1.3
    wcwidth==0.2.13
    weasel==0.4.1
    webcolors==24.8.0
    webencodings==0.5.1
    websocket-client==1.8.0
    Werkzeug==3.0.4
    wheel==0.44.0
    widgetsnbextension==3.6.9
    wordcloud==1.9.3
    wrapt==1.16.0
    xarray==2024.9.0
    xarray-einstats==0.8.0
    xgboost==2.1.1
    xlrd==2.0.1
    xxhash==3.5.0
    xyzservices==2024.9.0
    yarl==1.13.1
    yellowbrick==1.5
    yfinance==0.2.44
    zict==3.0.0
    zipp==3.20.2
