// amp_C_frontend.cpp
// Forked from NVIDIA/apex.
#include <torch/extension.h>
// Forward declaration only: the definition is not in this file (the _cuda suffix
// suggests it lives in a CUDA source — confirm against the build).
// Registered with Python as "multi_tensor_scale" in the PYBIND11_MODULE block,
// where it is described as a fused overflow check + scale over a list of
// contiguous tensors.
// NOTE(review): the roles of chunk_size and noop_flag, and the expected layout of
// tensor_lists, are defined by the implementation and are not visible here.
void multi_tensor_scale_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float scale);
// Forward declaration only; the definition is provided elsewhere.
// Registered with Python as "multi_tensor_sgd", described there as a fused SGD
// optimizer for a list of contiguous tensors.
// All hyperparameters are passed by value as scalars/flags.
// NOTE(review): wd presumably means weight decay, and first_run /
// wd_after_momentum presumably select variants of the update — confirm against
// the implementation, which is not visible here.
void multi_tensor_sgd_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float wd,
float momentum,
float dampening,
float lr,
bool nesterov,
bool first_run,
bool wd_after_momentum,
float scale);
// Forward declaration only; the definition is provided elsewhere.
// Registered with Python as "multi_tensor_axpby", described there as
// "out = a*x + b*y for a list of contiguous tensors".
// NOTE(review): arg_to_check presumably selects which input is inspected for
// non-finite values — confirm in the implementation; it is not visible here.
void multi_tensor_axpby_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float a,
float b,
int arg_to_check);
// Forward declaration only; the definition is provided elsewhere.
// Registered with Python as "multi_tensor_l2norm", described there as computing
// the L2 norm for a list of contiguous tensors.
// Returns a pair of tensors.
// NOTE(review): presumably (overall norm, per-tensor norms), with
// per_tensor_python toggling the per-tensor output — confirm against the
// implementation. The flag is optional, so callers may pass None from Python.
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
// Forward declaration only; the definition is provided elsewhere.
// Registered with Python as "multi_tensor_l2norm_mp". Its signature is identical
// to multi_tensor_l2norm_cuda and it carries the same binding description.
// NOTE(review): "mp" presumably stands for mixed precision — confirm against the
// implementation; how it differs from the non-mp variant is not visible here.
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_mp_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
// Forward declaration only; the definition is provided elsewhere.
// Registered with Python as "multi_tensor_l2norm_scale", described there as
// computing the L2 norm for a list of contiguous tensors and doing scaling.
// Same shape of interface as multi_tensor_l2norm_cuda plus a scalar `scale`.
// NOTE(review): whether the norm is taken before or after scaling, and the
// meaning of the two returned tensors, must be confirmed in the implementation.
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_scale_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float scale,
at::optional<bool> per_tensor_python);
// Forward declaration only; the definition is provided elsewhere.
// Registered with Python as "multi_tensor_lamb_stage1_cuda" (the Python name
// keeps the _cuda suffix), described there as computing the update part of the
// LAMB optimizer.
// per_tensor_decay and global_grad_norm are passed as tensors; the remaining
// hyperparameters are scalars.
// NOTE(review): presumably per_tensor_decay holds one weight-decay value per
// parameter tensor — confirm against the implementation.
void multi_tensor_lamb_stage1_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_decay,
const int step,
const float beta1,
const float beta2,
const float epsilon,
at::Tensor global_grad_norm,
const float max_global_grad_norm);
// Forward declaration only; the definition is provided elsewhere.
// Registered with Python as "multi_tensor_lamb_stage2_cuda" (the Python name
// keeps the _cuda suffix), described there as completing the application of the
// gradient to parameters for the LAMB optimizer.
// NOTE(review): per_tensor_param_norm / per_tensor_update_norm are presumably
// per-tensor norms computed beforehand by the caller (e.g. via the l2norm entry
// points) — confirm. use_nvlamb_python is optional, so Python may pass None.
void multi_tensor_lamb_stage2_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
const float lr,
const float weight_decay,
at::optional<bool> use_nvlamb_python);
// Forward declaration only; the definition is provided elsewhere.
// Registered with Python as "multi_tensor_adam", described there as computing
// and applying the gradient update to parameters for the Adam optimizer.
// All hyperparameters are scalars passed by value.
// NOTE(review): `mode` and `bias_correction` are ints whose accepted values are
// defined by the implementation (presumably mode selects the weight-decay
// variant and bias_correction is a 0/1 flag) — confirm.
void multi_tensor_adam_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr,
const float beta1,
const float beta2,
const float epsilon,
const int step,
const int mode,
const int bias_correction,
const float weight_decay);
// Forward declaration only; the definition is provided elsewhere.
// Registered with Python as "multi_tensor_adagrad".
// NOTE(review): judging by the name this is the Adagrad counterpart of
// multi_tensor_adam_cuda; the accepted values of `mode` are defined by the
// implementation and are not visible here — confirm.
void multi_tensor_adagrad_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr,
const float epsilon,
const int mode,
const float weight_decay);
// Forward declaration only; the definition is provided elsewhere.
// Registered with Python as "multi_tensor_novograd".
// Unlike the Adam/Adagrad declarations, this one additionally takes a
// grad_norms tensor plus grad_averaging and norm_type selectors.
// NOTE(review): grad_norms presumably holds one norm per parameter tensor, and
// norm_type presumably selects which norm is used — confirm against the
// implementation; the int flags' accepted values are not visible here.
void multi_tensor_novograd_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor grad_norms,
const float lr,
const float beta1,
const float beta2,
const float epsilon,
const int step,
const int bias_correction,
const float weight_decay,
const int grad_averaging,
const int mode,
const int norm_type);
// Forward declaration only; the definition is provided elsewhere.
// Registered with Python as "multi_tensor_lamb", described there as computing
// and applying the update for the LAMB optimizer (a single entry point, as
// opposed to the separate stage1/stage2 declarations).
// lr, step and max_grad_norm are scalars here; global_grad_norm is a tensor.
// NOTE(review): the accepted values of bias_correction / grad_averaging / mode
// and the effect of use_nvlamb_python are defined by the implementation — confirm.
void multi_tensor_lamb_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr,
const float beta1,
const float beta2,
const float epsilon,
const int step,
const int bias_correction,
const float weight_decay,
const int grad_averaging,
const int mode,
at::Tensor global_grad_norm,
const float max_grad_norm,
at::optional<bool> use_nvlamb_python);
// Forward declaration only; the definition is provided elsewhere.
// Registered with Python as "multi_tensor_lamb_mp", with the same binding
// description as multi_tensor_lamb.
// Differences from multi_tensor_lamb_cuda visible in the signature: lr, step and
// max_grad_norm are passed as tensors rather than scalars, and two extra
// tensors, found_inf and inv_scale, are appended.
// NOTE(review): "mp" presumably stands for mixed precision, and found_inf /
// inv_scale presumably carry gradient-scaler state — confirm against the
// implementation, which is not visible here.
void multi_tensor_lamb_mp_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr,
const float beta1,
const float beta2,
const float epsilon,
at::Tensor step,
const int bias_correction,
const float weight_decay,
const int grad_averaging,
const int mode,
at::Tensor global_grad_norm,
at::Tensor max_grad_norm,
at::optional<bool> use_nvlamb_python,
at::Tensor found_inf,
at::Tensor inv_scale);
// Python bindings for the multi-tensor entry points declared earlier in this
// file. Each m.def call registers one function under the given Python-visible
// name together with a short docstring.
//
// The two LAMB stage functions keep the "_cuda" suffix in their Python names,
// unlike every other binding here; the names are left untouched so existing
// Python callers keep working.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Scaling / elementwise helpers.
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors");
  m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
        "Fused SGD optimizer for list of contiguous tensors");
  m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
        "out = a*x + b*y for a list of contiguous tensors");

  // L2-norm helpers.
  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors");
  m.def("multi_tensor_l2norm_mp", &multi_tensor_l2norm_mp_cuda,
        "Computes L2 norm for a list of contiguous tensors");
  m.def("multi_tensor_l2norm_scale", &multi_tensor_l2norm_scale_cuda,
        "Computes L2 norm for a list of contiguous tensors and does scaling");

  // Two-stage LAMB (Python names intentionally retain the "_cuda" suffix).
  m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
        "Computes update part of LAMB optimizer");
  m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda,
        "Completes application of gradient to parameters for LAMB optimizer");

  // Optimizer updates. The Adagrad and NovoGrad docstrings previously said
  // "Adam optimizer" (copy-paste from the Adam binding); they now name the
  // optimizer each function actually belongs to.
  m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer");
  m.def("multi_tensor_adagrad", &multi_tensor_adagrad_cuda,
        "Compute and apply gradient update to parameters for Adagrad optimizer");
  m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda,
        "Compute and apply gradient update to parameters for NovoGrad optimizer");

  // Single-call LAMB variants.
  m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
        "Computes and apply update for LAMB optimizer");
  m.def("multi_tensor_lamb_mp", &multi_tensor_lamb_mp_cuda,
        "Computes and apply update for LAMB optimizer");
}