Skip to content

Commit

Permalink
🎨 Perf and toying with MCTS
Browse files Browse the repository at this point in the history
  • Loading branch information
wrenger committed Feb 8, 2024
1 parent 0cb5a92 commit 9c634a9
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 18 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ owo-colors = "4.0"
async-recursion = "1.0"
log = { version = "0.4", features = ["release_max_level_info"] }
env_logger = { version = "0.11", default_features = false }
mocats = "0.2"

[dev-dependencies]
criterion = { version = "0.5", features = ["async_tokio"] }
Expand Down
10 changes: 10 additions & 0 deletions src/agents/mcts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
use std::sync::Arc;

use crate::env::MoveResponse;
use crate::game::Game;
use crate::search::{mcts, Heuristic};

/// Chooses the next move for `game` with Monte Carlo tree search.
///
/// Hands `heuristic`, `timeout` and `game` straight to [`mcts`] and wraps the
/// direction it returns in a [`MoveResponse`].
pub async fn step(heuristic: Arc<dyn Heuristic>, timeout: u64, game: &Game) -> MoveResponse {
    MoveResponse::new(mcts(heuristic, timeout, game).await)
}
4 changes: 4 additions & 0 deletions src/agents/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ pub use random::*;
pub mod maxn;
mod solo;
pub use solo::*;
mod mcts;
pub use mcts::*;

use crate::game::Game;

Expand All @@ -26,6 +28,7 @@ pub enum Agent {
Mobility(MobilityAgent),
Tree(TreeHeuristic),
Flood(FloodHeuristic),
MonteCarlo(FloodHeuristic),
Solo(SoloHeuristic),
Random(RandomAgent),
}
Expand Down Expand Up @@ -53,6 +56,7 @@ impl Agent {
Agent::Mobility(agent) => agent.step(game).await,
Agent::Tree(agent) => maxn::step(Arc::new(agent.clone()), timeout, game).await,
Agent::Flood(agent) => maxn::step(Arc::new(agent.clone()), timeout, game).await,
Agent::MonteCarlo(agent) => mcts::step(Arc::new(agent.clone()), timeout, game).await,
Agent::Solo(agent) => maxn::step(Arc::new(agent.clone()), timeout, game).await,
Agent::Random(agent) => agent.step(game).await,
}
Expand Down
6 changes: 6 additions & 0 deletions src/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ pub enum Direction {
Left,
}

/// `Display` prints exactly what `Debug` prints for a direction.
impl fmt::Display for Direction {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate to the `Debug` impl so both representations stay identical.
        Debug::fmt(self, f)
    }
}

