5 changes: 4 additions & 1 deletion .github/scripts/filter-matrix.py
@@ -8,6 +8,7 @@

# currently we don't support python 3.13t because tensorrt does not support 3.13t
disabled_python_versions: List[str] = ["3.13t", "3.14", "3.14t"]
disabled_cuda_versions: List[str] = ["cu130"]

# jetpack 6.2 only officially supports python 3.10 and cu126
jetpack_python_versions: List[str] = ["3.10"]
@@ -36,7 +37,9 @@ def filter_matrix_item(
if item["python_version"] in disabled_python_versions:
# Skipping disabled Python version
return False

if item["desired_cuda"] in disabled_cuda_versions:
# Skipping disabled CUDA version
return False
if is_jetpack:
if limit_pr_builds:
# pr build: the matrix passed from test-infra is cu128, python 3.9; change it to cu126, python 3.10
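For context, a minimal sketch of the filtering predicate these checks implement; the sample matrix entries below are hypothetical, and only the disabled-version checks mirror the script:

from typing import Any, Dict, List

disabled_python_versions: List[str] = ["3.13t", "3.14", "3.14t"]
disabled_cuda_versions: List[str] = ["cu130"]

def keep(item: Dict[str, Any]) -> bool:
    # drop entries whose python or CUDA version is on a disabled list
    if item["python_version"] in disabled_python_versions:
        return False
    if item["desired_cuda"] in disabled_cuda_versions:
        return False
    return True

# hypothetical matrix entries, just to show the effect of the new cu130 check
matrix = [
    {"python_version": "3.11", "desired_cuda": "cu128"},
    {"python_version": "3.11", "desired_cuda": "cu130"},
    {"python_version": "3.13t", "desired_cuda": "cu126"},
]
print([m for m in matrix if keep(m)])  # only the 3.11 / cu128 entry remains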
39 changes: 25 additions & 14 deletions docsrc/getting_started/jetpack.rst
@@ -60,8 +60,22 @@ System Preparation
sudo cp -a libcusparse_lt-linux-sbsa-0.5.2.1-archive/include/* /usr/local/cuda/include/
sudo cp -a libcusparse_lt-linux-sbsa-0.5.2.1-archive/lib/* /usr/local/cuda/lib64/

Building Torch-TensorRT
***********************
Installing Torch-TensorRT in JetPack
**************************************

You can install the torch-tensorrt wheel directly from the JPL repo, which is built specifically for JetPack 6.2.

.. code-block:: sh

# verify that tensorrt 10.3 is already installed via the jetpack installation process
python -m pip list | grep tensorrt
# install the torch-tensorrt wheel from the JPL repo, which is built specifically for JetPack 6.2
python -m pip install torch==2.8.0 torch_tensorrt==2.8.0 torchvision==0.24.0 --extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu126


Building Torch-TensorRT in JetPack
**********************************

You can also build the torch-tensorrt wheel from source yourself.

Build Environment Setup
=======================
@@ -92,25 +106,22 @@ Build Environment Setup
# only the torch and torchvision wheels from the JPL repo, which are built specifically for JetPack 6.2, can be installed
python -m pip install torch==2.8.0 torchvision==0.23.0 --index-url=https://pypi.jetson-ai-lab.io/jp6/cu126

4. **Build the Wheel**:

Building the Wheel
==================
.. code-block:: sh

.. code-block:: sh
python setup.py bdist_wheel --jetpack
python setup.py bdist_wheel --jetpack

Installation
============
5. **Install the Wheel**:

.. code-block:: sh
# you will find the wheel in the dist directory; it has the platform name linux_tegra_aarch64
.. code-block:: sh

# you will be able to find the wheel in the dist directory
cd dist
python -m pip install torch_tensorrt-2.8.0.dev0+d8318d8fc-cp310-cp310-linux_tegra_aarch64.whl
python -m pip install torch_tensorrt-2.8.0.dev0+d8318d8fc-cp310-cp310-linux_aarch64.whl

Post-Installation Verification
==============================
6. **Verify installation by importing in Python**:

Verify installation by importing in Python:
.. code-block:: python

# verify that torch_tensorrt can be imported
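A rough sketch of the verification step the doc ends with, assuming the wheel installed cleanly (torch_tensorrt exposes __version__; version numbers follow the install commands above):

# minimal post-install check on the Jetson device
import torch
import torch_tensorrt

print(torch.__version__)           # expect 2.8.0 per the install command
print(torch_tensorrt.__version__)  # expect 2.8.0
print(torch.cuda.is_available())   # should be True once CUDA is set up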
11 changes: 11 additions & 0 deletions py/torch_tensorrt/_compile.py
@@ -3,6 +3,7 @@
import collections.abc
import logging
import platform
import warnings
from enum import Enum
from typing import Any, Callable, List, Optional, Sequence, Set, Union

@@ -121,6 +122,11 @@ def _get_target_fe(module_type: _ModuleType, ir: str) -> _IRType:
"Requested using the TS frontend but the TS frontend is not available in this build of Torch-TensorRT"
)
elif module_is_fxable and ir_targets_fx:
warnings.warn(
"FX frontend is deprecated. Please use the Dynamo frontend instead.",
DeprecationWarning,
stacklevel=2,
)
if ENABLED_FEATURES.fx_frontend:
return _IRType.fx
else:
@@ -237,6 +243,11 @@ def compile(
)
return compiled_ts_module
elif target_ir == _IRType.fx:
warnings.warn(
"FX frontend is deprecated. Please use the Dynamo frontend instead.",
DeprecationWarning,
stacklevel=2,
)
if not ENABLED_FEATURES.fx_frontend:
raise RuntimeError(
"FX frontend is not enabled, cannot compile with target_ir=fx"
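The added warnings use the standard warnings-module pattern; a small self-contained sketch of how a caller can surface or escalate the DeprecationWarning (pick_frontend below is a hypothetical stand-in, not the real _get_target_fe):

import warnings

def pick_frontend(ir: str) -> str:
    # hypothetical stand-in mimicking the pattern added in _compile.py
    if ir == "fx":
        warnings.warn(
            "FX frontend is deprecated. Please use the Dynamo frontend instead.",
            DeprecationWarning,
            stacklevel=2,
        )
    return ir

# DeprecationWarning is only shown by default when triggered from __main__,
# so tests typically record it explicitly:
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    pick_frontend("fx")
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# or escalate it to an error while migrating callers to the Dynamo frontend
warnings.filterwarnings("error", category=DeprecationWarning)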
92 changes: 49 additions & 43 deletions tests/py/dynamo/automatic_plugin/test_automatic_plugin.py
@@ -12,54 +12,60 @@
from ..conversion.harness import DispatchTestCase


@triton.jit
def elementwise_mul_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):
# Program ID determines the block of data each thread will process
pid = tl.program_id(0)
# Compute the range of elements that this thread block will work on
block_start = pid * BLOCK_SIZE
# Range of indices this thread will handle
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Load elements from the X and Y tensors
x_vals = tl.load(X + offsets)
y_vals = tl.load(Y + offsets)
# Perform the element-wise multiplication
z_vals = x_vals * y_vals
# Store the result in Z
tl.store(Z + offsets, z_vals)


@torch.library.custom_op("torchtrt_ex::elementwise_mul", mutates_args=()) # type: ignore[misc]
def elementwise_mul(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
# Ensure the tensors are on the GPU
assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
assert X.shape == Y.shape, "Tensors must have the same shape."

# Create output tensor
Z = torch.empty_like(X)

# Define block size
BLOCK_SIZE = 1024

# Grid of programs
grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)

# Launch the kernel
elementwise_mul_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE)

return Z


@torch.library.register_fake("torchtrt_ex::elementwise_mul")
def elementwise_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x


if not torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx:
torch_tensorrt.dynamo.conversion.plugins.custom_op(
"torchtrt_ex::elementwise_mul", supports_dynamic_shapes=True
)


@unittest.skipIf(
torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx,
"TensorRT RTX does not support plugins",
)
class TestAutomaticPlugin(DispatchTestCase):
@triton.jit
def elementwise_mul_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):
# Program ID determines the block of data each thread will process
pid = tl.program_id(0)
# Compute the range of elements that this thread block will work on
block_start = pid * BLOCK_SIZE
# Range of indices this thread will handle
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Load elements from the X and Y tensors
x_vals = tl.load(X + offsets)
y_vals = tl.load(Y + offsets)
# Perform the element-wise multiplication
z_vals = x_vals * y_vals
# Store the result in Z
tl.store(Z + offsets, z_vals)

@torch.library.custom_op("torchtrt_ex::elementwise_mul", mutates_args=()) # type: ignore[misc]
def elementwise_mul(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
# Ensure the tensors are on the GPU
assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
assert X.shape == Y.shape, "Tensors must have the same shape."

# Create output tensor
Z = torch.empty_like(X)

# Define block size
BLOCK_SIZE = 1024

# Grid of programs
grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)

# Launch the kernel
elementwise_mul_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE)

return Z

@torch.library.register_fake("torchtrt_ex::elementwise_mul")
def elementwise_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x

torch_tensorrt.dynamo.conversion.plugins.custom_op(
"torchtrt_ex::elementwise_mul", supports_dynamic_shapes=True
)

@parameterized.expand(
[
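Moving the registrations to module scope follows the usual torch.library pattern of decorating free functions rather than methods on a TestCase; a minimal sketch with a toy op (the torchtrt_ex::toy_add name is hypothetical):

import torch

# toy custom op registered at module scope
@torch.library.custom_op("torchtrt_ex::toy_add", mutates_args=())  # type: ignore[misc]
def toy_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

# fake (meta) implementation so the op can be traced without running the real kernel
@torch.library.register_fake("torchtrt_ex::toy_add")
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)

if __name__ == "__main__":
    a, b = torch.randn(4), torch.randn(4)
    print(torch.ops.torchtrt_ex.toy_add(a, b))

As in the diff, the torch_tensorrt.dynamo.conversion.plugins.custom_op call is then guarded on not ENABLED_FEATURES.tensorrt_rtx, since TensorRT RTX does not support plugins.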
@@ -1,3 +1,4 @@
import unittest
from typing import Tuple

import torch
@@ -11,57 +12,62 @@
from ..conversion.harness import DispatchTestCase


@unittest.skipIf(
torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx,
"TensorRT RTX does not support plugins",
)
class TestAutomaticPlugin(DispatchTestCase):
@triton.jit
def elementwise_scale_mul_kernel(X, Y, Z, a, b, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
# Compute the range of elements that this thread block will work on
block_start = pid * BLOCK_SIZE
# Range of indices this thread will handle
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Load elements from the X and Y tensors
x_vals = tl.load(X + offsets)
y_vals = tl.load(Y + offsets)
# Perform the element-wise multiplication
z_vals = x_vals * y_vals * a + b
# Store the result in Z
tl.store(Z + offsets, z_vals)


@torch.library.custom_op("torchtrt_ex::elementwise_scale_mul", mutates_args=()) # type: ignore[misc]
def elementwise_scale_mul(
X: torch.Tensor, Y: torch.Tensor, b: float = 0.2, a: int = 2
) -> torch.Tensor:
# Ensure the tensors are on the GPU
assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
assert X.shape == Y.shape, "Tensors must have the same shape."

# Create output tensor
Z = torch.empty_like(X)

# Define block size
BLOCK_SIZE = 1024

# Grid of programs
grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)

# Launch the kernel with parameters a and b
elementwise_scale_mul_kernel[grid](X, Y, Z, a, b, BLOCK_SIZE=BLOCK_SIZE)

@triton.jit
def elementwise_scale_mul_kernel(X, Y, Z, a, b, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
# Compute the range of elements that this thread block will work on
block_start = pid * BLOCK_SIZE
# Range of indices this thread will handle
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Load elements from the X and Y tensors
x_vals = tl.load(X + offsets)
y_vals = tl.load(Y + offsets)
# Perform the element-wise multiplication
z_vals = x_vals * y_vals * a + b
# Store the result in Z
tl.store(Z + offsets, z_vals)

@torch.library.custom_op("torchtrt_ex::elementwise_scale_mul", mutates_args=()) # type: ignore[misc]
def elementwise_scale_mul(
X: torch.Tensor, Y: torch.Tensor, b: float = 0.2, a: int = 2
) -> torch.Tensor:
# Ensure the tensors are on the GPU
assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
assert X.shape == Y.shape, "Tensors must have the same shape."

# Create output tensor
Z = torch.empty_like(X)

# Define block size
BLOCK_SIZE = 1024

# Grid of programs
grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)

# Launch the kernel with parameters a and b
elementwise_scale_mul_kernel[grid](X, Y, Z, a, b, BLOCK_SIZE=BLOCK_SIZE)

return Z

@torch.library.register_fake("torchtrt_ex::elementwise_scale_mul")
def _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor:
return x
return Z


@torch.library.register_fake("torchtrt_ex::elementwise_scale_mul")
def _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor:
return x


if not torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx:
torch_tensorrt.dynamo.conversion.plugins.custom_op(
"torchtrt_ex::elementwise_scale_mul", supports_dynamic_shapes=True
)


@unittest.skipIf(
torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx,
"TensorRT RTX does not support plugins",
)
class TestAutomaticPlugin(DispatchTestCase):

@parameterized.expand(
[
((64, 64), torch.float),
38 changes: 20 additions & 18 deletions tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py
@@ -11,10 +11,26 @@

from ..conversion.harness import DispatchTestCase

# flashinfer has been impacted by a torch upstream change: https://github.com/pytorch/pytorch/commit/660b0b8128181d11165176ea3f979fa899f24db1
# which results in: ImportError: cannot import name '_get_pybind11_abi_build_flags' from 'torch.utils.cpp_extension'
# if importlib.util.find_spec("flashinfer"):
# import flashinfer
if importlib.util.find_spec("flashinfer"):
import flashinfer


@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=()) # type: ignore[misc]
def flashinfer_rmsnorm(
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
return flashinfer.norm.rmsnorm(input, weight)


@torch.library.register_fake("flashinfer::rmsnorm")
def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor:
return input


if not torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx:
torch_tensorrt.dynamo.conversion.plugins.custom_op(
"flashinfer::rmsnorm", supports_dynamic_shapes=True
)


@unittest.skip("Not Available")
@@ -25,20 +41,6 @@
)
class TestAutomaticPlugin(DispatchTestCase):

@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=()) # type: ignore[misc]
def flashinfer_rmsnorm(
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
return flashinfer.norm.rmsnorm(input, weight)

@torch.library.register_fake("flashinfer::rmsnorm")
def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor:
return input

torch_tensorrt.dynamo.conversion.plugins.custom_op(
"flashinfer::rmsnorm", supports_dynamic_shapes=True
)

@parameterized.expand(
[
((64, 64), (64,), torch.float16),
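For reference, the wrapped op is standard RMSNorm; a rough pure-PyTorch equivalent (flashinfer's exact eps and dtype handling may differ) showing what the fake registration stands in for during tracing:

import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # scale each row by the reciprocal root-mean-square over the last dimension
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * weight

x = torch.randn(64, 64)  # the test uses (64, 64) float16 inputs on CUDA
w = torch.ones(64)
print(rmsnorm_reference(x, w).shape)  # torch.Size([64, 64])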