
Commit 8e69c3b

csarofeen authored and facebook-github-bot committed
[nvFuser] Reduction support in codegen, fp16 support (pytorch#38627)
Summary: Adds reduction support to the code generator. Reductions are fully supported with the split/merge/reorder/rfactor/computeAt/unroll operators, and cross-thread (intra-block) reductions are supported as well. Two pieces are still missing for full reduction support:

- Safety: if a cross-thread reduction was used, child operators should no longer be able to bind that thread dimension.
- Cross-block reduction: inter-block reduction support is needed to reach parity with TensorIterator.

The PR also adds FP16 support for fusions: casts to FP32 are inserted on FP16 inputs, and casts back to FP16 are inserted on FP16 outputs. It also works toward reductions, and shape inference for reductions, in the fusion pass.

Pull Request resolved: pytorch#38627
Reviewed By: albanD
Differential Revision: D21663196
Pulled By: soumith
fbshipit-source-id: 3ff2df563f86c39cd5821ab9c1148149e5172a9e
1 parent d3b0cf9 commit 8e69c3b
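
The FP16 handling deserves a concrete illustration: the fuser keeps all arithmetic in FP32 by wrapping half-precision tensors in casts. Below is a minimal sketch of that strategy in terms of the castOp/add helpers declared in torch/csrc/jit/codegen/cuda/arith.h (shown later in this commit); the enclosing function, the DataType::Half enum name, and the namespace are assumptions for illustration, not code from this diff.

#include <torch/csrc/jit/codegen/cuda/arith.h>

using namespace torch::jit::fuser; // namespace assumed for this sketch

// Sketch: FP16 inputs -> cast up to FP32 -> compute -> cast back to FP16.
Val* fusedHalfAdd(Val* x_fp16, Val* y_fp16) {
  Val* x = castOp(DataType::Float, x_fp16); // cast inserted on an FP16 input
  Val* y = castOp(DataType::Float, y_fp16); // cast inserted on an FP16 input
  Val* o = add(x, y);                       // all math runs in FP32
  return castOp(DataType::Half, o);         // cast inserted on the FP16 output
}

The test_half test added to test/test_jit_cuda_fuser.py below exercises this path end to end through the JIT.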


50 files changed: +5312 −1769 lines

aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h

Lines changed: 4 additions & 1 deletion
@@ -48,7 +48,10 @@ namespace at { namespace cuda {
   _(cuLaunchKernel) \
   _(cuCtxGetCurrent) \
   _(cuModuleUnload) \
-  _(cuDevicePrimaryCtxGetState)
+  _(cuDevicePrimaryCtxGetState) \
+  _(cuLinkCreate) \
+  _(cuLinkAddData) \
+  _(cuLinkComplete)
 
 #else
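
The three new cuLink* entries expose the CUDA driver's runtime linker through this lazily-resolved stub table, presumably so generated PTX can be linked into a cubin when a fusion is compiled. For orientation, here is a minimal sketch of how that driver API is used; these are standard CUDA driver calls, not code from this PR, and error handling is omitted.

#include <cuda.h>
#include <cstddef>
#include <cstdio>

// Link a PTX blob into a cubin with the driver's runtime linker.
void linkPtx(const char* ptx, size_t ptx_len) {
  CUlinkState state;
  void* cubin = nullptr;
  size_t cubin_size = 0;
  cuLinkCreate(0, nullptr, nullptr, &state);   // fresh link state, no options
  cuLinkAddData(state, CU_JIT_INPUT_PTX,       // feed the PTX by value
                const_cast<char*>(ptx), ptx_len,
                "fusion_kernel", 0, nullptr, nullptr);
  cuLinkComplete(state, &cubin, &cubin_size);  // produce the linked cubin
  std::printf("linked cubin: %zu bytes\n", cubin_size);
  // `cubin` is owned by `state`; load it (e.g. with cuModuleLoadData)
  // before tearing the state down with cuLinkDestroy.
}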

caffe2/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -429,6 +429,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
   ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/tensor_view.cpp
   ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_iter.cpp
   ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_replay.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_rfactor.cpp
   ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/type.cpp
   ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/utils.cpp
   ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/register_interface.cpp

test/cpp/jit/test_gpu.cpp

Lines changed: 814 additions & 283 deletions
Large diff not rendered here.

test/cpp/jit/tests.h

Lines changed: 6 additions & 3 deletions
@@ -113,10 +113,10 @@ namespace jit {
   _(GPU_FusionTVReorder) \
   _(GPU_FusionEquality) \
   _(GPU_FusionReplaceAll) \
+  _(GPU_FusionParser) \
   _(GPU_FusionDependency) \
   _(GPU_FusionCodeGen) \
   _(GPU_FusionCodeGen2) \
-  _(GPU_FusionCodeGen3) \
   _(GPU_FusionSimplePWise) \
   _(GPU_FusionExecKernel) \
   _(GPU_FusionForLoop) \
@@ -125,8 +125,11 @@ namespace jit {
   _(GPU_FusionBinaryOps) \
   _(GPU_FusionTernaryOps) \
   _(GPU_FusionCompoundOps) \
-  _(GPU_FusionCastOps)
-  //_(GPU_FusionCodeGen4)
+  _(GPU_FusionAdvancedComputeAt) \
+  _(GPU_FusionScalarInputs) \
+  _(GPU_FusionRFactorReplay) \
+  _(GPU_FusionReduction) \
+  _(GPU_FusionReduction2)
 #else
 #define TH_FORALL_TESTS_CUDA(_) \
   _(ArgumentSpec) \

test/test_jit_cuda_fuser.py

Lines changed: 58 additions & 1 deletion
@@ -4,10 +4,12 @@
 from __future__ import unicode_literals
 
 import unittest
+import os
 
 import torch
 
 from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, skipIfRocm
+from torch.testing._internal.codegen.random_topo_test import runDefaultTestWithSeed
 
 from test_jit import JitTestCase, RUN_CUDA
 
@@ -52,6 +54,31 @@ def _has_cuda_fusion_group(self, graph):
                 has_cuda_fusion_group = True
         return has_cuda_fusion_group
 
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
+    @skipIfRocm
+    def test_half(self):
+        def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor, alpha : float):
+            o_16 = torch.add(x, y)
+            o_32_a = torch.add(y, z, alpha=alpha)
+            o_32_b = torch.add(o_16, z)
+            return (o_16, o_32_a, o_32_b)
+
+        t_jit = torch.jit.script(t)
+        alpha = 0.5
+        # stick to integers, this avoid the numerical difference due to our
+        # promotion
+        x = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
+        y = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
+        z = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
+        jit_o = t_jit(x, y, z, alpha)
+        jit_o = t_jit(x, y, z, alpha)
+        o = t(x, y, z, alpha)
+        for oo, jit_oo in zip(o, jit_o):
+            self.assertEqual(oo.dtype, jit_oo.dtype)
+            self.assertEqual(oo, jit_oo)
+        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, z, alpha)))
+
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
     @skipIfRocm
@@ -149,7 +176,6 @@ def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
         # Currently cannot fuse this
         self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, z)))
 
-    @unittest.skipIf(True, "temporary disable for buggy codegen")
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
     @skipIfRocm
@@ -299,5 +325,36 @@ def where(x : torch.Tensor, y : torch.Tensor, cond : torch.Tensor):
         where_jit = torch.jit.script(where)
         self._run_helper(where_jit, where, True, x, y, cond)
 
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Requires profiling node to run cuda fuser")
+    @skipIfRocm
+    def test_dynamic_size(self):
+        def t(x : torch.Tensor, y : torch.Tensor, z : float):
+            o = x + y
+            o = o + z
+            return o
+        t_jit = torch.jit.script(t)
+        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
+        y = torch.randn(32, 32, dtype=torch.float, device="cuda")
+        jit_o = t_jit(x, y, 2.0)
+        jit_o = t_jit(x, y, 2.0)
+        o = t(x, y, 2.0)
+        self.assertEqual(o, jit_o)
+        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))
+        x = torch.randn(8, 32, 16, 8, dtype=torch.float, device="cuda")
+        y = torch.randn(16, 8, dtype=torch.float, device="cuda")
+        jit_o = t_jit(x, y, 2.0)
+        o = t(x, y, 2.0)
+        self.assertEqual(o, jit_o)
+        x = torch.randn(8, 17, 8, dtype=torch.float, device="cuda")
+        y = torch.randn(8, 17, 1, dtype=torch.float, device="cuda")
+        jit_o = t_jit(x, y, 2.0)
+
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    @skipIfRocm
+    def test_random_topo(self):
+        os.environ["PYTORCH_CUDA_FUSER_DISABLE_FALLBACK"] = "1"
+        self.assertTrue(runDefaultTestWithSeed(28449))
+
 if __name__ == '__main__':
     run_tests()

tools/build_variables.bzl

Lines changed: 1 addition & 0 deletions
@@ -314,6 +314,7 @@ libtorch_cuda_sources = [
     "torch/csrc/jit/codegen/cuda/tensor_view.cpp",
     "torch/csrc/jit/codegen/cuda/transform_iter.cpp",
     "torch/csrc/jit/codegen/cuda/transform_replay.cpp",
+    "torch/csrc/jit/codegen/cuda/transform_rfactor.cpp",
     "torch/csrc/jit/codegen/cuda/type.cpp",
     "torch/csrc/jit/codegen/cuda/utils.cpp",
     "torch/csrc/jit/codegen/cuda/register_interface.cpp",

torch/csrc/jit/codegen/cuda/arith.cpp

Lines changed: 114 additions & 11 deletions
@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/codegen/cuda/arith.h>
 #include <c10/util/Exception.h>
-#include <torch/csrc/jit/codegen/cuda/ir_internal_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/type.h>
 
 namespace torch {
 namespace jit {
@@ -61,12 +62,42 @@ TORCH_CUDA_API Val* promoteNew(Val* v1, Val* v2) {
   return newValLike(v1, out_dtype);
 }
 
+Val* newConstScalar(DataType dtype, int val) {
+  switch (dtype) {
+    case (DataType::Int):
+      return new Int(val);
+    default:
+      break;
+  }
+  TORCH_CHECK(
+      false,
+      "Could not generate a new Scalar with data type ",
+      dtype,
+      "and constant value: ",
+      val);
+}
+
+Val* newConstScalar(DataType dtype, float val) {
+  switch (dtype) {
+    case (DataType::Float):
+      return new Float(val);
+    default:
+      break;
+  }
+  TORCH_CHECK(
+      false,
+      "Could not generate a new Scalar with data type ",
+      dtype,
+      "and constant value: ",
+      val);
+}
+
 TORCH_CUDA_API Val* castOp(DataType dtype, Val* v1) {
   if (v1->getDataType().value() == dtype)
     return v1;
 
-  auto uop_type = cast_type(v1->getDataType().value(), dtype);
-  if (uop_type == c10::nullopt) {
+  if (cast_func_str(std::make_pair(v1->getDataType().value(), dtype)) ==
+      c10::nullopt) {
     TORCH_CHECK(
         false,
         "Illegal Cast value from DataType: ",
@@ -76,16 +107,20 @@ TORCH_CUDA_API Val* castOp(DataType dtype, Val* v1) {
   }
 
   Val* out = newValLike(v1, dtype);
-  Statement* expr = new UnaryOp(uop_type.value(), out, v1);
+  new UnaryOp(UnaryOpType::Cast, out, v1);
   return out;
 }
 
+// UNARY OPERATIONS
+
 TORCH_CUDA_API Val* unaryOp(UnaryOpType type, Val* v1) {
   Val* out = newValLike(v1);
-  Statement* expr = new UnaryOp(type, out, v1);
+  new UnaryOp(type, out, v1);
   return out;
 }
 
+// BINARY OPERATIONS
+
 TORCH_CUDA_API Val* binaryOp(BinaryOpType type, Val* v1, Val* v2) {
   Val* out = promoteNew(v1, v2);
   if (is_logical_op(type)) {
@@ -95,7 +130,7 @@ TORCH_CUDA_API Val* binaryOp(BinaryOpType type, Val* v1, Val* v2) {
     if (out->getDataType().value() != DataType::Int)
       out = newValLike(out, DataType::Int);
   }
-  Statement* expr = new BinaryOp(type, out, v1, v2);
+  new BinaryOp(type, out, v1, v2);
   return out;
 }
 
@@ -139,6 +174,72 @@ TORCH_CUDA_API Val* andOp(Val* v1, Val* v2) {
   return binaryOp(BinaryOpType::And, v1, v2);
 }
 
+// REDUCTION OPERATIONS
+
+Val* reductionOp(
+    BinaryOpType reduction_op_type,
+    const std::vector<int>& axes,
+    Val* init,
+    Val* v1) {
+  TORCH_CHECK(
+      v1->getValType().value() == ValType::TensorView,
+      "Cannot reduce on values that are not TensorViews, but recieved type ",
+      v1->getValType().value());
+
+  TORCH_CHECK(
+      init->isConstScalar(),
+      "Cannot create a reduction operation where the initial value is not a const scalar.");
+
+  TensorView* tv = static_cast<TensorView*>(v1);
+
+  TORCH_CHECK(
+      tv->getRootDomain() == tv->domain(),
+      "Reducing a tensor once it's gone under transformations is not permitted at this time. Please set reductions before calling split/merge/reorder/computeAt.");
+
+  std::vector<unsigned int> uint_axes;
+  for (int axis : axes) {
+    if (axis < 0)
+      axis += int(tv->nDims());
+
+    TORCH_CHECK(
+        axis >= 0 && (unsigned int)axis < tv->nDims(),
+        "Reduction on invalid axis, recieved: ",
+        axis,
+        " however tensor view only has ",
+        tv->nDims(),
+        " dims.");
+
+    uint_axes.push_back((unsigned int)axis);
+  }
+
+  Val* out = tv->newForReduction(uint_axes);
+  if (init->getDataType().value() != v1->getDataType().value())
+    init = castOp(v1->getDataType().value(), init);
+  new ReductionOp(reduction_op_type, init, out, v1);
+  return out;
+}
+
+TORCH_CUDA_API Val* sum(Val* v1, const std::vector<int>& axes) {
+  Val* init;
+  switch (v1->getDataType().value()) {
+    case (DataType::Float):
+      init = new Float(0.0);
+      break;
+    case (DataType::Int):
+      init = new Int(0);
+      break;
+    default:
+      TORCH_CHECK(
+          false,
+          "Could not generate a sum op for tensor with type: ",
+          v1->getDataType().value());
+  }
+
+  return reductionOp(BinaryOpType::Add, axes, init, v1);
+}
+
+// COMPOUND OPERATIONS
+
 TORCH_CUDA_API Val* add_alpha(Val* v1, Val* v2, Val* s) {
   TORCH_CHECK(
       s->getValType().value() == ValType::Scalar,
@@ -183,10 +284,12 @@ TORCH_CUDA_API Val* where(Val* c, Val* v1, Val* v2) {
       c->getDataType().value());
 
   Val* out = promoteNew(v1, v2);
-  Statement* expr = new TernaryOp(TernaryOpType::Where, out, c, v1, v2);
+  new TernaryOp(TernaryOpType::Where, out, c, v1, v2);
   return out;
 }
 
+// TERNARY OPERATIONS
+
 TORCH_CUDA_API Val* threshold(Val* in, Val* thresh, Val* value) {
   TORCH_CHECK(
       in->getDataType().value() == thresh->getDataType().value() &&
@@ -199,8 +302,8 @@ TORCH_CUDA_API Val* threshold(Val* in, Val* thresh, Val* value) {
       "Thresh and Value values should be Scalars");
 
   Val* out = newValLike(in);
-  Statement* expr =
-      new TernaryOp(TernaryOpType::Threshold, out, in, thresh, value);
+
+  new TernaryOp(TernaryOpType::Threshold, out, in, thresh, value);
   return out;
 }
 
@@ -216,8 +319,8 @@ TORCH_CUDA_API Val* clamp(Val* in, Val* min_val, Val* max_val) {
      "Min and Max values should be Scalars");
 
   Val* out = newValLike(in);
-  Statement* expr =
-      new TernaryOp(TernaryOpType::Clamp, out, in, min_val, max_val);
+
+  new TernaryOp(TernaryOpType::Clamp, out, in, min_val, max_val);
  return out;
 }
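
One pattern in this file deserves a note: the rewritten ops no longer bind the new expression to a local (Statement* expr = new UnaryOp(...) becomes a bare new UnaryOp(...)). Constructing an IR expression registers it with the currently active fusion as the definition of out, so the pointer is never consulted again at the call site. A toy sketch of that registration idiom follows; the name Fusion::registerExpr and the ownership details are illustrative assumptions, not this codebase's exact API.

#include <vector>

class Expr; // IR expression base, defined below

// Owns every expression created while it is the active fusion.
class Fusion {
 public:
  void registerExpr(Expr* e) { exprs_.push_back(e); }
 private:
  std::vector<Expr*> exprs_;
};

class Expr {
 public:
  explicit Expr(Fusion* fusion) {
    // Side effect: the constructor records the node in the fusion graph,
    // which is why call sites can discard the result of `new`.
    fusion->registerExpr(this);
  }
};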

torch/csrc/jit/codegen/cuda/arith.h

Lines changed: 15 additions & 1 deletion
@@ -32,6 +32,15 @@ TORCH_CUDA_API Val* unaryOp(UnaryOpType type, Val* v1);
 // Mod, CeilDiv, and LT are considered Int only output operations for now.
 TORCH_CUDA_API Val* binaryOp(BinaryOpType type, Val* v1, Val* v2);
 
+// Perform a reduction operation on v1, initial value for reduction is init,
+// reduces across axes, and reduction operation defined by BinaryOp.
+TORCH_CUDA_API Val* reductionOp(
+    BinaryOpType reduction_op_type,
+    const std::vector<int>& axes,
+    Val* init,
+    Val* v1);
+
+// BINARY OPAERATIONS
 TORCH_CUDA_API Val* add(Val* v1, Val* v2);
 TORCH_CUDA_API Val* sub(Val* v1, Val* v2);
 TORCH_CUDA_API Val* mul(Val* v1, Val* v2);
@@ -41,12 +50,17 @@ TORCH_CUDA_API Val* lt(Val* v1, Val* v2);
 TORCH_CUDA_API Val* ceilDiv(Val* v1, Val* v2);
 TORCH_CUDA_API Val* andOp(Val* v1, Val* v2);
 
+// REDUCTION OPERATIONS
+TORCH_CUDA_API Val* sum(Val* v1, const std::vector<int>& reduction_axes);
+
+// COMPOUND OPERATIONS
 TORCH_CUDA_API Val* add_alpha(Val* v1, Val* v2, Val* s);
 TORCH_CUDA_API Val* sub_alpha(Val* v1, Val* v2, Val* s);
 TORCH_CUDA_API Val* lerp(Val* start, Val* end, Val* weight);
 TORCH_CUDA_API Val* addcmul(Val* v1, Val* v2, Val* v3, Val* s);
-
 TORCH_CUDA_API Val* where(Val* c, Val* v1, Val* v2);
+
+// TERNARY OPERATIONS
 TORCH_CUDA_API Val* threshold(Val* in, Val* thresh, Val* value);
 TORCH_CUDA_API Val* clamp(Val* in, Val* min_val, Val* max_val);
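
Taken together with the scheduling operators named in the summary, the intended flow for the new reduction API looks roughly like the sketch below. Fusion, FusionGuard, and the TensorView scheduling calls follow the summary's split/rfactor/computeAt vocabulary, and makeDummyTensor is a test-utility-style helper assumed here for creating an input TensorView; treat the exact signatures as assumptions rather than verbatim API.

#include <torch/csrc/jit/codegen/cuda/arith.h>
// Fusion/FusionGuard/TensorView headers elided; names assumed for this sketch.

// Rough sketch: define a row-sum, then carve up the reduction axis so the
// inner piece can be reduced cooperatively by the threads of a block.
void sketchRowSum(Fusion& fusion) {
  FusionGuard fg(&fusion);              // make `fusion` the active fusion

  TensorView* tv0 = makeDummyTensor(2); // assumed helper: 2-D float input
  fusion.addInput(tv0);

  // sum() builds a zero init and calls reductionOp(BinaryOpType::Add, ...).
  auto* tv1 = static_cast<TensorView*>(sum(tv0, {1}));
  fusion.addOutput(tv1);

  tv1->split(1, 128);                   // [I, R] -> [I, R/128, 128]
  TensorView* rf = tv1->rFactor({1});   // partial sums land in a new tensor
  rf->computeAt(tv1, -1);               // inline the partial reduction
}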
