forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_recordio.py
92 lines (82 loc) · 2.84 KB
/
test_recordio.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
import os
import random
import shutil
import string
import sys
import tempfile

import mxnet as mx
import numpy as np

from common import setup_module, with_seed, teardown
@with_seed()
def test_recordio():
    """Round-trip N one-character records through a sequential MXRecordIO file.

    One record is written for every ``i`` in ``range(N)``; the file is then
    reopened for reading and each record is compared, in order, with the
    payload that was written.
    """
    # A private scratch directory replaces tempfile.mktemp(), which only
    # returns a name (racy, deprecated) and left the file behind afterwards.
    workdir = tempfile.mkdtemp()
    frec = os.path.join(workdir, 'test.rec')
    N = 255
    is_py2 = sys.version_info[0] < 3

    def payload(i):
        # Python 2 writes the native str; Python 3 needs UTF-8 encoded bytes.
        return chr(i) if is_py2 else chr(i).encode('utf-8')

    try:
        writer = mx.recordio.MXRecordIO(frec, 'w')
        try:
            for i in range(N):
                writer.write(payload(i))
        finally:
            # Close explicitly so the data is on disk before reading it back.
            writer.close()

        reader = mx.recordio.MXRecordIO(frec, 'r')
        try:
            for i in range(N):
                res = reader.read()
                assert res == payload(i)
        finally:
            reader.close()
    finally:
        shutil.rmtree(workdir, ignore_errors=True)
@with_seed()
def test_indexed_recordio():
    """Round-trip N keyed records through an MXIndexedRecordIO file pair.

    Record ``i`` is written under key ``i`` for every ``i`` in ``range(N)``.
    After reopening for reading, the test checks that the reader exposes
    exactly those keys and that each key, visited in shuffled order, yields
    the payload written for it.
    """
    # A private scratch directory replaces tempfile.mktemp(), which only
    # returns a name (racy, deprecated) and left both files behind afterwards.
    workdir = tempfile.mkdtemp()
    fidx = os.path.join(workdir, 'test.idx')
    frec = os.path.join(workdir, 'test.rec')
    N = 255
    is_py2 = sys.version_info[0] < 3

    def payload(i):
        # Python 2 writes the native str; Python 3 needs UTF-8 encoded bytes.
        return chr(i) if is_py2 else chr(i).encode('utf-8')

    try:
        writer = mx.recordio.MXIndexedRecordIO(fidx, frec, 'w')
        try:
            for i in range(N):
                writer.write_idx(i, payload(i))
        finally:
            # Close explicitly so index and data are on disk before reading.
            writer.close()

        reader = mx.recordio.MXIndexedRecordIO(fidx, frec, 'r')
        try:
            # Copy the key list: shuffling reader.keys in place would reorder
            # the reader's own attribute.
            keys = list(reader.keys)
            assert sorted(keys) == list(range(N))
            random.shuffle(keys)
            for i in keys:
                res = reader.read_idx(i)
                assert res == payload(i)
        finally:
            reader.close()
    finally:
        shutil.rmtree(workdir, ignore_errors=True)
@with_seed()
def test_recordio_pack_label():
    """Check that pack/unpack round-trips float32 label vectors and payloads.

    For every label length ``i`` in ``1..N-1`` and every content length ``j``
    in ``0..N-1``, a random label vector and a random ASCII payload are
    packed into one record and unpacked again; both must come back unchanged.
    """
    N = 255
    # Loop-invariant: the characters random payloads are drawn from.
    alphabet = string.ascii_uppercase + string.digits
    for i in range(1, N):
        for j in range(N):
            content = ''.join(random.choice(alphabet) for _ in range(j))
            content = content.encode('utf-8')
            label = np.random.uniform(size=i).astype(np.float32)
            # The label vector sits in the second header slot; the other
            # fields are zero. NOTE(review): presumably this matches the
            # (flag, label, id, id2) layout of mx.recordio.IRHeader - confirm.
            header = (0, label, 0, 0)
            s = mx.recordio.pack(header, content)
            rheader, rcontent = mx.recordio.unpack(s)
            assert (label == rheader.label).all()
            assert content == rcontent
if __name__ == '__main__':
    # Direct execution: run each test once, pack/unpack first, then the
    # sequential and the indexed file round-trips.
    for run_test in (test_recordio_pack_label, test_recordio, test_indexed_recordio):
        run_test()