[PT2/Profiler] Add Context Info to Torch-Compiled Regions (#132765)
Summary:
We want to add compile IDs and frames to each Torch-Compiled Region to help users cross-reference the section they are inspecting with data obtained from tools such as tlparse.
This diff operates on the assumption that each graph section enters and exits a CompileContext before it is run, either to compile the graph or to look it up in the cache. Based on this assumption, we can save the compile ID of the graph section from the exited CompileContext in eval_frame.c using the Python C API. We then add a new interface to the cpp shim that wraps record_function so the new "context" keyword argument can be passed in.
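As a rough sketch of the user-visible effect (modeled on the updated test below; the event name and the kwinputs attribute come from this diff, while the function and tensor shapes here are made up for illustration):

import torch

def fn(x, y):
    return x @ y

fn_c = torch.compile(fn)

with torch.profiler.profile() as prof:
    fn_c(torch.randn(4, 4), torch.randn(4, 4))

# Each Torch-Compiled Region event now carries a "context" kwinput such as
# "0/0" (frame_id/frame_compile_id), matching the compile IDs that tlparse
# reports.
for e in prof.events():
    if e.name == "Torch-Compiled Region":
        print(e.kwinputs["context"])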

Test Plan:
Enhanced test_profiler_dynamo_compiled_region to check the kwinputs of each event as well as its name, verifying that the context label is present. Also changed the test to run a graph that produces multiple compile contexts, so that a wider range of profiling is exercised.

Differential Revision: D60803317

Pull Request resolved: pytorch/pytorch#132765
Approved by: https://github.com/anijain2305
sraikund16 authored and pytorchmergebot committed Aug 27, 2024
1 parent de57a6e commit 0b81f70
Showing 7 changed files with 87 additions and 16 deletions.
aten/src/ATen/record_function.h (1 addition, 1 deletion)
@@ -319,7 +319,7 @@ struct TORCH_API RecordFunction {
     if (!isActive()) {
       return;
     }
-    kwinputs_ = *kwargs;
+    kwinputs_ = std::unordered_map<std::string, IValue>(*kwargs);
     before(std::move(fn), args, current_sequence_nr);
   }

test/dynamo/test_profiler.py (32 additions, 9 deletions)
@@ -1,10 +1,12 @@
 # Owner(s): ["module: dynamo"]
+import logging
 from unittest.mock import patch
 
 import torch
 import torch._dynamo.test_case
 import torch._dynamo.testing
 import torch._dynamo.utils
+import torch._logging
 from torch._dynamo.utils import dynamo_timed
 from torch.testing._internal.common_utils import TemporaryFileName
@@ -163,20 +165,41 @@ def fn(x, y, z):
         )
 
     def test_profiler_dynamo_compiled_region(self):
-        def fn(x, y, z):
-            return x @ y + z
+        torch._logging.set_logs(dynamo=logging.INFO)
 
-        opt_fn = torch._dynamo.optimize("eager")(fn)
+        def fn(x, y):
+            r = y.sum(dim=1)
+            print(r.shape)
+            return x * r
 
-        inputs = [torch.rand(4, 4) for _ in range(3)]
+        fn_c = torch.compile(fn)
 
-        for _ in range(2):
-            opt_fn(*inputs)
+        with torch.profiler.profile(record_shapes=True) as prof:
+            fn_c(
+                torch.randn(10),
+                torch.randn(10, 10),
+            )
 
-        with torch.profiler.profile() as prof:
-            opt_fn(*inputs)
+            fn_c(
+                torch.randn(10),
+                torch.randn(10, 15),
+            )
 
-        self.assertTrue(any(e.name == "Torch-Compiled Region" for e in prof.events()))
+        for e in prof.events():
+            if e.name == "Torch-Compiled Region":
+                print(e.kwinputs)
+        self.assertTrue(
+            any(
+                e.name == "Torch-Compiled Region" and e.kwinputs["context"] == "0/0_1"
+                for e in prof.events()
+            )
+        )
+        self.assertTrue(
+            any(
+                e.name == "Torch-Compiled Region" and e.kwinputs["context"] == "1/0"
+                for e in prof.events()
+            )
+        )
 
 
 if __name__ == "__main__":
torch/_C/_dynamo/eval_frame.pyi (2 additions, 1 deletion)
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import types
-from typing import NewType
+from typing import NewType, Tuple
 
 from torch._dynamo.types import DynamoCallback, DynamoGuardHook
 
@@ -13,6 +13,7 @@ def reset_code(code: types.CodeType) -> None: ...
 def unsupported(obj1: object, obj2: object) -> object: ...
 def skip_code(code: types.CodeType) -> None: ...
 def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
+def set_context_frame(context: Tuple[int, int, int]) -> None: ...
 
 class _CacheEntry:
     def check_fn(self, *args, **kwargs): ...
torch/_guards.py (10 additions)
@@ -26,6 +26,7 @@
     TypeVar,
 )
 
+from torch._C._dynamo.eval_frame import set_context_frame  # noqa: F401
 from torch.utils import _pytree as pytree
 from torch.utils._traceback import CapturedTraceback
 from torch.utils.weak import WeakTensorKeyDictionary
@@ -781,6 +782,15 @@ def compile_context(context: Optional[CompileContext]):
     try:
         yield context
     finally:
+        if context is not None:
+            if context.compile_id is not None:
+                set_context_frame(
+                    (
+                        context.compile_id.frame_id,
+                        context.compile_id.frame_compile_id,
+                        context.attempt,
+                    )
+                )
         _TLS.compile_context = old_context
 
 
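To sketch when this exit hook fires (CompileContext, CompileId, and compile_context are the names used in this file; the literal IDs below are hypothetical, and the constructor calls are an assumption about their signatures):

from torch._guards import CompileContext, CompileId, compile_context

# Dynamo wraps each compile or cache lookup of a frame in a compile context.
with compile_context(CompileContext(CompileId(frame_id=1, frame_compile_id=0))):
    ...  # the graph is compiled or looked up in the cache here
# On exit, set_context_frame((1, 0, 0)) stashes the compile ID in
# eval_frame.c, so the next Torch-Compiled Region event is labeled "1/0".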
torch/csrc/dynamo/cpp_shim.cpp (19 additions, 2 deletions)
@@ -1,6 +1,5 @@
-#include <torch/csrc/dynamo/cpp_shim.h>
-
 #include <ATen/record_function.h>
+#include <torch/csrc/dynamo/cpp_shim.h>
 
 struct _PytorchRecordFunctionState {
   at::RecordFunction guard;
@@ -14,6 +13,24 @@ _PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name) {
   return state;
 }
 
+static inline _PytorchRecordFunctionState*
+_pytorch_record_function_enter_with_kwinputs(
+    const char* name,
+    const std::unordered_map<std::string, c10::IValue>* kwargs) {
+  _PytorchRecordFunctionState* state = new _PytorchRecordFunctionState();
+  std::vector<c10::IValue> args;
+  state->guard.before(name, &args, kwargs);
+  return state;
+}
+
+_PytorchRecordFunctionState* _pytorch_record_function_enter_with_context(
+    const char* name,
+    const char* context) {
+  auto map = std::unordered_map<std::string, c10::IValue>();
+  map.insert({"context", c10::IValue(context)});
+  return _pytorch_record_function_enter_with_kwinputs(name, &map);
+}
+
 void _pytorch_record_function_exit(_PytorchRecordFunctionState* state) {
   if (state == nullptr) {
     return;
torch/csrc/dynamo/cpp_shim.h (3 additions, 1 deletion)
@@ -1,5 +1,4 @@
 #pragma once
-
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -8,6 +7,9 @@ struct _PytorchRecordFunctionState;
 typedef struct _PytorchRecordFunctionState _PytorchRecordFunctionState;
 
 _PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name);
+_PytorchRecordFunctionState* _pytorch_record_function_enter_with_context(
+    const char* name,
+    const char* context);
 void _pytorch_record_function_exit(_PytorchRecordFunctionState* state);
 
 #ifdef __cplusplus
torch/csrc/dynamo/eval_frame.c (20 additions, 2 deletions)
@@ -10,9 +10,11 @@
 #include <opcode.h>
 #include <stdbool.h>
 
+#define MAX_COMPILE_CONTEXT_SIZE 100
+
 PyObject* guard_error_hook = NULL;
 const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup";
-
+static char compile_context[MAX_COMPILE_CONTEXT_SIZE];
 static int active_dynamo_threads = 0;
 
 static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT;
@@ -483,7 +485,8 @@ inline static PyObject* eval_custom_code(
     PyCodeObject* code,
     int throw_flag,
     int free_vars_copied) {
-  _PytorchRecordFunctionState* rf = _pytorch_record_function_enter("Torch-Compiled Region");
+  const char* trace_id = compile_context;
+  _PytorchRecordFunctionState* rf = _pytorch_record_function_enter_with_context("Torch-Compiled Region", trace_id);
   PyObject* result = eval_custom_code_impl(
       tstate,
       frame,
@@ -817,12 +820,27 @@ static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* obj) {
   Py_RETURN_NONE;
 }
 
+static PyObject* set_context_frame(PyObject* dummy, PyObject* obj) {
+  int frame_id, frame_compile_id, attempt;
+  if (!PyArg_ParseTuple(obj, "iii", &frame_id, &frame_compile_id, &attempt)) {
+    PyErr_SetString(PyExc_TypeError, "Expected three integers");
+    return NULL;
+  }
+  if (attempt == 0) {
+    sprintf(compile_context, "%d/%d", frame_id, frame_compile_id);
+  } else {
+    sprintf(compile_context, "%d/%d_%d", frame_id, frame_compile_id, attempt);
+  }
+  Py_RETURN_NONE;
+}
+
 static PyMethodDef _methods[] = {
     {"set_eval_frame", set_eval_frame_py, METH_O, NULL},
     {"reset_code", reset_code, METH_O, NULL},
     {"unsupported", unsupported, METH_VARARGS, NULL},
     {"skip_code", skip_code, METH_O, NULL},
     {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
+    {"set_context_frame", set_context_frame, METH_O, NULL},
     {NULL, NULL, 0, NULL}};
 
 static struct PyModuleDef _module = {
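For reference, the label format produced by set_context_frame above, restated as a Python sketch (illustrative only; format_compile_context is not a real helper in the codebase):

def format_compile_context(frame_id: int, frame_compile_id: int, attempt: int) -> str:
    # Mirrors the sprintf calls in set_context_frame: the attempt number is
    # appended only for recompile attempts after the first.
    if attempt == 0:
        return f"{frame_id}/{frame_compile_id}"
    return f"{frame_id}/{frame_compile_id}_{attempt}"

assert format_compile_context(1, 0, 0) == "1/0"    # first attempt
assert format_compile_context(0, 0, 1) == "0/0_1"  # restart attempt 1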
