Skip to content

Commit

Permalink
[fix] MathBits: serialization (pytorch#88182)
Browse files Browse the repository at this point in the history
Fixes pytorch#81690

TODO:

* [x] C++ Unpickler Fix (locally tested pickled in Python and unpickled in C++)
* [x] C++ Pickler Fix (locally tested pickled in C++ and unpickled in Python)
* [x] Do quant_tensor, sparse_tensor, etc require similar changes? (Sparse and Quant don't need this)
* [x] Add Comments
* [x] How to make sure C++ and Python are in sync? (Functions in `pickler.h` help in getting and setting Tensor Metadata (math-bits for now) on a tensor. They are the only place which should handle this.)

Notes:
Quant Tensor don't support complex dtypes and for float they segfault with `_neg_view` : pytorch#88484

Sparse Tensor:
```python
>>> a = torch.tensor([[0, 2.], [3j, 0]]).to_sparse()
>>> a.conj().is_conj()
False
>>> a._neg_view()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: Cannot access storage of SparseTensorImpl
```

Pull Request resolved: pytorch#88182
Approved by: https://github.com/ezyang, https://github.com/anjali411
  • Loading branch information
kshitij12345 authored and pytorchmergebot committed Nov 9, 2022
1 parent 525fe53 commit eb9b156
Show file tree
Hide file tree
Showing 9 changed files with 167 additions and 4 deletions.
33 changes: 33 additions & 0 deletions test/cpp/api/serialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,39 @@ TEST(SerializeTest, Basic) {
ASSERT_TRUE(x.allclose(y));
}

TEST(SerializeTest, MathBits) {
torch::manual_seed(0);

auto options = torch::TensorOptions{}.dtype(torch::kComplexFloat);
auto x = torch::randn({5, 5}, options);
{
auto expected = torch::conj(x);
auto actual = save_and_load(expected);

ASSERT_TRUE(actual.defined());
ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
ASSERT_TRUE(actual.allclose(expected));
}

{
auto expected = torch::_neg_view(x);
auto actual = save_and_load(expected);

ASSERT_TRUE(actual.defined());
ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
ASSERT_TRUE(actual.allclose(expected));
}

{
auto expected = torch::conj(torch::_neg_view(x));
auto actual = save_and_load(expected);

ASSERT_TRUE(actual.defined());
ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
ASSERT_TRUE(actual.allclose(expected));
}
}

TEST(SerializeTest, BasicToFile) {
torch::manual_seed(0);

Expand Down
20 changes: 20 additions & 0 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,26 @@ def __reduce__(self):
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported class"):
torch.load(f, weights_only=True)

@parametrize('weights_only', (False, True))
def test_serialization_math_bits(self, weights_only):
t = torch.randn(1, dtype=torch.cfloat)

def _save_load_check(t):
with BytesIOContext() as f:
torch.save(t, f)
f.seek(0)
# Unsafe load should work
self.assertEqual(torch.load(f, weights_only=weights_only), t)

t_conj = torch.conj(t)
_save_load_check(t_conj)

t_neg = torch._neg_view(t)
_save_load_check(t_neg)

t_n_c = torch._neg_view(torch.conj(t))
_save_load_check(t_n_c)

def run(self, *args, **kwargs):
with serialization_method(use_zip=True):
return super(TestSerialization, self).run(*args, **kwargs)
Expand Down
4 changes: 4 additions & 0 deletions torch/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,10 @@ def _reduce_ex_internal(self, proto):
self.requires_grad,
backward_hooks,
) # previously was self._backward_hooks

metadata = torch._utils.get_tensor_metadata(self)
if metadata:
args = args + (metadata,) # type: ignore[assignment]
return (torch._utils._rebuild_tensor_v2, args)

def __setstate__(self, state):
Expand Down
20 changes: 19 additions & 1 deletion torch/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,29 @@ def _rebuild_tensor(storage, storage_offset, size, stride):
return t.set_(storage._untyped_storage, storage_offset, size, stride)


def get_tensor_metadata(tensor):
# Tensor's Metadata for serializing.
# Currently, this only returns a dict[string, bool] specifing whether
# `conj` or `neg` bit is set.
assert isinstance(tensor, torch.Tensor)
return torch._C._get_tensor_metadata(tensor) # type: ignore[attr-defined]


def set_tensor_metadata(tensor, metadata):
# See `get_tensor_metadata` above
assert isinstance(metadata, dict)
assert isinstance(tensor, torch.Tensor)
torch._C._set_tensor_metadata(tensor, metadata) # type: ignore[attr-defined]


def _rebuild_tensor_v2(
storage, storage_offset, size, stride, requires_grad, backward_hooks
storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None
):
tensor = _rebuild_tensor(storage, storage_offset, size, stride)
tensor.requires_grad = requires_grad
if metadata:
set_tensor_metadata(tensor, metadata)

# NB: This line exists only for backwards compatibility; the
# general expectation is that backward_hooks is an empty
# OrderedDict. See Note [Don't serialize hooks]
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include <torch/csrc/jit/python/init.h>
#include <torch/csrc/jit/python/python_ir.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/lazy/python/init.h>
#include <torch/csrc/monitor/python_init.h>
#include <torch/csrc/multiprocessing/init.h>
Expand Down Expand Up @@ -1544,6 +1545,12 @@ Call this whenever a new thread is created in order to propagate values from
"_set_conj", [](const at::Tensor& x, bool conj) { x._set_conj(conj); });
py_module.def(
"_set_neg", [](const at::Tensor& x, bool neg) { x._set_neg(neg); });
py_module.def("_get_tensor_metadata", &torch::jit::getTensorMetadata);
py_module.def(
"_set_tensor_metadata",
static_cast<void (*)(
const at::Tensor&, std::unordered_map<std::string, bool>)>(
torch::jit::setTensorMetadata));
py_module.def("_dispatch_key_set", [](const at::Tensor& x) {
return toString(x.key_set());
});
Expand Down
14 changes: 14 additions & 0 deletions torch/csrc/jit/serialization/pickler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,20 @@ void Pickler::pushLiteralTensor(const IValue& ivalue) {
// Construct the collections.OrderedDict for the backward_hooks
push<PickleOpCode>(PickleOpCode::REDUCE);

if (!quantized) {
// Only push it for regular tensor if the dictionary is not empty.
auto metadata = torch::jit::getTensorMetadata(tensor);
if (!metadata.empty()) {
// IValues based on std::unordered_map<K, V> are slow and deprecated.
// Thus, pass a c10::Dict to pushDict.
c10::Dict<std::string, bool> math_bits_;
for (const auto& pair : metadata) {
math_bits_.insert(pair.first, pair.second);
}
pushDict(math_bits_);
}
}

push<PickleOpCode>(PickleOpCode::TUPLE);

// Call torch._utils._rebuild_tensor_v2
Expand Down
49 changes: 49 additions & 0 deletions torch/csrc/jit/serialization/pickler.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,5 +296,54 @@ uint64_t getStorageKey(const at::Tensor& tensor);
// otherwise return false
bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls);

// Return a map of Tensor Metadata for serialization.
// For now, it only takes care of `conj` and `neg` bit.
inline std::unordered_map<std::string, bool> getTensorMetadata(
const at::Tensor& t) {
std::unordered_map<std::string, bool> metadata{};

// Only add meta-data if the value is not default.
if (t.is_conj()) {
metadata["conj"] = true;
}
if (t.is_neg()) {
metadata["neg"] = true;
}
return metadata;
}

// set Tensor Metadata based on the map.
// Refer: getTensorMathdata
inline void setTensorMetadata(
const at::Tensor& t,
std::unordered_map<std::string, bool> metadata) {
for (auto& key_value_pair : metadata) {
if (key_value_pair.first == "conj") {
t._set_conj(true);
} else if (key_value_pair.first == "neg") {
t._set_neg(true);
} else {
TORCH_CHECK(
false,
"Unexpected key `",
key_value_pair.first,
"` passed to setTensorMetadata.");
}
}
}

// set Tensor metadata based on the map.
// NOTE: This overload is required by unpickler.cpp
inline void setTensorMetadata(
const at::Tensor& t,
c10::Dict<c10::IValue, c10::IValue> metadata_idict) {
std::unordered_map<std::string, bool> metadata;
for (auto& pair : metadata_idict) {
auto key = *pair.key().toString();
metadata[key] = pair.value().toBool();
}
setTensorMetadata(t, metadata);
}

} // namespace jit
} // namespace torch
19 changes: 17 additions & 2 deletions torch/csrc/jit/serialization/unpickler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -823,13 +823,28 @@ void Unpickler::rebuildTensor(bool quantized) {
} else {
result = at::empty({0}, storage_tensor.options());
}
bool requires_grad = elements.at(idx).toBool();
// elements[idx++] is empty backwards hooks
bool requires_grad = elements.at(idx++).toBool();
idx++; // backwards hooks is empty
at::TensorImpl* impl = result.unsafeGetTensorImpl();
impl->set_storage_keep_dtype(storage_tensor.storage());
impl->set_storage_offset(storage_offset);
impl->set_sizes_and_strides(size, stride);
result = autograd::make_variable(result, requires_grad);

// Handle if math_bits were pickled.
// See `args` of _reduce_ex_internal
// for a regular tensor (final else case).
// Tensors pickled before this patch didn't
// have this argument for storing MathBits,
// in that case, we do nothing.
// NOTE: `math_bits` is the 7th arg.
// NOTE: This is only meant for regular tensor and not quantized
// which also has 7 args serialized.
if (!quantized && elements.size() == 7) {
auto math_bits = elements.at(idx++).toGenericDict();
torch::jit::setTensorMetadata(result, math_bits);
}

stack_.emplace_back(std::move(result));
});
}
Expand Down
5 changes: 4 additions & 1 deletion torch/utils/model_dump/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ def hierarchical_pickle(data):
}
if typename == "torch._utils._rebuild_tensor_v2":
assert data.state is None
storage, offset, size, stride, requires_grad, hooks = data.args
if len(data.args) == 6:
storage, offset, size, stride, requires_grad, hooks = data.args
else:
storage, offset, size, stride, requires_grad, hooks, metadata = data.args
storage_info = get_storage_info(storage)
return {"__tensor_v2__": [storage_info, offset, size, stride, requires_grad]}
if typename == "torch._utils._rebuild_qtensor":
Expand Down

0 comments on commit eb9b156

Please sign in to comment.