add back lamb stage1/2 to amp_C python
FDecaYed committed Aug 20, 2019
1 parent f855f85 commit b9f0995
Showing 1 changed file with 24 additions and 0 deletions.
csrc/amp_C_frontend.cpp: 24 additions & 0 deletions
@@ -33,6 +33,26 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::optional<bool> per_tensor_python);
 
+void multi_tensor_lamb_stage1_cuda(
+  int chunk_size,
+  at::Tensor noop_flag,
+  std::vector<std::vector<at::Tensor>> tensor_lists,
+  at::Tensor per_tensor_decay,
+  const int step,
+  const float beta1,
+  const float beta2,
+  const float epsilon,
+  const float global_grad_norm,
+  const float max_global_grad_norm);
+
+void multi_tensor_lamb_stage2_cuda(
+  int chunk_size,
+  at::Tensor noop_flag,
+  std::vector<std::vector<at::Tensor>> tensor_lists,
+  at::Tensor per_tensor_param_norm,
+  at::Tensor per_tensor_update_norm,
+  const float step_size);
+
 void multi_tensor_adam_cuda(
   int chunk_size,
   at::Tensor noop_flag,
@@ -86,6 +106,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "out = a*x + b*y for a list of contiguous tensors");
   m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
         "Computes L2 norm for a list of contiguous tensors");
+  m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
+        "Computes update part of LAMB optimizer");
+  m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda,
+        "Completes application of gradient to parameters for LAMB optimizer");
   m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
         "Compute and apply gradient update to parameters for Adam optimizer");
   m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda,
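For context, the two entry points this commit re-exposes are driven from Python through the built amp_C extension, and the work is split so that the raw Adam-style update (stage 1) can be renormalized by LAMB's per-tensor trust ratio before being applied (stage 2). The sketch below is illustrative only: the tensor-list ordering, the separate updates output buffer, and all hyperparameter values are assumptions, not taken from this commit; consult apex's FusedLAMB and multi_tensor_applier for the real calling convention.

# Hypothetical Python driver for the re-exposed bindings; layout assumptions
# are flagged inline. Requires a CUDA build of the amp_C extension.
import torch
import amp_C

chunk_size = 2048 * 32  # chunking granularity typical of apex's multi-tensor kernels
noop_flag = torch.zeros(1, dtype=torch.int, device="cuda")  # kernels skip work if nonzero

params     = [torch.randn(1024, device="cuda") for _ in range(4)]
grads      = [torch.randn_like(p) for p in params]
exp_avg    = [torch.zeros_like(p) for p in params]  # first moment (m)
exp_avg_sq = [torch.zeros_like(p) for p in params]  # second moment (v)
updates    = [torch.empty_like(p) for p in params]  # stage-1 output (assumed buffer)

per_tensor_decay = torch.full((len(params),), 0.01, device="cuda")

# Global gradient norm for clipping, via the l2norm binding declared above.
grad_norm, _ = amp_C.multi_tensor_l2norm(chunk_size, noop_flag, [grads], False)

# Stage 1: form the Adam-style update for every tensor (list order assumed).
amp_C.multi_tensor_lamb_stage1_cuda(
    chunk_size, noop_flag,
    [grads, params, exp_avg, exp_avg_sq, updates],
    per_tensor_decay,
    1,                 # step
    0.9, 0.999,        # beta1, beta2
    1e-6,              # epsilon
    grad_norm.item(),  # global_grad_norm
    1.0)               # max_global_grad_norm

# Per-tensor norms feed the trust ratio ||p|| / ||update|| in stage 2.
_, param_norm  = amp_C.multi_tensor_l2norm(chunk_size, noop_flag, [params],  True)
_, update_norm = amp_C.multi_tensor_l2norm(chunk_size, noop_flag, [updates], True)

# Stage 2: scale each update by its trust ratio and apply it to the weights.
amp_C.multi_tensor_lamb_stage2_cuda(
    chunk_size, noop_flag,
    [params, updates],  # list order assumed
    param_norm, update_norm,
    1e-3)               # step_size (learning rate)

Keeping the two stages as separate bindings presumably lets the per-tensor norms be computed between them with the existing multi_tensor_l2norm launch, rather than fusing everything into one kernel.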
