|
40 | 40 |
|
41 | 41 | from grain._src.core import monitoring |
42 | 42 |
|
| 43 | +# Conditionally import profiler from JAX. |
| 44 | +# TODO: refactor this to conditionally import profiler from all |
| 45 | +# supported frameworks (e.g. TF/JAX/PyTorch) |
| 46 | +try: |
| 47 | + from jax import profiler # pylint: disable=g-import-not-at-top # pytype: disable=import-error |
| 48 | + |
| 49 | + _TRACE_ANNOTATION = profiler.TraceAnnotation |
| 50 | +except ImportError: |
| 51 | + logging.warning("Failed to import TraceAnnotation.") |
| 52 | + _TRACE_ANNOTATION = None |
| 53 | + |
43 | 54 |
|
44 | 55 | # Registry of weak references to output dataset iterators for collecting |
45 | 56 | # execution stats. |
@@ -310,8 +321,17 @@ def __next__(self): |
310 | 321 |
|
311 | 322 | @functools.wraps(next_fn) |
312 | 323 | def wrapper(iterator): |
313 | | - start_time = time.perf_counter_ns() |
314 | | - result = next_fn(iterator) |
| 324 | + if _TRACE_ANNOTATION is not None and _TRACE_ANNOTATION.is_enabled(): |
| 325 | + with _TRACE_ANNOTATION( |
| 326 | + f"{iterator.__class__.__name__}.{next_fn.__name__}", |
| 327 | + _ipl_stage_name=str(iterator), |
| 328 | + _ipl_stage_id=id(iterator), |
| 329 | + ): |
| 330 | + start_time = time.perf_counter_ns() |
| 331 | + result = next_fn(iterator) |
| 332 | + else: |
| 333 | + start_time = time.perf_counter_ns() |
| 334 | + result = next_fn(iterator) |
315 | 335 |
|
316 | 336 | if iterator._stats._is_output: # pylint:disable=protected-access |
317 | 337 | next_duration_ns = time.perf_counter_ns() - start_time |
|
0 commit comments