Skip to content

Commit ea24f19

Browse files
committed
Web visualizer
1 parent 97ef4fd commit ea24f19

File tree

5 files changed

+500
-0
lines changed

5 files changed

+500
-0
lines changed

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
numpy==2.2.2
22
torch==2.4.0
3+
Flask==3.1.0

web_vis.py

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""
2+
Usage:
3+
1. Verify flask is installed (run "pip install -r requirements.txt" if not)
4+
2. Run "flask --app web_vis run" in a terminal
5+
"""
6+
7+
# Imports
8+
from flask import Flask, request, jsonify
9+
import torch
10+
11+
from game_node import GameNode
12+
from network import NeuralNet
13+
14+
from data_preprocess import node_to_tensor
15+
16+
# Model setup
17+
MODEL_STATE_DICT_PATH = "model.pt" # Update this as needed
18+
19+
model = NeuralNet()
20+
21+
try:
22+
model.load_state_dict(torch.load(MODEL_STATE_DICT_PATH, weights_only=True))
23+
except:
24+
print(f"Failed to load model at {MODEL_STATE_DICT_PATH}")
25+
26+
res = ""
27+
28+
while res not in list("yn"):
29+
res = input("Load random model (y/n)? ")
30+
res = res.lower()
31+
32+
if res == "n":
33+
print("Program exited early: cannot run without model")
34+
exit(1)
35+
36+
# Set up game node
37+
SIZE = 9
38+
curr_node = GameNode(SIZE)
39+
40+
# Game node utils
41+
def small_string(node: GameNode):
42+
global SIZE
43+
invert = lambda s: s.replace("○", "B").replace("●", "W").replace("W", "○").replace("B", "●")
44+
return "\n".join([invert(s.replace(" ", "")[-SIZE:]) for s in str(node).split("\n")[3:]])
45+
46+
# Flask things (assumes model behaves well)
47+
app = Flask(__name__, static_folder="web_vis")
48+
49+
@app.route("/")
50+
def main():
51+
return app.send_static_file("index.html")
52+
53+
@app.route("/play_move", methods=["POST"])
54+
def play_move():
55+
global curr_node, SIZE
56+
57+
data = request.get_json()
58+
59+
if not data:
60+
return jsonify({"error": "No JSON data provided"}), 400
61+
62+
if "row" not in data or "col" not in data:
63+
return jsonify({"error": "JSON data missing row and/or col fields"}), 400
64+
65+
if (data["row"], data["col"]) != (-1, -1) and (not (0 <= data["row"] < SIZE) or not (0 <= data["col"] < SIZE)):
66+
return jsonify({"error": f"Specified location {data['row'], data['col']} is out of bounds"}), 400
67+
68+
if not curr_node.is_valid_move(data["row"], data["col"]):
69+
return jsonify({"error": f"Specified location {data['row'], data['col']} is an invalid move"}), 400
70+
71+
curr_node = curr_node.create_child((data["row"], data["col"]))
72+
73+
return "Good", 200
74+
75+
@app.route("/get_board", methods=["POST"])
76+
def get_board():
77+
return small_string(curr_node), 200
78+
79+
@app.route("/reset", methods=["POST"])
80+
def reset():
81+
global curr_node, SIZE
82+
83+
curr_node = GameNode(SIZE)
84+
85+
return "Good", 200
86+
87+
@app.route("/undo", methods=["POST"])
88+
def undo():
89+
global curr_node
90+
91+
if curr_node.prev is None:
92+
return jsonify({"error": "No move to undo"}), 400
93+
94+
curr_node = curr_node.prev
95+
96+
return "Good", 200
97+
98+
@app.route("/network", methods=["POST"])
99+
def network():
100+
global curr_node
101+
102+
policy, val = model(node_to_tensor(curr_node).unsqueeze(0))
103+
104+
policy = policy.softmax(1).flatten().detach()
105+
106+
policy /= policy.max()
107+
policy = policy / 5
108+
109+
policy *= torch.tensor(curr_node.available_moves_mask())
110+
111+
return jsonify({
112+
"policy": policy.tolist(),
113+
"value": val.detach().item()
114+
}), 200

web_vis/index.html

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
<!DOCTYPE html>
2+
<html lang="en">
3+
<head>
4+
<meta charset="UTF-8">
5+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
6+
<title>Mini-AlphaGo</title>
7+
8+
<link rel="preconnect" href="https://fonts.googleapis.com">
9+
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
10+
<link href="https://fonts.googleapis.com/css2?family=42dot+Sans&display=swap" rel="stylesheet">
11+
12+
<link rel="stylesheet" href="web_vis/styles.css">
13+
<body>
14+
<h1 id="title">Mini-AlphaGo Web Interface</h1>
15+
16+
<main>
17+
<div id="main-game">
18+
<div id="eval-bar">
19+
<div id="eval-bar-black"></div>
20+
<div id="eval-bar-middle"></div>
21+
22+
<p id="eval-num" class="black">0.00</p>
23+
</div>
24+
25+
<div id="board" class="show-policy"></div>
26+
27+
<div id="game-nav">
28+
<button id="reset">Reset</button>
29+
<button id="undo">Undo</button>
30+
<button id="pass">Pass</button>
31+
</div>
32+
</div>
33+
34+
<div id="settings">
35+
<p class="header">Settings</p>
36+
37+
<div class="toggle">
38+
<input type="checkbox" name="show-policy" id="show-policy" checked>
39+
<label class="box" for="show-policy"></label>
40+
<p>Show policy</p>
41+
</div>
42+
43+
<div class="toggle">
44+
<input type="checkbox" name="show-value" id="show-value" checked>
45+
<label class="box" for="show-value"></label>
46+
<p>Show value</p>
47+
</div>
48+
</div>
49+
</main>
50+
51+
<script src="web_vis/script.js"></script>
52+
</body>
53+
</html>

web_vis/script.js

+142
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// Global
2+
const board = []
3+
const SIZE = 9
4+
5+
let show_policy = true
6+
let show_value = true
7+
8+
// Elts
9+
const eval_bar = document.getElementById("eval-bar")
10+
const eval_num = document.getElementById("eval-num")
11+
12+
const board_elt = document.getElementById("board")
13+
14+
const nav_reset = document.getElementById("reset")
15+
const nav_undo = document.getElementById("undo")
16+
const nav_pass = document.getElementById("pass")
17+
18+
const settings_policy = document.getElementById("show-policy")
19+
const settings_value = document.getElementById("show-value")
20+
21+
// Functions
22+
async function updateBoard() {
23+
const response1 = await fetch("/get_board", {"method": "POST"})
24+
const newBoardStr = await response1.text()
25+
26+
const temp = newBoardStr.split("\n")
27+
28+
for (let i = 0; i < SIZE; i++) {
29+
for (let j = 0; j < SIZE; j++) {
30+
board[i][j].innerText = temp[i][j]
31+
}
32+
}
33+
34+
const response2 = await fetch("/network", {"method": "POST"})
35+
const things = await response2.json()
36+
37+
const policy = things["policy"]
38+
const value = things["value"]
39+
40+
updatePolicy(policy)
41+
updateEval(value)
42+
}
43+
44+
function updateEval(value) {
45+
eval_bar.setAttribute("style", "--eval: " + value)
46+
eval_num.innerText = Math.round(value * 100) / 100
47+
48+
if (value >= 0) {
49+
eval_num.classList.remove("white")
50+
eval_num.classList.add("black")
51+
} else {
52+
eval_num.classList.add("white")
53+
eval_num.classList.remove("black")
54+
}
55+
}
56+
57+
function updatePolicy(policy) {
58+
for (let i = 0; i < SIZE; i++) {
59+
for (let j = 0; j < SIZE; j++) {
60+
board[i][j].setAttribute("style", "--policy: " + policy[i * SIZE + j])
61+
}
62+
}
63+
}
64+
65+
async function playMove(i, j) {
66+
await fetch("/play_move", {
67+
"method": "POST",
68+
"headers": {
69+
"Content-Type": "application/json",
70+
},
71+
"body": JSON.stringify({
72+
"row": i,
73+
"col": j
74+
})
75+
})
76+
77+
updateBoard()
78+
}
79+
80+
async function reset() {
81+
await fetch("/reset", {"method": "POST"})
82+
83+
updateBoard()
84+
}
85+
86+
async function undo() {
87+
await fetch("/undo", {"method": "POST"})
88+
89+
updateBoard()
90+
}
91+
92+
async function pass() {
93+
playMove(-1, -1)
94+
}
95+
96+
function togglePolicy() {
97+
show_policy = !show_policy
98+
99+
if (show_policy) {
100+
board_elt.classList.add("show-policy")
101+
} else {
102+
board_elt.classList.remove("show-policy")
103+
}
104+
}
105+
106+
function toggleValue() {
107+
show_value = !show_value
108+
109+
if (show_value) {
110+
eval_bar.classList.remove("hidden")
111+
} else {
112+
eval_bar.classList.add("hidden")
113+
}
114+
}
115+
116+
// Main
117+
nav_reset.addEventListener("click", reset)
118+
nav_undo.addEventListener("click", undo)
119+
nav_pass.addEventListener("click", pass)
120+
121+
settings_policy.addEventListener("click", togglePolicy)
122+
settings_value.addEventListener("click", toggleValue)
123+
124+
for (let i = 0; i < SIZE; i++) {
125+
const row = []
126+
const rowElt = document.createElement("div")
127+
rowElt.classList.add("row")
128+
129+
for (let j = 0; j < SIZE; j++) {
130+
const temp = document.createElement("button")
131+
132+
temp.addEventListener("click", () => {playMove(i, j)})
133+
row.push(temp)
134+
135+
rowElt.appendChild(temp)
136+
}
137+
138+
board.push(row)
139+
board_elt.appendChild(rowElt)
140+
}
141+
142+
updateBoard()

0 commit comments

Comments
 (0)