Skip to content

searching: A*, \(\alpha-\beta\) and MCTS

A*

算法

function A*():
  // Step 1 
  Mark s "open" and calculate f(s).

  // Step 2 
  Select the open node n whose value of f is smallest.
  Resolve ties arbitrarily, but always in favor of any node n in T.

  // Step 3
   If n is in T, mark n "closed" and terminate the algorithm.

  // Step 4 
  Otherwise, mark n "closed" and apply the successor operator V to n. 

  Calculate f for each successor of n and mark as open each successor not already marked closed.

  Remark as open any closed node m, which is a successor of n,
  and for which f(m) is smaller now than it was when m was marked closed. 

  Go to Step 2.

定义

我们使用如下定义:

一个图 \(G\) 含有起点 \(s\) 和目标点集 \(T\), 边表示为 \(d(m,n)\)

\(f(n)\) 表示经过 \(n\) 的从 \(s\) 到目标点集 \(T\) 的最短距离

我们有 \(f(n)=g(n)+h(n)\), \(g(n)\) 表示从 \(s\)\(n\) 的最短距离,\(h(n)\) 表示从 \(n\) 到目标点集的最短距离

同时我们需要两个估计函数 \(\hat g(n)\)\(\hat h(n)\)

我们可以取 \(\hat g(n)\)当前\(s\)\(n\) 的最短距离

\(\hat g(n)\geq g(n)\), 因为任何局部最优解不会比全局最优解更优

同时可以取 \(\hat h(n)\) 为我们在反图上跑出来的,从目标点集到 \(n\) 的最短距离

这里要求 \(\hat h(n)\leq h(n)\),即 \(\hat h(n)\)\(h(n)\) 的一个 lower bound

可容性

A*算法可以找到最优解,则满足可容性/可采纳性 (admissible)

在之前的定义中,\(\hat h(n)\leq h(n)\), 意为我们不会过分估计未来需要的花费

我们接下来要证明,如果 \(\hat h(n)\leq h(n)\), 则 A* 算法满足可容性

