-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathdlio_postprocessor_test.py
61 lines (49 loc) · 2.21 KB
/
dlio_postprocessor_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
Copyright (c) 2022, UChicago Argonne, LLC
All Rights Reserved
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
#!/usr/bin/env python
import os

# Silence TensorFlow's native (C++) and AutoGraph logging during the test run.
# TF_CPP_MIN_LOG_LEVEL is commonly documented as needing to be set *before*
# TensorFlow is loaded, so both variables are set ahead of the project import
# below. NOTE(review): unclear whether dlio_benchmark.postprocessor imports
# TensorFlow at module load; setting the variables first is safe either way.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['AUTOGRAPH_VERBOSITY'] = '0'

from collections import namedtuple
import unittest

from dlio_benchmark.postprocessor import DLIOPostProcessor
class TestDLIOPostProcessor(unittest.TestCase):
    """Regression tests for DLIOPostProcessor run against pre-recorded test data."""

    def create_DLIO_PostProcessor(self, args):
        """Construct and return a DLIOPostProcessor for the given ``args``."""
        return DLIOPostProcessor(args)

    def test_process_loading_and_processing_times(self):
        """Check aggregate throughput and per-process timing statistics.

        The expected values are pinned to the fixture data under
        ``tests/test_data`` and must be updated whenever that data changes.
        """
        # NOTE(review): 'tests/test_data' is resolved relative to the current
        # working directory, so this test assumes it is run from the repo root.
        config = {
            'output_folder': 'tests/test_data',
            'name': '',
            'num_proc': 2,
            'epochs': 2,
            'do_eval': False,
            'do_checkpoint': False,
            'batch_size': 4,
            'batch_size_eval': 1,
            'record_size': 234560851,
        }
        # Turn the dict into an attribute-style object (a namedtuple instance);
        # presumably this mirrors the argparse-style namespace the
        # post-processor normally receives.
        args = namedtuple('args', config.keys())(*config.values())

        postproc = self.create_DLIO_PostProcessor(args)
        postproc.process_loading_and_processing_times()

        # overall_stats holds these statistics as formatted strings
        # (e.g. '5.10'), so they are compared as exact strings, not floats.
        stats = postproc.overall_stats
        self.assertEqual(stats['samples/s']['mean'], '5.10')
        self.assertEqual(stats['avg_process_loading_time'], '7.78')
        self.assertEqual(stats['avg_process_processing_time'], '65.87')
if __name__ == '__main__':
    # Script entry point: hand control to unittest's command-line runner,
    # which collects and runs the TestCase classes defined in this module.
    unittest.main()