set device guard for multi tensor optimizer implementations (NVIDIA#927)
* add device guards to the optimizers

* add untracked file

* set deviceGuard in multi_tensor_apply

* address review comments; fix lamb

* indent

* typo
ngimel authored Aug 5, 2020
1 parent 5b53121 commit 274cc06
Showing 8 changed files with 167 additions and 157 deletions.
5 changes: 3 additions & 2 deletions apex/optimizers/fused_lamb.py
@@ -76,7 +76,7 @@ def __init__(self, params, lr=1e-3, bias_correction=True,
import amp_C
self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
# Skip buffer
- self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+ self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
self.multi_tensor_lamb = amp_C.multi_tensor_lamb
else:
raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')
@@ -117,7 +117,8 @@ def step(self, closure=None):
else:
raise RuntimeError('FusedLAMB only support fp16 and fp32.')

- g_norm_32, g_norm_16 = torch.zeros(1, device='cuda'), torch.zeros(1, device='cuda')
+ device = self.param_groups[0]["params"][0].device
+ g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)
# compute grad norm for two lists
if len(g_all_32) > 0:
g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
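Context for the buffer change above (the same change appears in fused_sgd.py below): torch.cuda.IntTensor([0]) always allocates on the current CUDA device, so an optimizer whose parameters live on another GPU would create its overflow buffer on cuda:0. Pinning the buffer to the first parameter's device keeps the optimizer state on one device. A minimal sketch of the difference, assuming a machine with at least two CUDA devices:

    import torch

    # Assumes a machine with at least two CUDA devices.
    params = [torch.zeros(16, device="cuda:1")]

    # Legacy constructor: allocates on the *current* CUDA device (cuda:0 unless
    # changed), no matter where the parameters live.
    old_buf = torch.cuda.IntTensor([0])

    # New form used by the optimizers: follows the parameters' device.
    new_buf = torch.tensor([0], dtype=torch.int, device=params[0].device)

    print(old_buf.device)  # cuda:0
    print(new_buf.device)  # cuda:1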
2 changes: 1 addition & 1 deletion apex/optimizers/fused_sgd.py
@@ -98,7 +98,7 @@ def __init__(self, params, lr=required, momentum=0, dampening=0,
if multi_tensor_applier.available:
import amp_C
# Skip buffer
- self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+ self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
self.multi_tensor_sgd = amp_C.multi_tensor_sgd
else:
raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')
17 changes: 10 additions & 7 deletions csrc/multi_tensor_apply.cuh
@@ -2,6 +2,7 @@
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
+ #include <c10/cuda/CUDAGuard.h>
#include "compat.h"

#include <assert.h>
@@ -34,7 +35,7 @@ __global__ void multi_tensor_apply_kernel(
ArgTypes... args)
{
// Hand the chunk information to the user-supplied functor to process however it likes.
- callable(chunk_size, noop_flag, tl, args...);
+   callable(chunk_size, noop_flag, tl, args...);
}

template<int depth, typename T, typename... ArgTypes>
@@ -49,8 +50,9 @@ void multi_tensor_apply(
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");

- for(int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
+ auto ref_device = tensor_lists[0][0].device();
+ TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
+ for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
{
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for(int t = 0; t < tensor_lists[l].size(); t++)
@@ -61,7 +63,7 @@ void multi_tensor_apply(
contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
- TORCH_CHECK(tensor_lists[l][t].is_cuda(), "A tensor was not cuda.");
+ TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
}
}
@@ -70,8 +72,9 @@ void multi_tensor_apply(

TensorListMetadata<depth> tl;

+ const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
auto stream = at::cuda::getCurrentCUDAStream();

tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
@@ -90,7 +93,7 @@ void multi_tensor_apply(
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;

bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth-1] &&
chunk == chunks_this_tensor - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]);
@@ -112,7 +115,7 @@ void multi_tensor_apply(
if(chunk == chunks_this_tensor - 1)
{
// std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
- loc_tensor_info = 0;
+   loc_tensor_info = 0;
tl.start_tensor_this_launch = t + 1;
}
else
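For context on the guard added above: at::cuda::OptionalCUDAGuard makes the tensors' device the current CUDA device for the rest of the scope, so getCurrentCUDAStream() and the kernel launches in multi_tensor_apply target the device that actually holds the tensor lists rather than whatever device happens to be current. The closest Python-level analogue is the torch.cuda.device context manager; a rough sketch of the same idea (the helper below is hypothetical, not apex code):

    import torch

    def scale_on_own_device(t, factor):
        # Rough Python analogue of the OptionalCUDAGuard used in multi_tensor_apply:
        # temporarily make t's device the current one, so the current stream (and
        # any kernel launched on it) belongs to that device.
        with torch.cuda.device(t.device):
            assert torch.cuda.current_stream().device == t.device
            return t.mul_(factor)

    # Works regardless of which device is "current" when the function is called.
    x = torch.ones(4, device="cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
    scale_on_own_device(x, 2.0)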
5 changes: 3 additions & 2 deletions csrc/multi_tensor_l2norm_kernel.cu
@@ -2,6 +2,7 @@
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
+ #include <c10/cuda/CUDAGuard.h>
// Another possibility:
// #include <torch/all.h>

@@ -335,13 +336,13 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
max_chunks_per_tensor);)

AT_CUDA_CHECK(cudaGetLastError());

// AT_CUDA_CHECK(cudaDeviceSynchronize());

// This involves one more small kernel launches, but will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now
auto ret = at::empty({1}, output.options());
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
auto stream = at::cuda::getCurrentCUDAStream();
cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
output.DATA_PTR<float>(),
@@ -369,7 +370,7 @@ void multi_tensor_norm_out_cuda(
const int norm_type)
{
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);

+ TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(), "noop flag should be on the same device as tensors");
// we don't need global thus uses empty here
auto output = at::empty({320}, float_options);

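The new device checks (here and in the SGD kernel below) require the no-op/overflow flag to live on the same device as the tensor lists, matching the buffer change in the Python optimizers. A sketch of a direct call, assuming apex was built with its CUDA extensions (amp_C) and mirroring the l2-norm call in fused_lamb.py above:

    import torch
    import amp_C
    from apex.multi_tensor_apply import multi_tensor_applier

    device = torch.device("cuda:0")
    grads = [torch.randn(1024, device=device) for _ in range(3)]

    # The overflow/no-op flag must be on the same device as the tensors; the
    # kernels now check this rather than dereferencing a flag on another device.
    overflow_buf = torch.zeros(1, dtype=torch.int, device=device)

    # Returns a (norm, per-tensor norms) pair; False disables per-tensor norms,
    # so take the first element for the overall L2 norm.
    total_norm = multi_tensor_applier(amp_C.multi_tensor_l2norm,
                                      overflow_buf, [grads], False)[0]
    print(total_norm)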
2 changes: 2 additions & 0 deletions csrc/multi_tensor_sgd_kernel.cu
@@ -160,6 +160,8 @@ void multi_tensor_sgd_cuda(
TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
"Additional output tensors should always be fp16.");

+ TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors");

// We have 3 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type, requires_fp16_copy
// 1. fp16, fp16, fp16, No
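Taken together, these changes let the fused optimizers run entirely on whichever device holds the parameters instead of implicitly touching cuda:0. An end-to-end sketch, assuming apex is installed with its CUDA extensions and at least two GPUs are available:

    import torch
    from apex.optimizers import FusedLAMB

    # Parameters on a non-default device; the optimizer's helper buffers and
    # multi-tensor kernel launches now follow this device.
    model = torch.nn.Linear(32, 32).to("cuda:1")
    opt = FusedLAMB(model.parameters(), lr=1e-3)

    loss = model(torch.randn(8, 32, device="cuda:1")).sum()
    loss.backward()
    opt.step()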
114 changes: 0 additions & 114 deletions tests/L0/run_optimizers/test_adagrad.py

This file was deleted.

(The remaining changed files are not expanded here.)
