9
9
SourceIR ,
10
10
cast_trt_tensor ,
11
11
get_trt_tensor ,
12
+ set_layer_name ,
12
13
)
13
14
from torch_tensorrt .fx .types import TRTTensor
14
15
@@ -22,36 +23,45 @@ def arange(
22
23
end : Union [int , TRTTensor ],
23
24
step : Union [int , TRTTensor ],
24
25
) -> TRTTensor :
25
- if any (isinstance (tensor , TRTTensor ) for tensor in (start , end , step )):
26
+ """
27
+ Creates a sequence of values (arange) either dynamically or statically,
28
+ then outputs a TensorRT tensor.
29
+
30
+ If any of (start, end, step) is a TRT tensor, it sets up a dynamic arange
31
+ using a Fill layer. Otherwise, it creates a static NumPy array and converts
32
+ it into a TensorRT constant tensor.
33
+ """
34
+ # If any argument is a TRT tensor, use dynamic arange with a Fill layer
35
+ if any (isinstance (x , TRTTensor ) for x in (start , end , step )):
36
+ # Convert start, end, step into TRT tensors with appropriate rank
26
37
start_rank_0 = get_trt_tensor (ctx , start , name + "_start_rank_0" , min_rank = 0 )
27
38
start_rank_1 = get_trt_tensor (ctx , start , name + "_start_rank_1" , min_rank = 1 )
28
39
end = get_trt_tensor (ctx , end , name + "_end" , min_rank = 1 )
29
40
step = get_trt_tensor (ctx , step , name + "_step" , min_rank = 1 )
30
- # Calculate shape = (end-start) / step
41
+
42
+ # Compute (end - start) / step to determine the output length
31
43
shape = impl .elementwise .sub (
32
- ctx ,
33
- target ,
34
- source_ir ,
35
- name + "_sub" ,
36
- end ,
37
- start_rank_1 ,
44
+ ctx , target , source_ir , name + "_sub" , end , start_rank_1
38
45
)
39
46
shape = impl .elementwise .trunc_div (
40
- ctx ,
41
- target ,
42
- source_ir ,
43
- name + "_shape" ,
44
- shape ,
45
- step ,
47
+ ctx , target , source_ir , name + "_shape" , shape , step
46
48
)
47
49
shape = cast_trt_tensor (ctx , shape , end .dtype , name + "_shape_casted" )
50
+
51
+ # Build a Fill layer in LINSPACE mode
48
52
fill_layer = ctx .net .add_fill (
49
53
shape .shape , trt .FillOperation .LINSPACE , shape .dtype
50
54
)
51
- fill_layer .set_input (0 , shape )
52
- # Set start index
53
- fill_layer .set_input (1 , start_rank_0 )
54
- # Set delta/step
55
- fill_layer .set_input (2 , step )
55
+ fill_layer .set_input (0 , shape ) # output length
56
+ fill_layer .set_input (1 , start_rank_0 ) # start value
57
+ fill_layer .set_input (2 , step ) # step size
58
+
56
59
return fill_layer .get_output (0 )
57
- return np .arange (start , end , step )
60
+
61
+ else :
62
+ # All arguments are static, so use NumPy arange and create a TRT constant
63
+ arr = np .arange (start , end , step , dtype = np .int32 )
64
+ weights = trt .Weights (arr )
65
+ const_layer = ctx .net .add_constant (arr .shape , weights )
66
+ set_layer_name (const_layer , target , f"{ name } _arange_const" , source_ir )
67
+ return const_layer .get_output (0 )
0 commit comments