Skip to content

Commit

Permalink
add l2norm source for FusedLAMB
Browse files Browse the repository at this point in the history
  • Loading branch information
Kexin Yu committed Mar 23, 2020
1 parent 04927b3 commit a3ffb8a
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
ext_modules.append(
CUDAExtension(name='fused_lamb_cuda',
sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu'],
'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
Expand Down

0 comments on commit a3ffb8a

Please sign in to comment.