impl Direction {
pub fn all() -> [Self; 4] {
[Self::Up, Self::Right, Self::Down, Self::Left]
Expand Down
134 changes: 134 additions & 0 deletions src/search/mcts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
//! # Monte Carlo Tree Search
//!
//! Idea: use MCTS with a fast agent to simulate games instead of relying on random playouts.
use std::sync::Arc;

use log::{info, warn};
use mocats::UctPolicy;
use tokio::time::Instant;

use crate::{env::Direction, game::Game};

use super::Heuristic;

/// Identifies the participant whose turn it is in the search tree.
///
/// The wrapped `u8` is used as an index into `Game::snakes` (it is advanced
/// modulo the number of snakes and cast to `usize` for indexing).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Player(u8);

// Empty impls (no methods overridden) so that `Player` and `Direction` can be
// used as the player and action type parameters of `mocats::GameState`.
impl mocats::Player for Player {}
impl mocats::GameAction for Direction {}

/// Search-tree state: a snake `Game` plus the bookkeeping needed to feed it
/// to `mocats` one player at a time.
#[derive(Debug, Clone)]
struct MctsGame {
/// Turn number of the game the search was started from; used as the
/// reference point for the depth cut-off in `get_actions`.
start: usize,
/// Current game state of this node.
game: Game,
/// Moves chosen so far for the current turn; applied to `game` and cleared
/// once there is one entry per snake.
actions: Vec<Direction>,
/// Player who chooses the next action.
player: Player,
/// Evaluation function used to compute rewards.
heuristic: Arc<dyn Heuristic>,
}

/// Adapts the snake [`Game`] to the one-player-at-a-time interface of `mocats`.
///
/// `Game::step` is called with one action per snake, so moves are buffered in
/// `actions` as each player picks in turn and are only applied to the game
/// once every snake has chosen.
impl mocats::GameState<Direction, Player> for MctsGame {
    /// Returns the moves available to the player whose turn it is.
    ///
    /// Returns an empty list once the game is more than 8 turns past `start`
    /// (presumably `mocats` treats a state without actions as terminal, which
    /// makes this the search depth limit — confirm against the crate docs).
    fn get_actions(&self) -> Vec<Direction> {
        if self.game.turn > self.start + 8 {
            return Vec::new();
        }

        let mut moves: Vec<Direction> = self.game.valid_moves(self.player.0).collect();
        // Players other than 0 always get at least one action so that the
        // per-turn action buffer can still be filled when they have no valid move.
        if self.player.0 != 0 && moves.is_empty() {
            moves.push(Direction::Up);
        }
        moves
    }

    /// Records `action` for the current player and advances to the next one.
    ///
    /// When the buffer holds one action per snake, the game is stepped with
    /// the buffered actions and the buffer is cleared.
    fn apply_action(&mut self, action: &Direction) {
        self.actions.push(*action);
        if self.actions.len() == self.game.snakes.len() {
            info!("step={:?}", self.actions);
            self.game.step(&self.actions);
            self.actions.clear();
        }
        self.player = Player((self.player.0 + 1) % self.game.snakes.len() as u8);
    }

    /// Returns the player who chooses the next action.
    fn get_turn(&self) -> Player {
        self.player
    }

    /// Scores the current state from the point of view of `player`.
    ///
    /// The requested player's snake is swapped into slot 0 of a copy of the
    /// game before the heuristic is evaluated (heuristics appear to score the
    /// snake at index 0 — verify against the `Heuristic` implementations).
    fn get_reward_for_player(&self, player: Player) -> f32 {
        let mut game = self.game.clone();
        game.snakes.swap(0, player.0 as usize);
        // Evaluate the swapped copy, not `self.game`: otherwise the swap is
        // discarded and every player receives the same reward.
        let res = self.heuristic.eval(&game) as f32;
        info!("reward={res} for {player:?}");
        res
    }
}

/// Runs a time-boxed Monte Carlo tree search from `game` and returns the move
/// the tree considers best.
///
/// * `heuristic` - evaluation function used to compute rewards.
/// * `timeout` - search budget in milliseconds.
/// * `game` - state to search from; it is cloned, the caller's copy is untouched.
///
/// Falls back to `Direction::default()` when the tree reports no best action.
pub async fn mcts(heuristic: Arc<dyn Heuristic>, timeout: u64, game: &Game) -> Direction {
    let root = MctsGame {
        start: game.turn,
        game: game.clone(),
        actions: Vec::new(),
        player: Player(0),
        heuristic,
    };
    let mut tree = mocats::SearchTree::new(root, UctPolicy::new(2.0));

    let budget_ms = u128::from(timeout);
    let started = Instant::now();
    while started.elapsed().as_millis() < budget_ms {
        warn!(">>> mcts {:?}", started.elapsed().as_millis());
        // NOTE(review): this async block has no await point of its own, so
        // awaiting it runs the batch to completion without handing control
        // back to the runtime — confirm whether a real yield was intended.
        async {
            tree.run(4);
        }
        .await;
    }

    tree.get_best_action().unwrap_or_default()
}

// Smoke test for the MCTS search on a small two-snake board.
#[cfg(test)]
mod test {
use crate::game::Game;
use crate::logging;
use crate::search::{mcts, Heuristic};
use log::info;
use std::sync::Arc;

// Runs the search for 1000 ms and logs the chosen direction.
// NOTE(review): the test makes no assertion on the result; it only checks
// that the search completes without panicking.
#[tokio::test]
async fn simple() {
logging();

// Minimal heuristic: 1.0 while snake 0 is alive, 0.0 otherwise.
#[derive(Debug, Clone, Default)]
struct SimpleHeuristic;
impl Heuristic for SimpleHeuristic {
fn eval(&self, game: &Game) -> f64 {
if game.snake_is_alive(0) {
1.0
} else {
0.0
}
}
}

// 11x11 board; presumably `0` and `1` mark the snake heads and `^` their
// body segments — confirm against `Game::parse`.
let game = Game::parse(
r#"
. . . . . . . . . . .
. . . . . . . . . . .
. . . . 0 . 1 . . . .
. . . . ^ . ^ . . . .
. . . . ^ . ^ . . . .
. . . . . . . . . . .
. . . . . . . . . . .
. . . . . . . . . . .
. . . . . . . . . . .
. . . . . . . . . . .
. . . . . . . . . . ."#,
)
.unwrap();

let heuristic = Arc::new(SimpleHeuristic);
let dir = mcts(heuristic, 1000, &game).await;
info!("dir={:?}", dir);
}
}
34 changes: 16 additions & 18 deletions src/search/minimax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::game::Game;
use crate::{env::Direction, game::Outcome};

use async_recursion::async_recursion;
use tokio::task::JoinSet;

use super::{Heuristic, DRAW, LOSS, WIN};

Expand All @@ -17,7 +18,7 @@ use super::{Heuristic, DRAW, LOSS, WIN};
pub async fn async_max_n(game: &Game, depth: usize, heuristic: Arc<dyn Heuristic>) -> [f64; 4] {
assert!(game.snakes.len() <= 4);

let mut futures = [None, None, None, None];
let mut set = JoinSet::new();
for d in Direction::all() {
if !game.move_is_valid(0, d) {
continue;
Expand All @@ -28,17 +29,16 @@ pub async fn async_max_n(game: &Game, depth: usize, heuristic: Arc<dyn Heuristic
let heuristic = heuristic.clone();

// Create tasks for subtrees.
futures[d as usize] = Some(tokio::task::spawn(async move {
async_max_n_rec(&game, depth, 1, actions, heuristic).await
}));
set.spawn(async move {
let r = async_max_n_rec(&game, depth, 1, actions, heuristic).await;
(d, r)
});
}

let mut result = [LOSS; 4];
for (i, future) in futures.into_iter().enumerate() {
if let Some(f) = future {
if let Ok(r) = f.await {
result[i] = r;
}
while let Some(r) = set.join_next().await {
if let Ok((d, r)) = r {
result[d as usize] = r;
}
}
result
Expand Down Expand Up @@ -72,7 +72,7 @@ async fn async_max_n_rec(
}
} else if ply == 0 {
// max
let mut futures = [None, None, None, None];
let mut set = JoinSet::new();
for d in Direction::all() {
if !game.move_is_valid(0, d) {
continue;
Expand All @@ -83,17 +83,15 @@ async fn async_max_n_rec(
let heuristic = heuristic.clone();

// Create tasks for subtrees.
futures[d as usize] = Some(tokio::task::spawn(async move {
async_max_n_rec(&game, depth, ply + 1, actions, heuristic).await
}));
set.spawn(
async move { async_max_n_rec(&game, depth, ply + 1, actions, heuristic).await },
);
}

let mut max = LOSS;
for future in futures {
if let Some(f) = future {
if let Ok(r) = f.await {
max = max.max(r);
}
while let Some(r) = set.join_next().await {
if let Ok(r) = r {
max = max.max(r);
}
}
max
Expand Down
2 changes: 2 additions & 0 deletions src/search/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ mod minimax;
pub use minimax::*;
mod alphabeta;
pub use alphabeta::*;
mod mcts;
pub use mcts::*;

use std::fmt::Debug;

Expand Down

0 comments on commit 9c634a9

Please sign in to comment.