Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 60 additions & 42 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,13 @@ def __init__(
multi_gpu_device_check()

self.name = name
self._input_buffers: List[torch.Tensor] = []
self._output_buffers: List[torch.Tensor] = []
self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
self._input_buffers: Dict[str, List[torch.Tensor]] = {}
self._output_buffers: Dict[str, List[torch.Tensor]] = {}
self._caller_stream: Optional[torch.cuda.Stream] = None
self._engine_stream: Optional[torch.cuda.Stream] = None

# TODO: Make the below a Dictionary {shape: cudagraph}
self.shape_key: Optional[str] = None
self.shape_key_to_cudagraph: Dict[str, torch.cuda.CUDAGraph] = {}

# See https://github.com/pytorch/pytorch/blob/acfe237a71af609e837a34bb38048aa8acb8eb4d/torch/cuda/graphs.py#L92-L98
# Unused currently - to be used by Dynamic Shape support implementation
Expand Down Expand Up @@ -293,9 +292,6 @@ def setup_engine(self) -> None:
if self.requires_output_allocator:
self.create_output_allocator()

if torch_tensorrt.runtime.get_cudagraphs_mode():
self.cudagraph = torch.cuda.CUDAGraph()

def _check_initialized(self) -> None:
if not self.initialized:
raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
Expand Down Expand Up @@ -342,10 +338,13 @@ def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule:
result.__setstate__(self.__getstate__())
return result

def _reset_captured_graph(self) -> None:
if self.cudagraph:
self.cudagraph.reset()
self.cudagraph = None
def _reset_captured_graph(self, inputs_shape_key: str | None = None) -> None:
if (
inputs_shape_key is not None
and inputs_shape_key in self.shape_key_to_cudagraph
):
self.shape_key_to_cudagraph[inputs_shape_key].reset()
self.shape_key_to_cudagraph.pop(inputs_shape_key)

def __del__(self) -> None:
self._reset_captured_graph()
Expand All @@ -355,6 +354,7 @@ def setup_input_tensors(
contiguous_inputs: List[torch.Tensor],
cudagraphs_enabled: bool,
need_cudagraphs_record: bool,
inputs_shape_key: str | None = None,
) -> None:
for i, input_name in enumerate(self.input_names):
if not contiguous_inputs[i].is_cuda:
Expand All @@ -374,14 +374,22 @@ def setup_input_tensors(
contiguous_inputs[i].dtype == self.input_dtypes[i]
), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

is_shape_tensor_input = self.engine.is_shape_inference_io(input_name)
if need_cudagraphs_record:
# If cudagraphs is enabled, this memory is reserved for future cudagraph runs
# Clone is required to avoid re-using user-provided GPU memory
self._input_buffers[i] = contiguous_inputs[i].clone()
if is_shape_tensor_input:
self._input_buffers[inputs_shape_key][i] = (
contiguous_inputs[i].cpu().clone()
)
else:
self._input_buffers[inputs_shape_key][i] = contiguous_inputs[
i
].clone()

# For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
# as per TensorRT requirements
if self.engine.is_shape_inference_io(input_name):
if is_shape_tensor_input:
# Shape tensor inputs are casted to int64 explicitly
# Currently Torch CPU pointers are not working; numpy pointers are used instead
# to refer to underlying memory
Expand All @@ -392,9 +400,9 @@ def setup_input_tensors(
input_name, tuple(contiguous_inputs[i].shape)
)
if cudagraphs_enabled:
self._input_buffers[i].copy_(contiguous_inputs[i])
self._input_buffers[inputs_shape_key][i].copy_(contiguous_inputs[i])
self.context.set_tensor_address(
input_name, self._input_buffers[i].data_ptr()
input_name, self._input_buffers[inputs_shape_key][i].data_ptr()
)
else:
self.context.set_tensor_address(
Expand Down Expand Up @@ -430,7 +438,7 @@ def create_output_allocator(self) -> None:
def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:

def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
shape_changed = self.validate_input_shapes(inputs)
shape_changed, inputs_shape_key = self.validate_input_shapes(inputs)
(
need_cudagraphs_record,
can_use_pre_allocated_outputs,
Expand All @@ -440,11 +448,11 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
)

if need_cudagraphs_reset:
self._reset_captured_graph()
self._reset_captured_graph(inputs_shape_key)

if need_cudagraphs_record:
self._input_buffers = [None] * len(self.input_names)
self._output_buffers = [None] * len(self.output_names)
self._input_buffers[inputs_shape_key] = [None] * len(self.input_names)
self._output_buffers[inputs_shape_key] = [None] * len(self.output_names)

with (
torch.autograd.profiler.record_function(
Expand All @@ -458,7 +466,10 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."

self.setup_input_tensors(
contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record
contiguous_inputs,
self.cudagraphs_enabled,
need_cudagraphs_record,
inputs_shape_key,
)

if shape_changed:
Expand Down Expand Up @@ -492,11 +503,12 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:

for o, output_name in enumerate(self.output_names):
if need_cudagraphs_record:
self._output_buffers[o] = outputs[o].clone()
self._output_buffers[inputs_shape_key][o] = outputs[o].clone()

if self.cudagraphs_enabled:
self.context.set_tensor_address(
output_name, self._output_buffers[o].data_ptr()
output_name,
self._output_buffers[inputs_shape_key][o].data_ptr(),
)
else:
self.context.set_tensor_address(
Expand All @@ -522,24 +534,31 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
with torch.cuda.stream(self._engine_stream):
if self.cudagraphs_enabled:
if need_cudagraphs_record:
self.cudagraph = torch.cuda.CUDAGraph()
self.shape_key_to_cudagraph[inputs_shape_key] = (
torch.cuda.CUDAGraph()
)

if self.profiling_enabled:
self.cudagraph.enable_debug_mode()
self.shape_key_to_cudagraph[
inputs_shape_key
].enable_debug_mode()

with torch.cuda.graph(
self.cudagraph, stream=self._engine_stream
self.shape_key_to_cudagraph[inputs_shape_key],
stream=self._engine_stream,
):
self.context.execute_async_v3(
self._engine_stream.cuda_stream
)

if self.profiling_enabled:
self.cudagraph.debug_dump(
self.shape_key_to_cudagraph[
inputs_shape_key
].debug_dump(
f"{DEBUG_LOGGING_DIR}/{self.name}_cudagraph.dot"
)

self.cudagraph.replay() # type: ignore
self.shape_key_to_cudagraph[inputs_shape_key].replay() # type: ignore

else:
self.context.execute_async_v3(self._engine_stream.cuda_stream)
Expand All @@ -551,7 +570,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:

if self.cudagraphs_enabled:
for idx, o in enumerate(outputs):
o.copy_(self._output_buffers[idx])
o.copy_(self._output_buffers[inputs_shape_key][idx])

if len(outputs) == 1:
return outputs[0]
Expand Down Expand Up @@ -742,27 +761,26 @@ def get_layer_info(self) -> str:
)
return engine_json

def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
def validate_input_shapes(
self, inputs: Sequence[torch.Tensor | Any]
) -> Tuple[bool, str]:
"""
Validates the input shapes of the forward function has changed
"""
# Representation of input shapes to a given model
# Shapes are concatenated as so:
# x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
tensor_inputs = []
for t in inputs:
if not isinstance(t, torch.Tensor):
return True
tensor_inputs.append(t)
new_shape_key = "".join(
str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs
str(
tuple(t.shape if hasattr(t, "shape") else torch.tensor(t).shape)
).replace(" ", "")
for t in inputs
)

# If the new shape key differs from the existing one,
# invalidate the old shape key and remove the CUDAGraph
if new_shape_key != self.shape_key:
logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
self.shape_key = new_shape_key
return True
if new_shape_key not in self.shape_key_to_cudagraph:
logger.debug(
f"The user provided input shape {new_shape_key} is not found in recorded CUDAGraph input shapes. A new CUDAGraph will be recorded with this input shape."
)
return True, new_shape_key

return False
return False, new_shape_key
2 changes: 1 addition & 1 deletion tools/llm/run_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def measure_perf(trt_model, input_signature, backend_name):
arg_parser.add_argument(
"--model",
type=str,
default="meta-llama/Llama-3.2-1B-Instruct",
default="Qwen/Qwen2.5-0.5B-Instruct",
help="Name of LLM model",
)
arg_parser.add_argument(
Expand Down
Loading