diff --git a/.github/scripts/filter-matrix.py b/.github/scripts/filter-matrix.py index 3710539f59..69ee24080a 100644 --- a/.github/scripts/filter-matrix.py +++ b/.github/scripts/filter-matrix.py @@ -3,8 +3,9 @@ import argparse import json import sys +from typing import List -disabled_python_versions = "3.13" +disabled_python_versions: List[str] = [] def main(args: list[str]) -> None: diff --git a/.github/scripts/generate-release-matrix.py b/.github/scripts/generate-release-matrix.py index 2a84d2bfb2..4b232026de 100644 --- a/.github/scripts/generate-release-matrix.py +++ b/.github/scripts/generate-release-matrix.py @@ -9,7 +9,7 @@ "tarball": ["cu128"], } RELEASE_PYTHON_VERSION = { - "wheel": ["3.9", "3.10", "3.11", "3.12"], + "wheel": ["3.9", "3.10", "3.11", "3.12", "3.13"], "tarball": ["3.11"], } diff --git a/.github/scripts/generate-tensorrt-test-matrix.py b/.github/scripts/generate-tensorrt-test-matrix.py index 546116d7c2..02b0f746ca 100644 --- a/.github/scripts/generate-tensorrt-test-matrix.py +++ b/.github/scripts/generate-tensorrt-test-matrix.py @@ -28,6 +28,10 @@ # please update the future tensorRT version you want to test here TENSORRT_VERSIONS_DICT = { "windows": { + "10.3.0": { + "urls": "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/zip/TensorRT-10.3.0.26.Windows.win10.cuda-12.5.zip", + "strip_prefix": "TensorRT-10.3.0.26", + }, "10.7.0": { "urls": "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.7.0/zip/TensorRT-10.7.0.23.Windows.win10.cuda-12.6.zip", "strip_prefix": "TensorRT-10.7.0.23", @@ -42,6 +46,10 @@ }, }, "linux": { + "10.3.0": { + "urls": "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/tars/TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-12.5.tar.gz", + "strip_prefix": "TensorRT-10.3.0.26", + }, "10.7.0": { "urls": "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.7.0/tars/TensorRT-10.7.0.23.Linux.x86_64-gnu.cuda-12.6.tar.gz", "strip_prefix": "TensorRT-10.7.0.23", diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 61a7894dd9..82267a185f 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -18,15 +18,16 @@ import sys from typing import Any, Callable, Dict, List, Optional, Tuple +PYTHON_VERSIONS_FOR_PR_BUILD = ["3.11"] PYTHON_ARCHES_DICT = { - "nightly": ["3.9", "3.10", "3.11", "3.12"], - "test": ["3.9", "3.10", "3.11", "3.12"], - "release": ["3.9", "3.10", "3.11", "3.12"], + "nightly": ["3.9", "3.10", "3.11", "3.12", "3.13"], + "test": ["3.9", "3.10", "3.11", "3.12", "3.13"], + "release": ["3.9", "3.10", "3.11", "3.12", "3.13"], } CUDA_ARCHES_DICT = { "nightly": ["11.8", "12.6", "12.8"], "test": ["11.8", "12.6", "12.8"], - "release": ["11.8", "12.6", "12.8"], + "release": ["11.8", "12.4", "12.6"], } ROCM_ARCHES_DICT = { "nightly": ["6.1", "6.2"], @@ -58,8 +59,8 @@ CURRENT_NIGHTLY_VERSION = "2.7.0" -CURRENT_CANDIDATE_VERSION = "2.5.1" -CURRENT_STABLE_VERSION = "2.5.1" +CURRENT_CANDIDATE_VERSION = "2.7.0" +CURRENT_STABLE_VERSION = "2.6.0" CURRENT_VERSION = CURRENT_STABLE_VERSION # By default use Nightly for CUDA arches @@ -422,11 +423,6 @@ def generate_wheels_matrix( # Define default python version python_versions = list(PYTHON_ARCHES) - # If the list of python versions is set explicitly by the caller, stick with it instead - # of trying to add more versions behind the scene - if channel == NIGHTLY and (os in (LINUX, MACOS_ARM64, 
LINUX_AARCH64)): - python_versions += ["3.13"] - if os == LINUX: # NOTE: We only build manywheel packages for linux package_type = "manywheel" @@ -456,7 +452,7 @@ def generate_wheels_matrix( arches += [XPU] if limit_pr_builds: - python_versions = [python_versions[0]] + python_versions = PYTHON_VERSIONS_FOR_PR_BUILD global WHEEL_CONTAINER_IMAGES diff --git a/.github/workflows/build-test-linux.yml b/.github/workflows/build-test-linux.yml index 024afd8c62..6c32db5f91 100644 --- a/.github/workflows/build-test-linux.yml +++ b/.github/workflows/build-test-linux.yml @@ -23,7 +23,6 @@ jobs: test-infra-ref: main with-rocm: false with-cpu: false - python-versions: '["3.11", "3.12", "3.10", "3.9"]' filter-matrix: needs: [generate-matrix] @@ -143,6 +142,7 @@ jobs: python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/ python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin.py python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin_with_attrs.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_flashinfer_rmsnorm.py popd tests-py-dynamo-fe: @@ -173,7 +173,13 @@ jobs: cd tests/py python -m pip install -r requirements.txt cd dynamo - python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_models.xml --ir dynamo models/test_models.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_models_dynamic.xml --ir dynamo models/test_dyn_models.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/engine_cache.xml --ir dynamo models/test_engine_cache.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dtype_support.xml --ir dynamo models/test_dtype_support.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/model_refit.xml --ir dynamo models/test_model_refit.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/modelopt_models.xml --ir dynamo models/test_modelopt_models.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/weight_stripped_engine.xml --ir dynamo models/test_weight_stripped_engine.py popd tests-py-dynamo-serde: @@ -206,6 +212,7 @@ jobs: cd dynamo python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/reexport_test_results.xml --ir dynamo models/test_reexport.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_kwargs_serde_test_results.xml --ir dynamo models/test_export_kwargs_serde.py popd tests-py-torch-compile-be: diff --git a/.github/workflows/build-test-windows.yml b/.github/workflows/build-test-windows.yml index f78218e75d..2ee31b4b74 100644 --- a/.github/workflows/build-test-windows.yml +++ b/.github/workflows/build-test-windows.yml @@ -23,7 +23,6 @@ jobs: test-infra-ref: main with-rocm: false with-cpu: false - python-versions: '["3.11", "3.12", "3.10", "3.9"]' substitute-runner: needs: generate-matrix diff --git a/.github/workflows/docgen.yml b/.github/workflows/docgen.yml index b21a4ffc87..a597a80610 100644 --- a/.github/workflows/docgen.yml +++ b/.github/workflows/docgen.yml @@ -35,7 +35,7 @@ jobs: - name: Install base deps run: | python3 -m pip install pip --upgrade - python3 -m pip 
install pyyaml numpy torch --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 + python3 -m pip install pyyaml numpy torch --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu128 ./packaging/pre_build_script.sh - name: Get HEAD SHA id: vars @@ -44,7 +44,7 @@ jobs: env: USE_PRE_CXX11_ABI: 0 run: | - python3 -m pip install --pre . --extra-index-url https://download.pytorch.org/whl/nightly/cu126 + python3 -m pip install --pre . --extra-index-url https://download.pytorch.org/whl/nightly/cu128 - name: Generate New Docs run: | cd docsrc diff --git a/MODULE.bazel b/MODULE.bazel index 07d6c4c129..dd6db1c5ae 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -55,21 +55,21 @@ http_archive( name = "libtorch", build_file = "@//third_party/libtorch:BUILD", strip_prefix = "libtorch", - urls = ["https://download.pytorch.org/libtorch/nightly/cu128/libtorch-cxx11-abi-shared-with-deps-latest.zip"], + urls = ["https://download.pytorch.org/libtorch/test/cu128/libtorch-cxx11-abi-shared-with-deps-latest.zip"], ) http_archive( name = "libtorch_pre_cxx11_abi", build_file = "@//third_party/libtorch:BUILD", strip_prefix = "libtorch", - urls = ["https://download.pytorch.org/libtorch/nightly/cu128/libtorch-shared-with-deps-latest.zip"], + urls = ["https://download.pytorch.org/libtorch/test/cu126/libtorch-shared-with-deps-latest.zip"], ) http_archive( name = "libtorch_win", build_file = "@//third_party/libtorch:BUILD", strip_prefix = "libtorch", - urls = ["https://download.pytorch.org/libtorch/nightly/cu128/libtorch-win-shared-with-deps-latest.zip"], + urls = ["https://download.pytorch.org/libtorch/test/cu128/libtorch-win-shared-with-deps-latest.zip"], ) # Download these tarballs manually from the NVIDIA website diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 9f93fe4b4e..9a04aba6de 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -453,6 +453,10 @@ std::vector TRTEngine::serialize() { return serialized_info; } +void TRTEngine::reset_captured_graph() { + cudagraph.reset(); +} + } // namespace runtime } // namespace core } // namespace torch_tensorrt diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index e9b1905610..2db640b6b1 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -185,6 +185,7 @@ struct TRTEngine : torch::CustomClassHolder { // c10::List Run(c10::List inputs); void set_profiling_paths(); + void reset_captured_graph(); #ifndef NDEBUG bool profile_execution = true; #else diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index c05be4e8aa..cbe19b0af6 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -88,6 +88,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = .def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info) .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info) .def("infer_outputs", &TRTEngine::infer_outputs) + .def("reset_captured_graph", &TRTEngine::reset_captured_graph) .def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs) .def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs) .def_property( diff --git a/docker/dist-build.sh b/docker/dist-build.sh index 00ce6882c1..214b7b3b19 100755 --- a/docker/dist-build.sh +++ b/docker/dist-build.sh @@ -5,9 +5,9 @@ set -x TOP_DIR=$(cd $(dirname $0); pwd)/.. if [[ -z "${USE_PRE_CXX11}" ]]; then - BUILD_CMD="python -m pip wheel . 
--extra-index-url https://download.pytorch.org/whl/nightly/cu124 -w dist" + BUILD_CMD="python -m pip wheel . --extra-index-url https://download.pytorch.org/whl/test/cu128 -w dist" else - BUILD_CMD="python -m pip wheel . --config-setting="--build-option=--use-pre-cxx11-abi" --extra-index-url https://download.pytorch.org/whl/nightly/cu124 -w dist" + BUILD_CMD="python -m pip wheel . --config-setting="--build-option=--use-pre-cxx11-abi" --extra-index-url https://download.pytorch.org/whl/test/cu128 -w dist" fi # TensorRT restricts our pip version diff --git a/examples/distributed_inference/llama3_model.py b/examples/distributed_inference/llama3_model.py deleted file mode 100644 index 9fa59b5c49..0000000000 --- a/examples/distributed_inference/llama3_model.py +++ /dev/null @@ -1,538 +0,0 @@ -# Taken and modified pytorch lightening -# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning - - -from dataclasses import dataclass -from typing import Any, Optional, Tuple - -import torch -import torch.nn.functional as F -from torch import nn -from torch.distributed._tensor import Replicate, Shard -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor.parallel import ( - ColwiseParallel, - PrepareModuleInput, - RowwiseParallel, - SequenceParallel, - parallelize_module, -) - - -@dataclass -class ModelArgs: - dim: int = 4096 - n_layers: int = 32 - n_heads: int = 32 - n_kv_heads: Optional[int] = None - vocab_size: int = -1 # defined later by tokenizer - multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 - ffn_dim_multiplier: Optional[float] = None - norm_eps: float = 1e-5 - rope_theta: float = 10000 - - max_batch_size: int = 32 - max_seq_len: int = 2048 - # If `True`, then each transformer block init uses its layer ID, and if - # `False`, each uses the total number of transformer blocks - depth_init: bool = True - device: str = "cuda" - - -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: - """Precompute the frequency tensor for complex exponentials (cis) with given dimensions. - - This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' - and the end index 'end'. The 'theta' parameter scales the frequencies. - The returned tensor contains complex values in complex64 data type. - - Args: - dim (int): Dimension of the frequency tensor. - end (int): End index for precomputing frequencies. - theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. - - Returns: - torch.Tensor: Precomputed frequency tensor with complex exponentials. - - """ - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) - freqs = torch.outer(t, freqs).float() - return torch.polar(torch.ones_like(freqs), freqs) # complex64 - - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - """Reshape frequency tensor for broadcasting it with another tensor. - - This function reshapes the frequency tensor to have the same shape as the target tensor 'x' - for the purpose of broadcasting the frequency tensor during element-wise operations. - - The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), - and the first seqlen elements will be sliced, but dim must match x. - - Args: - freqs_cis (torch.Tensor): Frequency tensor to be reshaped. 
- x (torch.Tensor): Target tensor for broadcasting compatibility. - - Returns: - torch.Tensor: Reshaped frequency tensor. - - """ - ndim = x.ndim - assert 0 <= 1 < ndim - seqlen = x.shape[1] - freqs_cis = freqs_cis[0:seqlen] - assert freqs_cis.shape == (seqlen, x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Apply rotary embeddings to input tensors using the given frequency tensor. - - This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided - frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor - is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are - returned as real tensors. - - Args: - xq (torch.Tensor): Query tensor to apply rotary embeddings. - xk (torch.Tensor): Key tensor to apply rotary embeddings. - freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. - - """ - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" - bs, slen, n_kv_heads, head_dim = x.shape - if n_rep == 1: - return x - return ( - x[:, :, :, None, :] - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) - ) - - -class RMSNorm(nn.Module): - """Initialize the RMSNorm normalization layer. - - Args: - dim (int): The dimension of the input tensor. - eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. - - Attributes: - eps (float): A small value added to the denominator for numerical stability. - weight (nn.Parameter): Learnable scaling parameter. - - """ - - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x: torch.Tensor): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x: torch.Tensor): - output = self._norm(x.float()).type_as(x) - return output * self.weight - - def reset_parameters(self): - torch.nn.init.ones_(self.weight) # type: ignore - - -class Attention(nn.Module): - """Multi-head attention module. - - Args: - model_args (ModelArgs): Model configuration arguments. - - Attributes: - n_kv_heads (int): Number of key and value heads. - n_heads (int): Number of query heads. - n_rep (int): Number of repetitions for local heads. - head_dim (int): Dimension size of each attention head. - wq (Linear): Linear transformation for queries. - wk (Linear): Linear transformation for keys. - wv (Linear): Linear transformation for values. - wo (Linear): Linear transformation for output. 
- - """ - - def __init__(self, model_args: ModelArgs): - super().__init__() - self.n_heads = model_args.n_heads - self.n_kv_heads = ( - model_args.n_heads - if model_args.n_kv_heads is None - else model_args.n_kv_heads - ) - self.n_rep = self.n_heads // self.n_kv_heads - self.head_dim = model_args.dim // model_args.n_heads - - self.wq = nn.Linear( - model_args.dim, model_args.n_heads * self.head_dim, bias=False - ) - self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wo = nn.Linear( - model_args.n_heads * self.head_dim, model_args.dim, bias=False - ) - - def init_weights(self, init_std: float) -> None: - for linear in (self.wq, self.wk, self.wv): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) - nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) - - def forward( - self, - x: torch.Tensor, - freqs_cis: torch.Tensor, - ) -> Any: - """Forward pass of the attention module. - - Args: - x (torch.Tensor): Input tensor. - freqs_cis (torch.Tensor): Precomputed frequency tensor. - - Returns: - torch.Tensor: Output tensor after attention. - - """ - bs, seqlen, _ = x.shape - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - - xq = xq.view(bs, seqlen, self.n_heads, self.head_dim) - xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim) - xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim) - - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - - # repeat k/v heads if n_kv_heads < n_heads - keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - - xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - - # we use casual mask for training - output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) - output = output.transpose( - 1, 2 - ).contiguous() # (bs, seqlen, n_local_heads, head_dim) - output = output.view(bs, seqlen, -1) - return self.wo(output) - - -class FeedForward(nn.Module): - """FeedForward module. - - Args: - dim (int): Input dimension. - hidden_dim (int): Hidden dimension of the feedforward layer. - multiple_of (int): Value to ensure hidden dimension is a multiple of this value. - ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None. - - Attributes: - w1 (Linear): Linear transformation for the first layer. - w2 (Linear): Linear transformation for the second layer. - w3 (Linear): Linear transformation for the third layer. 
- - """ - - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int, - ffn_dim_multiplier: Optional[float], - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - # custom dim factor multiplier - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - - def forward(self, x) -> Any: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - def init_weights(self, init_std: float) -> None: - nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) - for linear in (self.w2, self.w3): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) - - -class TransformerBlock(nn.Module): - """TransformerBlock Module. - - Args: - layer_id (int): Identifier for the layer. - model_args (ModelArgs): Model configuration arguments. - - Attributes: - n_heads (int): Number of attention heads. - dim (int): Dimension size of the model. - head_dim (int): Dimension size of each attention head. - attention (Attention): Attention module. - feed_forward (FeedForward): FeedForward module. - layer_id (int): Identifier for the layer. - attention_norm (RMSNorm): Layer normalization for attention output. - ffn_norm (RMSNorm): Layer normalization for feedforward output. - - """ - - def __init__(self, layer_id: int, model_args: ModelArgs): - super().__init__() - self.n_heads = model_args.n_heads - self.dim = model_args.dim - self.attention = Attention(model_args) - self.feed_forward = FeedForward( - dim=model_args.dim, - hidden_dim=4 * model_args.dim, - multiple_of=model_args.multiple_of, - ffn_dim_multiplier=model_args.ffn_dim_multiplier, - ) - self.layer_id = layer_id - self.num_layers = model_args.n_layers - - self.attention_norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps) - self.ffn_norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps) - - if model_args.depth_init: - self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5 - else: - self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5 - - def forward( - self, - x: torch.Tensor, - freqs_cis: torch.Tensor, - ): - """Perform a forward pass through the TransformerBlock. - - Args: - x (torch.Tensor): Input tensor. - freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. - - Returns: - torch.Tensor: Output tensor after applying attention and feedforward layers. - - """ - h = x + self.attention(self.attention_norm(x), freqs_cis) - return h + self.feed_forward(self.ffn_norm(h)) - - def init_weights(self): - for norm in (self.attention_norm, self.ffn_norm): - norm.reset_parameters() - self.attention.init_weights(self.weight_init_std) - self.feed_forward.init_weights(self.weight_init_std) - - -class ParallelTransformer(nn.Module): - """Transformer Module. - - Args: - model_args (ModelArgs): Model configuration arguments. - - Attributes: - model_args (ModelArgs): Model configuration arguments. - vocab_size (int): Vocabulary size. - n_layers (int): Number of layers in the model. - tok_embeddings (ParallelEmbedding): Token embeddings. - layers (torch.nn.ModuleList): List of Transformer blocks. - norm (RMSNorm): Layer normalization for the model output. - output (ColumnParallelLinear): Linear layer for final output. - freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. 
- - """ - - def __init__(self, model_args: ModelArgs, tp_mesh: DeviceMesh = None): - # Here we use distributed model initialization to avoid memory overflow - super().__init__() - self.model_args = model_args - self.vocab_size = model_args.vocab_size - self.n_layers = model_args.n_layers - - self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) - self.tok_embeddings.to(model_args.device) - self.tok_embeddings = self.parallel_embeddings(self.tok_embeddings, tp_mesh) - - # TODO persistent should be set to false, since this buffer can be recomputed. - # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, - # compile or pipeline-tracer will not correctly handle non-persistent buffers, - # so we need to fix that. (2) if we initialize pipeline-parallel models from - # a seed checkpoint rather than calling init_weights, we need freqs_cis to be - # initialized by the checkpoint, or we need to add a separate initializer for - # just the non-persistent buffers that is called after loading checkpoints. - self.register_buffer( - "freqs_cis", - self._precompute_freqs_cis().to(model_args.device), - persistent=True, - ) - - self.layers = torch.nn.ModuleDict().to(model_args.device) - for layer_id in range(model_args.n_layers): - block = TransformerBlock(layer_id, model_args).to(model_args.device) - self.layers[str(layer_id)] = block - self.parallel_transformer_block(self.layers[str(layer_id)], tp_mesh) - - self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps).to( - model_args.device - ) - self.norm = self.parallel_norm(self.norm, tp_mesh) - self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False).to( - model_args.device - ) - self.output = self.parallel_output(self.output, tp_mesh) - self.init_weights() - - def parallel_transformer_block(self, transformer_block, tp_mesh): - if tp_mesh.size() <= 1: - return - plan = { - "attention": PrepareModuleInput( - input_layouts=(Shard(1), None), - desired_input_layouts=(Replicate(), None), - ), - "attention.wq": ColwiseParallel(), - "attention.wk": ColwiseParallel(), - "attention.wv": ColwiseParallel(), - "attention.wo": RowwiseParallel(output_layouts=Shard(1)), - "attention_norm": SequenceParallel(), - "feed_forward": PrepareModuleInput( - input_layouts=(Shard(1),), - desired_input_layouts=(Replicate(),), - ), - "feed_forward.w1": ColwiseParallel(), - "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)), - "feed_forward.w3": ColwiseParallel(), - "ffn_norm": SequenceParallel(), - } - - # Adjust attention module to use the local number of heads - attn_layer = transformer_block.attention - attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size() - attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size() - - # Apply the plan for the current transformer block - parallelize_module(transformer_block, tp_mesh, plan) - - def parallel_embeddings(self, embedding, tp_mesh): - plan = { - "tok_embeddings": RowwiseParallel( - input_layouts=Replicate(), - output_layouts=Shard(1), - ) - } - return parallelize_module(embedding, tp_mesh, plan) - - def parallel_output(self, output, tp_mesh): - plan = { - "output": ColwiseParallel( - input_layouts=Shard(1), - ), - } - return parallelize_module(output, tp_mesh, plan) - - def parallel_norm(self, norm, tp_mesh): - plan = { - "norm": SequenceParallel(), - } - return parallelize_module(norm, tp_mesh, plan) - - def reset_parameters(self): - with torch.device(self.freqs_cis.device): - self.freqs_cis = self._precompute_freqs_cis() - - def 
init_weights(self): - """[Note: On ``init_weights`` vs. - - ``reset_parameters``] - Modules may define ``reset_parameters`` to initialize parameter values. - ``reset_parameters`` is meant to only initialize directly owned - parameters/buffers, not those of their child modules, and it can be - used to give the initial values for these tensors. - Separately, users may want custom initialization for their modules, - different from that in ``reset_parameters``. For this, we define - ``init_weights``. We only call it in the constructor of this - ``Transformer`` root module to avoid reinitializing tensors. - - """ - with torch.device(self.freqs_cis.device): - self.freqs_cis = self._precompute_freqs_cis() - nn.init.normal_(self.tok_embeddings.weight) - for layer in self.layers.values(): - layer.init_weights() - self.norm.reset_parameters() - final_out_std = self.model_args.dim**-0.5 - cutoff_factor = 3 - nn.init.trunc_normal_( - self.output.weight, - mean=0.0, - std=final_out_std, - a=-cutoff_factor * final_out_std, - b=cutoff_factor * final_out_std, - ) - - def _precompute_freqs_cis(self) -> torch.Tensor: - return precompute_freqs_cis( - self.model_args.dim // self.model_args.n_heads, - # Need to compute until at least the max token limit for generation - # (use 2x max sequence length to be safe) - self.model_args.max_seq_len * 2, - self.model_args.rope_theta, - ) - - def forward(self, tokens: torch.Tensor): - """Perform a forward pass through the Transformer model. - - Args: - tokens (torch.Tensor): Input token indices. - - Returns: - torch.Tensor: Output logits after applying the Transformer model. - - """ - # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages - h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens - - for layer in self.layers.values(): - h = layer(h, self.freqs_cis) - - h = self.norm(h) if self.norm else h - return self.output(h).float() if self.output else h - - @classmethod - def from_model_args(cls, model_args: ModelArgs) -> "Transformer": - """Initialize a Transformer model from a ModelArgs object. - - Args: - model_args (ModelArgs): Model configuration arguments. - - Returns: - Transformer: Transformer model. 
- - """ - return cls(model_args) diff --git a/examples/distributed_inference/tensor_parallel_llama3.py b/examples/distributed_inference/tensor_parallel_llama3.py deleted file mode 100644 index 998c378be2..0000000000 --- a/examples/distributed_inference/tensor_parallel_llama3.py +++ /dev/null @@ -1,70 +0,0 @@ -# Taken and modified pytorch lightening -# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning -import logging -import os -import time - -import torch -from llama3_model import ModelArgs, ParallelTransformer -from tensor_parallel_initialize_dist import initialize_distributed_env -from torch.distributed._composable.fsdp import MixedPrecisionPolicy -from torch.distributed._composable.fsdp.fully_shard import fully_shard -from torch.distributed._tensor import Replicate, Shard -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - checkpoint_wrapper, -) - -device_mesh, _world_size, _rank, logger = initialize_distributed_env( - "./tensor_parallel_llama3" -) -# Import should be after initialization of the TRT-LLM plugin .so path -import tensorrt_llm - -logger.info(f"Starting PyTorch TP example on rank {_rank}.") -assert ( - _world_size % 2 == 0 -), f"TP examples require even number of GPUs, but got {_world_size} gpus" - -model_args = ModelArgs( - vocab_size=32000, - dim=1024, - n_layers=4, - n_heads=8, - rope_theta=500000.0, - n_kv_heads=8, - device="cuda", -) - -with torch.no_grad(): - model = ParallelTransformer(model_args, device_mesh) - torch.manual_seed(0) - inp = torch.randint(32000, (8, 256), device="cuda") - python_result = model(inp) - torch_tensorrt.runtime.set_multi_device_safe_mode(True) - model = torch.compile( - model, - fullgraph=True, - backend="torch_tensorrt", - options={ - "truncate_long_and_double": True, - "enabled_precisions": {torch.float32, torch.float16}, - "use_python_runtime": True, - "workspace_size": 1 << 33, - "debug": False, - "use_aot_joint_export": False, - }, - dynamic=False, - ) - for i in range(15): - # seeding with dp_rank to ensure identical inputs for TP groups - torch.manual_seed(i) - start = time.time() - output = model(inp) - end = time.time() - if i == 0: - logger.info(f"Compilation time is {end-start}") - assert ( - python_result - output - ).std() < 0.01, "Compilation result is not correct." 
- elif _rank == 0: - logger.info(f"Inference time is {end-start}") diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py index 837648fdb4..d2e3c590c6 100755 --- a/examples/distributed_inference/tensor_parallel_simple_example.py +++ b/examples/distributed_inference/tensor_parallel_simple_example.py @@ -2,6 +2,7 @@ import tensorrt as trt import torch +import torch.distributed as dist import torch.nn as nn import torch_tensorrt from tensor_parallel_initialize_dist import initialize_distributed_env @@ -15,7 +16,6 @@ device_mesh, _world_size, _rank, logger = initialize_distributed_env( "./tensor_parallel_simple_example" ) -import tensorrt_llm """ This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py @@ -65,7 +65,6 @@ def forward(self, x): inp = torch.rand(20, 10, device="cuda") python_result = tp_model(inp) - backend = "torch_tensorrt" tp_model = torch.compile( tp_model, @@ -75,23 +74,28 @@ def forward(self, x): "enabled_precisions": {torch.float32, torch.float16}, "use_python_runtime": True, "min_block_size": 1, - "use_aot_joint_export": False, + "use_distributed_mode_trace": True, }, - dynamic=False, + dynamic=None, ) -for i in range(10): - # For TP, input needs to be same across all TP ranks. - # Setting the random seed is to mimic the behavior of dataloader. - torch.manual_seed(i) - inp = torch.rand(20, 10, device="cuda") - start = time.time() - output = tp_model(inp) - end = time.time() - if i == 0: - logger.info(f"Compilation time is {end-start}") - assert ( - python_result - output - ).std() < 0.01, "Compilation result is not correct." - elif _rank == 0: - logger.info(f"Inference time is {end-start}") +try: + for i in range(10): + # For TP, input needs to be same across all TP ranks. + # Setting the random seed is to mimic the behavior of dataloader. + torch.manual_seed(i) + inp = torch.rand(20, 10, device="cuda") + start = time.time() + output = tp_model(inp) + end = time.time() + if i == 0: + logger.info(f"Compilation time is {end-start}") + assert ( + python_result - output + ).std() < 0.01, "Compilation result is not correct." 
+ elif _rank == 0: + logger.info(f"Inference time is {end-start}") +finally: + # This cleans up the distributed process group + if dist.is_initialized(): + dist.destroy_process_group() diff --git a/examples/dynamo/llama2_flashinfer_rmsnorm.py b/examples/dynamo/llama2_flashinfer_rmsnorm.py new file mode 100644 index 0000000000..847d80238b --- /dev/null +++ b/examples/dynamo/llama2_flashinfer_rmsnorm.py @@ -0,0 +1,241 @@ +from typing import Callable, Optional, Sequence, Union + +import flashinfer +import torch +import torch_tensorrt +from torch.fx.passes.shape_prop import TensorMetadata +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) +from transformers import LlamaConfig, LlamaForCausalLM + + +@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=()) # type: ignore[misc] +def flashinfer_rmsnorm( + input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 +) -> torch.Tensor: + return flashinfer.norm.rmsnorm(input, weight) + + +@torch.library.register_fake("flashinfer::rmsnorm") +def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor: + return input + + +torch_tensorrt.dynamo.conversion.plugins.custom_op( + "flashinfer::rmsnorm", supports_dynamic_shapes=True +) + + +@_aten_lowering_pass +def replace_rmsnorm( + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor] +) -> torch.fx.GraphModule: + for node in gm.graph.nodes: + if ( + node.target == torch.ops.aten._to_copy.default + and node.kwargs.get("dtype") is torch.float32 + and len(node.users) == 2 + ): + if ( + list(node.users)[0].target == torch.ops.aten.pow.Tensor_Scalar + and list(node.users)[1].target == torch.ops.aten.mul.Tensor + ): + pow_node = list(node.users)[0] + if ( + len(pow_node.users) == 1 + and list(pow_node.users)[0].target == torch.ops.aten.mean.dim + ): + mean_node = list(pow_node.users)[0] + if ( + len(mean_node.users) == 1 + and list(mean_node.users)[0].target == torch.ops.aten.add.Tensor + ): + add_node = list(mean_node.users)[0] + if ( + len(add_node.users) == 1 + and list(add_node.users)[0].target + == torch.ops.aten.sqrt.default + ): + sqrt_node = list(add_node.users)[0] + if ( + len(sqrt_node.users) == 1 + and list(sqrt_node.users)[0].target + == torch.ops.aten.div.Tensor + ): + div_node = list(sqrt_node.users)[0] + if list(div_node.users)[0] == list(node.users)[1]: + mul_node = list(div_node.users)[0] + copy_node = list(mul_node.users)[0] + weight_mul_node = list(copy_node.users)[0] + + weight = weight_mul_node.args[0] + + original_meta = weight_mul_node.meta.get( + "tensor_meta", {} + ) + memory_format = original_meta.memory_format + + with gm.graph.inserting_after(weight_mul_node): + b = gm.graph.create_node( + op="call_function", + target=torch.ops.aten.sym_size.int, + args=(node.args[0], 0), + ) + b.meta["tensor_meta"] = TensorMetadata( + shape=torch.Size([]), + dtype=torch.int64, + requires_grad=False, + stride=None, + memory_format=memory_format, + is_quantized=False, + qparams={}, + ) + s = gm.graph.create_node( + op="call_function", + target=torch.ops.aten.sym_size.int, + args=(node.args[0], 1), + ) + s.meta.update(b.meta) + + d = gm.graph.create_node( + op="call_function", + target=torch.ops.aten.sym_size.int, + args=(node.args[0], 2), + ) + d.meta.update(b.meta) + + with gm.graph.inserting_after(b): + new_first_dim = gm.graph.create_node( + op="call_function", + target=torch.ops.aten.mul.Scalar, + args=(b, 
s), + ) + new_first_dim.meta.update(b.meta) + + with gm.graph.inserting_after(new_first_dim): + # with gm.graph.inserting_after(weight_mul_node): + reshape_node = gm.graph.create_node( + op="call_function", + target=torch.ops.aten.reshape.default, + args=(node.args[0], [new_first_dim, d]), + ) + b_val = original_meta.shape[0] + s_val = original_meta.shape[1] + d_val = original_meta.shape[2] + + reshape_node.meta["tensor_meta"] = ( + TensorMetadata( + shape=torch.Size( + [b_val * s_val, d_val] + ), + dtype=original_meta.dtype, + requires_grad=True, + stride=None, + memory_format=memory_format, + is_quantized=False, + qparams={}, + ) + ) + + with gm.graph.inserting_after(reshape_node): + flashinfer_rmsnorm_node = gm.graph.create_node( + op="call_function", + target=torch.ops.flashinfer.rmsnorm.default, + args=( + reshape_node, + weight, + add_node.args[1], + ), + ) + flashinfer_rmsnorm_node.meta.update( + reshape_node.meta + ) + + with gm.graph.inserting_after( + flashinfer_rmsnorm_node + ): + reshapback_node = gm.graph.create_node( + op="call_function", + target=torch.ops.aten.reshape.default, + args=( + flashinfer_rmsnorm_node, + [b, s, d], + ), + ) + + weight_mul_node.replace_all_uses_with( + reshapback_node + ) + reshapback_node.meta.update(weight_mul_node.meta) + + modified_graph = True + + gm.graph.erase_node(weight_mul_node) + gm.graph.erase_node(copy_node) + gm.graph.erase_node(mul_node) + gm.graph.erase_node(div_node) + gm.graph.erase_node(sqrt_node) + gm.graph.erase_node(add_node) + gm.graph.erase_node(mean_node) + gm.graph.erase_node(pow_node) + gm.graph.erase_node(node) + + if modified_graph: + gm = clean_up_graph_after_modifications(gm) + + return gm + + +# 1. Create a custom config with 1 layer +config = LlamaConfig( + vocab_size=32000, + hidden_size=4096, # LLaMA2-7B dimensions + intermediate_size=11008, # FFN hidden_dim = 4 * 4096 * 0.7 (SwiGLU scaling) + num_hidden_layers=1, # Only 1 decoder layer + num_attention_heads=32, + max_position_embeddings=4096, + use_cache=False, # Disable KV caching for export +) + +# 2. Initialize model (random weights) +with torch.no_grad(): + model = LlamaForCausalLM(config).eval().half() + +# 3. 
Export with static shapes +input_ids = torch.randint(0, 32000, (1, 64)) # Static [batch=1, seq=64] +exported = torch.export.export( + model, + (input_ids,), + dynamic_shapes=None, # Fully static +) + +# Test forward pass +input_ids = torch.randint(0, 32000, (1, 64)) +output = model(input_ids) +print(output) + +# Export validation + +DEVICE = torch.device("cuda:0") + +with torch_tensorrt.logging.errors(): + trt_model = torch_tensorrt.dynamo.compile( + exported, + inputs=[input_ids], + enabled_precisions={torch.float32, torch.float16}, + truncate_double=True, + device=DEVICE, + disable_tf32=True, + use_explicit_typing=False, + use_fp32_acc=True, + # debug=True, + ) + +input_ids = input_ids.to(DEVICE) + +res = trt_model.forward(input_ids) +print(res) diff --git a/py/requirements.txt b/py/requirements.txt index 5644656330..6a4d3aa90c 100644 --- a/py/requirements.txt +++ b/py/requirements.txt @@ -1,8 +1,8 @@ numpy packaging pybind11==2.6.2 ---extra-index-url https://download.pytorch.org/whl/nightly/cu124 -torch>=2.7.0.dev,<2.8.0 -torchvision>=0.22.0.dev,<0.23.0 +--extra-index-url https://download.pytorch.org/whl/test/cu128 +torch>=2.7.0,<2.8.0 +torchvision>=0.22.0,<0.23.0 --extra-index-url https://pypi.ngc.nvidia.com pyyaml diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py index 8da7ac6fff..bee0c3dbf0 100644 --- a/py/torch_tensorrt/_features.py +++ b/py/torch_tensorrt/_features.py @@ -14,6 +14,7 @@ "torch_tensorrt_runtime", "dynamo_frontend", "fx_frontend", + "refit", ], ) @@ -36,9 +37,10 @@ _TORCHTRT_RT_AVAIL = _TS_FE_AVAIL or os.path.isfile(linked_file_runtime_full_path) _DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev") _FX_FE_AVAIL = True +_REFIT_AVAIL = True ENABLED_FEATURES = FeatureSet( - _TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL + _TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL ) @@ -62,6 +64,22 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: return wrapper +def needs_refit(f: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: + if ENABLED_FEATURES.refit: + return f(*args, **kwargs) + else: + + def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: + raise NotImplementedError( + "Refit feature is currently not available in Python 3.13 or higher" + ) + + return not_implemented(*args, **kwargs) + + return wrapper + + T = TypeVar("T") diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 6928347baa..acd16a32f0 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -1206,7 +1206,7 @@ def save_cross_compiled_exported_program( from torch_tensorrt.dynamo._exporter import export - exp_program = export(gm, cross_compile_flag=True) + exp_program = export(gm, cross_compile_module=True) torch.export.save(exp_program, file_path) logger.debug(f"successfully saved the module for windows at {file_path}") diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index ba404a4102..379a196e2e 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -46,9 +46,9 @@ IMMUTABLE_WEIGHTS = True ENABLE_WEIGHT_STREAMING = False ENABLE_CROSS_COMPILE_FOR_WINDOWS = False -USE_AOT_JOINT_EXPORT = True TILING_OPTIMIZATION_LEVEL = "none" L2_LIMIT_FOR_TILING = -1 +USE_DISTRIBUTED_MODE_TRACE = False def default_device() -> Device: diff --git 
a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py index f2d4cfee88..17e0ad4561 100644 --- a/py/torch_tensorrt/dynamo/_exporter.py +++ b/py/torch_tensorrt/dynamo/_exporter.py @@ -22,23 +22,23 @@ def export( gm: torch.fx.GraphModule, - cross_compile_flag: Optional[bool] = False, + cross_compile_module: Optional[bool] = False, ) -> ExportedProgram: """Export the result of TensorRT compilation into the desired output format. Arguments: gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile`` inputs (torch.Tensor): Torch input tensors - cross_compile_flag (bool): Flag to indicated whether it is cross_compilation enabled or not + cross_compile_module (bool): Flag to indicated whether it is cross_compilation enabled or not """ - patched_module = transform(gm, cross_compile_flag) + patched_module = transform(gm, cross_compile_module) exp_program = create_trt_exp_program(patched_module) return exp_program def transform( gm: torch.fx.GraphModule, - cross_compile_flag: Optional[bool] = False, + cross_compile_module: Optional[bool] = False, ) -> torch.fx.GraphModule: """ Transforms the graphmodule by inlining Pytorch and TensorRT submodules. @@ -48,7 +48,7 @@ def transform( Arguments: gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile`` inputs (torch.Tensor): Torch input tensors - cross_compile_flag (bool): Flag to indicated whether it is cross_compilation enabled or not + cross_compile_module (bool): Flag to indicated whether it is cross_compilation enabled or not Returns an inlined torch.fx.GraphModule """ @@ -57,7 +57,7 @@ def transform( gm = copy.deepcopy(gm) # Inline TensorRT submodules - inline_trt_modules(gm, cross_compile_flag) + inline_trt_modules(gm, cross_compile_module) # Inline pytorch submodules inline_torch_modules(gm) @@ -356,7 +356,7 @@ def create_trt_exp_program( def inline_trt_modules( - gm: torch.fx.GraphModule, cross_compile_flag: Optional[bool] = False + gm: torch.fx.GraphModule, cross_compile_module: Optional[bool] = False ) -> torch.fx.GraphModule: """ Replace TRT submodules with trt engine nodes. 
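# --- Editor's note (illustrative sketch, not part of the patch) ---------------
# The _exporter.py hunks around this point rename the exporter argument
# `cross_compile_flag` to `cross_compile_module`. Based on the
# save_cross_compiled_exported_program hunk earlier in this diff, a call site is
# expected to look roughly like the following; `gm` (a compiled GraphModule) and
# the output path are assumptions for illustration only.
import torch
from torch_tensorrt.dynamo._exporter import export

def save_for_windows(gm: torch.fx.GraphModule, file_path: str) -> None:
    # Renamed keyword: previously export(gm, cross_compile_flag=True)
    exp_program = export(gm, cross_compile_module=True)
    torch.export.save(exp_program, file_path)
# ------------------------------------------------------------------------------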
@@ -380,7 +380,16 @@ def inline_trt_modules( num_outputs = len(trt_module_node.meta["val"]) # Insert a call_function node to perform inference on TRT engine with gm.graph.inserting_before(trt_module_node): - if not cross_compile_flag: + if cross_compile_module: + engine_info = trt_module._pack_engine_info() + engine_bytes = engine_info[ENGINE_IDX] + engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8") + # insert the no_placeholder node in the graph which should be replaced to the actual execute_engine node while load in the windows + trt_node = gm.graph.call_function( + torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default, + (trt_module_node.args, *engine_info), + ) + else: # for the normal workflow: use the execute_engine node engine_name = f"{name}_engine" setattr(gm, engine_name, trt_module.engine) @@ -396,16 +405,6 @@ def inline_trt_modules( engine_node.meta["val"] = CustomObjArgument( name=engine_node.name, class_fqn="" ) - else: - # for the cross compile for windows workflow: use the no_op_placeholder node - engine_info = trt_module._pack_engine_info() - engine_bytes = engine_info[ENGINE_IDX] - engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8") - # insert the no_placeholder node in the graph which should be replaced to the actual execute_engine node while load in the windows - trt_node = gm.graph.call_function( - torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default, - (trt_module_node.args, *engine_info), - ) # set trt_node.meta with trt_module_node.meta assert num_outputs > 0 trt_node.meta["val"] = trt_module_node.meta["val"] @@ -464,16 +463,10 @@ def replace_execute_engine_no_op_node( name=engine_node.name, class_fqn="" ) - if len(no_op_placeholder_node.meta["val"]) == 1: - with gm.graph.inserting_after(trt_node): - getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0)) - getitem_output.meta["val"] = trt_node.meta["val"] - no_op_placeholder_node.replace_all_uses_with(getitem_output) - else: - no_op_placeholder_node.replace_all_uses_with(trt_node) - getitem_nodes = trt_node.users - for idx, getitem_node in enumerate(getitem_nodes): - getitem_node.meta["val"] = trt_node.meta["val"][idx] + no_op_placeholder_node.replace_all_uses_with(trt_node) + getitem_nodes = trt_node.users + for idx, getitem_node in enumerate(getitem_nodes): + getitem_node.meta["val"] = trt_node.meta["val"][idx] gm.graph.erase_node(no_op_placeholder_node) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 96fc6daad2..f217383f5c 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -9,7 +9,9 @@ import tensorrt as trt import torch from torch.export import ExportedProgram +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._enums import dtype +from torch_tensorrt._features import needs_refit from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import partitioning from torch_tensorrt.dynamo._exporter import inline_torch_modules @@ -46,6 +48,7 @@ logger = logging.getLogger(__name__) +@needs_refit # type: ignore def construct_refit_mapping( module: torch.fx.GraphModule, inputs: Sequence[Input], @@ -59,18 +62,6 @@ def construct_refit_mapping( Returns: Mapping from weight name in TensorRT to actual weight value in np.ndarray """ - MODULE_MAP = { - "SCALE": (trt.IScaleLayer, [("scale", "SCALE"), ("shift", "SHIFT")]), - "CONVOLUTION": ( - trt.IConvolutionLayer, - [("kernel", "KERNEL"), ("bias", "BIAS")], - ), - 
"DECONVOLUTION": ( - trt.IDeconvolutionLayer, - [("kernel", "KERNEL"), ("bias", "BIAS")], - ), - "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]), - } output_dtypes = infer_module_output_dtypes( module, @@ -78,7 +69,6 @@ def construct_refit_mapping( ) # Use Interpreter - weight_map = {} interpreter = TRTInterpreter( module, inputs, @@ -87,39 +77,27 @@ def construct_refit_mapping( compilation_settings=settings, ) interpreter._construct_trt_network_def() - net = interpreter.ctx.net - for i in range(net.num_layers): - layer = net[i] - layer_type: str = layer.type.name - if layer_type in MODULE_MAP: - # Cast the parent class to child class to access attributes - # For example: ILayer does not have ILayer.kernel/ILayer.bias - # So we cast it to IConvolutionLayer and access the attributes - layer.__class__ = MODULE_MAP[layer_type][0] - for weight_type, weight_name in MODULE_MAP[layer_type][1]: - weight = layer.__getattribute__(weight_type).copy() - weight_dtype = dtype.try_from(weight.dtype).to(trt.DataType) - weight_map[f"{layer.name} {weight_name}"] = ( - weight, - weight_dtype, - ) - return weight_map + return interpreter.ctx.mapping +@needs_refit # type: ignore def construct_refit_mapping_from_weight_name_map( - weight_name_map: dict[Any, Any], state_dict: dict[Any, Any] + weight_name_map: dict[Any, Any], + state_dict: dict[Any, Any], + settings: CompilationSettings, ) -> dict[Any, Any]: engine_weight_map = {} for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items(): - trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) - torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) - if sd_weight_name not in state_dict: # If weights is not in sd, we can leave it unchanged continue else: - engine_weight_map[engine_weight_name] = state_dict[sd_weight_name] + trt_dtype = dtype._from(np_weight_type).to(trt.DataType) + torch_dtype = dtype._from(np_weight_type).to(torch.dtype) + engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to( + to_torch_device(settings.device) + ) engine_weight_map[engine_weight_name] = ( engine_weight_map[engine_weight_name] @@ -133,6 +111,7 @@ def construct_refit_mapping_from_weight_name_map( return engine_weight_map +@needs_refit # type: ignore def _refit_single_trt_engine_with_gm( new_gm: torch.fx.GraphModule, old_engine: trt.ICudaEngine, @@ -144,73 +123,76 @@ def _refit_single_trt_engine_with_gm( Refit a TensorRT Engine in place """ - refitted = set() - torch_device = get_model_device(new_gm) - refitter = trt.Refitter(old_engine, TRT_LOGGER) - weight_list = refitter.get_all_weights() - - if weight_name_map: - # Get the refitting mapping - trt_wt_location = ( - trt.TensorLocation.DEVICE - if torch_device.type == "cuda" - else trt.TensorLocation.HOST - ) - - constant_mapping: dict[str, Any] = weight_name_map.pop( - "constant_mapping", {} - ) # type: ignore - mapping = construct_refit_mapping_from_weight_name_map( - weight_name_map, new_gm.state_dict() - ) - constant_mapping_with_type = {} - - for constant_name, val in constant_mapping.items(): - np_weight_type = val.dtype - val_tensor = torch.from_numpy(val).cuda() - trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) - torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) - constant_mapping_with_type[constant_name] = ( - val_tensor.clone().reshape(-1).contiguous().to(torch_dtype), - trt_dtype, + with unset_fake_temporarily(): + refitted = set() + torch_device = get_model_device(new_gm) + refitter = trt.Refitter(old_engine, TRT_LOGGER) + 
weight_list = refitter.get_all_weights() + + if weight_name_map: + # Get the refitting mapping + trt_wt_location = ( + trt.TensorLocation.DEVICE + if torch_device.type == "cuda" + else trt.TensorLocation.HOST ) - mapping.update(constant_mapping_with_type) - - for layer_name in weight_list: - if layer_name not in mapping: - logger.warning(f"{layer_name} is not found in weight mapping.") - continue - # Use Numpy to create weights - weight, weight_dtype = mapping[layer_name] - trt_wt_tensor = trt.Weights( - weight_dtype, weight.data_ptr(), torch.numel(weight) + constant_mapping: dict[str, Any] = weight_name_map.pop( + "constant_mapping", {} + ) # type: ignore + mapping = construct_refit_mapping_from_weight_name_map( + weight_name_map, new_gm.state_dict(), settings ) - refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) - assert ( - len(refitter.get_missing_weights()) == 0 - ), "Fast refitting failed due to incomplete mapping" - - else: - mapping = construct_refit_mapping(new_gm, input_list, settings) - trt_wt_location = trt.TensorLocation.HOST - for layer_name in weight_list: - if layer_name not in mapping: - raise AssertionError(f"{layer_name} is not found in weight mapping") - # Use Numpy to create weights - weight, datatype = mapping[layer_name] - trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size) - refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) - refitted.add(layer_name) - - if len(refitted) != len(weight_list): - logger.warning("Not all weights have been refitted!!!") + constant_mapping_with_type = {} + + for constant_name, val in constant_mapping.items(): + np_weight_type = val.dtype + val_tensor = torch.from_numpy(val).cuda() + trt_dtype = dtype._from(np_weight_type).to(trt.DataType) + torch_dtype = dtype._from(np_weight_type).to(torch.dtype) + constant_mapping_with_type[constant_name] = ( + val_tensor.clone().reshape(-1).contiguous().to(torch_dtype), + trt_dtype, + ) - if not refitter.refit_cuda_engine(): - logger.error("Error: failed to refit new weights.") - raise AssertionError("Refitting failed.") + mapping.update(constant_mapping_with_type) + for layer_name in weight_list: + if layer_name not in mapping: + logger.warning(f"{layer_name} is not found in weight mapping.") + continue + # Use Numpy to create weights + weight, weight_dtype = mapping[layer_name] + trt_wt_tensor = trt.Weights( + weight_dtype, weight.data_ptr(), torch.numel(weight) + ) + refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) + assert ( + len(refitter.get_missing_weights()) == 0 + ), "Fast refitting failed due to incomplete mapping" + else: + mapping = construct_refit_mapping(new_gm, input_list, settings) + trt_wt_location = trt.TensorLocation.HOST + for layer_name in weight_list: + if layer_name not in mapping: + raise AssertionError(f"{layer_name} is not found in weight mapping") + # Use Numpy to create weights + weight = mapping[layer_name] + trt_dtype = dtype._from(weight.dtype).to(trt.DataType) + trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size) + refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) + refitted.add(layer_name) + + if len(refitted) != len(weight_list): + logger.warning("Not all weights have been refitted!!!") + + if not refitter.refit_cuda_engine(): + logger.error("Error: failed to refit new weights.") + raise AssertionError("Refitting failed.") + + +@needs_refit # type: ignore def refit_module_weights( compiled_module: torch.fx.GraphModule | ExportedProgram, 
new_weight_module: ExportedProgram, diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index fc23ad76cf..98fda3696c 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -35,7 +35,7 @@ TILING_OPTIMIZATION_LEVEL, TIMING_CACHE_PATH, TRUNCATE_DOUBLE, - USE_AOT_JOINT_EXPORT, + USE_DISTRIBUTED_MODE_TRACE, USE_EXPLICIT_TYPING, USE_FAST_PARTITIONER, USE_FP32_ACC, @@ -94,9 +94,9 @@ class CompilationSettings: enable_weight_streaming (bool): Enable weight streaming. enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built. True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows - use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). + use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) @@ -137,9 +137,9 @@ class CompilationSettings: immutable_weights: bool = IMMUTABLE_WEIGHTS enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS - use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING + use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE _SETTINGS_TO_BE_ENGINE_INVARIANT = ( diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index ef04745562..f3e1b3e1fa 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -10,11 +10,11 @@ from torch._dynamo.backends.common import aot_autograd from torch._dynamo.utils import detect_fake_mode from torch._functorch.aot_autograd import aot_export_joint_simple +from torch.distributed.tensor import DTensor from torch_tensorrt.dynamo import CompilationSettings from torch_tensorrt.dynamo._compiler import compile_module from torch_tensorrt.dynamo.lowering import ( get_decompositions, - modify_reshape_complex_nodes, post_lowering, remove_detach, remove_sym_nodes, @@ -52,25 +52,39 @@ def aot_torch_tensorrt_aten_backend( gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any ) -> torch.nn.Module: settings, engine_cache = parse_dynamo_kwargs(kwargs) - if settings.use_aot_joint_export: - return _pretraced_backend(gm, sample_inputs, settings, engine_cache) - logger.debug("Wrapping the backend with aot_autograd\n") - _pretraced_backend_autograd = functools.partial( - _pretraced_backend, settings=settings, engine_cache=engine_cache - ) - settings_aot_autograd = {} - settings_aot_autograd["decompostions"] = get_decompositions( - settings.enable_experimental_decompositions - ) - # This is added since detach lowering leads to alias nodes - # Error - View operation returned a tensor that is the same as the input base tensor - # torch 
nop_decompositions in torch/_decomp/decompositions.py - if aten.detach in settings_aot_autograd["decompositions"]: - del settings_aot_autograd["decompositions"][aten.detach] - return aot_autograd( - fw_compiler=_pretraced_backend_autograd, - decompositions=get_decompositions(settings.enable_experimental_decompositions), - )(gm, sample_inputs) + + if settings.use_distributed_mode_trace: + logger.debug( + "Wrapping the backend with aot_autograd for Distributed examples\n" + ) + _pretraced_backend_autograd = functools.partial( + _pretraced_backend, settings=settings, engine_cache=engine_cache + ) + settings_aot_autograd = {} + settings_aot_autograd["decompositions"] = get_decompositions( + settings.enable_experimental_decompositions + ) + # This is added since detach lowering leads to alias nodes + # Error - View operation returned a tensor that is the same as the input base tensor + # torch nop_decompositions in torch/_decomp/decompositions.py + # transpose key deleted since not desirable to lower it to permute + to_delete = { + key + for key in settings_aot_autograd["decompositions"] + if "detach" in key._name + } + for key in to_delete: + del settings_aot_autograd["decompositions"][key] + + return aot_autograd( + fw_compiler=_pretraced_backend_autograd, + decompositions=settings_aot_autograd["decompositions"], + )(gm, sample_inputs) + if any(isinstance(tensor, DTensor) for tensor in sample_inputs): + logger.warning( + "It is recommended to run the model with use_distributed_mode_trace = True since there are distributed tensors in the input which is not supported in aot_export_joint_simple" + ) + return _pretraced_backend(gm, sample_inputs, settings, engine_cache) def _pretraced_backend( @@ -110,18 +124,8 @@ def _pretraced_backend( # Remove detach nodes remove_detach(gm, settings) - complexInputIndices = [] - for i, torch_input in enumerate(torch_inputs): - if torch_inputs[i].dtype == torch.complex64: - complexInputIndices.append(i) - torch_input_real = torch_inputs[i].real - torch_input_imaginary = torch_inputs[i].imag - torch_inputs[i] = torch.stack( - (torch_input_real, torch_input_imaginary), dim=-1 - ) - # Invoke AOTAutograd to translate operators to aten - if settings.use_aot_joint_export: + if not settings.use_distributed_mode_trace: gm = aot_export_joint_simple( gm, sample_inputs, @@ -137,12 +141,6 @@ def _pretraced_backend( logger.debug("Lowered Input graph:\n " + str(gm.graph)) - if complexInputIndices: - modify_reshape_complex_nodes(gm, complexInputIndices) - logger.debug( - "Input graph after modifying complex nodes:\n " + str(gm.graph) - ) - torchtrt_inputs = prepare_inputs( torch_inputs, disable_memory_format_check=True ) diff --git a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py index 0dbdb2a8f4..fa5eacf7c7 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, field +import numpy as np from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.fx.types import TRTNetwork @@ -19,3 +20,4 @@ class ConversionContext: default_factory=CompilationSettings ) requires_output_allocator: bool = False + mapping: dict[str, np.array] = field(default_factory=dict) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 248e06bc3c..da5f3b36c9 100644 --- 
a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -21,10 +21,12 @@ import tensorrt as trt import torch import torch.fx +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.fx.node import _get_qualified_name from torch.fx.passes.shape_prop import TensorMetadata from torch.utils._python_dispatch import _disable_current_modes from torch_tensorrt._enums import dtype +from torch_tensorrt._features import needs_refit from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import _defaults from torch_tensorrt.dynamo._engine_cache import BaseEngineCache @@ -41,8 +43,9 @@ get_node_io, get_node_name, get_trt_tensor, + to_torch, ) -from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device +from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device from torch_tensorrt.fx.observer import Observer from torch_tensorrt.logging import TRT_LOGGER @@ -344,9 +347,10 @@ def _populate_trt_builder_config( self.compilation_settings.tiling_optimization_level ] - builder_config.l2_limit_for_tiling = ( - self.compilation_settings.l2_limit_for_tiling - ) + if self.compilation_settings.l2_limit_for_tiling != -1: + builder_config.l2_limit_for_tiling = ( + self.compilation_settings.l2_limit_for_tiling + ) return builder_config @@ -408,12 +412,13 @@ def find_weight( np_map: the map from weight name to np values in INetworkDefinition state_dict: state of the graph module """ - network_weight = torch.from_numpy(np_map[weight_name]).to(device) - for sd_w_name, sd_weight in state_dict.items(): - if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device): - del state_dict[sd_w_name] - return sd_w_name - return "" + with unset_fake_temporarily(): + network_weight = torch.from_numpy(np_map[weight_name]).to(device) + for sd_w_name, sd_weight in state_dict.items(): + if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device): + del state_dict[sd_w_name] + return sd_w_name + return "" @staticmethod def check_weight_equal( @@ -421,15 +426,17 @@ def check_weight_equal( network_weight: Union[torch.Tensor, np.ndarray], device: torch.device, ) -> Any: - if not isinstance(network_weight, torch.Tensor): - network_weight = torch.from_numpy(network_weight).to(device) - try: - return sd_weight.shape == network_weight.shape and torch.all( - torch.abs(sd_weight - network_weight) < 0.01 - ) - except Exception: - return torch.all(sd_weight == network_weight) + with unset_fake_temporarily(): + if not isinstance(network_weight, torch.Tensor): + network_weight = torch.from_numpy(network_weight).to(device) + try: + return sd_weight.shape == network_weight.shape and torch.all( + torch.abs(sd_weight - network_weight) < 0.01 + ) + except Exception: + return torch.all(sd_weight == network_weight) + @needs_refit def _save_weight_mapping(self) -> None: """ Construct the weight name mapping from engine weight name to state_dict weight name. 
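Editor's note: the find_weight / check_weight_equal changes above (now run under unset_fake_temporarily) and the @needs_refit gate feed the fast-refit path touched earlier in this diff. For context only, a typical end-to-end use of that path looks roughly like the sketch below; it assumes the public refit_module_weights API and an engine built with immutable_weights=False, and the toy model, shapes, and exact keyword names are illustrative rather than taken from this patch.

# Hedged sketch of the fast-refit flow that the weight name mapping serves.
# Assumes immutable_weights=False keeps the engine refittable; kwargs may differ.
import torch
import torch_tensorrt as torch_trt
from torch_tensorrt.dynamo import refit_module_weights

model = torch.nn.Linear(8, 8).eval().cuda()
inputs = [torch.randn(2, 8).cuda()]

exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torch_trt.dynamo.compile(
    exp_program,
    inputs=inputs,
    immutable_weights=False,  # build a refittable engine and record the weight map
)

# Same architecture, new weights
model2 = torch.nn.Linear(8, 8).eval().cuda()
exp_program2 = torch.export.export(model2, tuple(inputs))

# The cached engine-weight -> state_dict mapping lets refit skip per-layer re-matching
refitted_gm = refit_module_weights(
    compiled_module=trt_gm,
    new_weight_module=exp_program2,
    arg_inputs=inputs,
)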
@@ -487,29 +494,20 @@ def _save_weight_mapping(self) -> None: _LOGGER.info("Building weight name mapping...") # Stage 1: Name mapping torch_device = to_torch_device(self.compilation_settings.device) - gm_is_on_cuda = get_model_device(self.module).type == "cuda" - if not gm_is_on_cuda: - # If the model original position is on CPU, move it GPU - sd = { - k: v.reshape(-1).to(torch_device) - for k, v in self.module.state_dict().items() - } - else: - sd = {k: v.reshape(-1) for k, v in self.module.state_dict().items()} + sd = { + k: v.reshape(-1).to(torch_device) + for k, v in self.module.state_dict().items() + } weight_name_map: dict[str, Any] = {} - np_map = {} - constant_mapping = {} + np_map = self.ctx.mapping + constant_mapping = {k: v for k, v in np_map.items() if v.size == 1} net = self.ctx.net for i in range(net.num_layers): layer = net[i] layer_type: str = layer.type.name if layer_type in MODULE_MAP: - layer.__class__ = MODULE_MAP[layer_type][0] # Name mapping for weight_type, weight_name, torch_attr in MODULE_MAP[layer_type][1]: - weight = layer.__getattribute__(weight_type).copy() - if weight.size == 0: - continue engine_weight_name = f"{layer.name} {weight_name}" # Infer the corresponding weight name(s) in state_dict sd_weight_name_list = ( @@ -537,17 +535,15 @@ def _save_weight_mapping(self) -> None: elif "bias" in suffix: sd_weight_name = f"{sd_weight_name}.bias" else: - # Save the constant weights for future fast refit sd_weight_name = f"{sd_weight_name}.unknown" - constant_mapping[engine_weight_name] = weight elif layer_type == "SCALE": # Batch norm needs all weights to calculate scale and shift sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr] else: sd_weight_name = f"{sd_weight_name}.{torch_attr}" - weight_name_map[engine_weight_name] = sd_weight_name - np_map[engine_weight_name] = weight + if engine_weight_name in np_map: + weight_name_map[engine_weight_name] = sd_weight_name # Stage 2: Value mapping for engine_weight_name, sd_weight_name in weight_name_map.items(): @@ -579,6 +575,7 @@ def _save_weight_mapping(self) -> None: gc.collect() torch.cuda.empty_cache() + @needs_refit def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None: # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine # if not self.compilation_settings.strip_engine_weights: @@ -606,6 +603,7 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> No ), ) + @needs_refit def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]: # query the cached TRT engine cached_data = self.engine_cache.check(hash_val) # type: ignore[union-attr] @@ -716,7 +714,7 @@ def run( if self.compilation_settings.reuse_cached_engines: interpreter_result = self._pull_cached_engine(hash_val) if interpreter_result is not None: # hit the cache - return interpreter_result + return interpreter_result # type: ignore[no-any-return] self._construct_trt_network_def() @@ -887,9 +885,7 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any: return converter(self.ctx, target, args, kwargs, self._cur_node_name) def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray: - with _disable_current_modes(): - from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy - + with _disable_current_modes(), unset_fake_temporarily(): frozen_attr = self.fetch_attr(target) if isinstance(frozen_attr, torch.nn.Parameter): @@ -897,9 +893,7 @@ def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray: else: 
constant_tensor = frozen_attr - network_constant = to_numpy(constant_tensor) - - return network_constant + return to_torch(constant_tensor) def call_method(self, target: str, args: Any, kwargs: Any) -> Any: assert isinstance(target, str) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 62526080c4..3541f57f1a 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -9,6 +9,7 @@ import tensorrt as trt import torch import torch_tensorrt.dynamo.conversion.impl as impl +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.fx.node import Argument, Target from torch.fx.passes.shape_prop import TensorMetadata from torch_tensorrt import _enums @@ -340,17 +341,44 @@ def create_constant( Returns: A TensorRT ITensor that represents the given value. """ - shape = (1,) - # Rank 0 constant is required in IFillLayer inputs. - if min_rank == 0: - shape = trt.Dims() - numpy_value = to_numpy(value, dtype) - constant = ctx.net.add_constant( - shape if isinstance(value, (int, float, bool)) else value.shape, - numpy_value.copy() if isinstance(numpy_value, np.ndarray) else numpy_value, - ) - constant.name = name - return constant.get_output(0) + with unset_fake_temporarily(): + torch_value = to_torch(value, dtype) + if torch_value is None: + raise ValueError( + f"Cannot convert tensor '{name}' to a TensorRT constant because its value is None." + ) + if torch_value.dtype == torch.float64: + raise ValueError( + "TensorRT does not support float64 (double) precision. To resolve this, please set truncate_double=True in your compilation settings and re-run the model." + ) + # Rank 0 constant is required in IFillLayer inputs. + if min_rank == 0 and isinstance(value, (int, float, bool)): + shape = trt.Dims() + elif list(torch_value.shape) == []: + shape = trt.Dims() + else: + shape = list(torch_value.shape) + + if torch_value.dtype == torch.bfloat16: + torch_value_fp32 = torch_value.to(torch.float32) + numpy_value = torch_value_fp32.numpy() + else: + numpy_value = torch_value.numpy() + + ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1) + constant = ctx.net.add_constant( + shape, + numpy_value, + ) + constant.name = name + if torch_value.dtype == torch.bfloat16: + return cast_trt_tensor( + ctx, + constant.get_output(0), + trt.DataType.BF16, + name + "_bf16_cast", + ) + return constant.get_output(0) def get_trt_tensor( @@ -554,39 +582,92 @@ def to_numpy( Returns: A Numpy array or None, if the input was None. """ - output = None + with unset_fake_temporarily(): + output = None + + if value is None or isinstance(value, np.ndarray): + output = value + + elif isinstance(value, torch.Tensor): + if value.is_quantized: + value = value.dequantize() + elif value.dtype == torch.bfloat16: + # TODO: Remove when numpy has a BF16 type + _LOGGER.warning( + "Requested a conversion of bfloat16 tensor from torch to numpy which isn't supported. Casting this tensor to FP32 precision currently. 
Please use to_torch() API for better data representation", + ) + value = value.to(torch.float) + + output = value.cpu().detach().contiguous().numpy() + + elif isinstance(value, int): + output = np.array([value], dtype=np.int32) + + elif isinstance(value, float): + output = np.array([value], dtype=np.float32) + + elif isinstance(value, bool): + output = np.array([value], dtype=np.bool_) + + if isinstance(output, np.ndarray) or output is None: + return ( + output + if (dtype is None or output is None) + else output.astype( + _enums.dtype._from(dtype).to(np.dtype, use_default=True) + ) + ) + else: + raise AssertionError( + f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}" + ) - if value is None or isinstance(value, np.ndarray): - output = value - elif isinstance(value, torch.Tensor): - if value.is_quantized: - value = value.dequantize() - elif value.dtype == torch.bfloat16: - # TODO: Remove when numpy has a BF16 type - value = value.to(torch.float) +def to_torch( + value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]], + dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None, +) -> Optional[torch.Tensor]: + """ + Convert a Numpy array, or scalar to a PyTorch tensor and move it to CPU + Args: + value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]): + A PyTorch tensor, Numpy array, int, float, or bool + dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]): + If a dtype is given, we will convert the type of the given `value` to this dtype. + Returns: + A PyTorch tensor or None, if the input was None. + """ - output = value.cpu().detach().contiguous().numpy() + cpu_device = torch.device("cpu") + torch_dtype = ( + _enums.dtype._from(dtype).to(torch.dtype, use_default=True) if dtype else None + ) - elif isinstance(value, int): - output = np.array([value], dtype=np.int32) + with unset_fake_temporarily(): + if value is None: + return None - elif isinstance(value, float): - output = np.array([value], dtype=np.float32) + elif isinstance(value, torch.Tensor): + output = value.to(cpu_device).contiguous() - elif isinstance(value, bool): - output = np.array([value], dtype=np.bool_) + elif isinstance(value, np.ndarray): + output = torch.from_numpy(value).to(cpu_device).contiguous() - if isinstance(output, np.ndarray) or output is None: - return ( - output - if (dtype is None or output is None) - else output.astype(_enums.dtype._from(dtype).to(np.dtype, use_default=True)) - ) - else: - raise AssertionError( - f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}" - ) + elif isinstance(value, int): + output = torch.tensor([value], device=cpu_device, dtype=torch.int32) + + elif isinstance(value, float): + output = torch.tensor([value], device=cpu_device, dtype=torch.float32) + + elif isinstance(value, bool): + output = torch.tensor([value], device=cpu_device, dtype=torch.bool) + + else: + raise AssertionError( + f"to_torch can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got an object of type: {type(value)}" + ) + + return output.to(torch_dtype) if torch_dtype else output def flatten_dims( diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py index 17850fabce..79611c7552 100644 --- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -3,6 +3,7 @@ import logging from typing 
import Dict, Sequence, Tuple, Union +import tensorrt as trt from torch.fx.node import Argument, Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion import impl @@ -16,8 +17,6 @@ tensorrt_fused_nccl_reduce_scatter_op, ) -import tensorrt as trt - _LOGGER: logging.Logger = logging.getLogger(__name__) if load_tensorrt_llm(): @@ -30,7 +29,7 @@ def fused_nccl_gather( kwargs: Dict[str, Argument], name: str, ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: - return impl.distributed.nccl_gather( + return impl.nccl_ops.nccl_gather( ctx, target, SourceIR.ATEN, @@ -46,7 +45,7 @@ def fused_nccl_reduce_scatter( kwargs: Dict[str, Argument], name: str, ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: - return impl.distributed.nccl_reduce_scatter( + return impl.nccl_ops.nccl_reduce_scatter( ctx, target, SourceIR.ATEN, @@ -54,7 +53,6 @@ def fused_nccl_reduce_scatter( [args[0]], ) - breakpoint() else: _LOGGER.debug( "Did not load torch.distributed converters since TensorRT-LLM is not available" diff --git a/py/torch_tensorrt/dynamo/conversion/impl/conv.py b/py/torch_tensorrt/dynamo/conversion/impl/conv.py index 25419d7f60..f27fb13e97 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/conv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/conv.py @@ -13,7 +13,7 @@ cast_trt_tensor, extend_attr_to_tuple, get_trt_tensor, - to_numpy, + to_torch, ) from torch_tensorrt.fx.converters.converter_utils import ( get_dyn_range, @@ -45,7 +45,6 @@ def convNd( assert input.shape[1] != -1, "Channel dim can't be dynamic for convolution." num_dims = len(input.shape) - 2 - if is_conv1d: # Apply an unsqueeze operation to transform the conv1d problem into conv2d input = impl.unsqueeze.unsqueeze( @@ -54,8 +53,8 @@ def convNd( # Process bias terms if isinstance(bias, (torch.Tensor, np.ndarray)): - # Transform the bias constant into a Numpy array - bias = to_numpy(bias, dtype=input.dtype) + bias = to_torch(bias, dtype=input.dtype) + bias = get_trt_tensor(ctx, bias, f"{name}_bias") elif isinstance(bias, TRTTensor): bias = get_trt_tensor(ctx, bias, f"{name}_bias") @@ -74,12 +73,11 @@ def convNd( ctx, target, source_ir, weight.name + "_unsqueeze_conv1d", weight, -1 ) elif isinstance(weight, (torch.Tensor, np.ndarray)): - # Transform the weight constant into a Numpy array - weight = to_numpy(weight, dtype=input.dtype) - + weight = to_torch(weight, dtype=input.dtype) # Append new dimension (unsqueeze) if the convolution is 1d if is_conv1d: - weight = np.expand_dims(weight, -1) + weight = torch.unsqueeze(weight, -1) + weight = get_trt_tensor(ctx, weight, f"{name}_weight") else: raise RuntimeError( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py index d19a92e646..629cecf5db 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py @@ -6,13 +6,12 @@ import tensorrt as trt import torch from torch.fx.node import Target - from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( extend_attr_to_tuple, get_trt_tensor, - to_numpy, + to_torch, ) from torch_tensorrt.fx.converters.converter_utils import ( SourceIR, @@ -53,7 +52,8 @@ def deconvNd( # Process bias terms if isinstance(bias, (torch.Tensor, np.ndarray)): # Transform the bias constant into a Numpy array - bias = to_numpy(bias) + bias = to_torch(bias, dtype=input.dtype) + bias = 
get_trt_tensor(ctx, bias, f"{name}_bias") elif isinstance(bias, TRTTensor): bias = get_trt_tensor(ctx, bias, f"{name}_bias") @@ -73,12 +73,12 @@ def deconvNd( ) elif isinstance(weight, (torch.Tensor, np.ndarray)): - # Transform the weight constant into a Numpy array - weight = to_numpy(weight) - + weight = to_torch(weight, dtype=input.dtype) # Append new dimension (unsqueeze) if the deconvolution is 1d if is_deconv1d: - weight = np.expand_dims(weight, axis=-1) + weight = torch.unsqueeze(weight, -1) + + weight = get_trt_tensor(ctx, weight, f"{name}_weight") else: raise RuntimeError( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index 4a4b33abea..ab9629b0db 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Optional, Union import numpy as np +import tensorrt as trt import torch from torch.fx.node import Target from torch_tensorrt import _enums @@ -15,11 +16,10 @@ get_trt_tensor, has_dynamic_shape, set_layer_name, + to_torch, ) from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor -import tensorrt as trt - def get_python_op_from_trt_elementwise_op( trt_op: TRTElementWiseOp, @@ -125,10 +125,9 @@ def convert_binary_elementwise( # dtype but we don't have a way to detect whether it makes sense for the # scalar to be float or half. Hence we go with the lhs dtype. if is_lhs_trt_tensor and isinstance(rhs_val, (float, int, bool)): - rhs_val = np.array([rhs_val], dtype=_enums.dtype._from(lhs_dtype).to(np.dtype)) + rhs_val = to_torch(rhs_val, dtype=lhs_dtype) if is_rhs_trt_tensor and isinstance(lhs_val, (float, int, bool)): - lhs_val = np.array([lhs_val], dtype=_enums.dtype._from(rhs_dtype).to(np.dtype)) - + lhs_val = to_torch(lhs_val, dtype=rhs_dtype) lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype) rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py index 013268f803..c28c5bcc7d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py @@ -3,12 +3,11 @@ from typing import Optional, Tuple, Union import numpy as np +import tensorrt as trt from torch.fx.node import Argument, Target from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name -import tensorrt as trt - # class for AllReduce class AllReduceStrategy(IntEnum): @@ -94,7 +93,7 @@ def nccl_reduce_scatter( "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32 ) - p_dtype = trt.float16 + p_dtype = trt.float32 pf_dtype = trt.PluginField( "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32 ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index b97840cd09..dbccd2e332 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -1,11 +1,13 @@ -from typing import Optional +from typing import Optional, Union import numpy as np import tensorrt as trt +import torch +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from 
torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_torch from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTTensor @@ -16,7 +18,7 @@ def quantize( source_ir: Optional[SourceIR], name: str, input_tensor: TRTTensor, - amax: np.ndarray, + amax: Union[np.ndarray, torch.Tensor], num_bits: int, exponent_bits: int, ) -> TRTTensor: @@ -24,40 +26,43 @@ def quantize( Adds quantize and dequantize ops (QDQ) which quantize to INT8 or FP8 based on the output_type set and dequantizes them back. """ - if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in ( - trt.float32, - trt.float16, - ): - raise ValueError( - f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16" - ) - if num_bits != 8 or exponent_bits not in (0, 4): - raise ValueError( - f"quantize converter currently only accept INT8 or FP8 based quantize, got {num_bits=}, {exponent_bits=}" - ) - if num_bits == 8 and exponent_bits == 0: - max_bound = 127 - elif num_bits == 8 and exponent_bits == 4: - max_bound = 448 - scale = np.divide(amax, max_bound) - scale = get_trt_tensor(ctx, scale, name + "_scale") - # Add Q node - quantize_layer = ctx.net.add_quantize(input_tensor, scale) - if num_bits == 8 and exponent_bits == 0: - quantize_layer.set_output_type(0, trt.DataType.INT8) - elif num_bits == 8 and exponent_bits == 4: - quantize_layer.set_output_type(0, trt.DataType.FP8) + with unset_fake_temporarily(): + if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in ( + trt.float32, + trt.float16, + ): + raise ValueError( + f"quantize converter received an input of {input_tensor.dtype} type. 
Supported types: float32 | float16" + ) + if num_bits != 8 or exponent_bits not in (0, 4): + raise ValueError( + f"quantize converter currently only accept INT8 or FP8 based quantize, got {num_bits=}, {exponent_bits=}" + ) + if num_bits == 8 and exponent_bits == 0: + max_bound = 127 + elif num_bits == 8 and exponent_bits == 4: + max_bound = 448 - set_layer_name(quantize_layer, target, name + "_quantize", source_ir) - q_output = quantize_layer.get_output(0) - # Add DQ node - dequantize_layer = ctx.net.add_dequantize(q_output, scale) - set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir) - if num_bits == 8 and exponent_bits == 0: - dequantize_layer.precision = trt.DataType.INT8 - elif num_bits == 8 and exponent_bits == 4: - # Set DQ layer precision to FP8 - dequantize_layer.precision = trt.DataType.FP8 - dq_output = dequantize_layer.get_output(0) + amax = to_torch(amax, None) + scale = torch.divide(amax, max_bound) + scale = get_trt_tensor(ctx, scale, name + "_scale") + # Add Q node + quantize_layer = ctx.net.add_quantize(input_tensor, scale) + if num_bits == 8 and exponent_bits == 0: + quantize_layer.set_output_type(0, trt.DataType.INT8) + elif num_bits == 8 and exponent_bits == 4: + quantize_layer.set_output_type(0, trt.DataType.FP8) - return dq_output + set_layer_name(quantize_layer, target, name + "_quantize", source_ir) + q_output = quantize_layer.get_output(0) + # Add DQ node + dequantize_layer = ctx.net.add_dequantize(q_output, scale) + set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir) + if num_bits == 8 and exponent_bits == 0: + dequantize_layer.precision = trt.DataType.INT8 + elif num_bits == 8 and exponent_bits == 4: + # Set DQ layer precision to FP8 + dequantize_layer.precision = trt.DataType.FP8 + dq_output = dequantize_layer.get_output(0) + + return dq_output diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py index 4211bae1fa..8f5f173a7b 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py @@ -1,3 +1,4 @@ +import itertools import logging from types import FunctionType from typing import Any, Callable, Tuple @@ -108,7 +109,6 @@ def generate_signature( def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]: shape_env = ShapeEnv() - fake_mode = FakeTensorMode(shape_env=shape_env) syms_args = [] tensor_args = [elem for elem in args if isinstance(elem, trtp.TensorDesc)] @@ -121,7 +121,7 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]: ] syms_args.append(syms_arg) - with FakeTensorMode() as fake_mode: + with FakeTensorMode(shape_env=shape_env) as fake_mode: fake_args = [] for syms_arg in syms_args: fake_arg = torch.randn(syms_arg) @@ -130,16 +130,25 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]: output = torch_op(*fake_args, **kwargs) # We assume that number of dimensions are the same in torch op - shape_calc_fns = [None] * args[0].ndim - for i in range(args[0].ndim): - input_node_expr = [syms_arg[i].node.expr for syms_arg in syms_args] + shape_calc_fns = [None] * output.ndim + + for i in range(output.ndim): + input_node_expr = list( + itertools.chain.from_iterable( + [sym.node.expr for sym in syms_arg] for syms_arg in syms_args + ) + ) + shape_calc_fns[i] = lambdify( tuple(input_node_expr), output.shape[i].node.expr, "math" ) out_desc = tensor_args[0].like() for i in 
range(out_desc.ndim): - input_shape_expr = [tensor_arg.shape_expr[i] for tensor_arg in tensor_args] + input_shape_expr = list( + itertools.chain.from_iterable(arg.shape_expr for arg in tensor_args) + ) + if output.shape[i].node.expr is None: raise ValueError(f"output.shape[{i}].node.expr cannot be None") out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr) # type: ignore[misc] diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py index 8b0e60881a..99ea3bc356 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py @@ -1,4 +1,5 @@ import logging +import uuid from typing import Callable, Dict, Optional, Sequence, Tuple, Union import numpy as np @@ -47,11 +48,15 @@ def custom_kernel_converter( kwargs: Dict[str, Argument], name: str, ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + plugin = getattr(getattr(trtp.op, namespace), op_name) + tensor_inputs = plugin.input_tensor_names tensor_args = args[0 : len(tensor_inputs)] + + unique_id = uuid.uuid4() itensor_args = [ - get_trt_tensor(ctx, t, f"{t_name}") + get_trt_tensor(ctx, t, f"{t_name}_{unique_id}") for (t, t_name) in zip(tensor_args, tensor_inputs) ] diff --git a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py index f709f177d6..02cb2ccd56 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py @@ -49,7 +49,6 @@ def fuse_distributed_ops( == torch.ops._c10d_functional.wait_tensor.default ): wait_tensor_node = list(node.users)[0] - fused_op = None if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default: with gm.graph.inserting_after(wait_tensor_node): fused_node = gm.graph.create_node( @@ -58,11 +57,12 @@ def fuse_distributed_ops( args=(node.args[0], node.args[1], node.args[2]), ) else: - fused_node = gm.graph.create_node( - op="call_function", - target=tensorrt_fused_nccl_reduce_scatter_op, # Define your custom fused function - args=(node.args[0], node.args[1], node.args[2], node.args[3]), - ) + with gm.graph.inserting_after(wait_tensor_node): + fused_node = gm.graph.create_node( + op="call_function", + target=tensorrt_fused_nccl_reduce_scatter_op, # Define your custom fused function + args=(node.args[0], node.args[1], node.args[2], node.args[3]), + ) wait_tensor_node.replace_all_uses_with(fused_node) fused_node.meta.update(node.meta) diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 5af9b11a4b..9e54fbac3d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -103,9 +103,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: return False - def __del__(self) -> None: + def _reset_captured_graph(self) -> None: if self.cudagraph: self.cudagraph.reset() + self.cudagraph = None + + def __del__(self) -> None: + self._reset_captured_graph() def set_use_output_allocator(self, enable: bool) -> None: self.use_output_allocator_outputs = enable @@ -119,8 +123,7 @@ def forward( shape_changed = self.validate_input_shapes(inputs) need_cudagraphs_record = shape_changed or self.is_weight_streaming_set if 
need_cudagraphs_record: - if self.cudagraph: - self.cudagraph.reset() + self._reset_captured_graph() self._input_buffers = [None] * len(inputs) self.is_weight_streaming_set = False @@ -196,7 +199,5 @@ def forward( return outputs[0] return outputs else: - if self.cudagraph: - self.cudagraph.reset() - self.cudagraph = None + self._reset_captured_graph() return self.compiled_module(*args, **kwargs) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 891d063ed3..6415ce11c3 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -333,9 +333,13 @@ def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule: result.__setstate__(self.__getstate__()) return result - def __del__(self) -> None: + def _reset_captured_graph(self) -> None: if self.cudagraph: self.cudagraph.reset() + self.cudagraph = None + + def __del__(self) -> None: + self._reset_captured_graph() def setup_input_tensors( self, @@ -426,9 +430,8 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: self.cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed ) - if need_cudagraphs_reset and self.cudagraph: - self.cudagraph.reset() - self.cudagraph = None + if need_cudagraphs_reset: + self._reset_captured_graph() if need_cudagraphs_record: self._input_buffers = [None] * len(self.input_names) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index e6b6a21421..c3fe925eee 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -209,6 +209,9 @@ def set_device_memory_budget(self, budget_bytes: int) -> int: return budget_bytes + def _reset_captured_graph(self) -> None: + self.engine.reset_captured_graph() + def setup_engine(self) -> None: """ Setup engine for a module which has deferred engine setup. diff --git a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py index f481c5b2b8..500a665688 100644 --- a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py +++ b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py @@ -142,8 +142,32 @@ def automatic_device_memory_budget_getter(self) -> Any: def infer_outputs(self, input_shapes: List[Any]) -> Any: pass + def reset_captured_graph(self) -> Any: + pass + def __setstate__(self, serialized_state: List[str]) -> Any: pass def __getstate__(self) -> Any: pass + + +@torch.library.custom_op( + "tensorrt::no_op_placeholder_for_execute_engine", mutates_args=() +) +def no_op_placeholder_for_execute_engine( + inputs: List[torch.Tensor], + abi_version: str, + name: str, + serialized_device_info: str, + serialized_engine: str, + serialized_in_binding_names: str, + serialized_out_binding_names: str, + serialized_hardware_compatible: str, + serialized_metadata: str, + serialized_target_platform: str, + serialized_require_output_allocator: str, +) -> List[torch.Tensor]: + raise RuntimeError( + "The saved model is cross compiled for windows in Linux, should only be loadded in Windows via torch_tensorrt.load_cross_compiled_exported_program() api." 
+ ) diff --git a/py/torch_tensorrt/runtime/_cudagraphs.py b/py/torch_tensorrt/runtime/_cudagraphs.py index c771564826..346132145e 100644 --- a/py/torch_tensorrt/runtime/_cudagraphs.py +++ b/py/torch_tensorrt/runtime/_cudagraphs.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Union +from typing import Any, Optional, Union import torch import torch_tensorrt @@ -68,6 +68,7 @@ def __init__(self, compiled_module: torch.nn.Module) -> None: global _PY_RT_CUDAGRAPHS self.old_mode = _PY_RT_CUDAGRAPHS self.compiled_module = compiled_module + self.cudagraphs_module: Optional[CudaGraphsTorchTensorRTModule] = None def __enter__(self) -> torch.nn.Module: global _PY_RT_CUDAGRAPHS @@ -98,7 +99,8 @@ def __enter__(self) -> torch.nn.Module: logger.debug( "Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule" ) - return CudaGraphsTorchTensorRTModule(self.compiled_module) + self.cudagraphs_module = CudaGraphsTorchTensorRTModule(self.compiled_module) + return self.cudagraphs_module else: if num_trt_module > 0: logger.debug("No graph breaks detected, using runtime cudagraphs mode") @@ -113,6 +115,9 @@ def __enter__(self) -> torch.nn.Module: def __exit__(self, *args: Any) -> None: # Set cudagraphs back to old mode set_cudagraphs_mode(self.old_mode) + # __del__ is not entirely predictable, so we reset cudagraph here + if self.cudagraphs_module: + self.cudagraphs_module._reset_captured_graph() def enable_cudagraphs( diff --git a/py/torch_tensorrt/runtime/_utils.py b/py/torch_tensorrt/runtime/_utils.py index c42a2b2a2b..bc2e5a6a70 100644 --- a/py/torch_tensorrt/runtime/_utils.py +++ b/py/torch_tensorrt/runtime/_utils.py @@ -128,23 +128,3 @@ def _get_most_compatible_device( best_match = candidate return best_match - - -@torch.library.custom_op( - "tensorrt::no_op_placeholder_for_execute_engine", mutates_args=() -) -def no_op_placeholder_for_execute_engine( - inputs: List[torch.Tensor], - abi_version: str, - name: str, - serialized_device_info: str, - serialized_engine: str, - serialized_in_binding_names: str, - serialized_out_binding_names: str, - serialized_hardware_compatible: str, - serialized_metadata: str, - serialized_target_platform: str, -) -> List[torch.Tensor]: - raise RuntimeError( - "The saved model is cross compiled for windows in Linux, should only be loadded in Windows via torch_tensorrt.load_cross_compiled_exported_program() api." 
- ) diff --git a/py/torch_tensorrt/runtime/_weight_streaming.py b/py/torch_tensorrt/runtime/_weight_streaming.py index 3b11087fcb..0874d31d11 100755 --- a/py/torch_tensorrt/runtime/_weight_streaming.py +++ b/py/torch_tensorrt/runtime/_weight_streaming.py @@ -76,12 +76,15 @@ def _set_streamable_weight_bytes(self, requested_budget: int) -> int: int(streamable_bytes / total_bytes * requested_budget) for streamable_bytes in self.streamable_budget ] + if self.cuda_graphs_module: + self.cuda_graphs_module.is_weight_streaming_set = True + self.cuda_graphs_module._reset_captured_graph() + for i, (name, rt_mod) in enumerate(self.rt_mods): + rt_mod._reset_captured_graph() ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size[i]) logger.debug(f"Set weight streaming size {normalized_size[i]} for {name}") - if self.cuda_graphs_module: - self.cuda_graphs_module.is_weight_streaming_set = True return ws_budget_bytes def __setattr__(self, name: str, value: Any) -> None: diff --git a/pyproject.toml b/pyproject.toml index e9b12e93e6..5f7fbb62cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] requires = [ - "setuptools>=68.0.0", + "setuptools>=77.0.0", "packaging>=23.1", "wheel>=0.40.0", "ninja>=1.11.0", @@ -9,7 +9,7 @@ requires = [ "typing-extensions>=4.7.0", "future>=0.18.3", "tensorrt-cu12>=10.9.0,<10.10.0", - "torch>=2.7.0.dev,<2.8.0", + "torch>=2.7.0,<2.8.0", "pybind11==2.6.2", "numpy", "sympy", @@ -55,7 +55,7 @@ keywords = [ "inference", ] dependencies = [ - "torch>=2.7.0.dev,<2.8.0", + "torch>=2.7.0,<2.8.0", "tensorrt>=10.9.0,<10.10.0", "tensorrt-cu12>=10.9.0,<10.10.0", "tensorrt-cu12-bindings>=10.9.0,<10.10.0", @@ -108,12 +108,12 @@ prerelease = "if-necessary-or-explicit" index-strategy = "unsafe-best-match" # Needed for TRT-LLM [tool.uv.sources] -torch = [{ index = "pytorch-nightly-cu126" }] -torchvision = [{ index = "pytorch-nightly-cu126" }] +torch = [{ index = "pytorch-test-cu128" }] +torchvision = [{ index = "pytorch-test-cu128" }] [[tool.uv.index]] -name = "pytorch-nightly-cu126" -url = "https://download.pytorch.org/whl/nightly/cu126" +name = "pytorch-test-cu128" +url = "https://download.pytorch.org/whl/test/cu128" explicit = false [[tool.uv.index]] @@ -151,7 +151,6 @@ explicit = false # url = "https://download.pytorch.org/whl/cu118" # explicit = false - [tool.ruff] # NOTE: Synchoronize the ignores with .flake8 lint.ignore = [ diff --git a/setup.py b/setup.py index 09933307c8..9f74cdb9d0 100644 --- a/setup.py +++ b/setup.py @@ -18,12 +18,12 @@ import torch import yaml from setuptools import Extension, find_namespace_packages, setup +from setuptools.command.bdist_wheel import bdist_wheel from setuptools.command.build_ext import build_ext from setuptools.command.develop import develop from setuptools.command.editable_wheel import editable_wheel from setuptools.command.install import install from torch.utils.cpp_extension import IS_WINDOWS, BuildExtension, CUDAExtension -from wheel.bdist_wheel import bdist_wheel __version__: str = "0.0.0" __cuda_version__: str = "0.0" diff --git a/tests/modules/custom_models.py b/tests/modules/custom_models.py index b62faffd1b..4906a1d495 100644 --- a/tests/modules/custom_models.py +++ b/tests/modules/custom_models.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from transformers import BertConfig, BertModel, BertTokenizer # Sample Pool Model (for testing plugin serialization) @@ -165,6 +164,8 @@ def forward(self, z: List[torch.Tensor]): def BertModule(): + from transformers import 
BertConfig, BertModel, BertTokenizer + enc = BertTokenizer.from_pretrained("google-bert/bert-base-uncased") text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]" tokenized_text = enc.tokenize(text) diff --git a/tests/modules/hub.py b/tests/modules/hub.py index 0cce523fb3..d87635b435 100644 --- a/tests/modules/hub.py +++ b/tests/modules/hub.py @@ -4,10 +4,7 @@ import custom_models as cm import timm import torch -import torch.nn as nn -import torch.nn.functional as F import torchvision.models as models -from transformers import BertConfig, BertModel, BertTokenizer torch.hub._validate_not_a_forked_repo = lambda a, b, c: True diff --git a/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py b/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py index ae60f8cda7..8ab47def08 100644 --- a/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py +++ b/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py @@ -81,12 +81,3 @@ def forward(self, lhs, rhs): if __name__ == "__main__": run_tests() - -# Example Usage -# A = torch.full((64, 64), 2, device="cuda", dtype=torch.float) -# B = torch.full((64, 64), 3, device="cuda", dtype=torch.float) - -# C, D = torch.ops.torchtrt_ex.elementwise_add_mul.default(A, B) - -# print("C (Addition):", C) -# print("D (Multiplication):", D) diff --git a/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py new file mode 100644 index 0000000000..fd5ed390ff --- /dev/null +++ b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py @@ -0,0 +1,57 @@ +import unittest + +import pytest +import torch +import torch.nn as nn +import torch_tensorrt +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt._enums import dtype + +from ..conversion.harness import DispatchTestCase + +# Toggle this flag to enable/disable flashinfer-based overrides +enable_flashinfer: bool = False +if enable_flashinfer: + import flashinfer + + @torch.library.custom_op("flashinfer::rmsnorm", mutates_args=()) # type: ignore[misc] + def flashinfer_rmsnorm( + input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 + ) -> torch.Tensor: + return flashinfer.norm.rmsnorm(input, weight) + + @torch.library.register_fake("flashinfer::rmsnorm") + def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor: + return input + + torch_tensorrt.dynamo.conversion.plugins.custom_op( + "flashinfer::rmsnorm", supports_dynamic_shapes=True + ) + + +@unittest.skip( + "Flashinfer RMSNorm test is disabled due to error: SM75 support not available" +) +class TestAutomaticPlugin(DispatchTestCase): + @parameterized.expand( + [ + ((64, 64), (64,), torch.float16), + ((256, 256), (256,), torch.float16), + ] + ) + def test_rmsnorm_float(self, input_shape, weight_shape, data_type): + class rmsnorm(nn.Module): + def forward(self, input, weight): + return torch.ops.flashinfer.rmsnorm.default(input, weight) + + inputs = [ + torch.randn(input_shape, device="cuda", dtype=data_type), + torch.randn(weight_shape, device="cuda", dtype=data_type), + ] + + self.run_test(rmsnorm(), inputs, precision=dtype.f16) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/backend/test_backend_compiler.py b/tests/py/dynamo/backend/test_backend_compiler.py index 4c65800f05..6369d3805c 100644 --- a/tests/py/dynamo/backend/test_backend_compiler.py +++ b/tests/py/dynamo/backend/test_backend_compiler.py @@ -2,11 +2,10 @@ from copy import deepcopy 
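Editor's note: the flashinfer RMSNorm test above drives the automatic plugin machinery through the one-shot plugins.custom_op(...) helper; as far as I can tell from the _generate_plugin.py and _generate_plugin_converter.py changes earlier in this diff, that helper wraps two lower-level steps. A hedged sketch of the decomposed form follows (the op name is reused from the test; the exact signatures are assumed from the auto-plugin examples rather than guaranteed by this patch).

# Hedged sketch: decomposed equivalent of plugins.custom_op("flashinfer::rmsnorm", ...)
import torch_tensorrt

# 1) Generate a TensorRT plugin from the custom op's registered fake/meta kernel,
#    which is where the FakeTensorMode/ShapeEnv fix earlier in this diff applies.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin("flashinfer::rmsnorm")

# 2) Register a dynamo converter that instantiates that plugin during conversion.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "flashinfer::rmsnorm",
    supports_dynamic_shapes=True,
)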
import torch +import torch_tensorrt from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt.dynamo.partitioning import fast_partition -import torch_tensorrt - from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing @@ -51,7 +50,6 @@ def forward(self, x, y): pass_through_build_failures=True, torch_executed_ops={"torch.ops.aten.add.Tensor"}, use_python_runtime=False, - debug=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() @@ -132,7 +130,6 @@ def forward(self, x, y): pass_through_build_failures=True, torch_executed_ops={"torch.ops.aten.add.Tensor"}, use_python_runtime=False, - debug=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = model(*inputs).detach().cpu() @@ -177,7 +174,6 @@ def forward(self, x, y): optimization_level=4, version_compatible=True, max_aux_streams=5, - debug=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() @@ -225,7 +221,6 @@ def forward(self, x, y): min_block_size=1, pass_through_build_failures=True, truncate_double=True, - debug=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() @@ -298,7 +293,6 @@ def forward(self, x, y): min_block_size=1, pass_through_build_failures=True, truncate_double=False, - debug=True, torch_executed_ops={"torch.ops.aten.add.Tensor"}, ) optimized_model_results = optimized_model(*inputs).detach().cpu() diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 6ff45507a0..aa22a74fc0 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -415,7 +415,6 @@ def run_test( compilation_settings = CompilationSettings( enabled_precisions={dtype._from(precision)}, truncate_double=True, - debug=True, immutable_weights=immutable_weights, ) @@ -507,7 +506,6 @@ def run_test_compare_tensor_attributes_only( compilation_settings = CompilationSettings( enabled_precisions={dtype._from(precision)}, truncate_double=True, - debug=True, immutable_weights=immutable_weights, ) diff --git a/tests/py/dynamo/conversion/test_binary_ops_aten.py b/tests/py/dynamo/conversion/test_binary_ops_aten.py index 79c0d9430a..ac8cf4b00b 100644 --- a/tests/py/dynamo/conversion/test_binary_ops_aten.py +++ b/tests/py/dynamo/conversion/test_binary_ops_aten.py @@ -228,6 +228,28 @@ def forward(self, x, y): ] self.run_test_with_dynamic_shape(Op(), input_specs) + @parameterized.expand( + [ + (f"bf16_{op[0].__name__}_one_constant", op[0]) + for op in elementwise_ops + if op[0].__name__ not in ["pow.Tensor_Tensor", "fmod.Tensor"] + ] + ) + def test_elementwise_ops_bf16(self, _, orig_op): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.constant = torch.randn(1) + self.orig_op = orig_op + + def forward(self, x): + x = self.orig_op(x, self.constant) + return self.orig_op(x, -2) + + m = TestModule(orig_op) + inputs = [torch.randn(2, 2, dtype=torch.bfloat16)] + self.run_test(m, inputs) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/distributed/distributed_utils.py b/tests/py/dynamo/distributed/distributed_utils.py new file mode 100644 index 0000000000..e3062249fa --- /dev/null +++ b/tests/py/dynamo/distributed/distributed_utils.py @@ -0,0 +1,60 @@ +import logging +import os + +import numpy as np +import tensorrt as trt +import 
torch +import torch.distributed as dist +from torch.distributed._tensor.device_mesh import init_device_mesh + + +def set_environment_variables_pytest(): + os.environ["WORLD_SIZE"] = str(1) + os.environ["RANK"] = str(0) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(29500) + os.environ["USE_TRTLLM_PLUGINS"] = "1" + + +def initialize_logger(rank, logger_file_name): + logger = logging.getLogger() + logger.setLevel(logging.INFO) + fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") + fh.setLevel(logging.INFO) + logger.addHandler(fh) + return logger + + +# This is required for env initialization since we use mpirun +def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500): + local_rank = int( + os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) + ) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size)) + + # Set up environment variable to run with mpirun + os.environ["RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["TRTLLM_PLUGINS_PATH"] = "./tmp/lib/libnvinfer_plugin_tensorrt_llm.so" + + # Necessary to assign a device to each rank. + torch.cuda.set_device(local_rank) + + # We use nccl backend + dist.init_process_group("nccl") + + # set a manual seed for reproducibility + torch.manual_seed(1111) + + device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,)) + rank = device_mesh.get_rank() + assert rank == local_rank + logger = initialize_logger(rank, logger_file_name) + device_id = ( + rank % torch.cuda.device_count() + ) # Ensure each rank gets a unique device + torch.cuda.set_device(device_id) + + return device_mesh, world_size, rank, logger diff --git a/tests/py/dynamo/distributed/test_distributed_simple_example.py b/tests/py/dynamo/distributed/test_distributed_simple_example.py new file mode 100644 index 0000000000..202469e2ea --- /dev/null +++ b/tests/py/dynamo/distributed/test_distributed_simple_example.py @@ -0,0 +1,97 @@ +import time + +import tensorrt as trt +import torch +import torch.distributed as dist +import torch.nn as nn +import torch_tensorrt +from distributed_utils import initialize_distributed_env +from torch.distributed._tensor import Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + RowwiseParallel, + parallelize_module, +) + +device_mesh, _world_size, _rank, logger = initialize_distributed_env( + "./tensor_parallel_simple_example" +) + +""" +This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py +""" + + +class ToyModel(nn.Module): + """MLP based model""" + + def __init__(self): + super(ToyModel, self).__init__() + self.in_proj = nn.Linear(10, 3200) + self.relu = nn.ReLU() + self.out_proj = nn.Linear(3200, 1600) + self.in_proj2 = nn.Linear(1600, 500) + self.out_proj2 = nn.Linear(500, 100) + + def forward(self, x): + x = self.out_proj(self.relu(self.in_proj(x))) + x = self.relu(x) + x = self.out_proj2(self.relu(self.in_proj2(x))) + return x + + +logger.info(f"Starting PyTorch TP example on rank {_rank}.") + +# # create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids. 
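Editor's note: before the parallelized model is built below, it may help to connect this example to the backend change earlier in this diff. The use_distributed_mode_trace option passed to torch.compile later in this file selects the aot_autograd wrapping path, while the default path keeps aot_export_joint_simple and only warns when DTensor inputs are detected. The following is a rough, illustrative sketch of that decision, not the actual backend code; the helper name is made up.

# Rough sketch of the branch added in backends.py (illustrative only).
from torch.distributed.tensor import DTensor

def describe_trace_path(sample_inputs, settings) -> str:
    if settings.use_distributed_mode_trace:
        # wrap _pretraced_backend with aot_autograd, dropping the detach
        # decomposition to avoid alias-node errors
        return "aot_autograd"
    if any(isinstance(t, DTensor) for t in sample_inputs):
        # the backend only warns here; it still uses aot_export_joint_simple
        print("warning: DTensor inputs found, consider use_distributed_mode_trace=True")
    return "aot_export_joint_simple"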
+tp_model = ToyModel().to("cuda") + + +# Custom parallelization plan for the model +tp_model = parallelize_module( + module=tp_model, + device_mesh=device_mesh, + parallelize_plan={ + "in_proj": ColwiseParallel(input_layouts=Shard(0)), + "out_proj": RowwiseParallel(output_layouts=Shard(0)), + "in_proj2": ColwiseParallel(input_layouts=Shard(0)), + "out_proj2": RowwiseParallel(output_layouts=Shard(0)), + }, +) +torch.manual_seed(0) +inp = torch.rand(20, 10, device="cuda") +python_result = tp_model(inp) + +backend = "torch_tensorrt" +tp_model = torch.compile( + tp_model, + backend=backend, + options={ + "truncate_long_and_double": True, + "enabled_precisions": {torch.float32, torch.float16}, + "use_python_runtime": True, + "min_block_size": 1, + "use_distributed_mode_trace": True, + }, + dynamic=None, +) + +try: + for i in range(10): + # For TP, input needs to be same across all TP ranks. + # Setting the random seed is to mimic the behavior of dataloader. + torch.manual_seed(i) + inp = torch.rand(20, 10, device="cuda") + start = time.time() + output = tp_model(inp) + end = time.time() + if i == 0: + logger.info(f"Compilation time is {end-start}") + assert ( + python_result - output + ).std() < 0.01, "Compilation result is not correct." + elif _rank == 0: + logger.info(f"Inference time is {end-start}") +finally: + # This cleans up the distributed process group + if dist.is_initialized(): + dist.destroy_process_group() diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py new file mode 100644 index 0000000000..89c94300b7 --- /dev/null +++ b/tests/py/dynamo/distributed/test_nccl_ops.py @@ -0,0 +1,76 @@ +import os + +import torch +import torch.distributed as dist +import torch.nn as nn +from distributed_utils import set_environment_variables_pytest +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +set_environment_variables_pytest() +dist.init_process_group(backend="nccl", init_method="env://") +group = dist.new_group(ranks=[0]) +group_name = group.group_name +world_size = 1 + +from conversion.harness import DispatchTestCase + + +class TestGatherNcclOpsConverter(DispatchTestCase): + @parameterized.expand([8]) + def test_nccl_ops(self, linear_layer_dim): + class DistributedGatherModel(nn.Module): + def __init__(self, input_dim): + super().__init__() + self.fc = torch.nn.Linear(input_dim, input_dim) + + def forward(self, x): + x = self.fc(x) + gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor( + x, world_size, group_name + ) + gathered_tensor = torch.ops._c10d_functional.wait_tensor( + gathered_tensor + ) + return gathered_tensor + + inputs = [torch.randn(1, linear_layer_dim).to("cuda")] + self.run_test( + DistributedGatherModel(linear_layer_dim).cuda(), + inputs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + @parameterized.expand([8]) + def test_nccl_ops_scatter(self, linear_layer_dim): + + class DistributedReduceScatterModel(nn.Module): + def __init__(self, input_dim): + super().__init__() + self.fc = torch.nn.Linear(input_dim, input_dim) + + def forward(self, x): + x = self.fc(x) + scatter_reduce_tensor = ( + torch.ops._c10d_functional.reduce_scatter_tensor( + x, "sum", world_size, group_name + ) + ) + scatter_reduce_tensor = torch.ops._c10d_functional.wait_tensor( + scatter_reduce_tensor + ) + return scatter_reduce_tensor + + inputs = [torch.zeros(1, linear_layer_dim).to("cuda")] + + self.run_test( + DistributedReduceScatterModel(linear_layer_dim).cuda(), + inputs, 
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/dynamo/distributed/test_nccl_ops.sh b/tests/py/dynamo/distributed/test_nccl_ops.sh
new file mode 100644
index 0000000000..dd54700048
--- /dev/null
+++ b/tests/py/dynamo/distributed/test_nccl_ops.sh
@@ -0,0 +1,137 @@
+#!/bin/bash
+
+check_command() {
+    command -v "$1" >/dev/null 2>&1
+}
+
+ensure_installed() {
+    local pkg="$1"
+    if ! check_command "$pkg"; then
+        echo "$pkg is not installed. Installing $pkg..."
+
+        # Determine if sudo is needed
+        if check_command sudo; then
+            SUDO="sudo"
+        else
+            SUDO=""
+        fi
+
+        # Detect OS and install accordingly
+        OS="$(uname -s)"
+        if [[ "$OS" == "Linux" ]]; then
+            if check_command apt-get; then
+                $SUDO apt-get update && $SUDO apt-get install -y "$pkg"
+            fi
+        else
+            echo "Unsupported OS: $OS. Please install $pkg manually."
+            exit 1
+        fi
+    else
+        echo "$pkg is already installed."
+    fi
+}
+
+ensure_mpi_installed() {
+    local pkg="$1"
+    if dpkg -l | grep -q "$pkg"; then
+        echo "$pkg is already installed."
+    else
+        echo "$pkg is not installed. Installing $pkg..."
+
+        # Determine if sudo is needed
+        if check_command sudo; then
+            SUDO="sudo"
+        else
+            SUDO=""
+        fi
+
+        # Detect OS and install accordingly
+        OS="$(uname -s)"
+        if [[ "$OS" == "Linux" ]]; then
+            if check_command apt-get; then
+                $SUDO apt-get update && $SUDO apt-get install -y "$pkg"
+            fi
+        else
+            echo "Unsupported OS: $OS. Please install $pkg manually."
+            exit 1
+        fi
+    fi
+}
+
+ensure_pytest_installed(){
+    if check_command pip; then
+        echo "pip is installed, installing pytest..."
+        pip install pytest
+    else
+        echo "pip is not installed. Please install pip first."
+        exit 1
+    fi
+}
+
+echo "Setting up the environment"
+
+OS="$(uname -s)"
+ARCH="$(uname -m)"
+
+
+#getting the file name for TensorRT-LLM download
+if [[ "$OS" == "Linux" && "$ARCH" == "x86_64" ]]; then
+    FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_x86_64.whl"
+elif [[ "$OS" == "Linux" && "$ARCH" == "aarch64" ]]; then
+    FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_aarch64.whl"
+else
+    echo "Unsupported platform: OS=$OS ARCH=$ARCH"
+    exit 1
+fi
+
+# Download the selected file
+URL="https://pypi.nvidia.com/tensorrt-llm/$FILE"
+echo "Downloading $FILE from $URL..."
+
+#Installing wget
+ensure_installed wget
+
+#Downloading the file
+filename=$(basename "$URL")
+if [ -f "$filename" ]; then
+    echo "File already exists: $filename"
+else
+    wget "$URL"
+fi
+echo "Download complete: $FILE"
+
+UNZIP_DIR="tensorrt_llm_unzip"
+if [[ ! -d "$UNZIP_DIR" ]]; then
+    echo "Creating directory: $UNZIP_DIR"
+    mkdir -p "$UNZIP_DIR"
+    echo "extracting $FILE to $UNZIP_DIR ..."
+    #Installing unzip
+    ensure_installed unzip
+    #unzip the TensorRT-LLM package
+    unzip -q "$FILE" -d "$UNZIP_DIR"
+    echo "Unzip complete"
+fi
+
+
+export TRTLLM_PLUGINS_PATH="$(pwd)/${UNZIP_DIR}/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
+echo ${TRTLLM_PLUGINS_PATH}
+
+ensure_mpi_installed libmpich-dev
+ensure_mpi_installed libopenmpi-dev
+
+run_tests() {
+    cd ..
+    export PYTHONPATH=$(pwd)
+    echo "Running pytest on distributed/test_nccl_ops.py..."
+ pytest distributed/test_nccl_ops.py +} + +run_mpi_tests(){ + cd distributed + echo "Running test_distributed_simple_example with mpirun..."--- + mpirun -n 1 --allow-run-as-root python test_distributed_simple_example.py +} + +ensure_pytest_installed +run_tests +run_mpi_tests \ No newline at end of file diff --git a/tests/py/dynamo/models/test_dtype_support.py b/tests/py/dynamo/models/test_dtype_support.py index 146f7fdb7d..37b40574a1 100644 --- a/tests/py/dynamo/models/test_dtype_support.py +++ b/tests/py/dynamo/models/test_dtype_support.py @@ -297,7 +297,6 @@ def forward(self, x): ir="torch_compile", inputs=inputs, enabled_precisions={torch.bfloat16}, - debug=True, min_block_size=1, device=device, cache_built_engines=False, diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py index 36bf5edc95..0bc7c665b3 100644 --- a/tests/py/dynamo/models/test_engine_cache.py +++ b/tests/py/dynamo/models/test_engine_cache.py @@ -250,6 +250,10 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms", ) + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_dynamo_compile_with_custom_engine_cache(self): model = models.resnet18(pretrained=True).eval().to("cuda") @@ -314,6 +318,10 @@ def test_dynamo_compile_with_custom_engine_cache(self): for h, count in custom_engine_cache.hashes.items() ] + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_dynamo_compile_change_input_shape(self): """Runs compilation 3 times, the cache should miss each time""" model = models.resnet18(pretrained=True).eval().to("cuda") @@ -346,6 +354,10 @@ def test_dynamo_compile_change_input_shape(self): for h, count in custom_engine_cache.hashes.items() ] + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) @pytest.mark.xfail def test_torch_compile_with_default_disk_engine_cache(self): # Custom Engine Cache @@ -485,6 +497,10 @@ def test_torch_compile_with_custom_engine_cache(self): for h, count in custom_engine_cache.hashes.items() ] + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_torch_trt_compile_change_input_shape(self): # Custom Engine Cache model = models.resnet18(pretrained=True).eval().to("cuda") @@ -611,6 +627,10 @@ def forward(self, c, d): assertions.assertEqual(hash1, hash2) # @unittest.skip("benchmark on small models") + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_caching_small_model(self): from torch_tensorrt.dynamo._refit import refit_module_weights diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index a0b3292c29..b170bcc47d 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -1,3 +1,4 @@ +import importlib import os import tempfile import unittest @@ -21,7 +22,6 @@ pre_export_lowering, ) from torch_tensorrt.logging import TRT_LOGGER -from transformers import BertModel assertions = 
unittest.TestCase() @@ -30,6 +30,10 @@ not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_mapping(): model = models.resnet18(pretrained=False).eval().to("cuda") @@ -85,6 +89,10 @@ def test_mapping(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_with_weightmap(): model = models.resnet18(pretrained=False).eval().to("cuda") @@ -134,6 +142,10 @@ def test_refit_one_engine_with_weightmap(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_no_map_with_weightmap(): model = models.resnet18(pretrained=False).eval().to("cuda") @@ -184,6 +196,10 @@ def test_refit_one_engine_no_map_with_weightmap(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_with_wrong_weightmap(): model = models.resnet18(pretrained=False).eval().to("cuda") @@ -238,8 +254,18 @@ def test_refit_one_engine_with_wrong_weightmap(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not importlib.util.find_spec("transformers"), + "transformers is required to run this test", +) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_bert_with_weightmap(): + from transformers import BertModel + inputs = [ torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), ] @@ -293,6 +319,10 @@ def test_refit_one_engine_bert_with_weightmap(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_inline_runtime__with_weightmap(): trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") @@ -339,6 +369,10 @@ def test_refit_one_engine_inline_runtime__with_weightmap(): torch._dynamo.reset() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_python_runtime_with_weightmap(): model = models.resnet18(pretrained=False).eval().to("cuda") @@ -387,6 +421,10 @@ def test_refit_one_engine_python_runtime_with_weightmap(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_multiple_engine_with_weightmap(): class net(nn.Module): @@ -458,6 +496,10 @@ def forward(self, x): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + 
"Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_without_weightmap(): model = models.resnet18(pretrained=True).eval().to("cuda") @@ -506,8 +548,18 @@ def test_refit_one_engine_without_weightmap(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not importlib.util.find_spec("transformers"), + "transformers is required to run this test", +) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_bert_without_weightmap(): + from transformers import BertModel + inputs = [ torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), ] @@ -561,6 +613,10 @@ def test_refit_one_engine_bert_without_weightmap(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_inline_runtime_without_weightmap(): trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") @@ -607,6 +663,10 @@ def test_refit_one_engine_inline_runtime_without_weightmap(): torch._dynamo.reset() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_python_runtime_without_weightmap(): model = models.resnet18(pretrained=True).eval().to("cuda") @@ -655,6 +715,10 @@ def test_refit_one_engine_python_runtime_without_weightmap(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_multiple_engine_without_weightmap(): class net(nn.Module): @@ -722,6 +786,10 @@ def forward(self, x): torch._dynamo.reset() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_cumsum_fallback(): class net(nn.Module): @@ -747,7 +815,6 @@ def forward(self, x): exp_program, tuple(inputs), enabled_precisions={torch.float}, - debug=True, min_block_size=1, immutable_weights=False, ) diff --git a/tests/py/dynamo/models/test_modelopt_models.py b/tests/py/dynamo/models/test_modelopt_models.py new file mode 100644 index 0000000000..c2cd719bf9 --- /dev/null +++ b/tests/py/dynamo/models/test_modelopt_models.py @@ -0,0 +1,117 @@ +# type: ignore +import importlib +import platform +import unittest +from importlib import metadata + +import pytest +import torch +import torch_tensorrt as torchtrt + +from packaging.version import Version + +assertions = unittest.TestCase() + + +@unittest.skipIf( + torch.cuda.get_device_capability() < (8, 9), + "FP8 quantization requires compute capability 8.9 or later", +) +@unittest.skipIf( + not importlib.util.find_spec("modelopt"), + "ModelOpt is required to run this test", +) +@pytest.mark.unit +def test_base_fp8(): + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.utils import export_torch_mode + + class SimpleNetwork(torch.nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + self.linear1 = torch.nn.Linear(in_features=10, out_features=5) + self.linear2 = torch.nn.Linear(in_features=5, out_features=1) + + def forward(self, x): + x = 
self.linear1(x) + x = torch.nn.ReLU()(x) + x = self.linear2(x) + return x + + def calibrate_loop(model): + """Simple calibration function for testing.""" + model(input_tensor) + + input_tensor = torch.randn(1, 10).cuda() + model = SimpleNetwork().eval().cuda() + + quant_cfg = mtq.FP8_DEFAULT_CFG + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + # model has FP8 qdq nodes at this point + output_pyt = model(input_tensor) + + with torch.no_grad(): + with export_torch_mode(): + exp_program = torch.export.export(model, (input_tensor,), strict=False) + trt_model = torchtrt.dynamo.compile( + exp_program, + inputs=[input_tensor], + enabled_precisions={torch.float8_e4m3fn}, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) + outputs_trt = trt_model(input_tensor) + assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) + + +@unittest.skipIf( + platform.system() != "Linux" + or not importlib.util.find_spec("modelopt") + or Version(metadata.version("nvidia-modelopt")) < Version("0.27.0"), + "modelopt 0.17.0 or later is required, Int8 quantization is supported in modelopt since 0.17.0 or later for linux", +) +@pytest.mark.unit +def test_base_int8(): + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.utils import export_torch_mode + + class SimpleNetwork(torch.nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + self.linear1 = torch.nn.Linear(in_features=10, out_features=5) + self.linear2 = torch.nn.Linear(in_features=5, out_features=1) + + def forward(self, x): + x = self.linear1(x) + x = torch.nn.ReLU()(x) + x = self.linear2(x) + return x + + def calibrate_loop(model): + """Simple calibration function for testing.""" + model(input_tensor) + + input_tensor = torch.randn(1, 10).cuda() + model = SimpleNetwork().eval().cuda() + + quant_cfg = mtq.INT8_DEFAULT_CFG + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + # model has INT8 qdq nodes at this point + output_pyt = model(input_tensor) + + with torchtrt.logging.debug(), torch.no_grad(): + with export_torch_mode(): + exp_program = torch.export.export(model, (input_tensor,), strict=False) + trt_model = torchtrt.dynamo.compile( + exp_program, + inputs=[input_tensor], + enabled_precisions={torch.int8}, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + truncate_double=True, + debug=True, + ) + outputs_trt = trt_model(input_tensor) + assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index b6f986711a..aa48836590 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -1,5 +1,5 @@ # type: ignore - +import importlib import unittest import pytest @@ -8,7 +8,6 @@ import torch_tensorrt as torchtrt import torchvision.models as models from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity -from transformers import BertModel assertions = unittest.TestCase() @@ -109,10 +108,16 @@ def test_efficientnet_b0(ir): @pytest.mark.unit +@unittest.skipIf( + not importlib.util.find_spec("transformers"), + "transformers is required to run this test", +) def test_bert_base_uncased(ir): + from transformers import BertModel + model = BertModel.from_pretrained("bert-base-uncased").cuda().eval() - input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda") - input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda") + input = torch.randint(0, 2, (1, 
14), dtype=torch.int32).to("cuda") + input2 = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda") compile_spec = { "inputs": [ @@ -182,3 +187,94 @@ def test_resnet18_half(ir): # Clean up model env torch._dynamo.reset() + + +@pytest.mark.unit +def test_bf16_model(ir): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + return out + + model = MyModule().eval().cuda().to(torch.bfloat16) + input = torch.randn((1, 3, 224, 224)).to("cuda").to(torch.bfloat16) + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.bfloat16, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float32}, + "ir": ir, + "pass_through_build_failures": True, + "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + "use_explicit_typing": True, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"BF16 model TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Clean up model env + torch._dynamo.reset() + + +@pytest.mark.unit +def test_bf16_fallback_model(ir): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, padding=1, stride=1, bias=True) + self.relu = torch.nn.ReLU() + self.conv2 = torch.nn.Conv2d(16, 16, 3, padding=1, stride=1, bias=True) + + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + out = self.conv2(out) + return out + + model = MyModule().eval().cuda().to(torch.bfloat16) + input = torch.randn((1, 3, 224, 224)).to("cuda").to(torch.bfloat16) + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.bfloat16, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float32}, + "ir": ir, + "pass_through_build_failures": True, + "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + "use_explicit_typing": True, + "torch_executed_ops": {"torch.ops.aten.relu.default"}, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"BF16 fallback model TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Clean up model env + torch._dynamo.reset() diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 469ed569d1..19fdeaa9ab 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -10,7 +10,6 @@ import torch_tensorrt as torchtrt import torchvision.models as models from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity -from transformers import BertModel from packaging.version import Version @@ -114,12 +113,18 @@ def test_efficientnet_b0(ir): @pytest.mark.unit +@unittest.skipIf( + not importlib.util.find_spec("transformers"), + "transformers is required to run this test", +) def test_bert_base_uncased(ir): + from transformers import BertModel + model = ( BertModel.from_pretrained("bert-base-uncased", return_dict=False).cuda().eval() ) - input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda") - input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda") + input = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda") + input2 = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda") compile_spec = { "inputs": [ @@ -249,6 +254,7 @@ def calibrate_loop(model): @unittest.skipIf( platform.system() != "Linux" + or torch.cuda.get_device_capability() < (8, 9) or not importlib.util.find_spec("modelopt") or Version(metadata.version("nvidia-modelopt")) < Version("0.17.0"), "modelopt 0.17.0 or later is required, Int8 quantization is supported in modelopt since 0.17.0 or later for linux", @@ -257,7 +263,6 @@ def calibrate_loop(model): def test_base_int8(ir): import modelopt.torch.quantization as mtq from modelopt.torch.quantization.utils import export_torch_mode - from torch.export._trace import _export class SimpleNetwork(torch.nn.Module): def __init__(self): @@ -285,7 +290,7 @@ def calibrate_loop(model): with torch.no_grad(): with export_torch_mode(): - exp_program = _export(model, (input_tensor,)) + exp_program = torch.export.export(model, (input_tensor,)) trt_model = torchtrt.dynamo.compile( exp_program, inputs=[input_tensor], @@ -294,6 +299,7 @@ def calibrate_loop(model): debug=True, cache_built_engines=False, reuse_cached_engines=False, + truncate_double=True, ) outputs_trt = trt_model(input_tensor) assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index 0c79ba7a3f..33bf94e711 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -16,6 +16,10 @@ class TestWeightStrippedEngine(TestCase): + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_three_ways_to_compile(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) @@ -57,6 +61,10 @@ def test_three_ways_to_compile(self): gm1_output, gm2_output, 1e-2, 1e-2 ), "gm2_output is not correct" + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_three_ways_to_compile_weight_stripped_engine(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) @@ -89,6 
+97,10 @@ def test_three_ways_to_compile_weight_stripped_engine(self): gm1_output.sum(), 0, msg="gm1_output should be all zeros" ) + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_weight_stripped_engine_sizes(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) @@ -126,6 +138,10 @@ def test_weight_stripped_engine_sizes(self): msg=f"Weight-stripped refit-identical engine size is not smaller than the weight included engine size. Weight included engine size: {len(bytes(weight_included_engine))}, weight-stripped refit-identical engine size: {len(bytes(weight_stripped_refit_identical_engine))}", ) + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_weight_stripped_engine_results(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) @@ -187,6 +203,10 @@ def test_weight_stripped_engine_results(self): @unittest.skip( "For now, torch-trt will save weighted engine if strip_engine_weights is False. In the near future, we plan to save weight-stripped engine regardless of strip_engine_weights, which is pending on TRT's feature development: NVBug #4914602" ) + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_engine_caching_saves_weight_stripped_engine(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) @@ -233,6 +253,10 @@ def test_engine_caching_saves_weight_stripped_engine(self): msg=f"cached engine size is not smaller than the weight included engine size. Weight included engine size: {len(bytes(weight_included_engine))}, cached stripped engine size: {len(bytes(cached_stripped_engine))}", ) + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_dynamo_compile_with_refittable_weight_stripped_engine(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) @@ -397,6 +421,10 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms", ) + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_different_args_dont_share_cached_engine(self): class MyModel(torch.nn.Module): def __init__(self): @@ -446,6 +474,10 @@ def forward(self, x): msg=f"It has {len(os.listdir(engine_cache_dir))} cached engine(s) but should have 2 engines", ) + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_constant_mul_in_refitting(self): class MyModel(torch.nn.Module): def __init__(self): @@ -483,6 +515,10 @@ def forward(self, x): msg=f"TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_two_TRTRuntime_in_refitting(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) @@ -523,6 +559,10 @@ def test_two_TRTRuntime_in_refitting(self): ) @unittest.skip("Waiting for implementation") + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_refit_identical_engine_weights(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) diff --git a/tests/py/dynamo/runtime/test_002_cudagraphs_py.py b/tests/py/dynamo/runtime/test_002_cudagraphs_py.py index 0a4629644d..0c9b8bc13f 100644 --- a/tests/py/dynamo/runtime/test_002_cudagraphs_py.py +++ b/tests/py/dynamo/runtime/test_002_cudagraphs_py.py @@ -61,7 +61,6 @@ def forward(self, x): min_block_size=1, pass_through_build_failures=True, use_python_runtime=True, - debug=True, ) result_samples = [] diff --git a/tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py b/tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py index 867bf14bee..44a14a74de 100644 --- a/tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py +++ b/tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py @@ -63,3 +63,31 @@ def forward(self, a, b): ) except Exception as e: pytest.fail(f"unexpected exception raised: {e}") + + @unittest.skipIf( + platform.system() != "Linux" or platform.architecture()[0] != "64bit", + "Cross compile for windows can only be enabled on linux x86-64 platform", + ) + @pytest.mark.unit + def test_dynamo_cross_compile_for_windows_multiple_output(self): + class Add(torch.nn.Module): + def forward(self, a, b): + return torch.add(a, b), torch.add(a, b) + + model = Add().eval().cuda() + inputs = (torch.randn(2, 3).cuda(), torch.randn(2, 3).cuda()) + trt_ep_path = os.path.join(tempfile.gettempdir(), "trt.ep") + exp_program = torch.export.export(model, inputs) + compile_spec = { + "inputs": inputs, + "min_block_size": 1, + } + try: + trt_gm = torch_tensorrt.dynamo.cross_compile_for_windows( + exp_program, **compile_spec + ) + torch_tensorrt.dynamo.save_cross_compiled_exported_program( + trt_gm, file_path=trt_ep_path + ) + except Exception as e: + pytest.fail(f"unexpected exception raised: {e}") diff --git a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py index c07e04b6a4..f1af1098b1 100644 --- a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py +++ b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py @@ -74,6 +74,10 @@ def test_check_input_shape_dynamic(): ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_model_complex_dynamic_shape(): device = "cuda:0" @@ -194,6 +198,10 @@ def forward(self, a, b, c=None): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_resnet18(): torch.manual_seed(0) @@ -230,6 +238,10 @@ def test_resnet18(): not 
torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_save(): torch.manual_seed(0) @@ -266,6 +278,10 @@ def test_save(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_resnet18_modify_attribute(): torch.manual_seed(0) @@ -306,6 +322,10 @@ def test_resnet18_modify_attribute(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_resnet18_modify_attribute_no_refit(): torch.manual_seed(0) @@ -353,6 +373,10 @@ def test_resnet18_modify_attribute_no_refit(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_custom_model_with_kwarg(): class net(nn.Module): @@ -420,6 +444,10 @@ def forward(self, x, b=5, c=None, d=None): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_custom_model_with_inplace_init(): class net(nn.Module): @@ -483,6 +511,10 @@ def set_weights(self): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_custom_model_with_init_recompile(): class net(nn.Module): @@ -546,6 +578,10 @@ def set_layer(self): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_custom_model_with_kwarg_different_input(): class net(nn.Module): diff --git a/tests/py/requirements.txt b/tests/py/requirements.txt index 6fb6128089..94db519d28 100644 --- a/tests/py/requirements.txt +++ b/tests/py/requirements.txt @@ -8,6 +8,6 @@ pytest>=8.2.1 pytest-xdist>=3.6.1 pyyaml timm>=1.0.3 -transformers==4.40.2 -nvidia-modelopt[deploy,hf,torch]~=0.17.0 +transformers==4.49.0 +nvidia-modelopt[all]~=0.27.0; python_version < "3.13" --extra-index-url https://pypi.nvidia.com
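Note on the new distributed tests: tests/py/dynamo/distributed/test_nccl_ops.py calls dist.init_process_group(backend="nccl", init_method="env://") at import time and relies on distributed_utils.set_environment_variables_pytest to populate the rendezvous environment. The sketch below is a minimal, self-contained illustration of what an env:// single-rank initialization needs; the specific variable values are illustrative assumptions, not a copy of that helper, and it assumes at least one CUDA GPU is visible.

import os

import torch.distributed as dist

# Illustrative values only; set_environment_variables_pytest presumably sets
# an equivalent single-rank rendezvous configuration.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

# init_method="env://" reads MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE
# from the environment, matching how the new NCCL ops tests initialize.
dist.init_process_group(backend="nccl", init_method="env://")
try:
    print("rank:", dist.get_rank(), "world size:", dist.get_world_size())
finally:
    # Mirrors the cleanup performed in the distributed example above.
    dist.destroy_process_group()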