[TEST] Add memoize to save test data (apache#424)
* [TEST] Add memoize to save test data

* Update comment

* mark py version
tqchen authored Sep 5, 2017
1 parent 071b138 commit df3c996
Showing 7 changed files with 205 additions and 45 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -127,6 +127,7 @@ jvm/*/*/target/
 *.perspectivev3
 !default.perspectivev3
 xcuserdata/
+.pkl_memoize_*
 
 .emscripten*
 .m2
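The new entry ignores the per-interpreter cache directories (.pkl_memoize_py2, .pkl_memoize_py3) that the tvm.contrib.pickle_memoize module added below creates in the working directory to store pickled test data.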
2 changes: 1 addition & 1 deletion python/tvm/_ffi/ndarray.py
@@ -253,7 +253,7 @@ def register_extension(cls):
     .. code-block:: python
 
-       @tvm.register_dltensor
+       @tvm.register_extension
        class MyTensor(object):
            def __init__(self):
                self.handle = _LIB.NewDLTensor()
91 changes: 91 additions & 0 deletions python/tvm/contrib/pickle_memoize.py
@@ -0,0 +1,91 @@
"""Memoize result of function via pickle, used for cache testcases."""
# pylint: disable=broad-except,superfluous-parens
import os
import sys
import atexit
from decorator import decorate
from .._ffi.base import string_types
try:
import cPickle as pickle
except ImportError:
import pickle

class Cache(object):
"""A cache object for result cache.
Parameters
----------
key: str
The file key to the function
"""
cache_by_key = {}
def __init__(self, key):
cache_dir = ".pkl_memoize_py{0}".format(sys.version_info[0])
if not os.path.exists(cache_dir):
os.mkdir(cache_dir)
self.path = os.path.join(cache_dir, key)
if os.path.exists(self.path):
try:
self.cache = pickle.load(open(self.path, "rb"))
except Exception:
self.cache = {}
else:
self.cache = {}
self.dirty = False

def save(self):
if self.dirty:
print("Save memoize result to %s" % self.path)
with open(self.path, "wb") as out_file:
pickle.dump(self.cache, out_file, pickle.HIGHEST_PROTOCOL)

@atexit.register
def _atexit():
"""Save handler."""
for value in Cache.cache_by_key.values():
value.save()


def memoize(key):
"""Memoize the result of function and reuse multiple times.
Parameters
----------
key: str
The unique key to the file
Returns
-------
fmemoize : function
The decorator function to perform memoization.
"""
def _register(f):
"""Registration function"""
allow_types = (string_types, int, float)
fkey = key + "." + f.__name__ + ".pkl"
if fkey not in Cache.cache_by_key:
Cache.cache_by_key[fkey] = Cache(fkey)
cache = Cache.cache_by_key[fkey]
cargs = tuple(x.cell_contents for x in f.__closure__)
cargs = (len(cargs),) + cargs

def _memoized_f(func, *args, **kwargs):
assert not kwargs, "Only allow positional call"
key = cargs + args
for arg in key:
if isinstance(arg, tuple):
for x in arg:
assert isinstance(x, allow_types)
else:
assert isinstance(arg, allow_types)
if key in cache.cache:
print("Use memoize {0}{1}".format(fkey, key))
return cache.cache[key]
res = func(*args)
cache.cache[key] = res
cache.dirty = True
return res

return decorate(f, _memoized_f)

return _register
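For reference, a minimal usage sketch of the new decorator (make_ref_data and the key string below are illustrative, not part of this commit). Positional arguments and captured closure values together form the cache key, keyword arguments are rejected, and the pickle is written by the atexit handler when the interpreter exits:

import numpy as np
from tvm.contrib.pickle_memoize import memoize

def make_ref_data(shape, dtype):
    # Only plain values (str/int/float, or tuples of them) may be captured:
    # the decorator reads get_ref_data.__closure__ to build the cache key.
    @memoize("my_tests.make_ref_data")
    def get_ref_data():
        return np.random.uniform(size=shape).astype(dtype)
    return get_ref_data()

# The first run computes the data and saves it under .pkl_memoize_py2/ or
# .pkl_memoize_py3/ at exit; later runs reload the pickled result instead.
a_np = make_ref_data((16, 16), "float32")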
39 changes: 24 additions & 15 deletions topi/tests/python/test_topi_conv2d_hwcn.py
@@ -3,10 +3,11 @@
 import numpy as np
 import tvm
 import topi
+from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 
 
-def verify_conv2d_hwcn_map(batch, in_channel, in_size, num_filter, kernel, stride, padding):
+def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, padding):
     in_height = in_width = in_size
 
     A = tvm.placeholder((in_height, in_width, in_channel, batch), name='A')
@@ -16,10 +17,18 @@ def verify_conv2d_hwcn_map(batch, in_channel, in_size, num_filter, kernel, stride, padding):
     s1 = topi.cuda.schedule_conv2d_hwcn([B])
     s2 = topi.cuda.schedule_conv2d_hwcn([C])
 
-    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
-    w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
-    b_np = topi.testing.conv2d_hwcn_python(a_np, w_np, stride, padding)
-    c_np = np.maximum(b_np, 0)
+    a_shape = get_const_tuple(A.shape)
+    w_shape = get_const_tuple(W.shape)
+    dtype = A.dtype
+
+    @memoize("topi.tests.test_topi_conv2d_hwcn.verify_hwcn")
+    def get_ref_data():
+        a_np = np.random.uniform(size=a_shape).astype(dtype)
+        w_np = np.random.uniform(size=w_shape).astype(dtype)
+        b_np = topi.testing.conv2d_hwcn_python(a_np, w_np, stride, padding)
+        c_np = np.maximum(b_np, 0)
+        return a_np, w_np, b_np, c_np
+    a_np, w_np, b_np, c_np = get_ref_data()
 
     def check_device(device):
         if not tvm.module.enabled(device):
@@ -44,16 +53,16 @@ def check_device(device):
         check_device(device)
 
 
-def test_conv2d_hwcn_map():
-    verify_conv2d_hwcn_map(1, 256, 32, 256, 3, 1, "SAME")
-    verify_conv2d_hwcn_map(1, 256, 32, 256, 3, 1, "SAME")
-    verify_conv2d_hwcn_map(4, 128, 16, 128, 5, 2, "SAME")
-    verify_conv2d_hwcn_map(4, 128, 16, 256, 5, 2, "SAME")
-    verify_conv2d_hwcn_map(1, 256, 32, 256, 3, 1, "VALID")
-    verify_conv2d_hwcn_map(1, 256, 32, 256, 3, 1, "VALID")
-    verify_conv2d_hwcn_map(4, 128, 16, 128, 5, 2, "VALID")
-    verify_conv2d_hwcn_map(4, 128, 16, 256, 5, 2, "VALID")
+def test_conv2d_hwcn():
+    verify_conv2d_hwcn(1, 256, 32, 256, 3, 1, "SAME")
+    verify_conv2d_hwcn(1, 256, 32, 256, 3, 1, "SAME")
+    verify_conv2d_hwcn(4, 128, 16, 128, 5, 2, "SAME")
+    verify_conv2d_hwcn(4, 128, 16, 256, 5, 2, "SAME")
+    verify_conv2d_hwcn(1, 256, 32, 256, 3, 1, "VALID")
+    verify_conv2d_hwcn(1, 256, 32, 256, 3, 1, "VALID")
+    verify_conv2d_hwcn(4, 128, 16, 128, 5, 2, "VALID")
+    verify_conv2d_hwcn(4, 128, 16, 256, 5, 2, "VALID")
 
 
 if __name__ == "__main__":
-    test_conv2d_hwcn_map()
+    test_conv2d_hwcn()
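Note the shape/dtype hoisting above: a_shape, w_shape, and dtype are bound to plain locals before get_ref_data is defined because, as _register shows, the decorator builds the cache key from the values captured in the function's closure and asserts that each is a string, int, float, or a tuple of those. The same pattern repeats in every converted test below. A self-contained sketch of that mechanism (hypothetical names, independent of TVM):

def outer():
    a_shape = (32, 32, 256, 1)   # tuple of ints: an allowed key component
    dtype = "float32"            # str: allowed
    def get_ref_data():
        return (a_shape, dtype)
    return get_ref_data

f = outer()
cargs = tuple(c.cell_contents for c in f.__closure__)
print((len(cargs),) + cargs)   # (2, (32, 32, 256, 1), 'float32'), a hashable key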
18 changes: 14 additions & 4 deletions topi/tests/python/test_topi_conv2d_nchw.py
@@ -3,6 +3,7 @@
 import numpy as np
 import tvm
 import topi
+from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 

@@ -16,10 +17,19 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
     s1 = topi.cuda.schedule_conv2d_nchw([B])
     s2 = topi.cuda.schedule_conv2d_nchw([C])
 
-    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
-    w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
-    b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
-    c_np = np.maximum(b_np, 0)
+    a_shape = get_const_tuple(A.shape)
+    w_shape = get_const_tuple(W.shape)
+    dtype = A.dtype
+
+    @memoize("topi.tests.test_topi_conv2d.verify_con2d_nchw")
+    def get_ref_data():
+        a_np = np.random.uniform(size=a_shape).astype(dtype)
+        w_np = np.random.uniform(size=w_shape).astype(dtype)
+        b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
+        c_np = np.maximum(b_np, 0)
+        return a_np, w_np, b_np, c_np
+
+    a_np, w_np, b_np, c_np = get_ref_data()
 
     def check_device(device):
         if not tvm.module.enabled(device):
17 changes: 14 additions & 3 deletions topi/tests/python/test_topi_convolution.py
@@ -3,6 +3,7 @@
 import numpy as np
 import tvm
 import topi
+from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 

@@ -13,11 +14,21 @@ def verify_convolution(batch, in_size, in_channel, num_filter, kernel, stride, padding):
     A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
     W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
     B = topi.nn.convolution(A, W, stride, padding)
+
     s = topi.rasp.schedule_convolution([B])
 
-    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
-    w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
-    b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
+    a_shape = get_const_tuple(A.shape)
+    w_shape = get_const_tuple(W.shape)
+    dtype = A.dtype
+
+    @memoize("topi.tests.test_topi_convolution.verify_convolution")
+    def get_ref_data():
+        a_np = np.random.uniform(size=a_shape).astype(dtype)
+        w_np = np.random.uniform(size=w_shape).astype(dtype)
+        b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
+        return a_np, w_np, b_np
+
+    a_np, w_np, b_np = get_ref_data()
 
     ctx = tvm.cpu(0)
     a = tvm.nd.array(a_np, ctx)
82 changes: 60 additions & 22 deletions topi/tests/python/test_topi_depthwise_conv2d.py
@@ -3,8 +3,10 @@
 import numpy as np
 from scipy import signal
 from topi.util import get_const_tuple
+from tvm.contrib.pickle_memoize import memoize
 from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
 
+
 def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
     in_width = in_height
     filter_channel = in_channel
@@ -25,11 +27,6 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
     s2 = schedule_depthwise_conv2d_nchw(ScaleShift)
     s3 = schedule_depthwise_conv2d_nchw(Relu)
 
-    input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
-    filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
-    scale_np = np.random.uniform(size=get_const_tuple(Scale.shape)).astype(Scale.dtype)
-    shift_np = np.random.uniform(size=get_const_tuple(Shift.shape)).astype(Shift.dtype)
-
     def check_device(device):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
@@ -39,7 +36,35 @@ def check_device(device):
         f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
         f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
         f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
-        # prepare data
+
+        # Prepare pod type for test data closure
+        dtype = Input.dtype
+        input_shape = get_const_tuple(Input.shape)
+        filter_shape = get_const_tuple(Filter.shape)
+        scale_shape = get_const_tuple(Scale.shape)
+        shift_shape = get_const_tuple(Shift.shape)
+        scale_shift_shape = get_const_tuple(ScaleShift.shape)
+
+        # Use memoize, pickle the test data for next time use.
+        @memoize("topi.tests.test_topi_depthwise_conv2d.nchw")
+        def get_ref_data():
+            input_np = np.random.uniform(size=input_shape).astype(dtype)
+            filter_np = np.random.uniform(size=filter_shape).astype(dtype)
+            scale_np = np.random.uniform(size=scale_shape).astype(dtype)
+            shift_np = np.random.uniform(size=shift_shape).astype(dtype)
+            # correctness with scipy
+            depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(
+                input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
+            scale_shift_scipy = np.zeros(shape=scale_shift_shape)
+            for c in range(in_channel * channel_multiplier):
+                scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
+            relu_scipy = np.maximum(scale_shift_scipy, 0)
+            return (input_np, filter_np, scale_np, shift_np,
+                    depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy)
+        # Get the test data
+        (input_np, filter_np, scale_np, shift_np,
+         depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy) = get_ref_data()
+
         input_tvm = tvm.nd.array(input_np, ctx)
         filter_tvm = tvm.nd.array(filter_np, ctx)
         scale_tvm = tvm.nd.array(scale_np, ctx)
@@ -56,12 +81,6 @@ def check_device(device):
         # launch kernel 3 (depthwise_conv2d + scale_shift + relu)
         timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
         tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
-        # correctness with scipy
-        depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
-        scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
-        for c in range(in_channel * channel_multiplier):
-            scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
-        relu_scipy = np.maximum(scale_shift_scipy, 0)
         np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
         np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
         np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
@@ -90,11 +109,6 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
     s2 = schedule_depthwise_conv2d_nhwc(ScaleShift)
     s3 = schedule_depthwise_conv2d_nhwc(Relu)
 
-    input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
-    filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
-    scale_np = np.random.uniform(size=get_const_tuple(Scale.shape)).astype(Scale.dtype)
-    shift_np = np.random.uniform(size=get_const_tuple(Shift.shape)).astype(Shift.dtype)
-
     def check_device(device):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
@@ -104,6 +118,35 @@ def check_device(device):
         f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
         f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
         f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
+
+        # Prepare pod type for test data closure
+        dtype = Input.dtype
+        input_shape = get_const_tuple(Input.shape)
+        filter_shape = get_const_tuple(Filter.shape)
+        scale_shape = get_const_tuple(Scale.shape)
+        shift_shape = get_const_tuple(Shift.shape)
+        scale_shift_shape = get_const_tuple(ScaleShift.shape)
+
+        # Use memoize, pickle the test data for next time use.
+        @memoize("topi.tests.test_topi_depthwise_conv2d.nhwc")
+        def get_ref_data():
+            input_np = np.random.uniform(size=input_shape).astype(dtype)
+            filter_np = np.random.uniform(size=filter_shape).astype(dtype)
+            scale_np = np.random.uniform(size=scale_shape).astype(dtype)
+            shift_np = np.random.uniform(size=shift_shape).astype(dtype)
+            # correctness with scipy
+            depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nhwc(
+                input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
+            scale_shift_scipy = np.zeros(shape=scale_shift_shape)
+            for c in range(in_channel * channel_multiplier):
+                scale_shift_scipy[:,:,:,c] = depthwise_conv2d_scipy[:,:,:,c] * scale_np[c] + shift_np[c]
+            relu_scipy = np.maximum(scale_shift_scipy, 0)
+            return (input_np, filter_np, scale_np, shift_np,
+                    depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy)
+        # Get the test data
+        (input_np, filter_np, scale_np, shift_np,
+         depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy) = get_ref_data()
+
         # prepare data
         input_tvm = tvm.nd.array(input_np, ctx)
         filter_tvm = tvm.nd.array(filter_np, ctx)
@@ -121,11 +164,6 @@ def check_device(device):
         # launch kernel 3 (depthwise_conv2d + scale_shift + relu)
         timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
         tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
-        # correctness with scipy
-        depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nhwc(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
-        scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
-        for c in range(in_channel * channel_multiplier):
-            scale_shift_scipy[:,:,:,c] = depthwise_conv2d_scipy[:,:,:,c] * scale_np[c] + shift_np[c]
-        relu_scipy = np.maximum(scale_shift_scipy, 0)
         np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
         np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)