✨[Feature] Proper bfloat16 support in elementwise converter #3458

Closed
HolyWu opened this issue Apr 1, 2025 · 1 comment
Labels: feature request (New feature or request)

Comments


HolyWu commented Apr 1, 2025

from __future__ import annotations

import os

import torch
import torch_tensorrt

os.environ["CI_BUILD"] = "1"

dtype = torch.bfloat16
device = torch.device("cuda", 0)


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 0.5


with torch.inference_mode():
    model = MyModule().eval().to(device, dtype)
    inputs = (torch.randn(1, 3, 224, 224, dtype=dtype, device=device),)
    exported_program = torch.export.export(model, inputs)

    trt_model = torch_tensorrt.dynamo.compile(
        exported_program,
        inputs,
        device=device,
        enabled_precisions={dtype},
        debug=True,
        min_block_size=1,
    )
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_nodes:Removed 0 assert_scalar nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_num_users_is_0_nodes:Removed ops that [num_users=0] nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.mul.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.mul.Tensor
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.mul.Tensor + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.mul.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.mul.Tensor
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.mul.Tensor + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
 Input shapes: [(1, 3, 224, 224)]
 graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return mul
WARNING:py.warnings:/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/utils.py:423: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  return torch.tensor(tensor).dtype

DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.mul.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.mul.Tensor
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[1, 3, 224, 224], dtype=DataType.BF16]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (1, 3, 224, 224)@torch.bfloat16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /mul (kind: aten.mul.Tensor, args: ('x <Node>', '0.5 <float>'))
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.mul.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.mul.Tensor
Traceback (most recent call last):
  File "/home/holywu/test.py", line 27, in <module>
    trt_model = torch_tensorrt.dynamo.compile(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 693, in compile
    trt_gm = compile_module(
             ^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 897, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 90, in convert_module
    interpreter_result = interpret_module_to_result(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 69, in interpret_module_to_result
    interpreter_result = interpreter.run()
                         ^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 725, in run
    self._construct_trt_network_def()
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 393, in _construct_trt_network_def
    super().run()
  File "/home/holywu/.local/lib/python3.12/site-packages/torch/fx/interpreter.py", line 171, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 784, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
                              ^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch/fx/interpreter.py", line 240, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 891, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 1854, in aten_ops_mul
    return impl.elementwise.mul(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py", line 473, in mul
    return convert_binary_elementwise(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py", line 128, in convert_binary_elementwise
    rhs_val = np.array([rhs_val], dtype=_enums.dtype._from(lhs_dtype).to(np.dtype))
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/_enums.py", line 426, in to
    raise TypeError("Unsupported numpy dtype")
TypeError: Unsupported numpy dtype

While executing %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self, x: "bf16[1, 3, 224, 224][150528, 50176, 224, 1]"):
         # File: /home/holywu/test.py:19 in forward, code: return x * 0.5
        mul: "bf16[1, 3, 224, 224][150528, 50176, 224, 1]" = torch.ops.aten.mul.Tensor(x, 0.5);  x = None
        return mul


Original traceback:
File "/home/holywu/test.py", line 19, in forward
    return x * 0.5
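
The failure comes from convert_binary_elementwise in impl/elementwise/base.py, which materializes the Python scalar operand as a numpy array in the dtype of the tensor operand (np.array([rhs_val], dtype=_enums.dtype._from(lhs_dtype).to(np.dtype))). Since numpy has no built-in bfloat16 dtype, that conversion raises "Unsupported numpy dtype". A minimal sketch of the gap, assuming stock numpy without extension packages such as ml_dtypes that register a bfloat16 type:

import numpy as np
import torch

try:
    np.dtype("bfloat16")  # numpy has no built-in bfloat16 dtype
except TypeError as e:
    print(e)  # "data type 'bfloat16' not understood"

# torch handles the same scalar multiplication natively in bfloat16:
x = torch.randn(1, 3, dtype=torch.bfloat16)
print((x * 0.5).dtype)  # torch.bfloat16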

apbose commented Apr 2, 2025

Reproduced. Working on the fix.
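
For reference, a minimal sketch of one possible direction (not the actual torch_tensorrt change; the helper name is hypothetical) is to fall back to a torch tensor for the scalar operand whenever the peer dtype has no numpy equivalent:

import numpy as np
import torch

def scalar_operand(value: float, dtype: torch.dtype):
    # Hypothetical helper, not torch_tensorrt API: keep the scalar in a
    # torch tensor when numpy cannot represent the dtype (e.g. bfloat16),
    # otherwise use the existing numpy path.
    t = torch.tensor([value], dtype=dtype)
    return t if dtype == torch.bfloat16 else t.numpy()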
