Skip to content

Commit

Permalink
Deterministic hashing for tensors (#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkgr authored Sep 12, 2022
1 parent 230d78e commit 15196f2
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
was found to be in an invalid state.
- Improved cluster choice logic in `BeakerExecutor` to ensure greater diversity of clusters when submitting many steps at once.
- Fixed bug where sub-processes of the multicore executor would use the wrong executor if `executor` was defined in a `tango.yml` file.
- Fixed a bug where deterministic hashes for numpy arrays and torch tensors were not actually deterministic.


## [v0.12.0](https://github.com/allenai/tango/releases/tag/v0.12.0) - 2022-08-23
Expand Down
29 changes: 27 additions & 2 deletions tango/common/det_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,23 @@
import hashlib
import io
from abc import abstractmethod
from typing import Any, MutableMapping, Optional
from typing import Any, MutableMapping, Optional, Type

import base58
import dill

# numpy is an optional dependency: bind `ndarray` to the real class when numpy
# can be imported, otherwise to None so that code can check
# `ndarray is not None` before using it in an isinstance() test.
ndarray: Optional[Type]
try:
    from numpy import ndarray
except ModuleNotFoundError:
    # numpy is not installed; leave the name defined but empty.
    ndarray = None

# torch is an optional dependency: bind `TorchTensor` to `torch.Tensor` when
# torch can be imported, otherwise to None so that code can check
# `TorchTensor is not None` before using it in an isinstance() test.
TorchTensor: Optional[Type]
try:
    from torch import Tensor as TorchTensor
except ModuleNotFoundError:
    # torch is not installed; leave the name defined but empty.
    TorchTensor = None


class CustomDetHash:
"""
Expand Down Expand Up @@ -82,9 +94,12 @@ def det_hash_object(self) -> Any:
return None # When you return `None` from here, it falls back to just hashing the object itself.


# Pickle protocol version shared by `_DetHashPickler` and the `torch.save()`
# call used for tensors, so both serialize with the same protocol.
# NOTE(review): presumably pinned so hashes do not depend on the interpreter's
# default pickle protocol — confirm.
_PICKLE_PROTOCOL = 4


class _DetHashPickler(dill.Pickler):
def __init__(self, buffer: io.BytesIO):
super().__init__(buffer, protocol=4)
super().__init__(buffer, protocol=_PICKLE_PROTOCOL)

# We keep track of how deeply we are nesting the pickling of an object.
# If a class returns `self` as part of `det_hash_object()`, it causes an
Expand All @@ -111,6 +126,16 @@ def persistent_id(self, obj: Any) -> Any:
return None
elif isinstance(obj, type):
return obj.__module__, obj.__qualname__
elif ndarray is not None and isinstance(obj, ndarray):
# It's unclear why numpy arrays don't pickle in a consistent way.
return obj.dumps()
elif TorchTensor is not None and isinstance(obj, TorchTensor):
# It's unclear why torch tensors don't pickle in a consistent way.
import torch

with io.BytesIO() as buffer:
torch.save(obj, buffer, pickle_protocol=_PICKLE_PROTOCOL)
return buffer.getvalue()
else:
return None

Expand Down
Binary file added templates.pkl
Binary file not shown.
18 changes: 18 additions & 0 deletions tests/integrations/torch/det_hash_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import numpy
import torch

from tango.common import det_hash


def test_numpy_det_hash():
    """Equal-valued numpy arrays built with order "C" and "K" hash identically."""
    values = [[1, 2], [3, 4]]
    c_ordered = numpy.array(values, order="C")
    k_ordered = numpy.array(values, order="K")
    assert det_hash(c_ordered) == det_hash(k_ordered)


def test_torch_det_hash():
    """Tensors made from equal-valued "C"- and "K"-order arrays hash identically."""
    values = [[1, 2], [3, 4]]
    # Build each tensor from its own numpy array rather than reusing one name
    # for both the array and the tensor.
    c_tensor = torch.tensor(numpy.array(values, order="C"))
    k_tensor = torch.tensor(numpy.array(values, order="K"))
    assert det_hash(c_tensor) == det_hash(k_tensor)

0 comments on commit 15196f2

Please sign in to comment.