Skip to content

Commit

Permalink
[Bugfix] Add bool data type to backend. (dmlc#1487)
Browse files Browse the repository at this point in the history
* add bool to F.data_type_dict

* add utest

* skip bool test for mx
  • Loading branch information
jermainewang authored Apr 29, 2020
1 parent f1e4f37 commit 3c4506e
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 7 deletions.
1 change: 1 addition & 0 deletions python/dgl/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def data_type_dict():
int16
int32
int64
bool
This function will be called only *once* during the initialization fo the
backend module. The returned dictionary will become the attributes of the
Expand Down
3 changes: 2 additions & 1 deletion python/dgl/backend/mxnet/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def data_type_dict():
'int8' : np.int8,
'int16' : np.int16,
'int32' : np.int32,
'int64' : np.int64}
'int64' : np.int64,
'bool' : np.bool}

def cpu():
return mx.cpu()
Expand Down
3 changes: 2 additions & 1 deletion python/dgl/backend/pytorch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def data_type_dict():
'int8' : th.int8,
'int16' : th.int16,
'int32' : th.int32,
'int64' : th.int64}
'int64' : th.int64,
'bool' : th.bool}

def cpu():
return th.device('cpu')
Expand Down
3 changes: 2 additions & 1 deletion python/dgl/backend/tensorflow/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def data_type_dict():
'int8': tf.int8,
'int16': tf.int16,
'int32': tf.int32,
'int64': tf.int64}
'int64': tf.int64,
'bool' : tf.bool}

def cpu():
return "/cpu:0"
Expand Down
28 changes: 24 additions & 4 deletions tests/compute/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import backend as F
import dgl
import unittest
import pickle
import pytest
import io

N = 10
D = 5
Expand All @@ -15,10 +18,10 @@ def check_fail(fn):
except:
return True

def create_test_data(grad=False):
c1 = F.randn((N, D))
c2 = F.randn((N, D))
c3 = F.randn((N, D))
def create_test_data(grad=False, dtype=F.float32):
c1 = F.astype(F.randn((N, D)), dtype)
c2 = F.astype(F.randn((N, D)), dtype)
c3 = F.astype(F.randn((N, D)), dtype)
if grad:
c1 = F.attach_grad(c1)
c2 = F.attach_grad(c2)
Expand Down Expand Up @@ -357,6 +360,23 @@ def test_inplace():
newa2addr = id(f['a2'])
assert a2addr == newa2addr

def _reconstruct_pickle(obj):
f = io.BytesIO()
pickle.dump(obj, f)
f.seek(0)
obj = pickle.load(f)
f.close()
return obj

@pytest.mark.parametrize('dtype',
[F.float32, F.int32] if dgl.backend.backend_name == "mxnet" else [F.float32, F.int32, F.bool])
def test_pickle(dtype):
f = create_test_data(dtype=dtype)
newf = _reconstruct_pickle(f)
assert F.array_equal(f['a1'], newf['a1'])
assert F.array_equal(f['a2'], newf['a2'])
assert F.array_equal(f['a3'], newf['a3'])

if __name__ == '__main__':
test_create()
test_column1()
Expand Down

0 comments on commit 3c4506e

Please sign in to comment.