From 545364e5f99d3cc6f17ca1d218852321236a09e9 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Tue, 26 Aug 2025 16:03:44 -0700
Subject: [PATCH 1/4] add deprecation notice for fx frontend

---
 py/torch_tensorrt/_compile.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py
index acae618f1b..7f4af1b68d 100644
--- a/py/torch_tensorrt/_compile.py
+++ b/py/torch_tensorrt/_compile.py
@@ -3,6 +3,7 @@
 import collections.abc
 import logging
 import platform
+import warnings
 from enum import Enum
 from typing import Any, Callable, List, Optional, Sequence, Set, Union
 
@@ -121,6 +122,11 @@ def _get_target_fe(module_type: _ModuleType, ir: str) -> _IRType:
             "Requested using the TS frontend but the TS frontend is not available in this build of Torch-TensorRT"
         )
     elif module_is_fxable and ir_targets_fx:
+        warnings.warn(
+            "FX frontend is deprecated. Please use the Dynamo frontend instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         if ENABLED_FEATURES.fx_frontend:
             return _IRType.fx
         else:
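The warning message points existing FX-frontend users at the Dynamo frontend. A minimal migration sketch follows; the tiny model, shapes, and the warnings filter are placeholders for illustration only and are not part of this patch.

.. code-block:: python

    import warnings

    import torch
    import torch_tensorrt

    # Placeholder module purely for illustration
    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval().cuda()
    inputs = [torch.randn(8, 16, device="cuda")]

    # Old path, which now emits a DeprecationWarning:
    # trt_mod = torch_tensorrt.compile(model, ir="fx", inputs=inputs)

    # Suggested replacement using the Dynamo frontend:
    trt_mod = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

    # While code is still on the FX path, the warning can be silenced explicitly:
    warnings.filterwarnings(
        "ignore",
        message="FX frontend is deprecated.*",
        category=DeprecationWarning,
    )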
From 863eeceedf32050912aed82b223c581f4fcd0a37 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Tue, 26 Aug 2025 16:24:04 -0700
Subject: [PATCH 2/4] update docs

---
 docsrc/getting_started/jetpack.rst | 39 +++++++++++++++++++-----------
 1 file changed, 25 insertions(+), 14 deletions(-)

diff --git a/docsrc/getting_started/jetpack.rst b/docsrc/getting_started/jetpack.rst
index f032685b68..874b526a2e 100644
--- a/docsrc/getting_started/jetpack.rst
+++ b/docsrc/getting_started/jetpack.rst
@@ -60,8 +60,22 @@ System Preparation
     sudo cp -a libcusparse_lt-linux-sbsa-0.5.2.1-archive/include/* /usr/local/cuda/include/
     sudo cp -a libcusparse_lt-linux-sbsa-0.5.2.1-archive/lib/* /usr/local/cuda/lib64/
 
-Building Torch-TensorRT
-***********************
+Installing Torch-TensorRT in JetPack
+************************************
+
+You can install the torch-tensorrt wheel directly from the JPL repo, which is built specifically for JetPack 6.2.
+
+.. code-block:: sh
+    # verify that TensorRT 10.3 is already installed via the JetPack installation process
+    python -m pip list | grep tensorrt
+    # install the torch-tensorrt wheel from the JPL repo, built specifically for JetPack 6.2
+    python -m pip install torch==2.8.0 torch_tensorrt==2.8.0 torchvision==0.24.0 --extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu126
+
+
+Building Torch-TensorRT in JetPack
+**********************************
+
+You can also build the torch-tensorrt wheel from source yourself.
 
 Build Environment Setup
 =======================
@@ -92,25 +106,22 @@ Build Environment Setup
     # Can only install the torch and torchvision wheel from the JPL repo which is built specifically for JetPack 6.2
     python -m pip install torch==2.8.0 torchvision==0.23.0 --index-url=https://pypi.jetson-ai-lab.io/jp6/cu126
 
+4. **Build the Wheel**:
 
-Building the Wheel
-==================
+   .. code-block:: sh
 
-.. code-block:: sh
-    python setup.py bdist_wheel --jetpack
+      python setup.py bdist_wheel --jetpack
 
-Installation
-============
+5. **Install the Wheel**:
 
-.. code-block:: sh
-    # you will be able to find the wheel in the dist directory, has platform name linux_tegra_aarch64
+   .. code-block:: sh
+
+      # you will be able to find the wheel in the dist directory
       cd dist
-    python -m pip install torch_tensorrt-2.8.0.dev0+d8318d8fc-cp310-cp310-linux_tegra_aarch64.whl
+      python -m pip install torch_tensorrt-2.8.0.dev0+d8318d8fc-cp310-cp310-linux_aarch64.whl
 
-Post-Installation Verification
-==============================
+6. **Verify installation by importing in Python**:
 
-Verify installation by importing in Python:
 .. code-block:: python
+
     # verify whether the torch-tensorrt can be imported

From e339583887218669f3524e7a06324102728bade8 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Wed, 27 Aug 2025 16:15:50 -0700
Subject: [PATCH 3/4] added fx deprecation notice

---
 py/torch_tensorrt/_compile.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py
index 7f4af1b68d..82e95955a1 100644
--- a/py/torch_tensorrt/_compile.py
+++ b/py/torch_tensorrt/_compile.py
@@ -243,6 +243,11 @@ def compile(
         )
         return compiled_ts_module
     elif target_ir == _IRType.fx:
+        warnings.warn(
+            "FX frontend is deprecated. Please use the Dynamo frontend instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         if (
             torch.float16 in enabled_precisions_set
             or torch_tensorrt.dtype.half in enabled_precisions_set

From 39a18da961f817f4ec5b978ad4b348add43c05c8 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Thu, 28 Aug 2025 10:16:26 -0700
Subject: [PATCH 4/4] reenable flashinfer

---
 .github/scripts/filter-matrix.py              |  5 +-
 .../automatic_plugin/test_automatic_plugin.py | 92 +++++++++---------
 .../test_automatic_plugin_with_attrs.py       | 96 ++++++++++---------
 .../test_flashinfer_rmsnorm.py                | 38 ++++----
 4 files changed, 124 insertions(+), 107 deletions(-)

diff --git a/.github/scripts/filter-matrix.py b/.github/scripts/filter-matrix.py
index 674ca1866c..14bcb7028d 100644
--- a/.github/scripts/filter-matrix.py
+++ b/.github/scripts/filter-matrix.py
@@ -8,6 +8,7 @@
 
 # currently we don't support python 3.13t due to tensorrt does not support 3.13t
 disabled_python_versions: List[str] = ["3.13t", "3.14", "3.14t"]
+disabled_cuda_versions: List[str] = ["cu130"]
 
 # jetpack 6.2 only officially supports python 3.10 and cu126
 jetpack_python_versions: List[str] = ["3.10"]
@@ -36,7 +37,9 @@ def filter_matrix_item(
     if item["python_version"] in disabled_python_versions:
         # Skipping disabled Python version
         return False
-
+    if item["desired_cuda"] in disabled_cuda_versions:
+        # Skipping disabled CUDA version
+        return False
     if is_jetpack:
         if limit_pr_builds:
             # pr build,matrix passed from test-infra is cu128, python 3.9, change to cu126, python 3.10
diff --git a/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py b/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py
index 83e367ff5f..44acb7b105 100644
--- a/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py
+++ b/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py
@@ -12,54 +12,60 @@
 from ..conversion.harness import DispatchTestCase
 
 
+@triton.jit
+def elementwise_mul_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):
+    # Program ID determines the block of data each thread will process
+    pid = tl.program_id(0)
+    # Compute the range of elements that this thread block will work on
+    block_start = pid * BLOCK_SIZE
+    # Range of indices this thread will handle
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Load elements from the X and Y tensors
+    x_vals = tl.load(X + offsets)
+    y_vals = tl.load(Y + offsets)
+    # Perform the element-wise multiplication
+    z_vals = x_vals * y_vals
+    # Store the result in Z
+    tl.store(Z + offsets, z_vals)
+
+
+@torch.library.custom_op("torchtrt_ex::elementwise_mul", mutates_args=())  # type: ignore[misc]
+def elementwise_mul(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
+    # Ensure the tensors are on the GPU
+    assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
+ assert X.shape == Y.shape, "Tensors must have the same shape." + + # Create output tensor + Z = torch.empty_like(X) + + # Define block size + BLOCK_SIZE = 1024 + + # Grid of programs + grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],) + + # Launch the kernel + elementwise_mul_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE) + + return Z + + +@torch.library.register_fake("torchtrt_ex::elementwise_mul") +def elementwise_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + + +if not torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx: + torch_tensorrt.dynamo.conversion.plugins.custom_op( + "torchtrt_ex::elementwise_mul", supports_dynamic_shapes=True + ) + + @unittest.skipIf( torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx, "TensorRT RTX does not support plugins", ) class TestAutomaticPlugin(DispatchTestCase): - @triton.jit - def elementwise_mul_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr): - # Program ID determines the block of data each thread will process - pid = tl.program_id(0) - # Compute the range of elements that this thread block will work on - block_start = pid * BLOCK_SIZE - # Range of indices this thread will handle - offsets = block_start + tl.arange(0, BLOCK_SIZE) - # Load elements from the X and Y tensors - x_vals = tl.load(X + offsets) - y_vals = tl.load(Y + offsets) - # Perform the element-wise multiplication - z_vals = x_vals * y_vals - # Store the result in Z - tl.store(Z + offsets, z_vals) - - @torch.library.custom_op("torchtrt_ex::elementwise_mul", mutates_args=()) # type: ignore[misc] - def elementwise_mul(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor: - # Ensure the tensors are on the GPU - assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device." - assert X.shape == Y.shape, "Tensors must have the same shape." - - # Create output tensor - Z = torch.empty_like(X) - - # Define block size - BLOCK_SIZE = 1024 - - # Grid of programs - grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],) - - # Launch the kernel - elementwise_mul_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE) - - return Z - - @torch.library.register_fake("torchtrt_ex::elementwise_mul") - def elementwise_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return x - - torch_tensorrt.dynamo.conversion.plugins.custom_op( - "torchtrt_ex::elementwise_mul", supports_dynamic_shapes=True - ) @parameterized.expand( [ diff --git a/tests/py/dynamo/automatic_plugin/test_automatic_plugin_with_attrs.py b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_with_attrs.py index 5153ead976..823d0d600e 100644 --- a/tests/py/dynamo/automatic_plugin/test_automatic_plugin_with_attrs.py +++ b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_with_attrs.py @@ -1,3 +1,4 @@ +import unittest from typing import Tuple import torch @@ -11,57 +12,62 @@ from ..conversion.harness import DispatchTestCase -@unittest.skipIf( - torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx, - "TensorRT RTX does not support plugins", -) -class TestAutomaticPlugin(DispatchTestCase): +@triton.jit +def elementwise_scale_mul_kernel(X, Y, Z, a, b, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + # Compute the range of elements that this thread block will work on + block_start = pid * BLOCK_SIZE + # Range of indices this thread will handle + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Load elements from the X and Y tensors + x_vals = tl.load(X + offsets) + y_vals = tl.load(Y + offsets) + # Perform the element-wise multiplication + z_vals = x_vals * y_vals * a + b + # Store the result in Z + tl.store(Z + offsets, z_vals) + + 
+@torch.library.custom_op("torchtrt_ex::elementwise_scale_mul", mutates_args=())  # type: ignore[misc]
+def elementwise_scale_mul(
+    X: torch.Tensor, Y: torch.Tensor, b: float = 0.2, a: int = 2
+) -> torch.Tensor:
+    # Ensure the tensors are on the GPU
+    assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
+    assert X.shape == Y.shape, "Tensors must have the same shape."
+
+    # Create output tensor
+    Z = torch.empty_like(X)
+
+    # Define block size
+    BLOCK_SIZE = 1024
+
+    # Grid of programs
+    grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)
+
+    # Launch the kernel with parameters a and b
+    elementwise_scale_mul_kernel[grid](X, Y, Z, a, b, BLOCK_SIZE=BLOCK_SIZE)
 
-    @triton.jit
-    def elementwise_scale_mul_kernel(X, Y, Z, a, b, BLOCK_SIZE: tl.constexpr):
-        pid = tl.program_id(0)
-        # Compute the range of elements that this thread block will work on
-        block_start = pid * BLOCK_SIZE
-        # Range of indices this thread will handle
-        offsets = block_start + tl.arange(0, BLOCK_SIZE)
-        # Load elements from the X and Y tensors
-        x_vals = tl.load(X + offsets)
-        y_vals = tl.load(Y + offsets)
-        # Perform the element-wise multiplication
-        z_vals = x_vals * y_vals * a + b
-        # Store the result in Z
-        tl.store(Z + offsets, z_vals)
-
-    @torch.library.custom_op("torchtrt_ex::elementwise_scale_mul", mutates_args=())  # type: ignore[misc]
-    def elementwise_scale_mul(
-        X: torch.Tensor, Y: torch.Tensor, b: float = 0.2, a: int = 2
-    ) -> torch.Tensor:
-        # Ensure the tensors are on the GPU
-        assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
-        assert X.shape == Y.shape, "Tensors must have the same shape."
-
-        # Create output tensor
-        Z = torch.empty_like(X)
-
-        # Define block size
-        BLOCK_SIZE = 1024
-
-        # Grid of programs
-        grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)
-
-        # Launch the kernel with parameters a and b
-        elementwise_scale_mul_kernel[grid](X, Y, Z, a, b, BLOCK_SIZE=BLOCK_SIZE)
-
-        return Z
-
-    @torch.library.register_fake("torchtrt_ex::elementwise_scale_mul")
-    def _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor:
-        return x
+    return Z
+
+
+@torch.library.register_fake("torchtrt_ex::elementwise_scale_mul")
+def _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor:
+    return x
+
+
+if not torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx:
     torch_tensorrt.dynamo.conversion.plugins.custom_op(
         "torchtrt_ex::elementwise_scale_mul", supports_dynamic_shapes=True
     )
+
+@unittest.skipIf(
+    torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx,
+    "TensorRT RTX does not support plugins",
+)
+class TestAutomaticPlugin(DispatchTestCase):
+
     @parameterized.expand(
         [
             ((64, 64), torch.float),
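With the kernels and custom ops now registered at module level, they can also be exercised eagerly, outside of ``DispatchTestCase``. A small sketch, assuming a CUDA device and Triton are available and the test module has been imported so the op is registered; shapes are chosen arbitrarily:

.. code-block:: python

    import torch

    # elementwise_scale_mul computes X * Y * a + b via the Triton kernel above
    x = torch.randn(64, 64, device="cuda")
    y = torch.randn(64, 64, device="cuda")

    z = torch.ops.torchtrt_ex.elementwise_scale_mul(x, y, b=0.2, a=2)
    torch.testing.assert_close(z, x * y * 2 + 0.2)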
diff --git a/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py
index 6068a002d1..d85c8a633f 100644
--- a/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py
+++ b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py
@@ -11,10 +11,26 @@
 from ..conversion.harness import DispatchTestCase
 
 
-# flashinfer has been impacted by torch upstream change: https://github.com/pytorch/pytorch/commit/660b0b8128181d11165176ea3f979fa899f24db1
-# got ImportError: cannot import name '_get_pybind11_abi_build_flags' from 'torch.utils.cpp_extension'
-# if importlib.util.find_spec("flashinfer"):
-#     import flashinfer
+if importlib.util.find_spec("flashinfer"):
+    import flashinfer
+
+
+@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=())  # type: ignore[misc]
+def flashinfer_rmsnorm(
+    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> torch.Tensor:
+    return flashinfer.norm.rmsnorm(input, weight)
+
+
+@torch.library.register_fake("flashinfer::rmsnorm")
+def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor:
+    return input
+
+
+if not torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx:
+    torch_tensorrt.dynamo.conversion.plugins.custom_op(
+        "flashinfer::rmsnorm", supports_dynamic_shapes=True
+    )
 
 
 @unittest.skip("Not Available")
@@ -25,20 +41,6 @@
 )
 class TestAutomaticPlugin(DispatchTestCase):
 
-    @torch.library.custom_op("flashinfer::rmsnorm", mutates_args=())  # type: ignore[misc]
-    def flashinfer_rmsnorm(
-        input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
-    ) -> torch.Tensor:
-        return flashinfer.norm.rmsnorm(input, weight)
-
-    @torch.library.register_fake("flashinfer::rmsnorm")
-    def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor:
-        return input
-
-    torch_tensorrt.dynamo.conversion.plugins.custom_op(
-        "flashinfer::rmsnorm", supports_dynamic_shapes=True
-    )
-
     @parameterized.expand(
         [
             ((64, 64), (64,), torch.float16),
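The module-level registration also makes it easy to spot-check the FlashInfer-backed op eagerly against a plain PyTorch RMSNorm reference. This is only a sketch; it assumes flashinfer is installed, the op above is registered, and a CUDA device is available:

.. code-block:: python

    import torch

    def rmsnorm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Plain PyTorch reference: x / sqrt(mean(x^2) + eps) * weight
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        return x / rms * w

    x = torch.randn(64, 64, device="cuda", dtype=torch.float16)
    w = torch.randn(64, device="cuda", dtype=torch.float16)

    out = torch.ops.flashinfer.rmsnorm(x, w)
    torch.testing.assert_close(out, rmsnorm_ref(x, w), rtol=1e-2, atol=1e-2)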