MCTS
```python
# main function for the Monte Carlo Tree Search
def monte_carlo_tree_search(root):
    while resources_left(time, computational_power):
        leaf = traverse(root)
        simulation_result = rollout(leaf)
        backpropagate(leaf, simulation_result)
    return best_child(root)

# function for node traversal
def traverse(node):
    while fully_expanded(node):
        node = best_uct(node)
    # in case no children are present / node is terminal
    return pick_unvisited(node.children) or node

# function for the result of the simulation
def rollout(node):
    while non_terminal(node):
        node = rollout_policy(node)
    return result(node)

# function for randomly selecting a child node
def rollout_policy(node):
    return pick_random(node.children)

# function for backpropagation
def backpropagate(node, result):
    if is_root(node):
        return
    node.stats = update_stats(node, result)
    backpropagate(node.parent, result)

# function for selecting the best child:
# the node with the highest number of visits
def best_child(node):
    pick_child_with_highest_number_of_visits(node)
```
Code
```python
import math
import random

class Node:
    def __init__(self, value, children=None):
        self.value = value
        self.children = children or []
        self.visits = 0
        self.score = 0

def uct_score(node, parent_visits):
    if node.visits == 0:
        return float('inf')  # always try unvisited children first
    return node.score / node.visits + math.sqrt(2 * math.log(parent_visits) / node.visits)

def mcts(root, iterations):
    for _ in range(iterations):
        node = root
        path = [root]  # include the root so its visit count grows (uct_score needs it)
        # Selection: descend via UCT until we reach a leaf.
        # (This toy tree is already fully expanded, so the expansion and
        # simulation phases collapse into reading off the leaf's value.
        # Note: for a 2-player game this should alternate between max()
        # and min(); plain max() models a single maximizing agent.)
        while node.children:
            parent_visits = node.visits
            node = max(node.children, key=lambda n: uct_score(n, parent_visits))
            path.append(node)
        # Backpropagation
        value = node.value
        for n in reversed(path):
            n.visits += 1
            n.score += value
    return max(root.children, key=lambda n: n.score / n.visits)

# Tree setup
A = Node('A', [Node('B', [Node(3), Node(5)]), Node('C', [Node(2), Node(9)])])
# MCTS evaluation
best_move = mcts(A, 1000)
print(f"Best move: {best_move.value}")
```
Stages:
- Selection: descend the existing tree (via UCT) to a promising leaf
- Expansion: add one or more children of that leaf to the tree
- Simulation (rollout): play out randomly from the new node to a terminal state
- Backpropagation: propagate the result back up the visited path
Contrast with minimax
```python
def minimax(node, is_maximizing):
    if node.is_leaf():
        return node.value
    if is_maximizing:
        return max(minimax(child, False) for child in node.children)
    else:
        return min(minimax(child, True) for child in node.children)
```
Example
Let's use a simple two-player game called "Number Picking" for our example:
Game tree:

```
        A
      /   \
     B     C
    / \   / \
   3   5 2   9
```
Minimax: B = min(3, 5) = 3 and C = min(2, 9) = 2, so Max picks B, guaranteeing a value of 3.
MCTS: as coded above (averaging values without modeling the opponent's min), C's average is pulled up by the 9, so it tends to pick C — a reminder that a naive single-agent backup can disagree with minimax in adversarial games.
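For concreteness, the minimax code above can be run on this tree (a minimal self-contained sketch; `is_leaf` is assumed to mean "no children"):

```python
class Node:
    def __init__(self, value, children=None):
        self.value = value
        self.children = children or []

    def is_leaf(self):
        return not self.children

def minimax(node, is_maximizing):
    if node.is_leaf():
        return node.value
    if is_maximizing:
        return max(minimax(child, False) for child in node.children)
    return min(minimax(child, True) for child in node.children)

A = Node('A', [Node('B', [Node(3), Node(5)]),
               Node('C', [Node(2), Node(9)])])
print(minimax(A, True))  # 3: Max picks B, Min replies with 3
```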
UCB1 for Trees (UCT) is ($w_i$ can be a cumulative numeric score):

$$\text{UCT}_i = \frac{w_i}{n_i} + c\sqrt{\frac{\ln N}{n_i}}$$

where $w_i$ is the total score of child $i$, $n_i$ its visit count, $N$ the parent's visit count, and $c$ the exploration constant (commonly $\sqrt{2}$).

Intuition: balance exploitation (the first term is the average reward) and exploration (the second term shrinks as a child's visit count grows).
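A quick numeric sanity check of this trade-off (a standalone sketch; `uct` here mirrors the `uct_score` function in the Code section):

```python
import math

def uct(w, n, N, c=math.sqrt(2)):
    # average reward (exploitation) + visit-count bonus (exploration)
    return w / n + c * math.sqrt(math.log(N) / n)

# Parent visited 100 times. A well-explored good child vs. a
# barely-explored mediocre one:
print(uct(60, 80, 100))  # ≈ 1.09: high average, small bonus
print(uct(2, 4, 100))    # ≈ 2.02: low average, large bonus — gets explored next
```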
Generalization
More generic pseudocode:

```python
def MCTS(root):
    for _ in range(number_of_simulations):
        node = Selection(root)
        child_node = Expansion(node)
        result = Simulation(child_node)
        Backpropagation(child_node, result)
    return BestMove(root)

def Selection(node):
    while node is fully expanded and not terminal:
        node = best_child(node)  # this is where UCT is incorporated
    return node

def Expansion(node):
    if node is not terminal:
        return expand_node(node)
    return node

def Simulation(node):
    while node is not terminal:
        node = random_play(node)
    return result_of_the_game(node)

def Backpropagation(node, result):
    while node is not None:
        node.update(result)
        node = node.parent

def BestMove(root):
    return child_with_highest_visit_count(root)
```
Vs. value/advantage functions only?
UCB1 vs UCT vs PUCT
UCB1 (Upper Confidence Bound 1):

UCB1 = X̄ᵢ + C * √(ln N / nᵢ)

Where:
- X̄ᵢ is the average reward of arm i so far
- nᵢ is the number of times arm i has been played
- N is the total number of plays
- C is the exploration constant (often √2)

UCB1 was originally designed for the multi-armed bandit problem, not specifically for trees.
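In the bandit setting this looks like the following (a minimal sketch with made-up Gaussian arms):

```python
import math
import random

def ucb1_bandit(arm_means, rounds, c=math.sqrt(2)):
    # Pull each arm once, then always pull the arm with the highest
    # UCB1 score: average observed reward + exploration bonus.
    k = len(arm_means)
    pulls = [1] * k
    totals = [random.gauss(mu, 1.0) for mu in arm_means]
    for t in range(k, rounds):
        ucb = [totals[i] / pulls[i] + c * math.sqrt(math.log(t) / pulls[i])
               for i in range(k)]
        i = max(range(k), key=ucb.__getitem__)
        pulls[i] += 1
        totals[i] += random.gauss(arm_means[i], 1.0)
    return pulls

random.seed(0)
pulls = ucb1_bandit([0.0, 1.0], 500)
print(pulls)  # the better arm (index 1) receives the bulk of the pulls
```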
UCT (Upper Confidence Bounds for Trees):

UCT = X̄ᵢ + C * √(ln N / nᵢ)

Where:
- X̄ᵢ is the average simulation result for child i
- nᵢ is the child's visit count
- N is the parent's visit count
- The formula is identical to UCB1 — UCT is UCB1 applied recursively at every node of the search tree
PUCT formula: PUCT = Q(s,a) + c_puct * P(s,a) * (√N(s)) / (1 + N(s,a))

Where:
- Q(s,a) is the average value of taking action a from state s
- P(s,a) is the prior probability of a (e.g. from a policy network, as in AlphaGo/AlphaZero)
- N(s) is the visit count of state s, and N(s,a) the visit count of action a from s
- c_puct is the exploration constant

X̄ᵢ or Q: these play the same role in each formula — the current average value estimate, i.e. the exploitation term.
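A minimal sketch of the PUCT score as written above (parameter names are assumptions, loosely following AlphaZero-style notation):

```python
import math

def puct(q, prior, parent_visits, action_visits, c_puct=1.5):
    # Q(s,a) exploitation term plus a prior-weighted exploration
    # bonus that decays as the action's own visit count grows.
    return q + c_puct * prior * math.sqrt(parent_visits) / (1 + action_visits)

# An unvisited action with a strong prior gets a large bonus...
print(puct(0.0, 1.0, 100, 0, c_puct=1.0))  # 10.0
# ...which shrinks as the action accumulates visits.
print(puct(0.5, 0.5, 100, 9, c_puct=1.0))  # 1.0
```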
Vs beam search
- MCTS builds up a tree to guide (future) simulations, unlike beam search, which just walks a tree once
- MCTS explores the built tree using UCT (balancing value estimation and exploration) and randomly simulates the rest until the end; beam search just keeps the top k candidates according to some value estimate
- MCTS makes the most sense with a smallish action space/branching factor, since with a huge one it is unlikely to re-traverse the same actions
- Beam search is greedy, so it can miss ultimately-optimal solutions, especially if the greedy heuristic is poor

Note: MCBS (Monte Carlo Beam Search) is a much simpler algorithm that just plugs Monte Carlo simulation into beam search as the value estimator.
```
function MCBS(root_state, beam_width, num_iterations, simulation_depth):
    beam = [root_state]
    for iteration in 1 to num_iterations:
        next_beam = []
        for state in beam:
            children = generate_children(state)
            for child in children:
                child.score = monte_carlo_simulation(child, simulation_depth)
                next_beam.append(child)
        next_beam = sort_by_score(next_beam)
        beam = select_top_k(next_beam, beam_width)
    return best_state_in(beam)
```
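The pseudocode can be made concrete on a toy problem (everything here — the bit-string task, `children`, and `mc_score` — is invented for illustration): build a bit-string left to right, scoring each partial string by averaging random completions.

```python
import random

def children(state):
    # toy action space: append a '0' or '1' to the partial string
    return [state + b for b in "01"]

def mc_score(state, depth, samples=20):
    # Monte Carlo value estimate: complete the string randomly to
    # full length and average the number of 1s over several rollouts.
    total = 0
    for _ in range(samples):
        s = state
        for _ in range(depth - len(s)):
            s += random.choice("01")
        total += s.count("1")
    return total / samples

def mcbs(root, beam_width, depth):
    beam = [root]
    for _ in range(depth):
        candidates = [c for s in beam for c in children(s)]
        candidates.sort(key=lambda s: mc_score(s, depth), reverse=True)
        beam = candidates[:beam_width]
    return beam[0]

random.seed(1)
print(mcbs("", beam_width=2, depth=8))  # converges toward a string of (mostly) 1s
```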