Skip to content

Commit

Permalink
Comparisson graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeyt8 committed Dec 6, 2023
1 parent 1fa2bd0 commit 993b7ed
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 13 deletions.
21 changes: 11 additions & 10 deletions Rubiks_cube.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"source": [
"from pocket_cube.cube import Cube\n",
"from pocket_cube.cube import Move\n",
"from tests import test_list, test, is_solved, TestCase, draw_graph, test_mcts\n",
"from tests import test_list, test, is_solved, TestCase, draw_graph, test_mcts, draw_comparison_graph\n",
"from heuristics import hamming, blocked_hamming, manhattan, build_database, database_heuristic, is_admissible\n",
"from utils import get_neighbors, get_path, met_in_the_middle, FrontierItem, DiscoveredDict\n",
"\n",
Expand Down Expand Up @@ -108,6 +108,16 @@
""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"draw_comparison_graph(test_res_astar, test_res_bfs, \"A*\", \"Bidirectional BFS\")\n",
""
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -258,15 +268,6 @@
"test_mcts(lambda cube, budget, c, heuristic: play_mcts(cube, budget, c, heuristic), [lambda cube: database_heuristic(cube, database, manhattan)])\n",
""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
""
]
}
],
"nbformat": 4,
Expand Down
48 changes: 45 additions & 3 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ def test(algorithm: Callable[[Cube], tuple[list[Move], int]], tests: list[list[M
return res

def test_mcts(algorithm: Callable[[Cube, int, float, Callable[[Cube], int]], tuple[list[Move], int]], heuristic_list: list[Callable[[Cube], int]]) -> None:
for heuristic in heuristic_list:
for c in [0.1, 0.5]:
for budget in [1000, 5000, 10000, 20000]:
for c in [0.1, 0.5]:
for budget in [1000, 5000, 10000, 20000]:
test_results_for_compare = []
for heuristic in heuristic_list:
print(f"Heuristic: {heuristic.__name__} Budget: {budget}, c: {c}")
test_results: list[list[TestCase]] = []
for _ in range(0, 20):
Expand All @@ -84,7 +85,9 @@ def test_mcts(algorithm: Callable[[Cube, int, float, Callable[[Cube], int]], tup
else:
print(f"Accuracy: {no_passed / len(test_results) * 100}%. Test {i} average: Time: {test_result[0] / no_passed} seconds. States expanded: {test_result[1] / no_passed}. Path length: {test_result[2] / no_passed}")
test_results_averaged.append((True, test_result[0] / no_passed, test_result[1] / no_passed, test_result[2] / no_passed))
test_results_for_compare.append(test_results_averaged)
draw_graph(test_results_averaged)
draw_comparison_graph(test_results_for_compare[0], test_results_for_compare[1], heuristic_list[0].__name__, heuristic_list[1].__name__)

def draw_graph(test_cases: list[TestCase]) -> None:
# time plot
Expand Down Expand Up @@ -113,4 +116,43 @@ def draw_graph(test_cases: list[TestCase]) -> None:
ax.set_xlabel("Test")
ax.set_ylabel("Path length")
ax.set_title("Path length of each test")
plt.show()

def draw_comparison_graph(test_cases1: list[TestCase], test_cases2: list[TestCase], label1, label2) -> None:
# draw graph similar to the ones in draw_graph, but make it a comparison graph by using 2 columns per x value, one for each test case
# time plot
fig, ax = plt.subplots()
ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
bars1 = ax.bar(np.arange(len(test_cases1))-0.2, [test_case[1] for test_case in test_cases1], width=0.4, label=label1)
bars2 = ax.bar(np.arange(len(test_cases2))+0.2, [test_case[1] for test_case in test_cases2], width=0.4, label=label2)
ax.bar_label(bars1)
ax.bar_label(bars2)
ax.set_xlabel("Test")
ax.set_ylabel("Time")
ax.set_title("Time taken by each test")
plt.legend()
plt.show()
# same for states expanded
fig, ax = plt.subplots()
ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
bars1 = ax.bar(np.arange(len(test_cases1))-0.2, [test_case[2] for test_case in test_cases1], width=0.4, label=label1)
bars2 = ax.bar(np.arange(len(test_cases2))+0.2, [test_case[2] for test_case in test_cases2], width=0.4, label=label2)
ax.bar_label(bars1)
ax.bar_label(bars2)
ax.set_xlabel("Test")
ax.set_ylabel("States expanded")
ax.set_title("States expanded by each test")
plt.legend()
plt.show()
# same for path length
fig, ax = plt.subplots()
ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
bars1 = ax.bar(np.arange(len(test_cases1))-0.2, [test_case[3] for test_case in test_cases1], width=0.4, label=label1)
bars2 = ax.bar(np.arange(len(test_cases2))+0.2, [test_case[3] for test_case in test_cases2], width=0.4, label=label2)
ax.bar_label(bars1)
ax.bar_label(bars2)
ax.set_xlabel("Test")
ax.set_ylabel("Path length")
ax.set_title("Path length of each test")
plt.legend()
plt.show()

0 comments on commit 993b7ed

Please sign in to comment.