-
Notifications
You must be signed in to change notification settings - Fork 0
/
tests.py
158 lines (146 loc) · 6.9 KB
/
tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from typing import Callable
import time
from pocket_cube.cube import Cube, Move
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
case1 = "R U' R' F' U"
case2 = "F' R U R U F' U'"
case3 = "F U U F' U' R R F' R"
case4 = "U' R U' F' R F F U' F U U"

# Scrambles used by the benchmarks: each string is split on single spaces and
# every token is parsed into a Move, in order.
test_list: list[list[Move]] = [
    [Move.from_str(token) for token in scramble.split(" ")]
    for scramble in (case1, case2, case3, case4)
]

# Result of one test run: (passed, elapsed seconds, states expanded, path length).
TestCase = tuple[bool, float, int, int]
def is_solved(cube: Cube) -> bool:
    """
    Report whether a cube is currently in its goal configuration.

    Args:
        cube (Cube): The cube to inspect.

    Returns:
        bool: True when ``cube.state`` and ``cube.goal_state`` have the same
        shape and identical elements, False otherwise.
    """
    current, target = cube.state, cube.goal_state
    return np.array_equal(current, target)
def test(algorithm: Callable[[Cube], tuple[list[Move], int]], tests: list[list[Move]], log: bool = True) -> list[TestCase]:
    """
    Tests the algorithm with the given tests.

    For every scramble a fresh Cube is built and the algorithm is timed while it
    searches. The path it returns is then replayed on the cube to check that it
    really reaches the goal state.

    Args:
        algorithm (Callable[[Cube], tuple[list[Move], int]]): The algorithm to test.
            It receives the scrambled cube and returns the solution path together
            with the number of states it expanded.
        tests (list[list[Move]]): The scrambles to run, one list of moves per test.
        log (bool): When True, print a pass/fail line for every test.

    Returns:
        list[tuple[bool, float, int, int]]: For each test, whether it passed, the
        time taken in seconds, the number of states expanded and the length of the path.
    """
    results: list[TestCase] = []
    # `scramble`, not `test`, so the loop variable does not shadow this function.
    for idx, scramble in enumerate(tests):
        cube: Cube = Cube(scramble)
        # perf_counter is monotonic and high-resolution; time.time() is wall-clock
        # time and can jump, so it is unsuitable for measuring durations.
        start = time.perf_counter()
        (path, states) = algorithm(cube)
        elapsed = time.perf_counter() - start
        # Replay the returned path; the result of each cube.move call is kept
        # as the new cube.
        for move in path:
            cube = cube.move(move)
        success = is_solved(cube)
        if log:
            status = "passed" if success else "failed"
            print(f"Test {idx} {status}. Time: {elapsed} seconds. States expanded: {states}. Path length: {len(path)}")
        results.append((success, elapsed, states, len(path)))
    return results
def _average_runs(runs: list[list[TestCase]]) -> list[tuple[bool, float, float, float]]:
    """
    Collapse repeated runs of the same test suite into one averaged result per test.

    Only passing runs contribute to the averages. A test that never passed is
    reported as failed and recorded as (False, 0, 0, 0). One summary line is
    printed per test.

    Args:
        runs: One result list per run; every run must cover the same tests in the
            same order.

    Returns:
        One (passed, mean time, mean states expanded, mean path length) tuple per test.
    """
    averaged: list[tuple[bool, float, float, float]] = []
    # zip(*runs) transposes the run-major results into the results of each test.
    for i, per_test in enumerate(zip(*runs)):
        passed = [result for result in per_test if result[0]]
        if not passed:
            print(f"Test {i} failed.")
            averaged.append((False, 0, 0, 0))
            continue
        n_passed = len(passed)
        avg_time = sum(result[1] for result in passed) / n_passed
        avg_states = sum(result[2] for result in passed) / n_passed
        avg_length = sum(result[3] for result in passed) / n_passed
        print(f"Accuracy: {n_passed / len(per_test) * 100}%. Test {i} average: Time: {avg_time} seconds. States expanded: {avg_states}. Path length: {avg_length}")
        averaged.append((True, avg_time, avg_states, avg_length))
    return averaged


