Skip to content

Commit 2f57561

Browse files
Xmaster6yvmoens
andauthored
[Feature] Add support for trackio (#3196)
Co-authored-by: vmoens <[email protected]>
1 parent 9c84bc5 commit 2f57561

File tree

7 files changed

+254
-1
lines changed

7 files changed

+254
-1
lines changed

.github/unittest/linux/scripts/run_all.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,12 @@ if [[ "$PYTHON_VERSION" != "3.14" ]]; then
143143
uv_pip_install ray
144144
fi
145145

146+
# Install trackio for Python < 3.14 (trackio wheels may not be available for Python 3.14 yet)
147+
if [[ "$PYTHON_VERSION" != "3.14" ]]; then
148+
echo "installing trackio"
149+
uv_pip_install trackio
150+
fi
151+
146152
# Install mujoco for Python < 3.14 (mujoco doesn't have Python 3.14 wheels yet)
147153
if [[ "$PYTHON_VERSION" != "3.14" ]]; then
148154
echo "installing mujoco"

.github/unittest/linux_distributed/scripts/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ dependencies:
2828
- dm_control
2929
- mujoco<3.3.6
3030
- mlflow
31+
- trackio
3132
- av
3233
- coverage
3334
- ray

.github/unittest/linux_sota/scripts/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ dependencies:
2525
- dm_control
2626
- mujoco<3.3.6
2727
- mlflow
28+
- trackio
2829
- av
2930
- coverage
3031
- vmas

docs/source/reference/trainers_loggers.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Logger classes for experiment tracking and visualization.
1515
csv.CSVLogger
1616
mlflow.MLFlowLogger
1717
tensorboard.TensorboardLogger
18+
trackio.TrackioLogger
1819
wandb.WandbLogger
1920
get_logger
2021
generate_exp_name

test/test_loggers.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from torchrl.record.loggers.csv import CSVLogger
2222
from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger
2323
from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger
24+
from torchrl.record.loggers.trackio import _has_trackio, TrackioLogger
2425
from torchrl.record.loggers.wandb import _has_wandb, WandbLogger
2526
from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
2627

@@ -455,6 +456,78 @@ def make_env():
455456
env.close()
456457

457458

459+
@pytest.fixture()
460+
def trackio_logger():
461+
exp_name = "ramala"
462+
logger = TrackioLogger(project="test", exp_name=exp_name)
463+
yield logger
464+
logger.experiment.finish()
465+
del logger
466+
467+
468+
@pytest.mark.skipif(not _has_trackio, reason="trackio not installed")
469+
class TestTrackioLogger:
470+
@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
471+
def test_log_scalar(self, steps, trackio_logger):
472+
torch.manual_seed(0)
473+
474+
values = torch.rand(3)
475+
for i in range(3):
476+
scalar_name = "foo"
477+
scalar_value = values[i].item()
478+
trackio_logger.log_scalar(
479+
value=scalar_value,
480+
name=scalar_name,
481+
step=steps[i] if steps else None,
482+
)
483+
484+
@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
485+
def test_log_str(self, steps, trackio_logger):
486+
for i in range(3):
487+
trackio_logger.log_str(
488+
name="foo",
489+
value="bar",
490+
step=steps[i] if steps else None,
491+
)
492+
493+
def test_log_video(self, trackio_logger):
494+
torch.manual_seed(0)
495+
496+
# creating a sample video (T, C, H, W), where T - number of frames,
497+
# C - number of image channels (e.g. 3 for RGB), H, W - image dimensions.
498+
# the first 64 frames are black and the next 64 are white
499+
video = torch.cat(
500+
(torch.zeros(128, 3, 32, 32), torch.full((128, 3, 32, 32), 255))
501+
)
502+
video = video[None, :]
503+
trackio_logger.log_video(
504+
name="foo",
505+
video=video,
506+
fps=4,
507+
format="mp4",
508+
)
509+
trackio_logger.log_video(
510+
name="foo_16fps",
511+
video=video,
512+
fps=16,
513+
format="mp4",
514+
)
515+
516+
def test_log_hparams(self, trackio_logger, config):
517+
trackio_logger.log_hparams(config)
518+
for key, value in config.items():
519+
assert trackio_logger.experiment.config[key] == value
520+
521+
@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
522+
def test_log_histogram(self, steps, trackio_logger):
523+
torch.manual_seed(0)
524+
for i in range(3):
525+
data = torch.randn(100)
526+
trackio_logger.log_histogram(
527+
"hist", data, step=steps[i] if steps else None, bins=10
528+
)
529+
530+
458531
if __name__ == "__main__":
459532
args, unknown = argparse.ArgumentParser().parse_known_args()
460533
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/record/loggers/trackio.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
7+
import importlib.util
8+
9+
from collections.abc import Sequence
10+
11+
import numpy as np
12+
13+
from torch import Tensor
14+
15+
from .common import Logger
16+
17+
_has_trackio = importlib.util.find_spec("trackio") is not None
18+
_has_omegaconf = importlib.util.find_spec("omegaconf") is not None
19+
20+
21+
class TrackioLogger(Logger):
22+
"""Wrapper for the trackio logger.
23+
24+
Args:
25+
exp_name (str): The name of the experiment.
26+
project (str): The name of the project.
27+
28+
Keyword Args:
29+
fps (int, optional): Number of frames per second when recording videos. Defaults to ``30``.
30+
**kwargs: Extra keyword arguments for ``trackio.init``.
31+
32+
"""
33+
34+
@classmethod
35+
def __new__(cls, *args, **kwargs):
36+
return super().__new__(cls)
37+
38+
def __init__(
39+
self,
40+
exp_name: str,
41+
project: str,
42+
*,
43+
video_fps: int = 32,
44+
**kwargs,
45+
) -> None:
46+
if not _has_trackio:
47+
raise ImportError("trackio could not be imported")
48+
49+
self.video_fps = video_fps
50+
self._trackio_kwargs = {
51+
"name": exp_name,
52+
"project": project,
53+
"resume": "allow",
54+
**kwargs,
55+
}
56+
57+
super().__init__(exp_name=exp_name, log_dir=project)
58+
59+
def _create_experiment(self):
60+
"""Creates a trackio experiment.
61+
62+
Args:
63+
exp_name (str): The name of the experiment.
64+
65+
Returns:
66+
A trackio.Experiment object.
67+
"""
68+
if not _has_trackio:
69+
raise ImportError("Trackio is not installed")
70+
import trackio
71+
72+
return trackio.init(**self._trackio_kwargs)
73+
74+
def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
75+
"""Logs a scalar value to trackio.
76+
77+
Args:
78+
name (str): The name of the scalar.
79+
value (float): The value of the scalar.
80+
step (int, optional): The step at which the scalar is logged.
81+
Defaults to None.
82+
"""
83+
self.experiment.log({name: value}, step=step)
84+
85+
def log_video(self, name: str, video: Tensor, **kwargs) -> None:
86+
"""Log videos inputs to trackio.
87+
88+
Args:
89+
name (str): The name of the video.
90+
video (Tensor): The video to be logged.
91+
**kwargs: Other keyword arguments. By construction, log_video
92+
supports 'step' (integer indicating the step index), 'format'
93+
(default is 'mp4') and 'fps' (defaults to ``self.video_fps``). Other kwargs are
94+
passed as-is to the :obj:`experiment.log` method.
95+
"""
96+
import trackio
97+
98+
fps = kwargs.pop("fps", self.video_fps)
99+
format = kwargs.pop("format", "mp4")
100+
self.experiment.log(
101+
{
102+
name: trackio.Video(
103+
video.numpy().astype(np.uint8), fps=fps, format=format
104+
)
105+
},
106+
**kwargs,
107+
)
108+
109+
def log_hparams(self, cfg: DictConfig | dict) -> None: # noqa: F821
110+
"""Logs the hyperparameters of the experiment.
111+
112+
Args:
113+
cfg (DictConfig or dict): The configuration of the experiment.
114+
115+
"""
116+
if type(cfg) is not dict and _has_omegaconf:
117+
if not _has_omegaconf:
118+
raise ImportError(
119+
"OmegaConf could not be imported. "
120+
"Cannot log hydra configs without OmegaConf."
121+
)
122+
from omegaconf import OmegaConf
123+
124+
cfg = OmegaConf.to_container(cfg, resolve=True)
125+
self.experiment.config.update(cfg)
126+
127+
def __repr__(self) -> str:
128+
return f"TrackioLogger(experiment={self.experiment.__repr__()})"
129+
130+
def log_histogram(self, name: str, data: Sequence, **kwargs):
131+
"""Add histogram to log.
132+
133+
Args:
134+
name (str): Data identifier
135+
data (torch.Tensor, numpy.ndarray): Values to build histogram
136+
137+
Keyword Args:
138+
step (int): Global step value to record
139+
bins (int): Number of bins to use for the histogram
140+
141+
"""
142+
import trackio
143+
144+
num_bins = kwargs.pop("bins", None)
145+
step = kwargs.pop("step", None)
146+
self.experiment.log(
147+
{name: trackio.Histogram(data, num_bins=num_bins)}, step=step
148+
)
149+
150+
def log_str(self, name: str, value: str, step: int | None = None) -> None:
151+
"""Logs a string value to trackio using a table format for better visualization.
152+
153+
Args:
154+
name (str): The name of the string data.
155+
value (str): The string value to log.
156+
step (int, optional): The step at which the string is logged.
157+
Defaults to None.
158+
"""
159+
import trackio
160+
161+
# Create a table with a single row
162+
table = trackio.Table(columns=["text"], data=[[value]])
163+
self.experiment.log({name: table}, step=step)

torchrl/record/loggers/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def get_logger(
3535
If empty, ``None`` is returned.
3636
logger_name (str): Name to be used as a log_dir
3737
experiment_name (str): Name of the experiment
38-
kwargs (dict[str]): might contain either `wandb_kwargs` or `mlflow_kwargs`
38+
kwargs (dict[str]): might contain either `wandb_kwargs`, `mlflow_kwargs` or `trackio_kwargs`
3939
"""
4040
if logger_type == "tensorboard":
4141
from torchrl.record.loggers.tensorboard import TensorboardLogger
@@ -63,6 +63,14 @@ def get_logger(
6363
exp_name=experiment_name,
6464
**mlflow_kwargs,
6565
)
66+
elif logger_type == "trackio":
67+
from torchrl.record.loggers.trackio import TrackioLogger
68+
69+
trackio_kwargs = kwargs.get("trackio_kwargs", {})
70+
project = trackio_kwargs.pop("project", "torchrl")
71+
logger = TrackioLogger(
72+
project=project, exp_name=experiment_name, **trackio_kwargs
73+
)
6674
elif logger_type in ("", None):
6775
return None
6876
else:

0 commit comments

Comments
 (0)