diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 7be7e0f16c..23648facaf 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -78,7 +78,7 @@ def construct_refit_mapping( ) interpreter._construct_trt_network_def() - return interpreter.ctx.mapping + return interpreter.ctx.weight_refit_map @needs_refit diff --git a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py index 1c4926bcfa..0bc15a36d1 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py @@ -15,6 +15,8 @@ class ConversionContext: net: TensorRT Network being built compilation_settings: Settings selected by the user for compilation requires_output_allocator: Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators) + weight_refit_map: Dictionary mapping weight names to their corresponding np.array + cpu_weights_reference_holder: Dictionary mapping weight names to their corresponding torch.Tensor """ net: TRTNetwork @@ -22,8 +24,8 @@ class ConversionContext: default_factory=CompilationSettings ) requires_output_allocator: bool = False - mapping: dict[str, np.array] = field(default_factory=dict) - cpu_weights_reference_holder: dict[str, Union[torch.Tensor, np.array]] = field( + weight_refit_map: dict[str, np.array] = field(default_factory=dict) + cpu_weights_reference_holder: dict[str, Union[torch.Tensor]] = field( default_factory=dict ) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index bb1a77b4eb..dd6bf346f8 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -499,7 +499,7 @@ def _save_weight_mapping(self) -> None: for k, v in self.module.state_dict().items() } weight_name_map: dict[str, Any] = {} - np_map = self.ctx.mapping + np_map = self.ctx.weight_refit_map constant_mapping = {k: v for k, v in np_map.items() if v.size == 1} net = self.ctx.net for i in range(net.num_layers): diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index b5b7cce868..2df2f0f31b 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -320,6 +320,37 @@ def cast_int_or_float_to_bool( return tensor +def to_trt_weights( + value: Any, target_quantized_type: Optional[trt.DataType] = None +) -> trt.Weights: + """ + Convert a PyTorch tensor or NumPy array to TensorRT weights. 
+ + Args: + value (Union[torch.Tensor, np.ndarray]): The tensor or array to convert to TRT weights + + Returns: + trt.Weights: TensorRT weights object with appropriate data type + + Note: + - Input tensors are made contiguous before conversion + - Data type is preserved from the original tensor/array + """ + if isinstance(value, torch.Tensor): + # Tensor must be contiguous before conversion + value = value.contiguous() + value_trt_dtype = _enums.dtype._from(value.dtype).to(trt.DataType) + return trt.Weights(value_trt_dtype, value.data_ptr(), value.nelement()) + elif isinstance(value, np.ndarray): + value = np.ascontiguousarray(value) + value_np_dtype = _enums.dtype._from(value.dtype).to(np.dtype, use_default=True) + return trt.Weights(value_np_dtype, value.data, value.size) + else: + raise AssertionError( + f"to_trt_weights can only be called on torch.Tensor or np.ndarray, got an object of type: {type(value)}" + ) + + def create_constant( ctx: ConversionContext, value: Union[int, float, bool, np.ndarray, torch.Tensor], @@ -363,19 +394,6 @@ def create_constant( shape = list(torch_value.shape) if torch_value is not None: - if torch_value.dtype == torch.float8_e4m3fn: - weights = trt.Weights( - type=trt.DataType.FP8, - ptr=torch_value.data_ptr(), - count=torch_value.numel(), - ) - constant = ctx.net.add_constant( - shape, - weights, - ) - constant.name = name - ctx.cpu_weights_reference_holder[name + " FP8_CONSTANT"] = torch_value - return constant.get_output(0) if torch_value.dtype == torch.uint8: if ( @@ -400,25 +418,27 @@ def create_constant( ctx.cpu_weights_reference_holder[name + " FP4_CONSTANT"] = torch_value return constant.get_output(0) + # TODO: Refit map uses numpy arrays. Remove this once refit is updated to use torch.Tensor if torch_value.dtype == torch.bfloat16: torch_value_fp32 = torch_value.to(torch.float32) numpy_value = torch_value_fp32.numpy() else: numpy_value = torch_value.numpy() - ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1) + + # Used for refit + ctx.weight_refit_map[name + " CONSTANT"] = numpy_value.reshape(-1) + + # This is a buffer to hold the torch.Tensor so that they are alive during the course of TRT compilation. 
+ ctx.cpu_weights_reference_holder[name] = torch_value + + # Convert the torch.Tensor to a trt.Weights object + trt_weights = to_trt_weights(torch_value) constant = ctx.net.add_constant( shape, - numpy_value, + trt_weights, ) constant.name = name - if torch_value.dtype == torch.bfloat16: - return cast_trt_tensor( - ctx, - constant.get_output(0), - trt.DataType.BF16, - name + "_bf16_cast", - ) return constant.get_output(0) else: raise ValueError( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/conv.py b/py/torch_tensorrt/dynamo/conversion/impl/conv.py index f27fb13e97..4d9573addf 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/conv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/conv.py @@ -13,13 +13,14 @@ cast_trt_tensor, extend_attr_to_tuple, get_trt_tensor, + has_dynamic_shape, + set_layer_name, to_torch, + to_trt_weights, ) from torch_tensorrt.fx.converters.converter_utils import ( get_dyn_range, - has_dynamic_shape, mark_as_int8_layer, - set_layer_name, ) from torch_tensorrt.fx.types import TRTTensor @@ -64,6 +65,8 @@ def convNd( f"Convolution {name} has bias of type {type(bias)}, Expected Torch Tensor or TRT Tensor" ) + num_output_maps = 0 + kernel_shape = () # Process weight terms if isinstance(weight, TRTTensor): weight = get_trt_tensor(ctx, weight, f"{name}_weight") @@ -72,23 +75,33 @@ def convNd( weight = impl.unsqueeze.unsqueeze( ctx, target, source_ir, weight.name + "_unsqueeze_conv1d", weight, -1 ) + num_output_maps = weight.shape[0] + kernel_shape = weight.shape[2:] elif isinstance(weight, (torch.Tensor, np.ndarray)): weight = to_torch(weight, dtype=input.dtype) # Append new dimension (unsqueeze) if the convolution is 1d if is_conv1d: weight = torch.unsqueeze(weight, -1) - weight = get_trt_tensor(ctx, weight, f"{name}_weight") + + num_output_maps = weight.shape[0] + kernel_shape = weight.shape[2:] + weight = to_trt_weights(weight) else: raise RuntimeError( f"Convolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]" ) + assert ( + num_output_maps > 0 + ), "Number of output channels in convolution must be greater than 0" + assert len(kernel_shape) > 0, "Convolution kernel shape must be non-empty" + # add conv layer conv_layer = ctx.net.add_convolution_nd( input=input, - num_output_maps=weight.shape[0], - kernel_shape=weight.shape[2:], + num_output_maps=num_output_maps, + kernel_shape=kernel_shape, kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight, bias=trt.Weights() if isinstance(bias, TRTTensor) else bias, ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py index 629cecf5db..bc796deab5 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py @@ -9,14 +9,15 @@ from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( + SourceIR, extend_attr_to_tuple, get_trt_tensor, + has_dynamic_shape, to_torch, + to_trt_weights, ) from torch_tensorrt.fx.converters.converter_utils import ( - SourceIR, get_dyn_range, - has_dynamic_shape, mark_as_int8_layer, set_layer_name, ) @@ -40,6 +41,7 @@ def deconvNd( scale: Optional[Union[torch.Tensor, float]] = None, zero_point: Optional[Union[torch.Tensor, float]] = None, ) -> TRTTensor: + if has_dynamic_shape(input.shape): assert input.shape[1] != -1, "Channel dim can't be dynamic for deconvolution." 
@@ -64,6 +66,8 @@ def deconvNd( ) # Process weight terms + num_output_maps = 0 + kernel_shape = () if isinstance(weight, TRTTensor): weight = get_trt_tensor(ctx, weight, f"{name}_weight") # Append new dimension (unsqueeze) if the deconvolution is 1d @@ -71,25 +75,33 @@ def deconvNd( input = impl.unsqueeze.unsqueeze( ctx, target, source_ir, name + "_unsqueeze_weight", weight, -1 ) + num_output_maps = weight.shape[1] + kernel_shape = weight.shape[2:] elif isinstance(weight, (torch.Tensor, np.ndarray)): weight = to_torch(weight, dtype=input.dtype) # Append new dimension (unsqueeze) if the deconvolution is 1d if is_deconv1d: weight = torch.unsqueeze(weight, -1) - - weight = get_trt_tensor(ctx, weight, f"{name}_weight") + num_output_maps = weight.shape[1] + kernel_shape = weight.shape[2:] + weight = to_trt_weights(weight) else: raise RuntimeError( - f"Convolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]" + f"Deconvolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]" ) + assert ( + num_output_maps > 0 + ), "Number of output channels in deconvolution must be greater than 0" + assert len(kernel_shape) > 0, "Deconvolution kernel shape must be non-empty" + # add deconv layer deconv_layer = ctx.net.add_deconvolution_nd( input=input, - num_output_maps=weight.shape[1] * groups, - kernel_shape=weight.shape[2:], + num_output_maps=num_output_maps * groups, + kernel_shape=kernel_shape, kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight, bias=trt.Weights() if isinstance(bias, TRTTensor) else bias, )
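
Reviewer note, not part of the patch: the sketch below is illustrative only. It assumes the pieces added above (to_trt_weights, ctx.weight_refit_map, ctx.cpu_weights_reference_holder) land as written, and shows how a converter-side constant is expected to be assembled from a CPU torch.Tensor under the new path; build_weight_constant is a hypothetical name used only for this example.

import torch
import tensorrt as trt

from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import to_trt_weights


def build_weight_constant(
    ctx: ConversionContext, name: str, torch_value: torch.Tensor
) -> trt.ITensor:
    # The refit path still consumes numpy arrays (see the TODO in create_constant),
    # so a flattened fp32 copy is recorded under the " CONSTANT" suffix.
    if torch_value.dtype == torch.bfloat16:
        numpy_value = torch_value.to(torch.float32).numpy()
    else:
        numpy_value = torch_value.numpy()
    ctx.weight_refit_map[name + " CONSTANT"] = numpy_value.reshape(-1)

    # trt.Weights built from data_ptr() holds a raw pointer and makes no copy,
    # so the CPU tensor must stay referenced until the TRT engine is built.
    ctx.cpu_weights_reference_holder[name] = torch_value

    constant = ctx.net.add_constant(
        list(torch_value.shape), to_trt_weights(torch_value)
    )
    constant.name = name
    return constant.get_output(0)

The intent of the new path, as this example reads it, is that weights are handed to TensorRT zero-copy instead of being round-tripped through numpy; the numpy copy in weight_refit_map remains only because the refit flow has not yet been migrated to torch.Tensor.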