diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
index 6b4ec04ecc..4b02417ecd 100644
--- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
+++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
@@ -11,6 +11,7 @@
 import enum
 import functools
 import gc
+import uuid
 import importlib
 import inspect
 import itertools
@@ -64,7 +65,7 @@
 from torch._dispatch.python import enable_python_dispatcher
 from torch._guards import TracingContext
 from torch._subclasses.meta_utils import is_sparse_compressed
-from torch._utils_internal import log_compilation_event
+from torch._utils_internal import log_compilation_event, log_chromium_event_internal
 from torch.fx._utils import _format_graph_code, lazy_format_graph_code
 from torch.nn.modules.lazy import LazyModuleMixin
 from torch.utils._triton import has_triton, has_triton_package
@@ -212,6 +213,15 @@ def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
     frame_phase_timing[key][phase_name] += time_spent
 
 
+def get_cache_stats() -> Dict[str, Any]:
+    """Get metadata about cache hits and misses for use in chromium events"""
+    cache_stats = {
+        "fxgraph_cache_hit": counters["inductor"]["fxgraph_cache_hit"],
+        "fxgraph_cache_miss": counters["inductor"]["fxgraph_cache_miss"],
+        "fxgraph_cache_bypass": counters["inductor"]["fxgraph_cache_bypass"],
+    }
+    return cache_stats
+
 # dynamo_timed is a context manager
 # By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
 # where the key is the functions name.
@@ -251,16 +261,20 @@ def dynamo_timed(
     fail_type: Optional[str] = None
     fail_reason: Optional[str] = None
     time_spent = float("-inf")
+    if phase_name == "entire_frame_compile":
+        reset_chromium_events()
     try:
         with torch.profiler.record_function(f"{key} (dynamo_timed)"):
             t0 = time.time()
-            ChromiumEventLogger.log_event_start(key, time.time_ns())
+            start = time.time_ns()
+            ChromiumEventLogger.log_event_start(key, start, None)
             if phase_name:
-                ChromiumEventLogger.log_event_start(phase_name, time.time_ns())
+                ChromiumEventLogger.log_event_start(phase_name, start)
             yield
+
             if phase_name:
-                ChromiumEventLogger.log_event_end(phase_name, time.time_ns())
-            ChromiumEventLogger.log_event_end(key, time.time_ns())
+                ChromiumEventLogger.log_event_end(phase_name, time.time_ns(), {"cache_stats": get_cache_stats()}, start)
+            ChromiumEventLogger.log_event_end(key, time.time_ns(), {"cache_stats": get_cache_stats()}, start)
             time_spent = time.time() - t0
         compilation_time_metrics[key].append(time_spent)
     except Exception as e:
@@ -807,6 +821,18 @@ def get_compilation_metrics() -> List[Union[CompilationMetrics, BwdCompilationMe
     return list(_compilation_metrics)
 
 
+chromium_event_stack = ["__start__"]
+# Generate a unique id for this process, which we can use in scuba to filter down
+# to a single python run.
+# TODO: figure out when this should actually be reset
+compile_unique_id = str(uuid.uuid4())
+
+
+def reset_chromium_events() -> None:
+    global chromium_event_stack
+    chromium_event_stack = ["__start__"]
+
+
 class ChromiumEventLogger:
     """Logs chromium events to structured logs. tlparse will concatenate these into a perfetto UI link.
@@ -826,18 +852,22 @@ def log_event_start(
         :param time_ns Timestamp in nanoseconds
         :param metadata: Any extra metadata associated with this event
         """
-        ChromiumEventLogger._log_timed_event(
+        global chromium_event_stack
+        event = ChromiumEventLogger._log_timed_event(
             event_name,
             time_ns,
             "B",
             metadata,
         )
+        log_chromium_event_internal(event, chromium_event_stack, compile_unique_id)
+        chromium_event_stack.append(event_name)
 
     @staticmethod
     def log_event_end(
         event_name: str,
         time_ns: int,
         metadata: Optional[Dict[str, Any]] = None,
+        start_time_ns: Optional[int] = None,
     ) -> None:
         """
         Logs the end of a single event. This function should only be
@@ -846,28 +876,53 @@ def log_event_end(
         :param time_ns: Timestamp in nanoseconds
         :param metadata: Any extra metadata associated with this event
         """
-        ChromiumEventLogger._log_timed_event(
+        global chromium_event_stack
+        # These stack health checks currently never fail, but they're written
+        # this way to future-proof against any weird event overlaps
+        # in the future.
+        if event_name not in chromium_event_stack:
+            # Something went wrong: we never called start on this event,
+            # or it was skipped due to overlapping events below
+            log.warning("Start event not in stack, ignoring")
+            return
+
+        event = ChromiumEventLogger._log_timed_event(
             event_name,
             time_ns,
             "E",
             metadata,
         )
+        while event_name != chromium_event_stack[-1]:
+            # If the event isn't the most recent one to end, pop
+            # off the stack until it is.
+            # Since event_name is in chromium_event_stack, this pop is always safe.
+            log.warning("Detected overlapping events, fixing stack")
+            chromium_event_stack.pop()
+
+        log_chromium_event_internal(event, chromium_event_stack, compile_unique_id, start_time_ns)
+        # Finally pop the actual event off the stack
+        chromium_event_stack.pop()
+
+
 
     @staticmethod
     def _log_timed_event(
         event_name: str,
         time_ns: int,
         phase: str,
         metadata: Optional[Dict[str, Any]] = None,
-    ) -> None:
+    ) -> Dict[str, Any]:
         """
         Logs a timed event in chromium format. See log_event_start, log_event_end, etc.
         """
         event = {
             "name": event_name,
-            "ts": time_ns / 1000,  # Chromium events are in ms
+            "ts": time_ns / 1000,  # Chromium events are in microseconds
             "args": metadata,
             "ph": phase,
+            # These categories are needed in all chromium traces
+            "cat": "dynamo_timed",
+            "tid": 0,
             "pid": 0,  # pid should be specified on all logs, we don't personally care about the actual process id
         }
         torch._logging.trace_structured(
@@ -876,6 +931,7 @@ def _log_timed_event(
             suppress_context=False,
             expect_trace_id=False,  # Not every chromium event will have a trace_id
         )
+        return event
 
     @staticmethod
     def log_instant_event(
@@ -895,7 +951,10 @@
             "ts": time_ns / 1000,
             "args": metadata,
             "ph": "i",
-            "pid": 0,  # pid should be specified on all logs, we don't personally care about the actual process id
+            # These categories are needed in all chromium traces
+            "cat": "dynamo_timed",
+            "tid": 0,
+            "pid": 0,
             "s": "p",  # We use "process" level instant events so they all appear on the same row in the trace.
         }
         torch._logging.trace_structured(
@@ -904,6 +963,9 @@
             suppress_context=False,
             expect_trace_id=True,
         )
+        # Log an instant event with the same start and end time
+        log_chromium_event_internal(event, chromium_event_stack, compile_unique_id)
+
 
 
 @dataclasses.dataclass
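Note (not part of the patch): for orientation, a minimal sketch of the call pattern this change introduces, mirroring what dynamo_timed does for a phase such as "entire_frame_compile"; the phase name and the placement of the calls here are illustrative only, and the snippet assumes the module-level helpers added above (reset_chromium_events, get_cache_stats) are in scope.

    import time

    # Illustrative sketch of the start/end pairing after this patch.
    reset_chromium_events()  # clear the event stack before a fresh frame compile
    start = time.time_ns()
    ChromiumEventLogger.log_event_start("entire_frame_compile", start, None)
    # ... compilation work happens here ...
    ChromiumEventLogger.log_event_end(
        "entire_frame_compile",
        time.time_ns(),
        {"cache_stats": get_cache_stats()},  # end-event metadata, e.g. fxgraph cache counters
        start,  # start_time_ns, forwarded to log_chromium_event_internal
    )

Each start pushes the event name onto chromium_event_stack and each end pops it, so the internal logger always sees the current stack of enclosing events plus compile_unique_id for filtering a single run.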