m0_69388475 2024-05-14 17:26 采纳率: 0%
浏览 7
已结题

论文AlphaTensor复现(有偿)

用Python复现论文AlphaTensor(有偿)
原论文只给出了神经网络部分伪代码和MCTS的描述,我将伪代码转化为Python代码,且用GPT写出了大致符合描述的MCTS算法,但仍无法完成整个复现过程。希望有同伴可以帮助我完成我的首次复现,感激不尽!希望可以进行具体交流直至完成整个过程,我会在完成后追加相应酬金,谢谢!请使用GPT生产无用答案的答主不必再回答了。
可加v: 2741347761
原论文:Discovering faster matrix multiplication algorithms with reinforcement learning
Paper:https://www.nature.com/articles/s41586-022-05172-4
GitHub:https://github.com/deepmind/alphatensor
神经网络部分代码


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class NeuralNetwork:
    def __init__(self):
        pass

    # A.1    
    def Attention(self, x, y, causal_mask=False, Nheads=16, d=32, w=4):
        Nx, Ny = x.size(0), y.size(0)
        c1, c2 = x.size(1), y.size(1) 

        # 1, 2
        layer_norm_x = nn.LayerNorm(c1)
        layer_norm_y = nn.LayerNorm(c2)
        xnorm = layer_norm_x(x)
        ynorm = layer_norm_y(y)

        # 3, 4, 5
        linear_q = nn.Linear(c1, Nheads * d)
        linear_k = nn.Linear(c2, Nheads * d)
        linear_v = nn.Linear(c2, Nheads * d)
        q = linear_q(xnorm).view(Nx, Nheads, d)
        k = linear_k(ynorm).view(Ny, Nheads, d)
        v = linear_v(ynorm).view(Ny, Nheads, d)

        # 6
        q = q / torch.sqrt(torch.tensor(d).float())
        attn_scores = torch.einsum('bih,bjh->bij', q, k)

        # 7: 如果指定了因果掩码参数,将注意力权重中未来位置的信息屏蔽掉,
        # 以确保模型在自回归任务中按正确的顺序生成输出
        if causal_mask:
            mask = torch.tril(torch.ones(Nx, Ny)).unsqueeze(1).repeat(1, Nheads, 1, 1).to(x.device)
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

        a= F.softmax(attn_scores, dim=-1)

        # 8, 9
        o = torch.einsum('bij,bjh->bih', a, v)

        # 11: 将多头注意力的结果合并,线性变换,残差连接
        o = o.contiguous().view(Nx, -1)
        linear_out = nn.Linear(Nheads * d, c1)
        x = x + linear_out(o)

        # 这部分代码实现了一个稠密连接块(Dense Block),包括层归一化、全连接、GELU 激活函数和残差连接
        # 12:
        layer_norm_dense = nn.LayerNorm(c1)
        linear_dense1 = nn.Linear(c1, c1 * w)
        linear_dense2 = nn.Linear(c1 * w, c1)
        x = x + linear_dense2(F.gelu(linear_dense1(layer_norm_dense(x))))
        # 13:
        return x
        # 这段代码实现了一个基于注意力机制的神经网络层,用于处理输入张量之间的关系,并产生更新后的输出张量

    # A.2
    def AttentiveModes(self, x1, x2, x3):
        g = [x1, x2, x3]
        pairs = [(0, 1), (2, 0), (1, 2)]
        
        for m1, m2 in pairs:
            # Concatenate g[m1] and transpose of g[m2] along the channel axis
            a = torch.cat((g[m1], g[m2].transpose(-2, -1)), dim=1)
            
            # Parallel loop through each row
            for i in range(g[m1].shape[1]):  # assuming g[m1].shape[1] is S
                c = self.Attention(a[i, :, :], a[i, :, :])
                # Update parts of g[m1] and g[m2] based on attention output
                g[m1][i, :, :] = c[:c.shape[0]//2, :]
                g[m2][i, :, :] = c[c.shape[0]//2:, :].transpose(0, 1)
        
        return g

    # Assuming x1, x2, x3 are torch tensors of shape [S, S, c]
    # You need to define the Attention function which is used in the code.

    # A.3
    def Torso(self, x, s, c):
        S = x.shape[1]  # Assuming x is of shape [T, S, S, S]
        T = x.shape[0]

        # Reshape and transpose operations
        x1 = x.permute(1, 2, 3, 0).reshape(S, S, S*T)
        x2 = x.permute(3, 1, 2, 0).reshape(S, S, S*T)
        x3 = x.permute(2, 3, 1, 0).reshape(S, S, S*T)
        g = [x1, x2, x3]

        # Applying linear transformation on s and concatenating it to each g[i]
        linear_s = nn.Linear(s.shape[0], S**2)
        p = linear_s(s).reshape(S, S, 1)
        
        for i in range(3):
            g[i] = torch.cat([g[i], p.expand(-1, -1, S * T + 1)], dim=-1)
            linear_g = nn.Linear(g[i].shape[-1], c)
            g[i] = linear_g(g[i])

        [x1, x2, x3] = g

        # Applying AttentiveModes repeatedly
        attentive_modes = self.AttentiveModes(c, S)
        for _ in range(8):
            x1, x2, x3 = attentive_modes(x1, x2, x3)

        # Stacking and reshaping
        e = torch.stack([x1, x2, x3], dim=1).reshape(3 * S**2, c)
        return e

    # Note: The function AttentiveModes should be defined elsewhere as per the model's requirement.    

    # A.4 
    def predict_action_logits(self, a, e, is_training, Nsteps, Nlogits, m, c, Nfeatures=64, Nheads=32, Nlayers=2):
    # 确保必要的模块已经初始化
        linear = nn.Linear(Nlogits, Nfeatures * Nheads)
        learnable_pos_enc = nn.Parameter(torch.randn(Nsteps, Nfeatures * Nheads))
        layer_norm = nn.LayerNorm(Nfeatures * Nheads)
        attention_causal = self.Attention(Nfeatures * Nheads, Nheads, causal=True)
        attention_cross = self.Attention(Nfeatures * Nheads, Nheads, causal=False)
        dropout = nn.Dropout(0.1)
        final_linear = nn.Linear(Nfeatures * Nheads, Nlogits)

        # 1: 线性变换
        x = linear(a)

        # 2: 可学习位置编码
        x = x + learnable_pos_enc

        # 3: 多层网络
        for i in range(Nlayers):
            # 4: 层归一化
            x = layer_norm(x)
            # 5: 因果自注意力
            c = attention_causal(x, x)
            # 6, 7: 训练时应用Dropout
            if is_training:
                c = dropout(c)
            # 8: 残差连接和再次归一化
            x = x + c
            x = layer_norm(x)

        # 10: 交叉注意力
        c = attention_cross(x, e)
        # 11, 12: 训练时应用Dropout
        if is_training:
            c = dropout(c)
        # 13: 残差连接
        x = x + c

        # 14: 线性输出层和激活函数
        o = final_linear(F.relu(x))

        return o, x
    
    # A.5
    def PolicyHead_training(self, e, g, Nlogits, Nsteps):
    # 假设 g 已经是一个序列或向量,Nlogits 是类别总数,Nsteps 是序列长度

    # onehot 和 shifted 操作
        def shifted(self, g, Nsteps):
            # 假设这里 g 是整数标签序列,进行移位操作,最后一位丢弃,前面补0
            shifted_g = torch.zeros_like(g)
            shifted_g[1:] = g[:-1]
            return shifted_g

        def onehot(self, g, Nlogits):
            # 创建独热编码
            return F.one_hot(g, num_classes=Nlogits).float()

        # 生成独热编码的移位版本
        shifted_g = self.shifted(g, Nsteps)
        onehot_g = self.onehot(shifted_g, Nlogits)

            # 计算行动对数概率和嵌入
        o, z = self.predict_action_logits(onehot_g, e)

            # 由于 z 可能是一个序列嵌入,我们只返回第一个时间步的嵌入,因为它不依赖于真值标签
        z1 = z[0] if len(z.shape) > 1 else z  # 确保 z 是序列

        return o, z1

    # A.6
    def PolicyHead_inference(self, e, Nlogits, Nsteps, Nsamples=32):
        # 初始化采样的动作和概率  
        a = torch.zeros((Nsamples, Nsteps), dtype=torch.long)
        p = torch.ones(Nsamples)

        # 外部循环,用于遍历样本集合中的每个样本
        for s in range(Nsamples):
            # 内部循环,用于遍历每个时间步
            for i in range(Nsteps):
                # 生成动作的对数概率和嵌入
                o, z = self.predict_action_logits(F.one_hot(a[s], num_classes=Nlogits).float(), e)
                # 从对数概率中采样动作和计算概率
                prob = F.softmax(o[i], dim=-1)
                action = torch.multinomial(prob, 1).item()
                a[s, i] = action
                p[s] *= prob[action]
                
                # 保留第一个时间步的嵌入
                if i == 0:
                    z1 = z if s == 0 else torch.cat((z1, z.unsqueeze(0)), 0)
        return a, p, z1
    
    # A.7
    def ValueHead(self, x, n=8):
        for _ in range(3):
            x = F.relu(nn.Linear(x.shape[-1], 512)(x))
        q = nn.Linear(512, n)(x)
        return q
    
    # A.8
    def Quantile_loss(self, q, g, delta=1):
        n = q.shape[0]
        # 计算n个分位数τ
        tau = torch.arange(1, n+1).float() / n - 0.5 / n
        # 计算差异d
        d = g - q
        # 计算Huber损失
        h = torch.nn.functional.huber_loss(q, g, delta=delta, reduction='none')
        # 计算k
        k = torch.abs(tau - (d < 0).float())
        # 返回加权平均损失
        return torch.mean(k * h)
    
    # A.9
    def ValueRiskManagement(self, q, uq=0.75):
        n = len(q)
        j = int(np.ceil(uq * n))  # 计算索引,根据给定的分位数uq
        # 返回从j到n的元素的均值
        return np.mean(q[j-1:])  # Python索引从0开始,因此使用j-1
    
    # A.10
    def AlphaTensor_Net_training(self, x, s, gaction, gvalue, c, Nlogits):
        # 计算特征提取
        e = self.Torso(x, s, c)
        
        # 进行策略头训练,获取输出和状态
        o, z1 = self.PolicyHead_training(e, gaction, Nlogits)
        
        # 计算策略损失
        lpolicy = torch.sum(F.cross_entropy(o, gaction, reduction='none'))
        
        # 计算价值头输出
        q = self.ValueHead(z1)
        
        # 计算价值损失
        lvalue = self.Quantile_loss(q, gvalue)
        
        # 返回策略损失和价值损失
        return lpolicy, lvalue
    
    # A.11
    def AlphaTensor_Net_inference(self, x, s, c, Nsamples, Nsteps, Nlogits):
        # 计算特征提取
        e = self.Torso(x, s, c)
        
        # 进行策略头训练,获取输出和状态
        a, p, z1 = self.PolicyHead_inference(e, Nsamples, Nsteps, Nlogits)
        
        # 计算价值头输出
        q = self.ValueHead(z1)
        
        # 计算价值损失
        q_1 = self.Quantile_loss(q)
        
        # 返回策略损失和价值损失
        return a,p,q_1
  • 写回答

15条回答 默认 最新

  • 专家-郭老师 Java领域新星创作者 2024-05-14 17:27
    关注

    🕒下午好,题主!目前:🟢[在线]
    📢参考通义千问、由郭老师整理提供!
    📣今日诗号:春风又绿江南岸,明月何时照我还。

    AlphaTensor是DeepMind在2022年发布的一篇研究论文中介绍的一种使用强化学习来发现矩阵乘法(Matrix Multiplication, MM)算法的方法。他们在论文《Discovering faster matrix multiplication algorithms with reinforcement learning》中展示了如何通过一种混合策略搜索(MCTS)算法来优化矩阵乘法的计算步骤。由于你提到原论文只提供了部分伪代码和MCTS的描述,复现这样的方法可能会比较复杂,但以下是一些可能的步骤和指导:

    1. 理解论文架构

      • 首先,你需要详细阅读论文,确保你对AlphaTensor的框架、搜索空间(可能涉及不同矩阵乘法子集的组合)、奖励函数(可能基于运行时间和性能提升)以及MCTS的参数设置有深入理解。
    2. MCTS实现

      • 将论文中的MCTS描述转换成Python代码。这通常包括节点扩展、模拟、选择和回溯等核心步骤。你可以参考OpenAI的uct库或自行实现一个版本。
    3. 搜索空间定义

      • 创建一个数据结构来表示矩阵乘法的不同子集(比如,是否使用特定的优化技术,如Strassen's或Coppersmith-Winograd算法)。这可能是一个树状结构,每个节点代表一个操作序列。
    4. 环境模拟

      • 实现一个模拟器,根据当前操作序列执行矩阵乘法,并返回运行时间和性能指标。这需要对NumPy或类似库有深入了解,特别是对于如何利用并行计算加速。
    5. 奖励函数

      • 设计一个奖励函数,它应该鼓励更快的运行时间和性能提升。这可能涉及到比较实际运行时间和理想最优时间,或者使用一个相对效率得分。
    6. 强化学习训练

      • 使用Python的深度学习库(如TensorFlow或PyTorch)实现强化学习模型,如Q-learning或策略梯度方法。训练模型以找到最优的操作序列。
    7. 评估与复现

      • 用你的实现运行多个实验,观察模型在不同规模的矩阵乘法任务上的表现。如果可能,对比原论文的结果以验证复现的准确性。

    由于这个过程涉及到多个技术和库的集成,这里不可能提供完整的代码,但我可以提供一些关键代码片段作为示例:

    # 假设你已经有了MCTS类和矩阵乘法环境类
    
    class MCTS:
        # ...
    
    def simulate(node):
        # 模拟函数,根据当前节点执行MM并返回运行时间
    
    class MatrixMultiplicationEnvironment:
        def __init__(self, matrix_size):
            self.size = matrix_size
            # ...
    
    def train_agent(env, mcts, num_episodes):
        # 使用强化学习训练过程,这里可能包括Q学习更新等
    
    agent = train_agent(MatrixMultiplicationEnvironment(10), MCTS(), 1000)
    optimal_sequence = agent.get_optimal_sequence()
    
    # 执行优化后的矩阵乘法
    optimized_mm = execute_matrix_multiplication(optimal_sequence)
    

    请确保在实际操作中查阅论文中的具体细节,因为这些步骤可能需要根据论文中的描述进行调整。如果你在某个部分遇到困难,可以分别提问,我会尽力帮助你解决具体问题。

    评论

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 5月16日
  • 修改了问题 5月14日
  • 修改了问题 5月14日
  • 修改了问题 5月14日
  • 展开全部

悬赏问题

  • ¥20 limma多组间分析最终p值只有一个
  • ¥15 nopCommerce开发问题
  • ¥15 torch.multiprocessing.spawn.ProcessExitedException: process 1 terminated with signal SIGKILL
  • ¥15 QuartusⅡ15.0编译项目后,output_files中的.jdi、.sld、.sof不更新怎么解决
  • ¥15 pycharm输出和导师的一样,但是标红
  • ¥15 想问问富文本拿到的html怎么转成docx的
  • ¥15 我看了您的文章,遇到了个问题。
  • ¥15 GitHubssh虚拟机连接不上
  • ¥15 装完kali之后下载Google输入法 重启电脑后出现以下状况 且退不出去 桌面消失 反复重启没用
  • ¥15 ESP-IDP-BLE配网连接wifi