Skip to content

Commit 34da537

Browse files
committed
Fast move validation
1 parent e4d8a28 commit 34da537

File tree

3 files changed

+67
-14
lines changed

3 files changed

+67
-14
lines changed

board.py

+53-8
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,35 @@ def pick_one(s: set[int]) -> int:
282282
return (score[0], score[1])
283283

284284

285+
def is_valid_move(self, row: int, col: int) -> bool:
286+
# Compute stone value
287+
val = self.move % 2 + 1
288+
289+
# Pass is always allowed
290+
if (row, col) == (-1, -1):
291+
return True
292+
293+
# On top of another stone is never allowed
294+
if self.grid[row, col] != 0:
295+
return False
296+
297+
# If capturing, needs play_stone
298+
for group in self.groups:
299+
if group.group_type == val:
300+
continue
301+
302+
if len(group - {(row, col)}) == 0:
303+
return self.__play_stone(row, col, False)
304+
305+
# Prohibit suicide
306+
for group in self.groups:
307+
if group.group_type == val:
308+
continue
309+
310+
if len(group - {(row, col)}) == 0:
311+
return False
312+
313+
285314
def play_stone(self, row: int, col: int, move: bool = True) -> bool:
286315
"""
287316
Attempts to place a stone of value val at (row, col)
@@ -294,6 +323,23 @@ def play_stone(self, row: int, col: int, move: bool = True) -> bool:
294323
move (optional): whether or not to update the board, default True
295324
"""
296325

326+
return self.__play_stone(row, col, move)
327+
328+
329+
def __play_stone(self, row: int, col: int, move: bool = True) -> bool:
330+
"""
331+
THIS IS A PRIVATE METHOD! DO NOT USE THIS OUTSIDE BOARD.PY
332+
333+
Attempts to place a stone of value val at (row, col)
334+
335+
Returns True if the move is valid, False if not
336+
337+
Args:
338+
row: index of the row to place the stone
339+
col: index of the column to place the stone
340+
move (optional): whether or not to update the board, default True
341+
"""
342+
297343
# Compute stone value
298344
val = self.move % 2 + 1
299345

@@ -363,18 +409,14 @@ def play_stone(self, row: int, col: int, move: bool = True) -> bool:
363409
if len(group.liberties) > 0:
364410
new_candidate_groups.append(group)
365411
continue
366-
367-
# Remove captured stones from the board
368-
for i in group.intersections:
369-
candidate[i // self.size, i % self.size] = 0
370412

371413
# Record captures
372414
captured |= group.intersections
373415

374416
# Update for newly opened intersections
375417
for group in candidate_groups:
376418
group.replenish_liberties(captured)
377-
419+
378420
# Prohibit suicide
379421
for group in candidate_groups:
380422
# Skip opposite color
@@ -386,6 +428,10 @@ def play_stone(self, row: int, col: int, move: bool = True) -> bool:
386428

387429
new_candidate_groups.append(group)
388430

431+
# Remove captured stones from the board
432+
for i in captured:
433+
candidate[i // self.size, i % self.size] = 0
434+
389435
candidate_groups = new_candidate_groups
390436

391437
# Prohibit repetition
@@ -430,10 +476,9 @@ def available_moves(self) -> list[tuple[int, int]]:
430476

431477
for i in range(self.size):
432478
for j in range(self.size):
433-
available = self.play_stone(
479+
available = self.is_valid_move(
434480
row = i,
435-
col = j,
436-
move = False
481+
col = j
437482
)
438483

439484
if available:

game_node.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,9 @@ def available_moves_mask(self) -> NDArray:
8989
out = np.zeros((self.size**2 + 1, ), dtype=bool)
9090

9191
for i in range(self.size**2):
92-
available = super().play_stone(
92+
available = self.is_valid_move(
9393
row = i // self.size,
9494
col = i % self.size,
95-
move = False
9695
)
9796

9897
if available:

group.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,19 @@ def add_union(groups: list[Self], new_stone_group: Self) -> list[Self]:
9292
out.append(group.copy())
9393

9494
# Union groups of same stone containing the new stone as a liberty
95-
unioned = Group(set(), set(), set(), new_stone_group.group_type)
96-
97-
for group in need_union:
98-
unioned.union_in_place_(group)
95+
new_intersections = set().union(*[group.intersections for group in need_union])
96+
new_borders = set().union(*[group.borders for group in need_union])
97+
new_liberties = set().union(*[group.liberties for group in need_union])
98+
99+
new_borders -= new_intersections
100+
new_liberties -= new_intersections
101+
102+
unioned = Group(
103+
intersections=new_intersections,
104+
borders=new_borders,
105+
liberties=new_liberties,
106+
group_type=new_stone_group.group_type
107+
)
99108

100109
# Add unioned group back to group list
101110
out.append(unioned)

0 commit comments

Comments
 (0)