Lemma 1: 对于任意一个未关闭节点 \(n\) 和对应的从 \(s\)\(n\) 的最短路 \(P\), 存在一个节点 \(n'\) 使得 \(\hat g(n') = g(n')\)

证明:

如果 \(s\) 是打开的,那么令 \(n'=s\) 即可

如果 \(s\) 关闭,令 \(\Delta\)\(P\) 上的节点 \(n_i\) 组成的集合,\(n_i\) 满足已关闭且满足 \(\hat g(n_i) = g(n_i)\)

由于 \(s\) 关闭,所以 \(\Delta\) 非空

选出一个标号最大的点 \(n^*\), 由于 \(n\) 未关闭,所以 \(n^* \not= n\)

\(n'\)\(n^*\) 的后继节点

\(\hat g\) 的定义我们有 \(\hat g(n') \leq \hat g(n^*) + d(n^*,n')\)

同时 \(\hat g(n^*) = g(n^*)\), 因为 \(n^*\in\Delta\)

\(g(n')=g(n^*)+d(n^*,n')\), 因为 \(P\) 是最短路

所以我们有 \(\hat g(n') \leq g(n')\)

而一般地,\(\hat g(n') \geq g(n')\) 一定成立

所以 \(\hat g(n') = g(n')\), 且 \(n'\not\in\Delta\) 未关闭

Corollary: 假设 \(\hat h(n) \leq h(n)\), 且 A* 算法未停止。那么在任意一条最短路 \(P\) 上, 存在一个 \(P\) 上未关闭点 \(n'\) 满足 \(\hat f(n') \leq f(s)\)

证明:由引理1,存在一个 \(n'\) 使得 \(\hat g(n') = g(n')\)

所以

\(\hat f(n') = \hat g(n') + \hat h(n') = g(n') + \hat h(n') \leq g(n') + h(n') = f(n')\)

而因为是最短路,所以有 \(f(n')=f(s)\) 对所有 \(n'\) 成立

所以存在一个 \(n'\) 使得 \(\hat f(n') \leq f(s)\)

Theorem 1: 如果 \(\hat h(n) \leq h(n)\) 对所有 \(n\) 成立,那么 A* 是 Admissable (可容) 的

证明:反证

假设不成立,有 3 中情况:

  • Case 1: 不会停止在目标点集

这与停止条件矛盾

  • Case 2: 不会停止

假设每条边最小值是 \(\delta\), 那么对于距离 \(s\) 点步数大于 \(M=f(s)/\delta\) 步的所有点 \(n\),我们有 \(\hat f(n) \geq \hat g(n)\geq g(n)> M\delta = f(s)\), 这说明这些节点都不会被拓展

这是由于引理1的推论,存在点 \(n'\) 使得 \(\hat f(n') \leq f(s) < \hat f(n)\), 所以算法选点时会选择 \(n'\) 而不是 \(n\)

那么算法不停止只会是因为距离起点 \(M\) 步以内的点被重复打开

\(\chi(M)\) 表示这些点组成的点集, \(v(M)\) 表示点集大小

但是经过这些点的路径只有有限条,记为 \(\rho(n, M)\)

\(\rho(M) = \max_{n\in \chi(M)}\rho(n, M)\)

那么经过 \(\rho(M)v(M)\) 步后,所有节点一定都关闭,此时算法结束

  • Case 3: 停止时解不是最优解

假设算法停在了 目标点 \(t\) 上,满足 \(\hat f(t) = \hat g(t) > f(s)\),后一个不等式说明了不是最优解

但是由引理1的推论,在终止前存在一个点 \(n'\) 满足 \(\hat f(n') \leq f(s) < \hat f(t)\)

所以算法应该先选择 \(n'\) 而不是 \(t\), 故得出矛盾

综上,三种情况都不满足,命题得证

一致性

如果 \(h(m, n) +\hat h(n) \geq \hat h(m)\), 则满足一致性

这里 \(h(m,n) \geq d(m,n)\), 可以直接看成实际距离 \(d(m,n)\)

Lemma 2: 在一致性条件下,如果一个节点 \(n\) 被关闭,那么 \(g(n)=\hat g(n)\)

先说一下这个引理的重要性

它主要重要在两方面:

首先,它可以用来证明 A* 的最优性

其次,它说明算法不需要重新打开一个已经关闭的节点

即 A* 算法不会重复扩展同一节点

这样,我们就可以把算法第 \(4\) 步的重新打开节点的操作省去

这一点与 Dijkstra 是一样的

下面是证明

证明:反证

假设关闭 \(n\) 之前有 \(\hat g(n)>g(n)\)

此时存在一个从 \(s\)\(n\) 的最短路 \(P\)

因为 \(\hat g(n)>g(n)\),所以当前的路径不是最短路,即算法尚未发现 \(P\)

根据引理1, 存在一个 \(P\) 上的点 \(n'\) 满足 \(\hat g(n') = g(n')\)

如果 \(n'=n\), 我们就证明了这个引理

否则我们有 \(g(n) =g(n') + h(n', n)= \hat g(n')+h(n',n)\)

所以 \(\hat g(n)>\hat g(n')+h(n',n)\)

两边同加 \(\hat h(n)\):

\(\hat g(n) + \hat h(n)>\hat g(n')+h(n',n) + \hat h(n)\)

由于一致性: \(h(n', n) +\hat h(n) \geq \hat h(n')\)

所以 \(\hat g(n) + \hat h(n)>\hat g(n')+\hat h(n')\)

\(\hat f(n)>\hat f(n')\)

这与我们选择 \(n\) 矛盾,因为我们应该选择 \(n'\)

故得证

Lemma 3: 令 \((n_1\cdots n_t)\) 表示 A* 依次关闭的节点. 那么对于 \(p\leq q\), 有\(\hat f(n_p) \leq \hat f(n_q)\)

这说明了 \(\hat f\) 的单调不降性

证明:

假设 \(n\) 是 A* 关闭 \(m\) 后要关闭的节点

如果到 \(n\) 的最短路不经过 \(m\),说明 \(n\) 不由 \(m\) 拓展,他们是并列关系

在选择 \(m\) 后才选 \(n\), 就说明了 \(\hat f(m) \leq \hat f(n)\), 则引理成立

如果到 \(n\) 的最短路经过 \(m\),我们有 \(g(n)=g(m)+h(m,n)\)

由引理2, 我们有 \(\hat g(n)=g(n)\)\(\hat g(m)=g(m)\)

\(\hat f(n) = \hat g(n) + \hat h(n)\)

\(= g(n) + \hat h(n)\)

\(= g(m) + h(m, n) + \hat h(n)\)

\(\geq g(m) + \hat h(m)\)

\(= \hat g(m) + \hat h(m) = \hat f(m)\)

故得证

Corollary: 如果 \(n\) 关闭,那么 \(\hat f(n)\leq f(s)\)

证明:

由引理3,\(\hat f(n) \leq \hat f(t) = f(t) = f(s)\)

最优性

下面的定理和推论我们都不加证明地给出,具体证明参见原论文

首先是没有出现 ties 的情况

这里的 ties 指 A* 同时可以选择两个及以上的节点

具体地,存在 \(n_1\cdots n_k\) 满足 \(\hat f(n_1)=\cdots=\hat f(n_k)<\hat f(n)\) 对其他 \(n\) 成立

Theorem 2: 令 A 是任意可容的,不比 A 有更多信息 (informed) 的算法,\(G\) 为边权最小值为 \(\delta\) ,满足 \(\hat f(n)\not=\hat f(m),n\not=m\) 的图, 且满足一致性。则如果 \(n\) 在 A 中被拓展,则在 A 中也会被拓展

这里的 informed 未定义,可以感性理解一下,就是 \(\hat h\) 函数包含的信息有效性

比如,\(\hat h = 0\) 就不如之前从 \(t\)\(n\) 的最短路更有信息

Corollary: 定义 \(N(A, G)\) 为 A 算法在 \(G\) 图上拓展的节点,那么 \(N(A^*, G) \leq N(A, G)\)

这样我们就说 A* 是一个最优的算法, 因为它拓展了最少的节点

然后是有 ties 存在的情况

令 G 表示所有处理 ties 方式不同的 A 算法的集合

Theorem 3: 令 A 是任意可容的,不比 G 有更多信息的算法,\(G\) 为边权最小值为 \(\delta\) ,满足 \(\hat f(n)\not=\hat f(m),n\not=m\) 的图, 且满足一致性。则存在一个 A \(\in\) G, 如果 \(n\) 在 A 中被拓展,则在 A 中也会被拓展

Corollary 1: 存在一个 A \(\in\) G 满足 \(N(A^*, G) \leq N(A, G)\)

Corollary 2: 定义 \(R(A^*, G)\) 为 A 应用于 \(G\) 上的总 critical ties 个数, 则对任意一个 A \(\in\) G*, 我们有 \(N(A*,G)\leq N(A,G)+R(A^*,G)\)

对于 noncritical ties, 满足所有其他可选的节点一定同时被 A 和 A* 拓展

critical ties 即为相反定义

可容性与一致性的关系

Theorem 4: 如果启发函数满足一致性, 一定满足可容性

证明:

对于任意节点 \(x_0\), 从 \(x_0\) 到目标点集的最短路为 \(P\),最短路上节点为 \(x_1 \cdots x_n\)

\(\displaystyle \hat h(x_0) \leq \hat h(x_1)+h(x_0, x_1)\leq \hat h(x_2)+h(x_0,x_1)+h(x_1,x_2)\leq \cdots\leq\sum_{i=1}^nh(x_{i-1},x_i)=h(x_0)\)

\(\hat g(n)\)\(\hat h(n)\) 的选取

\(\hat h=0\) 退化为 Dijkstra

\(\hat h=0\) 且边权为 \(1\) 退化为 BFS

\(\hat g=0\) 退化为 Greedy Best First Search

参考文献

更加严谨,具体的证明与分析可以见A*算法的原论文

A Formal Basis for the Heuristic Determination of Minimum Cost Paths

\(\alpha-\beta\)

Minimax

在对抗博弈中,Alice希望得到最大收益,Bob希望得到最小收益,在整个搜索树中,如果奇数层是Alice,偶数层是Bob,那么奇数层取子结点中最大,偶数层取最小

\(\alpha-\beta\) pruning

每个节点有两个参数, \(\alpha\)表示最大下界,\(\beta\)表示最小上界

每次递归地搜索,将当前节点的\(\alpha\), \(\beta\) 下传到子节点

如果有\(\alpha\geq \beta\), 则可以直接退出当前节点,因为剩余子节点不会产生更优解

这个条件的意义是,如果当前节点要对父亲节点产生贡献,就必须满足\(\alpha\leq v \leq\beta\), 而\(\alpha\), \(\beta\) 是由父亲节点传下来的,如果\(\alpha\geq \beta\),说明能够更新父节点的最值条件已经不被满足了(最小值超过了父节点要求的最大值,或最大值超过了父节点要求的最小值),那么剩余子节点就不用再搜了,直接退出即可

int alpha_beta(int u, int alpha, int beta, int minimax) { // minimax = 0 表示最小值点, 反之为最大值点
    if (!son[u].size()) return val[u];
    if (minimax) {
        for (auto d:son[u]) {
            alpha = max(alpha, alpha_beta(d, alpha, beta, minimax ^ 1));
            if (alpha >= beta) break;
        }
        return alpha;
    } else {
        for (auto d:son[u]) {
            beta = min(beta, alpha_beta(d, alpha, beta, minimax ^ 1));
            if (alpha >= beta) break;
        }
        return beta;
    }
}

MCTS

UCB (Upper Confidence Bounds)

定义\(UCB\)函数为 \(\displaystyle\frac{w_i}{n_i}+c\cdot \sqrt{\frac{2\ln N_i}{n_i}}\)

\(w_i\) 是当前节点胜利次数 (如果胜利有分数,则\(w_i\)可以更正为胜利所得分数), \(n_i\) 是节点模拟次数, \(N_i\) 是父亲节点总模拟次数, \(c\) 是常数,理论为\(\frac{1}{\sqrt 2}\)

可以看到,如果只有前一部分,算法将不断搜索胜率最高的节点,而导致新节点不被搜索

如果只有后一部分,在 \(N\) 一定时,算法优先搜索没被访问的节点,但忽略的新节点的胜率问题

所以,这个置信度综合考虑了最大胜率与探索新节点

这个式子是怎么构造的呢?

我们考虑霍夫丁不等式:

\(\displaystyle P(\mu_i-\bar x_{i,t-1}>\delta)\leq e^{-2T_{(i,t-1)}\delta^2}\)

这个式子的意义是,期望奖励 \(\mu_i\) 大小超过当前动作 \(a_i\) 的奖励均值 \(\bar x_{i,t-1}\) 加上标准差 \(\delta\) 的概率不超过 \(e^{-2T_{(i,t-1)}\delta^2}\)

那么我们只要让 \(e^{-2T_{(i,t-1)}\delta^2}\) 快速收敛就行

所以我们找到了一个函数 \(t^{-4}\)

我们让 \(e^{-2T_{(i,t-1)}\delta^2} = t^{-4}\), 取log得出 \(-2T_{(i,t-1)}\delta^2 = -4\ln t\)

\(\delta = \sqrt{\frac{2\ln t}{T_{(i,t-1)}}}\), 即 \(\sqrt{\frac{2\ln N_i}{n_i}}\)

UCT

node definition

那么\(UCB\)函数:

def UCB(self, v):
        return v.total_value/v.num_visited + self.constant*(2*math.log(v.parent.num_visited)/v.num_visited)**0.5

SelectPolicy

先从根开始,不断向下递归博弈树,当遇到叶子节点时,停止并返回当前节点\(v_{select}\);当遇到有未扩展的子节点的节点时,可以以\(\frac12\)的概率向下递归,以\(\frac12\)的概率停止并返回

这样的目的是为了更多地选择更优的分支,来判断这一分支是否更好,而不是更多地扩展未扩展的节点

def SelectPolicy(self, v):
    while True:
        all_actions = list(v.state.get_legal_actions(v.color))
        children_actions = [action for action, child in v.children.items()]
        actions = list(set(all_actions) - set(children_actions))
        if not v.children:  # leaf node
            break
        choice = random.uniform(0, 1)
        if choice < 0.5 and len(actions):  # not fully expanded
            break
        else:
            v = max(v.children.items(), key=lambda x: self.UCB(x[1]))[1]
    return v

Expand

随机选择当前节点\(v_{select}\)一个为扩展的节点\(v_{expand}\)并创建节点,返回\(v_{expand}\)

def Expand(self, v):
    all_actions = list(v.state.get_legal_actions(v.color))
    children_actions = [action for action, child in v.children.items()]
    actions = list(set(all_actions) - set(children_actions))
    if len(actions) == 0:
        return None
    action = random.choice(actions)
    state = deepcopy(v.state)
    state._move(action, v.color)
    color = "X" if v.color == "O" else "O"
    v.children[action] = TreeNode(state, color, v.depth + 1, v, action)
    return v.children[action]

SimulatePolicy

使用一个更简单的策略,如随机策略,模拟\(v_{expand}\)至游戏结束,得到输赢结果,同色赢贡献为\(1\),平局为\(0.5\),输为\(0\)

def game_over(self, state):  # 如果双方都不能落子,游戏结束
    return len(list(state.get_legal_actions("X"))) == 0 and len(list(state.get_legal_actions("O"))) == 0

def SimulatePolicy(self, v):
    s_result = deepcopy(v.state)
    color = v.color
    while not self.game_over(s_result):
        actions = list(s_result.get_legal_actions(color))
        if len(actions) == 0:
            color = "X" if color == "O" else "O"
            continue
        action = random.choice(actions)
        s_result._move(action, color)
        color = "X" if color == "O" else "O"
    return s_result

BackPropagate

\(v_{expand}\)开始向上更新整个分支

如果当前节点为\(MAX\)节点,则要减去贡献,因为对手是\(MIN\)节点,他选择子节点时会选择对于他最有利的分支,所以贡献越高,说明对手的对手(自己)赢得越多,对手越不会选择,所以要减去贡献

反之,\(MIN\)节点加上贡献,因为自己是\(MAX\)节点,选择\(MIN\)节点中贡献最大的分支

同时,每个节点的访问次数加\(1\)

def BackPropagate(self, v, s_result):
    winner, diff = s_result.get_winner()
    # print("winner is: ", winner, diff)
    color = 0 if self.color == "X" else 1
    result = 1 if winner == color else 0.5 if winner == 2 else 0
    while v:
        v.num_visited += 1
        v.total_value += result * (-1 if v.color == self.color else 1)
        v = v.parent

整体的\(AIPlayer\)代码:

import datetime
import math
from copy import deepcopy
import random


class TreeNode:
    def __init__(self, state, color, depth=0, parent=None, action=None):
        self.state = state
        self.parent = parent
        self.children = {}
        self.num_visited = 0
        self.total_value = 0
        self.depth = depth
        self.color = color
        self.action = action


class AIPlayer:
    """
    AI 玩家
    """
    def __init__(self, color, time=1, constant=1 / 2 ** 0.5):
        """
        玩家初始化
        :param color: 下棋方,'X' - 黑棋,'O' - 白棋
        """
        self.color = color
        self.time = time
        self.constant = constant

    def get_move(self, board):
        """
        根据当前棋盘状态获取最佳落子位置
        :param board: 棋盘
        :return: action 最佳落子位置, e.g. 'A1'
        """
        if self.color == 'X':
            player_name = '黑棋'
        else:
            player_name = '白棋'
        print("请等一会,对方 {}-{} 正在思考中...".format(player_name, self.color))
        action = self.UCTSearch(board)
        return action

    def UCTSearch(self, state):
        # state = board
        root = TreeNode(state, self.color)
        start_time = datetime.datetime.now()
        while True:
            if datetime.datetime.now() - start_time > datetime.timedelta(seconds=self.time):
                break
            v_select = self.SelectPolicy(root)
            v_expand = self.Expand(v_select)
            if not v_expand:
                continue
            s_result = self.SimulatePolicy(v_expand)
            self.BackPropagate(v_expand, s_result)
        action = self.get_action(root)
        return action

    def SelectPolicy(self, v):
        while True:
            all_actions = list(v.state.get_legal_actions(v.color))
            children_actions = [action for action, child in v.children.items()]
            actions = list(set(all_actions) - set(children_actions))
            if not v.children:  # leaf node or not fully expanded
                break
            choice = random.uniform(0, 1)
            if choice < 0.5 and len(actions):
                break
            else:
                v = max(v.children.items(), key=lambda x: self.UCB(x[1]))[1]
        return v

    def UCB(self, v):
        return v.total_value/v.num_visited + self.constant*(2*math.log(v.parent.num_visited)/v.num_visited)**0.5

    def Expand(self, v):
        all_actions = list(v.state.get_legal_actions(v.color))
        children_actions = [action for action, child in v.children.items()]
        actions = list(set(all_actions) - set(children_actions))
        if len(actions) == 0:
            return None
        action = random.choice(actions)
        state = deepcopy(v.state)
        state._move(action, v.color)
        color = "X" if v.color == "O" else "O"
        v.children[action] = TreeNode(state, color, v.depth + 1, v, action)
        return v.children[action]

    def game_over(self, state):
        return len(list(state.get_legal_actions("X"))) == 0 and len(list(state.get_legal_actions("O"))) == 0

    def SimulatePolicy(self, v):
        s_result = deepcopy(v.state)
        color = v.color
        while not self.game_over(s_result):
            actions = list(s_result.get_legal_actions(color))
            if len(actions) == 0:
                color = "X" if color == "O" else "O"
                continue
            action = random.choice(actions)
            s_result._move(action, color)
            color = "X" if color == "O" else "O"
        return s_result

    def BackPropagate(self, v, s_result):
        winner, diff = s_result.get_winner()
        color = 0 if self.color == "X" else 1
        result = 1 if winner == color else 0.5 if winner == 2 else 0
        while v:
            v.num_visited += 1
            v.total_value += result * (-1 if v.color == self.color else 1)
            v = v.parent

    def get_action(self, root):
        return max(root.children.items(), key=lambda x: x[1].total_value/x[1].num_visited)[0]