fix: Fix a perf regression due to weights being ITensors #3568

Merged (9 commits) on Jun 13, 2025
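
Summary of the change, as visible in the diff below: constant, convolution, and deconvolution weights are now handed to TensorRT as trt.Weights objects backed by the original torch.Tensor memory (via a new to_trt_weights helper) instead of being materialized as ITensor constants; ctx.mapping is renamed to the more descriptive ctx.weight_refit_map, and cpu_weights_reference_holder keeps the CPU tensors alive for the duration of compilation.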
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_refit.py
@@ -78,7 +78,7 @@ def construct_refit_mapping(
     )
     interpreter._construct_trt_network_def()

-    return interpreter.ctx.mapping
+    return interpreter.ctx.weight_refit_map


 @needs_refit
6 changes: 4 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
@@ -15,15 +15,17 @@ class ConversionContext:
         net: TensorRT Network being built
         compilation_settings: Settings selected by the user for compilation
         requires_output_allocator: Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators)
+        weight_refit_map: Dictionary mapping weight names to their corresponding np.array
+        cpu_weights_reference_holder: Dictionary mapping weight names to their corresponding torch.Tensor
     """

     net: TRTNetwork
     compilation_settings: CompilationSettings = field(
         default_factory=CompilationSettings
     )
     requires_output_allocator: bool = False
-    mapping: dict[str, np.array] = field(default_factory=dict)
-    cpu_weights_reference_holder: dict[str, Union[torch.Tensor, np.array]] = field(
+    weight_refit_map: dict[str, np.array] = field(default_factory=dict)
+    cpu_weights_reference_holder: dict[str, Union[torch.Tensor]] = field(
         default_factory=dict
     )
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -499,7 +499,7 @@ def _save_weight_mapping(self) -> None:
             for k, v in self.module.state_dict().items()
         }
         weight_name_map: dict[str, Any] = {}
-        np_map = self.ctx.mapping
+        np_map = self.ctx.weight_refit_map
         constant_mapping = {k: v for k, v in np_map.items() if v.size == 1}
         net = self.ctx.net
         for i in range(net.num_layers):
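
The only change in this hunk is the rename; for context, the constant_mapping line that follows it splits out the scalar entries of the refit map. A tiny sketch with assumed keys:

```python
import numpy as np

np_map = {
    "linear.weight CONSTANT": np.zeros(16, dtype=np.float32),
    "eps CONSTANT": np.array([1e-5], dtype=np.float32),
}
# Size-1 constants are separated out, as in _save_weight_mapping above.
constant_mapping = {k: v for k, v in np_map.items() if v.size == 1}
assert list(constant_mapping) == ["eps CONSTANT"]
```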
64 changes: 42 additions & 22 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -320,6 +320,37 @@ def cast_int_or_float_to_bool(
     return tensor


+def to_trt_weights(
+    value: Any, target_quantized_type: Optional[trt.DataType] = None
+) -> trt.Weights:
+    """
+    Convert a PyTorch tensor or NumPy array to TensorRT weights.
+
+    Args:
+        value (Union[torch.Tensor, np.ndarray]): The tensor or array to convert to TRT weights
+
+    Returns:
+        trt.Weights: TensorRT weights object with appropriate data type
+
+    Note:
+        - Input tensors are made contiguous before conversion
+        - Data type is preserved from the original tensor/array
+    """
+    if isinstance(value, torch.Tensor):
+        # Tensor must be contiguous before conversion
+        value = value.contiguous()
+        value_trt_dtype = _enums.dtype._from(value.dtype).to(trt.DataType)
+        return trt.Weights(value_trt_dtype, value.data_ptr(), value.nelement())
+    elif isinstance(value, np.ndarray):
+        value = np.ascontiguousarray(value)
+        value_np_dtype = _enums.dtype._from(value.dtype).to(np.dtype, use_default=True)
+        return trt.Weights(value_np_dtype, value.data, value.size)
+    else:
+        raise AssertionError(
+            f"to_trt_weights can only be called on torch.Tensor or np.ndarray, got an object of type: {type(value)}"
+        )
+
+
 def create_constant(
     ctx: ConversionContext,
     value: Union[int, float, bool, np.ndarray, torch.Tensor],
@@ -363,19 +394,6 @@ def create_constant(
     shape = list(torch_value.shape)

     if torch_value is not None:
-        if torch_value.dtype == torch.float8_e4m3fn:
-            weights = trt.Weights(
-                type=trt.DataType.FP8,
-                ptr=torch_value.data_ptr(),
-                count=torch_value.numel(),
-            )
-            constant = ctx.net.add_constant(
-                shape,
-                weights,
-            )
-            constant.name = name
-            ctx.cpu_weights_reference_holder[name + " FP8_CONSTANT"] = torch_value
-            return constant.get_output(0)

         if torch_value.dtype == torch.uint8:
             if (
@@ -400,25 +418,27 @@
             ctx.cpu_weights_reference_holder[name + " FP4_CONSTANT"] = torch_value
             return constant.get_output(0)

+        # TODO: Refit map uses numpy arrays. Remove this once refit is updated to use torch.Tensor
         if torch_value.dtype == torch.bfloat16:
             torch_value_fp32 = torch_value.to(torch.float32)
             numpy_value = torch_value_fp32.numpy()
         else:
             numpy_value = torch_value.numpy()
-        ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1)
+
+        # Used for refit
+        ctx.weight_refit_map[name + " CONSTANT"] = numpy_value.reshape(-1)
+
+        # This is a buffer to hold the torch.Tensor so that they are alive during the course of TRT compilation.
+        ctx.cpu_weights_reference_holder[name] = torch_value
+
+        # Convert the torch.Tensor to a trt.Weights object
+        trt_weights = to_trt_weights(torch_value)
         constant = ctx.net.add_constant(
             shape,
-            numpy_value,
+            trt_weights,
         )
         constant.name = name

         if torch_value.dtype == torch.bfloat16:
             return cast_trt_tensor(
                 ctx,
                 constant.get_output(0),
                 trt.DataType.BF16,
                 name + "_bf16_cast",
             )
         return constant.get_output(0)
     else:
         raise ValueError(
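
The heart of the fix is visible in create_constant above: rather than handing add_constant a NumPy array, the torch.Tensor is wrapped in a trt.Weights that points directly at the tensor's storage. A minimal sketch of the two paths, assuming a working TensorRT install; the variable names are illustrative:

```python
import tensorrt as trt
import torch

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network()

value = torch.randn(64, 32, 3, 3)

# Old path: materialize a NumPy array and pass it to add_constant.
numpy_value = value.numpy()
constant_old = network.add_constant(tuple(value.shape), numpy_value)

# New path: a zero-copy trt.Weights over the tensor's storage. The tensor must
# be contiguous and must outlive engine building, which is exactly what
# ctx.cpu_weights_reference_holder guarantees.
value = value.contiguous()
weights = trt.Weights(trt.DataType.FLOAT, value.data_ptr(), value.nelement())
constant_new = network.add_constant(tuple(value.shape), weights)
```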
23 changes: 18 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/impl/conv.py
@@ -13,13 +13,14 @@
     cast_trt_tensor,
     extend_attr_to_tuple,
     get_trt_tensor,
+    has_dynamic_shape,
+    set_layer_name,
     to_torch,
+    to_trt_weights,
 )
 from torch_tensorrt.fx.converters.converter_utils import (
     get_dyn_range,
-    has_dynamic_shape,
     mark_as_int8_layer,
-    set_layer_name,
 )
 from torch_tensorrt.fx.types import TRTTensor

@@ -64,6 +65,8 @@ def convNd(
f"Convolution {name} has bias of type {type(bias)}, Expected Torch Tensor or TRT Tensor"
)

num_output_maps = 0
kernel_shape = ()
# Process weight terms
if isinstance(weight, TRTTensor):
weight = get_trt_tensor(ctx, weight, f"{name}_weight")
@@ -72,23 +75,33 @@
             weight = impl.unsqueeze.unsqueeze(
                 ctx, target, source_ir, weight.name + "_unsqueeze_conv1d", weight, -1
             )
+        num_output_maps = weight.shape[0]
+        kernel_shape = weight.shape[2:]
     elif isinstance(weight, (torch.Tensor, np.ndarray)):
         weight = to_torch(weight, dtype=input.dtype)
         # Append new dimension (unsqueeze) if the convolution is 1d
         if is_conv1d:
             weight = torch.unsqueeze(weight, -1)
-        weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+
+        num_output_maps = weight.shape[0]
+        kernel_shape = weight.shape[2:]
+        weight = to_trt_weights(weight)

     else:
         raise RuntimeError(
             f"Convolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]"
         )

+    assert (
+        num_output_maps > 0
+    ), "Number of output channels in convolution must be greater than 0"
+    assert len(kernel_shape) > 0, "Convolution kernel shape must be non-empty"
+
     # add conv layer
     conv_layer = ctx.net.add_convolution_nd(
         input=input,
-        num_output_maps=weight.shape[0],
-        kernel_shape=weight.shape[2:],
+        num_output_maps=num_output_maps,
+        kernel_shape=kernel_shape,
         kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
         bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
     )
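
Two details are worth noting here. Passing to_trt_weights(weight) as the kernel lets TensorRT see static weights instead of an ITensor produced by a constant layer, which is the regression named in the title. And because trt.Weights carries only a dtype, a pointer, and an element count, the shape-derived arguments have to be captured before the conversion, hence the new num_output_maps and kernel_shape locals. A small sketch of that ordering constraint, with illustrative values:

```python
import tensorrt as trt
import torch

weight = torch.randn(16, 8, 3, 3).contiguous()  # (out_ch, in_ch, kH, kW)

# Read shape-dependent values first: trt.Weights exposes no shape attribute.
num_output_maps = weight.shape[0]
kernel_shape = tuple(weight.shape[2:])

kernel = trt.Weights(trt.DataType.FLOAT, weight.data_ptr(), weight.nelement())
# From here on, the locals above are the only record of the original layout.
```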
26 changes: 19 additions & 7 deletions py/torch_tensorrt/dynamo/conversion/impl/deconv.py
@@ -9,14 +9,15 @@
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
+    SourceIR,
     extend_attr_to_tuple,
     get_trt_tensor,
+    has_dynamic_shape,
     to_torch,
+    to_trt_weights,
 )
 from torch_tensorrt.fx.converters.converter_utils import (
-    SourceIR,
     get_dyn_range,
-    has_dynamic_shape,
     mark_as_int8_layer,
     set_layer_name,
 )
@@ -40,6 +41,7 @@ def deconvNd(
     scale: Optional[Union[torch.Tensor, float]] = None,
     zero_point: Optional[Union[torch.Tensor, float]] = None,
 ) -> TRTTensor:
+
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for deconvolution."

@@ -64,32 +66,42 @@
         )

     # Process weight terms
+    num_output_maps = 0
+    kernel_shape = ()
     if isinstance(weight, TRTTensor):
         weight = get_trt_tensor(ctx, weight, f"{name}_weight")
         # Append new dimension (unsqueeze) if the deconvolution is 1d
         if is_deconv1d:
             input = impl.unsqueeze.unsqueeze(
                 ctx, target, source_ir, name + "_unsqueeze_weight", weight, -1
             )
+        num_output_maps = weight.shape[1]
+        kernel_shape = weight.shape[2:]

     elif isinstance(weight, (torch.Tensor, np.ndarray)):
         weight = to_torch(weight, dtype=input.dtype)
         # Append new dimension (unsqueeze) if the deconvolution is 1d
         if is_deconv1d:
             weight = torch.unsqueeze(weight, -1)

-        weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+        num_output_maps = weight.shape[1]
+        kernel_shape = weight.shape[2:]
+        weight = to_trt_weights(weight)
+
     else:
         raise RuntimeError(
-            f"Convolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]"
+            f"Deconvolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]"
         )

+    assert (
+        num_output_maps > 0
+    ), "Number of output channels in deconvolution must be greater than 0"
+    assert len(kernel_shape) > 0, "Deconvolution kernel shape must be non-empty"
+
     # add deconv layer
     deconv_layer = ctx.net.add_deconvolution_nd(
         input=input,
-        num_output_maps=weight.shape[1] * groups,
-        kernel_shape=weight.shape[2:],
+        num_output_maps=num_output_maps * groups,
+        kernel_shape=kernel_shape,
         kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
         bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
     )
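
The deconvolution converter mirrors the convolution change, with two wrinkles: PyTorch stores transposed-convolution weights as (in_channels, out_channels // groups, *kernel), so the output-channel count comes from shape[1] and is scaled by groups, and the copy-pasted "Convolution" in the error message is corrected to "Deconvolution". A sketch of the channel arithmetic, assuming the standard PyTorch layout:

```python
import torch

groups = 2
# ConvTranspose2d weight layout: (in_channels, out_channels // groups, kH, kW)
weight = torch.randn(8, 4, 3, 3)

num_output_maps = weight.shape[1] * groups  # 4 * 2 = 8 output channels
kernel_shape = tuple(weight.shape[2:])      # (3, 3)
```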