```python
from __future__ import annotations

import os

import torch
import torch_tensorrt

os.environ["CI_BUILD"] = "1"

dtype = torch.bfloat16
device = torch.device("cuda", 0)


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 0.5


with torch.inference_mode():
    model = MyModule().eval().to(device, dtype)
    inputs = (torch.randn(1, 3, 224, 224, dtype=dtype, device=device),)

    exported_program = torch.export.export(model, inputs)
    trt_model = torch_tensorrt.dynamo.compile(
        exported_program,
        inputs,
        device=device,
        enabled_precisions={dtype},
        debug=True,
        min_block_size=1,
    )
```
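Running this script with debug logging enabled fails while converting the single `aten.mul.Tensor` node; full output: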
```
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_nodes:Removed 0 assert_scalar nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_num_users_is_0_nodes:Removed ops that [num_users=0] nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.mul.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.mul.Tensor
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.mul.Tensor + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.mul.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.mul.Tensor
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.mul.Tensor + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
Input shapes: [(1, 3, 224, 224)]
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return mul
WARNING:py.warnings:/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/utils.py:423: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  return torch.tensor(tensor).dtype

DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.mul.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.mul.Tensor
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[1, 3, 224, 224], dtype=DataType.BF16]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (1, 3, 224, 224)@torch.bfloat16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /mul (kind: aten.mul.Tensor, args: ('x <Node>', '0.5 <float>'))
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.mul.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.mul.Tensor
Traceback (most recent call last):
  File "/home/holywu/test.py", line 27, in <module>
    trt_model = torch_tensorrt.dynamo.compile(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 693, in compile
    trt_gm = compile_module(
             ^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 897, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 90, in convert_module
    interpreter_result = interpret_module_to_result(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 69, in interpret_module_to_result
    interpreter_result = interpreter.run()
                         ^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 725, in run
    self._construct_trt_network_def()
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 393, in _construct_trt_network_def
    super().run()
  File "/home/holywu/.local/lib/python3.12/site-packages/torch/fx/interpreter.py", line 171, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 784, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
                              ^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch/fx/interpreter.py", line 240, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 891, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 1854, in aten_ops_mul
    return impl.elementwise.mul(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py", line 473, in mul
    return convert_binary_elementwise(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py", line 128, in convert_binary_elementwise
    rhs_val = np.array([rhs_val], dtype=_enums.dtype._from(lhs_dtype).to(np.dtype))
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/_enums.py", line 426, in to
    raise TypeError("Unsupported numpy dtype")
TypeError: Unsupported numpy dtype

While executing %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})

GraphModule: class GraphModule(torch.nn.Module):
    def forward(self, x: "bf16[1, 3, 224, 224][150528, 50176, 224, 1]"):
        # File: /home/holywu/test.py:19 in forward, code: return x * 0.5
        mul: "bf16[1, 3, 224, 224][150528, 50176, 224, 1]" = torch.ops.aten.mul.Tensor(x, 0.5);  x = None
        return mul

Original traceback:
  File "/home/holywu/test.py", line 19, in forward
    return x * 0.5
```
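The root cause is visible in the traceback: `convert_binary_elementwise` materializes the Python scalar `0.5` as a NumPy array in the dtype of the lhs operand, and NumPy has no bfloat16 dtype, so `_enums.dtype._from(lhs_dtype).to(np.dtype)` raises `TypeError: Unsupported numpy dtype`. A minimal sketch of the underlying limitation, independent of Torch-TensorRT:

```python
import numpy as np
import torch

# PyTorch supports bfloat16 natively...
x = torch.zeros(2, dtype=torch.bfloat16)
print(x.dtype)  # torch.bfloat16

# ...but NumPy has no bfloat16 dtype, so any bf16 -> NumPy mapping must fail.
try:
    np.dtype("bfloat16")
except TypeError as e:
    print(e)  # data type 'bfloat16' not understood

try:
    x.numpy()
except TypeError as e:
    print(e)  # Got unsupported ScalarType BFloat16
```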
Reproduced. Working on the fix.
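One possible direction (a sketch only; `scalar_as_const` and the lookup table are hypothetical names for illustration, not the actual patch) is to fall back to a torch tensor for scalar constants whose dtype has no NumPy equivalent:

```python
import numpy as np
import torch

# Hypothetical mapping; torch.bfloat16 is deliberately absent because NumPy
# has no such dtype.
_NUMPY_EQUIVALENTS = {
    torch.float32: np.float32,
    torch.float16: np.float16,
    torch.int32: np.int32,
    torch.int64: np.int64,
}


def scalar_as_const(value: float, torch_dtype: torch.dtype):
    """Represent a Python scalar in the given dtype, keeping NumPy-less
    dtypes (e.g. bfloat16) on the torch side instead of raising."""
    np_dtype = _NUMPY_EQUIVALENTS.get(torch_dtype)
    if np_dtype is not None:
        return np.array([value], dtype=np_dtype)
    return torch.tensor([value], dtype=torch_dtype)


print(scalar_as_const(0.5, torch.float16))   # NumPy array, dtype float16
print(scalar_as_const(0.5, torch.bfloat16))  # torch tensor, dtype bfloat16
```

As a user-side stopgap, compiling with `dtype = torch.float16` instead of `torch.bfloat16` avoids the failing bf16 → NumPy conversion entirely, since float16 does have a NumPy equivalent.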
Assignees: narendasan, apbose