def test_mcts(
    algorithm: Callable[[Cube, int, float, Callable[[Cube], int]], tuple[list[Move], int]],
    heuristic_list: list[Callable[[Cube], int]],
    *,
    c_values: tuple[float, ...] = (0.1, 0.5),
    budgets: tuple[int, ...] = (1000, 5000, 10000, 20000),
    runs: int = 20,
) -> None:
    """
    Benchmark an MCTS-style solver over a grid of exploration constants and budgets.

    For every (c, budget, heuristic) combination, the whole ``test_list`` suite is
    run ``runs`` times with logging disabled. The results are averaged over the
    passing runs, printed, and plotted. After all heuristics for a given
    (c, budget) pair have run, the first two heuristics are drawn side by side.

    Args:
        algorithm: Solver called as ``algorithm(cube, budget, c, heuristic)``; it
            returns the solution path and the number of states expanded.
        heuristic_list: Heuristics to evaluate; their ``__name__`` is used in
            output and chart legends.
        c_values: Exploration constants to try (defaults match the original grid).
        budgets: Search budgets to try (defaults match the original grid).
        runs: Number of repetitions per combination, averaged together.

    Raises:
        ValueError: If ``runs`` is smaller than 1.
    """
    if runs < 1:
        raise ValueError("runs must be at least 1")
    for c in c_values:
        for budget in budgets:
            averaged_per_heuristic = []
            for heuristic in heuristic_list:
                print(f"Heuristic: {heuristic.__name__} Budget: {budget}, c: {c}")
                # The lambda is consumed inside test() before the loop advances,
                # so capturing budget/c/heuristic by closure is safe here.
                repeated = [
                    test(lambda cube: algorithm(cube, budget, c, heuristic), test_list, False)
                    for _ in range(runs)
                ]
                averaged = _average_runs(repeated)
                averaged_per_heuristic.append(averaged)
                draw_graph(averaged)
            # A comparison needs two heuristics; skip it rather than raise IndexError.
            if len(averaged_per_heuristic) >= 2:
                draw_comparison_graph(
                    averaged_per_heuristic[0],
                    averaged_per_heuristic[1],
                    heuristic_list[0].__name__,
                    heuristic_list[1].__name__,
                )
def draw_graph(test_cases: list[TestCase]) -> None:
    """
    Show one bar chart per metric for a list of test results.

    Three charts are shown in order: time, states expanded, then path length.
    Each is displayed with plt.show() before the next one is created.

    Args:
        test_cases: Result tuples. Index 1 is plotted as time, index 2 as states
            expanded and index 3 as path length; index 0 (the pass flag) is not
            plotted.
    """
    # (tuple index, y-axis label, chart title) for each chart, in display order.
    charts = (
        (1, "Time", "Time taken by each test"),
        (2, "States expanded", "States expanded by each test"),
        (3, "Path length", "Path length of each test"),
    )
    positions = range(len(test_cases))
    for column, y_label, title in charts:
        _, axes = plt.subplots()
        axes.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
        heights = [case[column] for case in test_cases]
        axes.bar_label(axes.bar(positions, heights))
        axes.set_xlabel("Test")
        axes.set_ylabel(y_label)
        axes.set_title(title)
        plt.show()
def draw_comparison_graph(test_cases1: list[TestCase], test_cases2: list[TestCase], label1, label2) -> None:
    """
    Show side-by-side bar charts comparing two sets of test results.

    For each metric (time, states expanded, path length) one chart is drawn with
    two bars per test. The first set is shifted 0.2 left and labelled ``label1``;
    the second is shifted 0.2 right and labelled ``label2``. Each chart is shown
    with a legend before the next one is built.

    Args:
        test_cases1: First set of result tuples (indices 1-3 are plotted).
        test_cases2: Second set of result tuples (indices 1-3 are plotted).
        label1: Legend label for the first set.
        label2: Legend label for the second set.
    """
    # (tuple index, y-axis label, chart title) for each chart, in display order.
    charts = (
        (1, "Time", "Time taken by each test"),
        (2, "States expanded", "States expanded by each test"),
        (3, "Path length", "Path length of each test"),
    )
    offset = 0.2
    bar_width = 0.4
    for column, y_label, title in charts:
        _, axes = plt.subplots()
        axes.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
        left_bars = axes.bar(
            np.arange(len(test_cases1)) - offset,
            [case[column] for case in test_cases1],
            width=bar_width,
            label=label1,
        )
        right_bars = axes.bar(
            np.arange(len(test_cases2)) + offset,
            [case[column] for case in test_cases2],
            width=bar_width,
            label=label2,
        )
        axes.bar_label(left_bars)
        axes.bar_label(right_bars)
        axes.set_xlabel("Test")
        axes.set_ylabel(y_label)
        axes.set_title(title)
        plt.legend()
        plt.show()