Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions training/tests/integration/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,69 @@
# Module-level logger; benchmark results are reported via log records, not asserts.
LOGGER = logging.getLogger(__name__)


@pytest.mark.multigpu
@pytest.mark.slow
def test_benchmark_dataloader(
    benchmark_config: tuple[DictConfig, str],  # cfg, benchmarkTestCase
) -> None:
    """Benchmark dataloader performance, testing MultiDataset batch sampling speed.

    Builds the graph and datamodule from the benchmark configuration, iterates up
    to 100 batches from the training dataloader, and logs throughput (it/s) and
    per-batch latency. Fails if the dataloader yields no batches.

    Parameters
    ----------
    benchmark_config : tuple[DictConfig, str]
        The resolved benchmark configuration and the test-case name.
    """
    import time

    from anemoi.graphs.create import GraphCreator
    from anemoi.training.data.datamodule import AnemoiDatasetsDataModule

    cfg, test_case = benchmark_config
    # The graph node builder must read the same dataset the system input uses.
    cfg.graph.nodes.data.node_builder.dataset = cfg.system.input.dataset
    LOGGER.info("Benchmarking dataloader for configuration: %s", test_case)

    # Reset memory logging and free all possible memory between runs
    # this ensures we report the peak memory used during each run,
    # and not the peak memory used by the run with the highest memory usage
    reset_peak_memory_stats()
    empty_cache()
    gc.collect()

    # Build the graph required by the datamodule.
    graph = GraphCreator(config=cfg.graph).create(overwrite=True)

    # Initialize datamodule with graph data
    datamodule = AnemoiDatasetsDataModule(config=cfg, graph_data={"data": graph})

    # Get training dataloader
    train_dataloader = datamodule.train_dataloader()

    # Benchmark batch sampling speed
    num_batches_to_test = 100
    LOGGER.info("Testing %d batches from MultiDataset", num_batches_to_test)

    start_time = time.perf_counter()
    batch_count = 0

    for batch_idx, batch in enumerate(train_dataloader):
        if batch_idx >= num_batches_to_test:
            break
        batch_count += 1

        # Log first batch structure
        if batch_idx == 0:
            LOGGER.info("First batch structure:")
            for dataset_name, data in batch.items():
                LOGGER.info(" Dataset '%s': shape %s, dtype %s", dataset_name, data.shape, data.dtype)

    end_time = time.perf_counter()
    elapsed_time = end_time - start_time

    # Fail explicitly on an empty dataloader instead of raising ZeroDivisionError
    # below; this also gives the benchmark at least one real assertion.
    assert batch_count > 0, "Dataloader yielded no batches; cannot compute throughput."

    # Calculate performance metrics
    batches_per_second = batch_count / elapsed_time
    time_per_batch_ms = (elapsed_time / batch_count) * 1000

    LOGGER.info("Dataloader Performance Results:")
    LOGGER.info(" Total batches: %d", batch_count)
    LOGGER.info(" Total time: %.2f seconds", elapsed_time)
    LOGGER.info(" Throughput: %.2f it/s", batches_per_second)
    LOGGER.info(" Time per batch: %.2f ms", time_per_batch_ms)


@pytest.mark.multigpu
@pytest.mark.slow
def test_benchmark_training_cycle(
Expand Down