Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeyt8 committed Dec 5, 2023
1 parent 4b2fdaf commit 1fa2bd0
Showing 1 changed file with 53 additions and 21 deletions.
74 changes: 53 additions & 21 deletions Rubiks_cube.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"from heapq import heappush, heappop\n",
"from typing import Callable\n",
"import time\n",
"\n",
""
]
},
Expand Down Expand Up @@ -42,7 +43,8 @@
" node: FrontierItem = FrontierItem(score + heuristic(neighbor), neighbor)\n",
" heappush(frontier, node)\n",
" # get path\n",
" return (get_path(currentCube.hash(), discovered), len(discovered))"
" return (get_path(currentCube.hash(), discovered), len(discovered))\n",
""
]
},
{
Expand All @@ -53,7 +55,8 @@
"source": [
"# Test A*\n",
"test_res_astar: list[TestCase] = test(lambda cube: a_star(cube, manhattan), test_list)\n",
"draw_graph(test_res_astar)"
"draw_graph(test_res_astar)\n",
""
]
},
{
Expand Down Expand Up @@ -89,7 +92,8 @@
" path2: list[Move] = get_path(met_cube_key, discovereds[1])\n",
" path2.reverse()\n",
" path2 = list(map(Move.opposite, path2))\n",
" return (path1 + path2, len(discovereds[0]) + len(discovereds[1]))"
" return (path1 + path2, len(discovereds[0]) + len(discovereds[1]))\n",
""
]
},
{
Expand All @@ -100,7 +104,8 @@
"source": [
"# Test Bidirectional BFS\n",
"test_res_bfs: list[TestCase] = test(bidirectional_bfs, test_list)\n",
"draw_graph(test_res_bfs)"
"draw_graph(test_res_bfs)\n",
""
]
},
{
Expand All @@ -109,13 +114,12 @@
"metadata": {},
"outputs": [],
"source": [
"# # MTCS with UCB\n",
"# MTCS with UCB\n",
"from math import sqrt, log\n",
"\n",
"N = \"N\"\n",
"Q = \"Q\"\n",
"PARENT = \"PARENT\"\n",
"MOVE = \"MOVE\"\n",
"CHILDREN = \"CHILDREN\"\n",
"Node = dict[int, int, Cube, dict[Move, Cube]]\n",
"\n",
Expand All @@ -132,7 +136,8 @@
" if expr > max_expr:\n",
" max_expr = expr\n",
" max_move = move\n",
" return max_move"
" return max_move\n",
""
]
},
{
Expand All @@ -143,7 +148,7 @@
"source": [
"from random import choice\n",
"\n",
"def mcts(cube0: Cube, budget: int, tree: Node, cp: float, heuristic: Callable[[Cube], int]) -> Node:\n",
"def mcts(cube0: Cube, budget: int, tree: Node, cp: float, heuristic: Callable[[Cube], int]) -> tuple[list[Move], Node, int]:\n",
" states_visited: int = 0\n",
" if not tree:\n",
" tree = init_node()\n",
Expand Down Expand Up @@ -175,21 +180,35 @@
" max_h = max(max_h, 1 / max(heuristic(cube), 0.1))\n",
" max_moves -= 1\n",
" states_visited += 1\n",
" if is_solved(cube):\n",
" # return path from node to root\n",
" path: list[Move] = []\n",
" while node[PARENT]:\n",
" parent: Node = node[PARENT]\n",
" for move in parent[CHILDREN]:\n",
" if parent[CHILDREN][move] == node:\n",
" path.append(move)\n",
" break\n",
" node = parent\n",
" path.reverse()\n",
" return (path, tree, states_visited)\n",
" while node:\n",
" node[N] += 1\n",
" node[Q] += max_h\n",
" node = node[PARENT]\n",
" return (tree, states_visited)\n",
" return ([], tree, states_visited)\n",
"\n",
"def play_mcts(cube: Cube, budget: int, cp: float, heuristic: Callable[[Cube], int]) -> (list[Move], int):\n",
" (tree, states) = mcts(cube, budget, None, cp, heuristic)\n",
" node: Node = tree\n",
" path: list[Move] = []\n",
" while node and node[CHILDREN]:\n",
" move: Move = select_move(node, 0)\n",
" node = node[CHILDREN][move]\n",
" path.append(move)\n",
" return (path, states)"
" (path, tree, states) = mcts(cube, budget, None, cp, heuristic)\n",
" return (path, states)\n",
" #node: Node = tree\n",
" #path: list[Move] = []\n",
" #while node and node[CHILDREN]:\n",
" #move: Move = select_move(node, 0)\n",
" #node = node[CHILDREN][move]\n",
" #path.append(move)\n",
" #return (path, states)\n",
""
]
},
{
Expand All @@ -199,7 +218,8 @@
"outputs": [],
"source": [
"# Test MTCS\n",
"test_mcts(lambda cube, budget, c, heuristic: play_mcts(cube, budget, c, heuristic), [manhattan, blocked_hamming])"
"test_mcts(lambda cube, budget, c, heuristic: play_mcts(cube, budget, c, heuristic), [manhattan, blocked_hamming])\n",
""
]
},
{
Expand All @@ -212,7 +232,8 @@
"start = time.time()\n",
"database = build_database()\n",
"end = time.time()\n",
"print(f\"Database built in {end - start} seconds.\")"
"print(f\"Database built in {end - start} seconds.\")\n",
""
]
},
{
Expand All @@ -223,7 +244,8 @@
"source": [
"# Test A* with database\n",
"test_result_astar_database: list[TestCase] = test(lambda cube: a_star(cube, lambda cube: database_heuristic(cube, database, manhattan)), test_list)\n",
"draw_graph(test_result_astar_database)"
"draw_graph(test_result_astar_database)\n",
""
]
},
{
Expand All @@ -233,7 +255,17 @@
"outputs": [],
"source": [
"# Test MTCS with database\n",
"test_mcts(lambda cube, budget, c, heuristic: play_mcts(cube, budget, c, heuristic), lambda cube: database_heuristic(cube, database, manhattan))"
"test_mcts(lambda cube, budget, c, heuristic: play_mcts(cube, budget, c, heuristic), [lambda cube: database_heuristic(cube, database, manhattan)])\n",
""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
""
]
}
],
Expand Down

0 comments on commit 1fa2bd0

Please sign in to comment.