Skip to content

Commit cbdb78f

Browse files
bmass02copybara-github
authored andcommitted
Add TraceMes to PyGrain for measuring input wait time on host.
PiperOrigin-RevId: 778664409
1 parent cb3ebf4 commit cbdb78f

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

grain/_src/python/dataset/stats.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@
4040

4141
from grain._src.core import monitoring
4242

43+
# Conditionally import profiler from JAX.
44+
# TODO: refactor this to conditionally import profiler from all
45+
# supported frameworks (e.g. TF/JAX/PyTorch)
46+
try:
47+
from jax import profiler # pylint: disable=g-import-not-at-top # pytype: disable=import-error
48+
49+
_TRACE_ANNOTATION = profiler.TraceAnnotation
50+
except ImportError:
51+
logging.warning("Failed to import TraceAnnotation.")
52+
_TRACE_ANNOTATION = None
53+
4354

4455
# Registry of weak references to output dataset iterators for collecting
4556
# execution stats.
@@ -310,8 +321,17 @@ def __next__(self):
310321

311322
@functools.wraps(next_fn)
312323
def wrapper(iterator):
313-
start_time = time.perf_counter_ns()
314-
result = next_fn(iterator)
324+
if _TRACE_ANNOTATION is not None and _TRACE_ANNOTATION.is_enabled():
325+
with _TRACE_ANNOTATION(
326+
f"{iterator.__class__.__name__}.{next_fn.__name__}",
327+
_ipl_stage_name=str(iterator),
328+
_ipl_stage_id=id(iterator),
329+
):
330+
start_time = time.perf_counter_ns()
331+
result = next_fn(iterator)
332+
else:
333+
start_time = time.perf_counter_ns()
334+
result = next_fn(iterator)
315335

316336
if iterator._stats._is_output: # pylint:disable=protected-access
317337
next_duration_ns = time.perf_counter_ns() - start_time

0 commit comments

Comments
 (0)