
Commit 1f8a158

Support broadcast index put (#3421)
1 parent f8e285d commit 1f8a158

File tree

4 files changed: +361 -83 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (-35 lines)
@@ -810,43 +810,8 @@ def aten_ops_select(
     )


-def index_put_validator(
-    node: Node, settings: Optional[CompilationSettings] = None
-) -> bool:
-    if args_bounds_check(node.args, 3, False):  # Check if accumulate is valid
-        _LOGGER.debug("We do not support accumulate=True for aten.index_put operation")
-        accumulate_valid = False
-    else:
-        accumulate_valid = True
-
-    # Retrieve input tensor's meta information
-    input_meta = node.args[0].meta.get("tensor_meta")
-    if not input_meta:
-        _LOGGER.warning(
-            "Meta information of input is missing. Unable to validate if broadcasting is needed, falling back to PyTorch operation."
-        )
-        return False
-
-    input_shape = input_meta.shape
-    input_num_dims = len(input_shape)
-
-    # Check if broadcasting is valid
-    indices_num_dims = len(node.args[1])
-    if indices_num_dims == input_num_dims:
-        broadcast_valid = True
-    else:
-        _LOGGER.debug(
-            "We do not support broadcasting when the number of index dimensions does not match the number of input tensor dimensions."
-        )
-        broadcast_valid = False
-
-    # Return validation result
-    return accumulate_valid and broadcast_valid
-
-
 @dynamo_tensorrt_converter(
     torch.ops.aten.index_put.default,
-    capability_validator=index_put_validator,
 )
 @enforce_tensor_types(
     {
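The validator deleted above is what previously gated broadcast support: it rejected any aten.index_put call whose number of index tensors differed from the input rank (and any call with accumulate=True), forcing those graphs to fall back to PyTorch. A minimal sketch of the broadcast pattern that the converter can now accept; the shapes and names here are illustrative, not taken from the diff:

import torch

# One index tensor against a rank-2 input: under the old validator,
# indices_num_dims (1) != input_num_dims (2), so this fell back to PyTorch.
x = torch.zeros(4, 3)
idx = torch.tensor([0, 2])
values = torch.ones(2, 3)

# Writes `values` into rows 0 and 2; the single index tensor
# broadcasts across the remaining column dimension.
x.index_put_((idx,), values)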

py/torch_tensorrt/dynamo/conversion/impl/arange.py (+30 -20 lines)
@@ -9,6 +9,7 @@
     SourceIR,
     cast_trt_tensor,
     get_trt_tensor,
+    set_layer_name,
 )
 from torch_tensorrt.fx.types import TRTTensor


@@ -22,36 +23,45 @@ def arange(
     end: Union[int, TRTTensor],
     step: Union[int, TRTTensor],
 ) -> TRTTensor:
-    if any(isinstance(tensor, TRTTensor) for tensor in (start, end, step)):
+    """
+    Creates a sequence of values (arange) either dynamically or statically,
+    then outputs a TensorRT tensor.
+
+    If any of (start, end, step) is a TRT tensor, it sets up a dynamic arange
+    using a Fill layer. Otherwise, it creates a static NumPy array and converts
+    it into a TensorRT constant tensor.
+    """
+    # If any argument is a TRT tensor, use dynamic arange with a Fill layer
+    if any(isinstance(x, TRTTensor) for x in (start, end, step)):
+        # Convert start, end, step into TRT tensors with appropriate rank
         start_rank_0 = get_trt_tensor(ctx, start, name + "_start_rank_0", min_rank=0)
         start_rank_1 = get_trt_tensor(ctx, start, name + "_start_rank_1", min_rank=1)
         end = get_trt_tensor(ctx, end, name + "_end", min_rank=1)
         step = get_trt_tensor(ctx, step, name + "_step", min_rank=1)
-        # Calculate shape = (end-start) / step
+
+        # Compute (end - start) / step to determine the output length
         shape = impl.elementwise.sub(
-            ctx,
-            target,
-            source_ir,
-            name + "_sub",
-            end,
-            start_rank_1,
+            ctx, target, source_ir, name + "_sub", end, start_rank_1
         )
         shape = impl.elementwise.trunc_div(
-            ctx,
-            target,
-            source_ir,
-            name + "_shape",
-            shape,
-            step,
+            ctx, target, source_ir, name + "_shape", shape, step
         )
         shape = cast_trt_tensor(ctx, shape, end.dtype, name + "_shape_casted")
+
+        # Build a Fill layer in LINSPACE mode
         fill_layer = ctx.net.add_fill(
             shape.shape, trt.FillOperation.LINSPACE, shape.dtype
         )
-        fill_layer.set_input(0, shape)
-        # Set start index
-        fill_layer.set_input(1, start_rank_0)
-        # Set delta/step
-        fill_layer.set_input(2, step)
+        fill_layer.set_input(0, shape)  # output length
+        fill_layer.set_input(1, start_rank_0)  # start value
+        fill_layer.set_input(2, step)  # step size
+
         return fill_layer.get_output(0)
-    return np.arange(start, end, step)
+
+    else:
+        # All arguments are static, so use NumPy arange and create a TRT constant
+        arr = np.arange(start, end, step, dtype=np.int32)
+        weights = trt.Weights(arr)
+        const_layer = ctx.net.add_constant(arr.shape, weights)
+        set_layer_name(const_layer, target, f"{name}_arange_const", source_ir)
+        return const_layer.get_output(0)
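Two notes on the rewritten function. In the dynamic branch, the LINSPACE fill reproduces arange arithmetic: the element count is trunc((end - start) / step) and the values are start + step * [0, 1, ...]. Below is a NumPy sketch of that same arithmetic for intuition only, not TensorRT API usage; arange_sketch is a hypothetical helper name:

import numpy as np

def arange_sketch(start: int, end: int, step: int) -> np.ndarray:
    # Mirrors the converter's shape math: length = trunc((end - start) / step)
    length = int(np.trunc((end - start) / step))
    # A LINSPACE fill with a scalar start and scalar delta then yields
    # start + step * [0, 1, ..., length - 1]
    return start + step * np.arange(length)

assert np.array_equal(arange_sketch(2, 11, 3), np.arange(2, 11, 3))  # [2, 5, 8]

In the static branch, the function now returns the output of an add_constant layer instead of a raw np.arange array, so both branches hand back a TensorRT tensor; note that the constant path pins the dtype to int32.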
