interface.cpp
#include <torch/torch.h>

// Ideally, I'd like to call this file "weight_norm.cu" and put both the interface and the
// implementation here, but I can't make nvcc play well with torch.h. For now, use a layer
// of indirection and separate .cu implementation files.
// If we want everything to be part of "apex_C", all the interface functions need to be defined
// in this file, or the linker will complain about "multiple definitions of PyInit".
// TODO: multiple modules?
// TODO: modify fwd+bwd calls to return a tuple of Tensors. This will require changing the
// Python client code as well. For now, get things working with the same Python-side API.
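// For illustration only, a minimal sketch of the indirection described above (not the actual
// apex .cu sources; the file name below is hypothetical): each *_cuda function declared in this
// file is defined in its own .cu file compiled by nvcc, presumably against ATen headers rather
// than torch/torch.h, roughly along these lines:
//
//   // weight_norm_fwd_cuda.cu (hypothetical name)
//   void weight_norm_fwd_cuda
//     (const at::Tensor& w,
//      const at::Tensor& norms,
//      const at::Tensor& v,
//      const at::Tensor& g,
//      int dim)
//   {
//     // dispatch on dtype and launch the fused kernel that computes the norms of v
//     // and the normalized weight w = g * v / norms
//   }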
// Defined in a separate .cu file compiled by nvcc.
void weight_norm_fwd_cuda
  (const at::Tensor& w,
   const at::Tensor& norms,
   const at::Tensor& v,
   const at::Tensor& g,
   int dim);

// Host-side wrapper exposed to Python; forwards directly to the CUDA implementation.
void weight_norm_fwd
  (at::Tensor w,
   at::Tensor norms,
   at::Tensor v,
   at::Tensor g,
   int dim)
{
  weight_norm_fwd_cuda(w, norms, v, g, dim);
}
// Defined in a separate .cu file compiled by nvcc.
void weight_norm_bwd_cuda
  (const at::Tensor& pLpv,
   const at::Tensor& pLpg,
   const at::Tensor& pLpw,
   const at::Tensor& savedv,
   const at::Tensor& savedg,
   const at::Tensor& savedNorms,
   int dim);

// Host-side wrapper exposed to Python; forwards directly to the CUDA implementation.
void weight_norm_bwd
  (at::Tensor pLpv,
   at::Tensor pLpg,
   at::Tensor pLpw,
   at::Tensor savedv,
   at::Tensor savedg,
   at::Tensor savedNorms,
   int dim)
{
  weight_norm_bwd_cuda(pLpv, pLpg, pLpw, savedv, savedg, savedNorms, dim);
}
// Defined in a separate .cu file compiled by nvcc.
void scale_check_overflow_cuda
  (const at::Tensor& d_grads,
   float scale,
   const at::Tensor& d_buf);

// Argument checking: PyTorch 0.4 and earlier only provide AT_ASSERT; later versions use AT_CHECK.
#ifdef VERSION_LE_04
#define VERSION_AGNOSTIC_CHECK AT_ASSERT
#else
#define VERSION_AGNOSTIC_CHECK AT_CHECK
#endif
// Checks that the inputs are CUDA tensors and that grads is FP32, then forwards to the
// CUDA implementation.
void scale_check_overflow
  (at::Tensor grads,
   float scale,
   at::Tensor overflow_buf)
{
  VERSION_AGNOSTIC_CHECK
    (grads.type().is_cuda(), "grads must be a CUDA tensor");
  VERSION_AGNOSTIC_CHECK
    (overflow_buf.type().is_cuda(), "overflow_buf must be a CUDA tensor");
  // Make sure we are downscaling the FP32 master grads
  VERSION_AGNOSTIC_CHECK
    (grads.type().scalarType() == at::ScalarType::Float,
     "grads supplied to scale_check_overflow should be fp32 (master grads).");
  scale_check_overflow_cuda(grads, scale, overflow_buf);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("weight_norm_fwd", &weight_norm_fwd, "Fused weight norm, forward pass");
  m.def("weight_norm_bwd", &weight_norm_bwd, "Fused weight norm, backward pass");
  m.def("scale_check_overflow", &scale_check_overflow, "Fused overflow check + scale for FP32 tensors");
}
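// Illustrative Python-side usage (a sketch with placeholder variable names, not the actual apex
// client code), assuming the extension is built under the "apex_C" name mentioned above:
//
//   import apex_C
//   apex_C.weight_norm_fwd(w, norms, v, g, dim)
//   apex_C.weight_norm_bwd(pLpv, pLpg, pLpw, saved_v, saved_g, saved_norms, dim)
//   apex_C.scale_check_overflow(master_grads, 1.0 / loss_scale, overflow_buf)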