forked from hiive/mlrose
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_decay.py
94 lines (66 loc) · 2.36 KB
/
test_decay.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
""" Unit tests for decay.py"""
# Author: Genevieve Hayes
# License: BSD 3 clause
# Probe for an importable mlrose_hiive; when it cannot be imported, add the
# parent directory to the module search path so the import below can pick the
# package up from a source checkout.
try:
    import mlrose_hiive
except ImportError:
    # Only a failed import should trigger the fallback; a bare ``except``
    # would also swallow KeyboardInterrupt/SystemExit and unrelated errors.
    import sys
    sys.path.append("..")
import unittest
from mlrose_hiive import GeomDecay, ArithDecay, ExpDecay, CustomSchedule
class TestDecay(unittest.TestCase):
    """Checks for the schedule classes imported from mlrose_hiive.

    Every case builds one schedule, evaluates it at a single time step and
    compares the outcome against a hard-coded expected value.
    """

    @staticmethod
    def test_geom_above_min():
        """GeomDecay at t=5: expected value lies above min_temp."""
        temperature = GeomDecay(
            init_temp=10, decay=0.95, min_temp=1).evaluate(5)
        # 7.73781 is 10 * 0.95 ** 5 rounded to five decimal places.
        assert round(temperature, 5) == 7.73781

    @staticmethod
    def test_geom_below_min():
        """GeomDecay at t=50: expected value equals min_temp."""
        temperature = GeomDecay(
            init_temp=10, decay=0.95, min_temp=1).evaluate(50)
        assert temperature == 1

    @staticmethod
    def test_arith_above_min():
        """ArithDecay at t=5: expected value lies above min_temp."""
        temperature = ArithDecay(
            init_temp=10, decay=0.95, min_temp=1).evaluate(5)
        # 5.25 is 10 - 0.95 * 5.
        assert temperature == 5.25

    @staticmethod
    def test_arith_below_min():
        """ArithDecay at t=50: expected value equals min_temp."""
        temperature = ArithDecay(
            init_temp=10, decay=0.95, min_temp=1).evaluate(50)
        assert temperature == 1

    @staticmethod
    def test_exp_above_min():
        """ExpDecay at t=5: expected value lies above min_temp."""
        temperature = ExpDecay(
            init_temp=10, exp_const=0.05, min_temp=1).evaluate(5)
        # 7.78801 is 10 * exp(-0.05 * 5) rounded to five decimal places.
        assert round(temperature, 5) == 7.78801

    @staticmethod
    def test_exp_below_min():
        """ExpDecay at t=50: expected value equals min_temp."""
        temperature = ExpDecay(
            init_temp=10, exp_const=0.05, min_temp=1).evaluate(50)
        assert temperature == 1

    @staticmethod
    def test_custom():
        """CustomSchedule wrapping a user-supplied function of (t, c)."""

        def offset_schedule(t, c):
            # User-defined schedule: the time step plus a constant.
            return t + c

        # The extra keyword argument is handed to the schedule constructor
        # together with the function; the test expects 5 + 10 back.
        temperature = CustomSchedule(offset_schedule, c=10).evaluate(5)
        assert temperature == 15
# Script entry point: when this file is executed directly (rather than
# imported), hand control to the unittest command-line runner.
if __name__ == '__main__':
    unittest.main()