Revert D22432885: [pytorch][PR] unsafe_split, unsafe_split_with_sizes, unsafe_chunk operations

Test Plan: revert-hammer

Differential Revision: D22432885 (pytorch/pytorch@c17670a)

Original commit changeset: 324aef091b32

fbshipit-source-id: 6b7c52bde46932e1cf77f61e7035d8a641b0beb6
Qiao Tan authored and facebook-github-bot committed Jul 14, 2020
1 parent 144f04e commit 359cdc2
Showing 17 changed files with 92 additions and 285 deletions.
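
For context: the `unsafe_` variants removed by this revert relaxed autograd's in-place checks on the outputs of multi-output ops. A minimal sketch of the difference, reconstructed from the tests and docstrings deleted below (it assumes a pre-revert build in which the `unsafe_` variants still exist):

import warnings
import torch

a = torch.randn(3, 3, requires_grad=True)
b = a + a

# Safe variant: in-place modification of an output is flagged; at this point it
# raised the deprecation warning checked by test_inplace_view_multi_output_safe.
s1, s2, s3 = b.split(1)
with warnings.catch_warnings(record=True) as w:
    s1.mul_(s2)
print(w[0].message)  # suggests switching to the `unsafe_` variant

# Unsafe variant (gone after this revert): outputs carry fresh version counters,
# so the same in-place op passes silently and backward() runs without complaint.
u1, u2, u3 = b.unsafe_split(1)
u1.mul_(u2)
u1.sum().backward()
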
14 changes: 7 additions & 7 deletions aten/src/ATen/native/RNN.cpp
@@ -707,7 +707,7 @@ struct LSTMCell : Cell<std::tuple<Tensor, Tensor>, cell_params> {

const auto gates = params.linear_hh(hx).add_(
pre_compute_input ? input : params.linear_ih(input));
-    auto chunked_gates = gates.unsafe_chunk(4, 1);
+    auto chunked_gates = gates.chunk(4, 1);
auto ingate = chunked_gates[0].sigmoid_();
auto forgetgate = chunked_gates[1].sigmoid_();
auto cellgate = chunked_gates[2].tanh_();
@@ -738,9 +738,9 @@ struct GRUCell : Cell<Tensor, cell_params> {
return std::move(std::get<0>(result));
}
const auto chunked_igates = pre_compute_input
-        ? input.unsafe_chunk(3, 1)
-        : params.linear_ih(input).unsafe_chunk(3, 1);
-    auto chunked_hgates = params.linear_hh(hidden).unsafe_chunk(3, 1);
+        ? input.chunk(3, 1)
+        : params.linear_ih(input).chunk(3, 1);
+    auto chunked_hgates = params.linear_hh(hidden).chunk(3, 1);
const auto reset_gate =
chunked_hgates[0].add_(chunked_igates[0]).sigmoid_();
const auto input_gate =
@@ -1477,7 +1477,7 @@ _thnn_differentiable_lstm_cell_backward(
if (hidden_bias.defined()) {
gates = gates + hidden_bias;
}
-  auto chunked_gates = gates.unsafe_chunk(4, 1);
+  auto chunked_gates = gates.chunk(4, 1);
Tensor i = chunked_gates[0].sigmoid();
Tensor f = chunked_gates[1].sigmoid();
Tensor c = chunked_gates[2].tanh();
@@ -1524,11 +1524,11 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_differentiable_gru_cell
if (hidden_bias.defined()){
h_g = h_g + hidden_bias;
}
-  auto chunked_input_gates = in_g.unsafe_chunk(3, 1);
+  auto chunked_input_gates = in_g.chunk(3, 1);
Tensor ir = chunked_input_gates[0];
Tensor ii = chunked_input_gates[1];
Tensor in = chunked_input_gates[2];
-  auto chunked_hidden_gates = h_g.unsafe_chunk(3, 1);
+  auto chunked_hidden_gates = h_g.chunk(3, 1);
Tensor hr = chunked_hidden_gates[0];
Tensor hi = chunked_hidden_gates[1];
Tensor hn = chunked_hidden_gates[2];
35 changes: 0 additions & 35 deletions aten/src/ATen/native/TensorShape.cpp
@@ -510,25 +510,6 @@ std::vector<Tensor> chunk(const Tensor& self, int64_t chunks, int64_t dim) {
}
}

-std::vector<Tensor> unsafe_chunk(const Tensor& self, int64_t chunks, int64_t dim) {
-  TORCH_CHECK(self.dim() > 0,
-           "chunk expects at least a 1-dimensional tensor");
-  TORCH_CHECK(chunks > 0,
-           "chunk expects `chunks` to be greater than 0, got: ", chunks);
-
-  std::vector<Tensor> result;
-  int64_t split_size = (self.size(dim) + chunks - 1) / chunks;
-
-  // See the comment above in chunk(...)
-  if (split_size == 0 && self.size(dim) == 0) {
-    std::vector<int64_t> split_sizes(chunks, split_size);
-    split_sizes[chunks - 1] = split_size - (split_size * chunks - self.size(dim));
-    return self.unsafe_split_with_sizes(split_sizes, dim);
-  } else {
-    return self.unsafe_split(split_size, dim);
-  }
-}
-
Tensor diagflat(const Tensor& self, int64_t offset) {
return self.contiguous().view(-1).diag(offset);
}
@@ -1071,14 +1052,6 @@ std::vector<Tensor> split(const Tensor& self, int64_t split_size, int64_t dim) {
return splits;
}

-std::vector<Tensor> unsafe_split(const Tensor& self, int64_t split_size, int64_t dim) {
-  auto result = at::native::split(self, split_size, dim);
-  for (auto& t : result) {
-    t.unsafeGetTensorImpl()->set_version_counter(c10::VariableVersion());
-  }
-  return result;
-}
-
std::vector<Tensor> split_with_sizes(const Tensor& self, IntArrayRef split_sizes, int64_t dim) {
TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
int64_t dim_size = self.size(dim);
@@ -1101,14 +1074,6 @@ std::vector<Tensor> split_with_sizes(const Tensor& self, IntArrayRef split_sizes
return splits;
}

-std::vector<Tensor> unsafe_split_with_sizes(const Tensor& self, IntArrayRef split_sizes, int64_t dim) {
-  auto result = at::native::split_with_sizes(self, split_sizes, dim);
-  for (auto& t : result) {
-    t.unsafeGetTensorImpl()->set_version_counter(c10::VariableVersion());
-  }
-  return result;
-}
-
// Precondition: tensors is non-empty
static inline std::vector<Tensor> get_stack_inputs(TensorList tensors, int64_t dim) {
std::vector<Tensor> inputs(tensors.size());
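
Note how both removed helpers delegate to the safe implementations and then give every output a fresh version counter via `set_version_counter`; that detachment is the entire "unsafe" part. A rough Python-level illustration of the counter mechanism itself (`Tensor._version` is a real but internal attribute, and the slice below stands in for any view):

import torch

x = torch.randn(4)
v = x[:2]                      # a view shares its base's version counter
print(x._version, v._version)  # 0 0
x.add_(1)                      # an in-place op bumps the shared counter
print(v._version)              # 1: autograd can now tell that saved values are stale
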
20 changes: 1 addition & 19 deletions aten/src/ATen/native/native_functions.yaml
@@ -712,12 +712,6 @@
use_c10_dispatcher: full
variants: function

-- func: unsafe_chunk(Tensor self, int chunks, int dim=0) -> Tensor[]
-  use_c10_dispatcher: full
-  variants: function, method
-  device_guard: False
-  supports_named_tensor: True
-
- func: chunk(Tensor(a) self, int chunks, int dim=0) -> Tensor(a)[]
use_c10_dispatcher: full
variants: function, method
@@ -2563,24 +2557,12 @@
CPU: softmax_backward_cpu
CUDA: softmax_backward_cuda

-- func: unsafe_split.Tensor(Tensor self, int split_size, int dim=0) -> Tensor[]
-  use_c10_dispatcher: full
-  variants: function, method
-  device_guard: False
-  supports_named_tensor: True
-
- func: split.Tensor(Tensor(a) self, int split_size, int dim=0) -> Tensor(a)[]
use_c10_dispatcher: full
variants: function, method
device_guard: False

-- func: unsafe_split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]
-  use_c10_dispatcher: full
-  variants: function, method
-  device_guard: False
-  supports_named_tensor: True
-
-- func: split_with_sizes(Tensor(a) self, int[] split_sizes, int dim=0) -> Tensor(a)[]
+- func: split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]
use_c10_dispatcher: full
variants: function, method
device_guard: False
1 change: 0 additions & 1 deletion docs/source/tensor_view.rst
@@ -70,7 +70,6 @@ For reference, here’s a full list of view ops in PyTorch:
- :meth:`~torch.Tensor.view_as`
- :meth:`~torch.Tensor.unbind`
- :meth:`~torch.Tensor.split`
-- :meth:`~torch.Tensor.split_with_sizes`
- :meth:`~torch.Tensor.chunk`
- :meth:`~torch.Tensor.indices` (sparse tensor only)
- :meth:`~torch.Tensor.values` (sparse tensor only)
2 changes: 1 addition & 1 deletion test/backward_compatibility/check_backward_compatibility.py
@@ -76,7 +76,6 @@
('aten::dict', datetime.date(2020, 6, 30)),
('aten::tensor', datetime.date(2020, 6, 30)),
('aten::as_tensor', datetime.date(2020, 6, 30)),
-('aten::split_with_sizes', datetime.date(2020, 7, 20)),
('quantized::linear_unpack_fp16', datetime.date(2020, 6, 1)),
('quantized::linear_unpack', datetime.date(2020, 6, 1)),
('quantized::linear_prepack_fp16', datetime.date(2020, 6, 1)),
@@ -115,6 +114,7 @@
('aten::__and__', datetime.date(2020, 6, 30)),
('aten::__or__', datetime.date(2020, 6, 30)),
('aten::__xor__', datetime.date(2020, 6, 30)),
+('aten::split', datetime.date(2020, 6, 30)),
('aten::add', datetime.date(2020, 7, 30)),
('aten::__upsample_bilinear', datetime.date(2020, 7, 30)),
('aten::hash', datetime.date(2020, 7, 30)),
21 changes: 0 additions & 21 deletions test/test_autograd.py
@@ -6428,27 +6428,6 @@ def test_inplace_view_non_contig(self, device):
x.sum().backward()
self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1], [1, 1]])

-    def test_inplace_view_multi_output_unsafe(self, device):
-        for f in [lambda t: t.unsafe_split(1),
-                  lambda t: t.unsafe_split_with_sizes((1, 1, 1)),
-                  lambda t: t.unsafe_chunk(3)]:
-            a = torch.randn(3, 3, device=device, requires_grad=True)
-            b = a + a
-            s1, s2, s3 = f(b)
-            s1.mul_(s2)
-            s1.sum().backward()
-
-    def test_inplace_view_multi_output_safe(self, device):
-        for f in [lambda t: t.split(1),
-                  lambda t: t.split_with_sizes((1, 1, 1)),
-                  lambda t: t.chunk(3)]:
-            a = torch.randn(3, 3, device=device, requires_grad=True)
-            b = a + a
-            s1, s2, s3 = f(b)
-            with warnings.catch_warnings(record=True) as w:
-                s1.mul_(s2)
-            self.assertIn('Consider using `unsafe_` version', str(w[0].message))
-
def test_mv_grad_stride_0(self, device):
# Reference: https://github.com/pytorch/pytorch/issues/38315
mat = torch.randn(2, 2, device=device)
1 change: 0 additions & 1 deletion test/test_torch.py
@@ -211,7 +211,6 @@ def test_namespace(ns, *skips):
'smm',
'softmax',
'split_with_sizes',
-'unsafe_split_with_sizes',
'sspaddmm',
'to_dense',
'sparse_resize_',
8 changes: 1 addition & 7 deletions tools/autograd/derivatives.yaml
@@ -916,13 +916,7 @@
- name: split.Tensor(Tensor(a) self, int split_size, int dim=0) -> Tensor(a)[]
self: split_backward(grads, split_size, dim, self.sizes(), self.options())

-- name: unsafe_split.Tensor(Tensor self, int split_size, int dim=0) -> Tensor[]
-  self: split_backward(grads, split_size, dim, self.sizes(), self.options())
-
-- name: split_with_sizes(Tensor(a) self, int[] split_sizes, int dim=0) -> Tensor(a)[]
-  self: split_with_sizes_backward(grads, split_sizes, dim, self.sizes(), self.options())
-
-- name: unsafe_split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]
+- name: split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]
self: split_with_sizes_backward(grads, split_sizes, dim, self.sizes(), self.options())

- name: sqrt(Tensor self) -> Tensor
14 changes: 3 additions & 11 deletions tools/autograd/gen_autograd.py
@@ -45,11 +45,10 @@
'as_strided': 'self',
'diagonal': 'self',
'expand': 'self',
+'narrow': 'self',
'permute': 'self',
'select': 'self',
'slice': 'self',
-'split': 'self',
-'split_with_sizes': 'self',
'squeeze': 'self',
't': 'self',
'transpose': 'self',
@@ -72,21 +71,14 @@
for key in VIEW_FUNCTIONS_WITH_METADATA_CHANGE:
VIEW_FUNCTIONS[key] = 'self'

-# Functions for which we use CreationMeta::MULTI_OUTPUT_SAFE. I.e., the ones for
-# which inplace modification of outputs is being gradually deprecated.
-MULTI_OUTPUT_SAFE_FUNCTIONS = {
-    'split',
-    'split_with_sizes',
-}
-
# note: some VIEW_FUNCTIONS are just compositions of the view functions above
# this list contains both the root view functions and any that are purely composed
# of viewing functions, and is used by the JIT to determine when an operator
# may return a view of its inputs; however they may sometimes return a copy.
# (e.g. `contiguous`)
RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()).union({
-    'chunk', 'detach', 'contiguous', 'reshape', 'reshape_as',
-    'expand_as', 'view_as', 'real', 'imag', 'narrow',
+    'chunk', 'split', 'detach', 'contiguous', 'reshape', 'reshape_as',
+    'expand_as', 'view_as', 'real', 'imag',
})

def format_return_type(returns):
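
The restored comment notes that RETURNS_VIEWS_OF_INPUT entries only *may* return views. A quick demonstration of the `contiguous` case it cites, using nothing beyond the public API:

import torch

x = torch.randn(2, 3)
print(x.contiguous().data_ptr() == x.data_ptr())  # True: already contiguous, no copy made
y = x.t()                                         # transposed, non-contiguous view
print(y.contiguous().data_ptr() == y.data_ptr())  # False: contiguous() returned a copy
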
13 changes: 5 additions & 8 deletions tools/autograd/gen_variable_type.py
@@ -24,8 +24,7 @@
#
from __future__ import print_function
from .utils import CodeTemplate, nested_dict, write, uninplace_api_name
-from .gen_autograd import VIEW_FUNCTIONS, VIEW_FUNCTIONS_WITH_METADATA_CHANGE, \
-    MULTI_OUTPUT_SAFE_FUNCTIONS, RETURNS_VIEWS_OF_INPUT
+from .gen_autograd import VIEW_FUNCTIONS, VIEW_FUNCTIONS_WITH_METADATA_CHANGE
from .gen_autograd_functions import uses_single_grad

# These functions we don't want to record for tracing, because we always want
@@ -824,8 +823,9 @@ def emit_body(declaration):

base_name = name[:-1] if inplace else name[:-4] if is_out_fn else name
view_info = VIEW_FUNCTIONS.get(base_name, None)
-    if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT:
-        view_info = "self"
+    # TODO: Add back when https://github.com/pytorch/pytorch/pull/32044 lands again
+    # if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT:
+    #     view_info = "self"

def is_differentiable(arg):
if 'TensorOptions' in arg['type']:
@@ -1095,10 +1095,7 @@ def wrap_output(return_values, var):
# If we are in a no grad block, raise a warning
# See NOTE [ View + Inplace detection ] for more details about this logic
if return_info['dynamic_type'] == 'TensorList':
-        if base_name in MULTI_OUTPUT_SAFE_FUNCTIONS:
-            creation_meta = "CreationMeta::MULTI_OUTPUT_SAFE"
-        else:
-            creation_meta = "CreationMeta::MULTI_OUTPUT_NODE"
+        creation_meta = "CreationMeta::MULTI_OUTPUT_NODE"
rhs_value = ("as_view(/* base */ {}, /* output */ {}, /* is_differentiable */ true, "
"/* creation_meta */ {})").format(view_info, var, creation_meta)
else:
3 changes: 0 additions & 3 deletions torch/_overrides.py
@@ -661,9 +661,6 @@ def get_testing_overrides():
torch.unbind: lambda input, dim=0: -1,
torch.unique: lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1,
torch.unique_consecutive: lambda input, return_inverse=False, return_counts=False, dim=None: -1,
-    torch.unsafe_chunk: lambda input, chunks, dim=0: -1,
-    torch.unsafe_split: lambda tensor, split_size_or_sections, dim=0: -1,
-    torch.unsafe_split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
torch.unsqueeze: lambda input, dim, out=None: -1,
torch.var: lambda input: -1,
torch.var_mean: lambda input: -1,
14 changes: 0 additions & 14 deletions torch/_tensor_docs.py
@@ -3635,20 +3635,6 @@ def callable(a, b) -> number
See :func:`torch.chunk`
""")

-add_docstr_all('unsafe_chunk',
-               r"""
-unsafe_chunk(chunks, dim=0) -> List of Tensors
-
-See :func:`torch.unsafe_chunk`
-""")
-
-add_docstr_all('unsafe_split',
-               r"""
-unsafe_split(split_size, dim=0) -> List of Tensors
-
-See :func:`torch.unsafe_split`
-""")
-
add_docstr_all('stft',
r"""
stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0) -> Tensor
30 changes: 0 additions & 30 deletions torch/_torch_docs.py
@@ -1030,36 +1030,6 @@ def merge_dicts(*dicts):
dim (int): dimension along which to split the tensor
""")

-add_docstr(torch.unsafe_chunk,
-           r"""
-unsafe_chunk(input, chunks, dim=0) -> List of Tensors
-
-Works like :func:`torch.chunk` but without enforcing the autograd restrictions
-on inplace modification of the outputs.
-
-.. warning::
-    This function is safe to use as long as only the input, or only the outputs
-    are modified inplace after calling this function. It is user's
-    responsibility to ensure that is the case. If both the input and one or more
-    of the outputs are modified inplace, gradients computed by autograd will be
-    silently incorrect.
-""")
-
-add_docstr(torch.unsafe_split,
-           r"""
-unsafe_split(tensor, split_size_or_sections, dim=0) -> List of Tensors
-
-Works like :func:`torch.split` but without enforcing the autograd restrictions
-on inplace modification of the outputs.
-
-.. warning::
-    This function is safe to use as long as only the input, or only the outputs
-    are modified inplace after calling this function. It is user's
-    responsibility to ensure that is the case. If both the input and one or more
-    of the outputs are modified inplace, gradients computed by autograd will be
-    silently incorrect.
-""")
-

add_docstr(torch.can_cast,
r"""
can_cast(from, to) -> bool
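
The deleted warnings pin down the single failure mode of the unsafe variants: mutating both the input and one of the outputs in place. A hypothetical sketch of that hazard (again assuming a build where `unsafe_chunk` exists):

import torch

a = torch.ones(4, requires_grad=True)
b = a * 2
c1, c2 = b.unsafe_chunk(2)
b.add_(1)            # the input is modified in place...
c1.mul_(3)           # ...and so is an output: exactly the case the warning forbids
c1.sum().backward()  # raises no error, but a.grad may be silently incorrect
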
14 changes: 1 addition & 13 deletions torch/csrc/autograd/VariableTypeManual.cpp
@@ -251,19 +251,7 @@ Tensor detach(const Tensor & self) {
Tensor & detach_(Tensor & self) {
RECORD_FUNCTION("detach_", std::vector<c10::IValue>({self}));
if (self.is_view()) {
-    // NB: is_view() ==> get_autograd_meta()
-    auto diff_view_meta = static_cast<torch::autograd::DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(self));
-    // See NOTE [ View + Inplace detection ]
-    if (diff_view_meta->creation_meta == CreationMeta::MULTI_OUTPUT_SAFE) {
-      TORCH_WARN("This view is an output of a function that "
-                 "returns multiple views. Detaching such views inplace "
-                 "is being deprecated and will be forbidden "
-                 "starting from version 1.8. Consider using detach() instead "
-                 "of detach_(). Alternatively, create this view with an "
-                 "`unsafe_` version of the function that produced it.");
-    } else {
-      AT_ERROR("Can't detach views in-place. Use detach() instead");
-    }
+    AT_ERROR("Can't detach views in-place. Use detach() instead");
}
// I think the choice here is conservative. In principle, doing
// an in-place detach should give us the ability to just clear
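
With the MULTI_OUTPUT_SAFE branch gone, detach_() on any differentiable view is a hard error again. A two-line reproduction of the restored behavior (standard semantics; the slice is just one convenient way to obtain a view):

import torch

base = torch.randn(3, requires_grad=True)
view = base[:2]  # a differentiable view of base
view.detach_()   # RuntimeError: Can't detach views in-place. Use detach() instead
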