forked from apple/corenet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_file_logger.py
45 lines (36 loc) · 1.23 KB
/
test_file_logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
import os
import tempfile
import pytest
import torch
from corenet.utils.file_logger import FileLogger
@pytest.mark.parametrize(
    "metric_name, epoch1, value1, epoch2, value2",
    [("metric", 0, 1.0, 1, 2.0), ("metric2", 5, 1.0, 6, 2.0)],
)
def test_file_logger(
    metric_name: str, epoch1: int, value1: float, epoch2: int, value2: float
) -> None:
    """Check that FileLogger creates a stats file and merges into an existing one.

    A first logger writes ``value1`` at ``epoch1`` to a path that does not yet
    exist; a second logger opened on the same path writes ``value2`` at
    ``epoch2``. The saved object must then contain both epochs, i.e. the second
    logger extended the existing contents instead of overwriting them.

    Args:
        metric_name: Name of the scalar being logged.
        epoch1: Epoch used for the first write.
        value1: Value logged at ``epoch1``.
        epoch2: Epoch used for the second write (distinct from ``epoch1`` in
            every parametrized case).
        value2: Value logged at ``epoch2``.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        filename = os.path.join(tempdir, "stats.pt")

        # Case 1: the file doesn't exist, so closing the logger must create it.
        logger = FileLogger(filename)
        logger.add_scalar(metric_name, value1, epoch1)
        logger.close()
        assert os.path.exists(filename)
        # The expected contents are plain dicts and Python scalars, so load with
        # weights_only=True explicitly: it matches the torch>=2.6 default and
        # avoids the FutureWarning torch 2.4/2.5 emit when the flag is omitted.
        stats = torch.load(filename, weights_only=True)
        assert stats == {"epochs": {epoch1: {"metrics": {metric_name: value1}}}}

        # Case 2: the file does exist, so the new epoch is added alongside the
        # previously saved one.
        logger = FileLogger(filename)
        logger.add_scalar(metric_name, value2, epoch2)
        logger.close()
        stats = torch.load(filename, weights_only=True)
        assert stats == {
            "epochs": {
                epoch1: {"metrics": {metric_name: value1}},
                epoch2: {"metrics": {metric_name: value2}},
            }
        }