Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

[Feature] Expose additional arguments for Wandb #201

Merged
merged 7 commits into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion benchmarl/conf/experiment/base_experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,12 @@ evaluation_static: False

# List of loggers to use, options are: wandb, csv, tensorboard, mflow
loggers: [csv,wandb]
# Wandb project name
# Wandb project name (kept for backward compatibility)
project_name: "benchmarl"
# Wandb extra kwargs passed to the WandbLogger (~superset of wandb.init kwargs)
# WandbLogger includes: offline, save_dir, project, video_fps
# wandb.init includes: entity, tags, notes, etc.
wandb_extra_kwargs: {}
# Create a json folder as part of the output in the format of marl-eval
create_json: True

Expand Down
4 changes: 3 additions & 1 deletion benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class ExperimentConfig:

loggers: List[str] = MISSING
project_name: str = MISSING
wandb_extra_kwargs: Dict[str, Any] = MISSING
create_json: bool = MISSING

save_folder: Optional[str] = MISSING
Expand Down Expand Up @@ -611,7 +612,6 @@ def _setup_name(self):

def _setup_logger(self):
self.logger = Logger(
project_name=self.config.project_name,
experiment_name=self.name,
folder_name=str(self.folder_name),
experiment_config=self.config,
Expand All @@ -621,6 +621,8 @@ def _setup_logger(self):
task_name=self.task_name,
group_map=self.group_map,
seed=self.seed,
project_name=self.config.project_name,
wandb_extra_kwargs=self.config.wandb_extra_kwargs,
)
self.logger.log_hparams(
critic_model_name=self.critic_model_name,
Expand Down
11 changes: 9 additions & 2 deletions benchmarl/experiment/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from collections.abc import MutableMapping, Sequence
from pathlib import Path

from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

import numpy as np
import torch
Expand Down Expand Up @@ -39,6 +39,7 @@ def __init__(
group_map: Dict[str, List[str]],
seed: int,
project_name: str,
wandb_extra_kwargs: Dict[str, Any],
):
self.experiment_config = experiment_config
self.algorithm_name = algorithm_name
Expand All @@ -62,15 +63,21 @@ def __init__(

self.loggers: List[torchrl.record.loggers.Logger] = []
for logger_name in experiment_config.loggers:
wandb_project = wandb_extra_kwargs.get("project", project_name)
if wandb_project != project_name:
raise ValueError(
f"wandb_extra_kwargs.project ({wandb_project}) is different from the project_name ({project_name})"
)
self.loggers.append(
get_logger(
logger_type=logger_name,
logger_name=folder_name,
experiment_name=experiment_name,
wandb_kwargs={
"group": task_name,
"project": project_name,
"id": experiment_name,
"project": project_name,
**wandb_extra_kwargs,
},
)
)
Expand Down
Loading