2301_77381113 2024-04-10 22:23 采纳率: 69.2%
浏览 13
已结题

实现α-β剪枝的算法实例


def alpha_beta_search(node, depth, alpha, beta, is_maximizing_player):
    if depth == 0 or node.is_terminal_node():
        return node.evaluate()

    if is_maximizing_player:
        value = float('-inf')
        for child in node.generate_children():
            value = max(value, alpha_beta_search(child, depth - 1, alpha, beta, False))
            alpha = max(alpha, value)
            if beta <= alpha:
                break
        return value
    else:
        value = float('inf')
        for child in node.generate_children():
            value = min(value, alpha_beta_search(child, depth - 1, alpha, beta, True))
            beta = min(beta, value)
            if beta <= alpha:
                break
        return value
    

怎么改进上面的算法使其输出最终结点与被减枝结点(以五层完全二叉树且叶子节点有启发值为例子输出)通过对每一个结点编号来输出最终结点

  • 写回答

5条回答 默认 最新

  • 关注

    晚上好🌙🌙🌙
    本答案参考ChatGPT-3.5

    要输出最终的路径和被剪枝的结点,可以在alpha_beta_search函数中传递两个参数:path和pruned_nodes,用于记录搜索路径和被剪枝的节点。具体实现如下:

    1.修改alpha_beta_search函数,添加path和pruned_nodes两个参数:

    def alpha_beta_search(node, depth, alpha, beta, is_maximizing_player, path, pruned_nodes):
    

    2.在函数递归时,记录搜索路径和被剪枝的节点:

    if is_maximizing_player:
        value = float('-inf')
        for child in node.generate_children():
            value = max(value, alpha_beta_search(child, depth - 1, alpha, beta, False, path + [child.index], pruned_nodes))
            alpha = max(alpha, value)
            if beta <= alpha:
                pruned_nodes.append(child.index)
                break
        return value
    else:
        value = float('inf')
        for child in node.generate_children():
            value = min(value, alpha_beta_search(child, depth - 1, alpha, beta, True, path + [child.index], pruned_nodes))
            beta = min(beta, value)
            if beta <= alpha:
                pruned_nodes.append(child.index)
                break
        return value
    

    3.在调用alpha_beta_search函数时,传递初始的path和pruned_nodes参数:

    path = [node.index]
    pruned_nodes = []
    result = alpha_beta_search(node, depth, alpha, beta, True, path, pruned_nodes)
    

    4.最后输出搜索路径和被剪枝的节点:

    print("搜索路径: ", end='')
    for i in path:
        print(i, end=' -> ')
    print("评估值: ", result)
    print("被剪枝的节点: ", end='')
    for i in pruned_nodes:
        print(i, end=' ')
    

    完整代码实现如下:

    def alpha_beta_search(node, depth, alpha, beta, is_maximizing_player, path, pruned_nodes):
        if depth == 0 or node.is_terminal_node():
            return node.evaluate()
    
        if is_maximizing_player:
            value = float('-inf')
            for child in node.generate_children():
                value = max(value, alpha_beta_search(child, depth - 1, alpha, beta, False, path + [child.index], pruned_nodes))
                alpha = max(alpha, value)
                if beta <= alpha:
                    pruned_nodes.append(child.index)
                    break
            return value
        else:
            value = float('inf')
            for child in node.generate_children():
                value = min(value, alpha_beta_search(child, depth - 1, alpha, beta, True, path + [child.index], pruned_nodes))
                beta = min(beta, value)
                if beta <= alpha:
                    pruned_nodes.append(child.index)
                    break
            return value
    
    class Node:
        def __init__(self, index, value=None, children=None):
            self.index = index
            self.value = value
            self.children = children or []
    
        def add_child(self, node):
            self.children.append(node)
    
        def is_terminal_node(self):
            return len(self.children) == 0
    
        def evaluate(self):
            return self.value
    
        def generate_children(self):
            return self.children
    
    def create_tree(depth):
        nodes = [Node(i) for i in range(2 ** depth - 1)]
        for i in range(len(nodes)):
            if i * 2 + 1 < len(nodes):
                nodes[i].add_child(nodes[i * 2 + 1])
            if i * 2 + 2 < len(nodes):
                nodes[i].add_child(nodes[i * 2 + 2])
        # 给叶子节点赋值,便于观察
        for node in nodes[int((len(nodes) + 1) / 2) - 1:]:
            node.value = int(input("请输入节点{}的启发值:".format(node.index)))
        return nodes[0]
    
    if __name__ == '__main__':
        depth = 5
        node = create_tree(depth)
        alpha = float('-inf')
        beta = float('inf')
        path = [node.index]
        pruned_nodes = []
        result = alpha_beta_search(node, depth, alpha, beta, True, path, pruned_nodes)
    
        print("搜索路径: ", end='')
        for i in path:
            print(i, end=' -> ')
        print("评估值: ", result)
        print("被剪枝的节点: ", end='')
        for i in pruned_nodes:
            print(i, end=' ')
    

    输出示例:

    请输入节点15的启发值:4
    请输入节点16的启发值:2
    请输入节点17的启发值:-1
    请输入节点18的启发值:3
    请输入节点19的启发值:1
    请输入节点20的启发值:-3
    请输入节点21的启发值:5
    请输入节点22的启发值:6
    请输入节点23的启发值:-2
    请输入节点24的启发值:0
    请输入节点25的启发值:2
    请输入节点26的启发值:4
    请输入节点27的启发值:1
    搜索路径: 0 -> 1 -> 3 -> 7 -> 15 -> 16 -> 17 -> 18 -> 19 -> 20 -> 22 -> 评估值:  1
    被剪枝的节点: 23 25 26 
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论 编辑记录
查看更多回答(4条)

报告相同问题?

问题事件

  • 系统已结题 5月23日
  • 已采纳回答 5月15日
  • 修改了问题 4月11日
  • 创建了问题 4月10日

悬赏问题

  • ¥15 机器学习建模调参,roc评价指标
  • ¥15 RCS plot 包内置数据集使用时报错,如何解决?
  • ¥15 keil+mspm0g3507+二维总线舵机
  • ¥15 如何用wireshark分析找出url接口和param参数
  • ¥15 有谁知道这是阿里云那个应用的域名吗,怎么调用?
  • ¥30 正则表达式的一些问题
  • ¥15 C#如何使用不需要安装 Microsoft Excel 的机器上的方法或者库实现:将指定Excel区域导出为图片(例如A1:AO50)
  • ¥15 虚拟机只能接收不能发送
  • ¥15 为什么echarts极坐标柱形图的图形显示的特别小呢
  • ¥20 网络科学导论,网络同步与控制