Reproducing the AlphaTensor paper in Python (paid)
The original paper only provides pseudocode for the neural network and a prose description of the MCTS. I have translated the pseudocode into Python, and used GPT to produce an MCTS implementation that roughly matches the description, but I still cannot complete the full reproduction. I am hoping someone can help me finish my first paper reproduction; I would be very grateful. I would like to discuss the details until the whole process is complete, and I will add a further reward on completion. Thank you! If you were planning to paste a useless GPT-generated answer, please don't bother.
WeChat: 2741347761
Original paper: Discovering faster matrix multiplication algorithms with reinforcement learning
Paper:https://www.nature.com/articles/s41586-022-05172-4
GitHub:https://github.com/deepmind/alphatensor
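MCTS sketch (assumed interfaces)
For concreteness, here is a minimal sketch of the kind of sampled MCTS loop the paper describes. Everything in it is an assumption of mine, not code from the paper or the repo: the Node class, the env.step(state, action) -> (next_state, reward, done) contract, and the network.inference interface (returning sampled actions, their probabilities, and a scalar value, as in A.11) are all placeholders, and terminal-state handling is deliberately simplified.

import math

class Node:
    def __init__(self, prior):
        self.prior = prior        # prior probability from the policy head
        self.visit_count = 0
        self.value_sum = 0.0
        self.children = {}        # action -> Node
        self.state = None         # environment state, set on expansion

    def value(self):
        return self.value_sum / self.visit_count if self.visit_count else 0.0

def puct_score(parent, child, c_puct=1.25):
    # AlphaZero-style PUCT: value estimate plus prior-weighted exploration.
    u = c_puct * child.prior * math.sqrt(parent.visit_count) / (1 + child.visit_count)
    return child.value() + u

def expand(node, state, network):
    # Sample candidate actions from the policy head (A.6/A.11) and attach
    # them as children; repeated samples accumulate prior mass.
    node.state = state
    actions, probs, value = network.inference(state)   # assumed interface
    for action, prob in zip(actions, probs):
        if action not in node.children:
            node.children[action] = Node(prior=0.0)
        node.children[action].prior += prob
    return value

def run_mcts(root_state, network, env, num_simulations=400):
    root = Node(prior=1.0)
    expand(root, root_state, network)
    for _ in range(num_simulations):
        node, path = root, [root]
        # Selection: descend by PUCT until an unexpanded node.
        while node.children:
            action, node = max(node.children.items(),
                               key=lambda kv: puct_score(path[-1], kv[1]))
            path.append(node)
        # Expansion + evaluation of the leaf.
        parent = path[-2]
        state, reward, done = env.step(parent.state, action)
        value = reward if done else reward + expand(node, state, network)
        # Backup: propagate the leaf value up the visited path.
        for n in reversed(path):
            n.visit_count += 1
            n.value_sum += value
    # Act greedily with respect to root visit counts.
    return max(root.children.items(), key=lambda kv: kv[1].visit_count)[0]

The paper additionally normalizes priors over the sampled action set and uses the risk-managed quantile value; both are left out here for brevity.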
Neural network code
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE: to stay close to the paper's pseudocode, this sketch creates nn layers
# inside each method. A trainable implementation must build them once in
# __init__ (as nn.Module submodules) so their weights persist and are
# registered with the optimizer.
class NeuralNetwork:
    def __init__(self):
        pass
    # A.1
    def Attention(self, x, y, causal_mask=False, Nheads=16, d=32, w=4):
        Nx, Ny = x.size(0), y.size(0)
        c1, c2 = x.size(1), y.size(1)
        # 1, 2: layer-normalize both inputs
        xnorm = nn.LayerNorm(c1)(x)
        ynorm = nn.LayerNorm(c2)(y)
        # 3, 4, 5: per-head query/key/value projections
        linear_q = nn.Linear(c1, Nheads * d)
        linear_k = nn.Linear(c2, Nheads * d)
        linear_v = nn.Linear(c2, Nheads * d)
        q = linear_q(xnorm).view(Nx, Nheads, d)
        k = linear_k(ynorm).view(Ny, Nheads, d)
        v = linear_v(ynorm).view(Ny, Nheads, d)
        # 6: scaled dot-product scores over positions, shape [Nheads, Nx, Ny]
        q = q / torch.sqrt(torch.tensor(d).float())
        attn_scores = torch.einsum('ihd,jhd->hij', q, k)
        # 7: with a causal mask, each position may only attend to itself and
        # earlier positions, so an autoregressive model decodes in order
        if causal_mask:
            mask = torch.tril(torch.ones(Nx, Ny, device=x.device))
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        a = F.softmax(attn_scores, dim=-1)
        # 8, 9: attention-weighted sum of the values, shape [Nx, Nheads, d]
        o = torch.einsum('hij,jhd->ihd', a, v)
        # 11: merge the heads, project back to c1, residual connection
        o = o.reshape(Nx, Nheads * d)
        linear_out = nn.Linear(Nheads * d, c1)
        x = x + linear_out(o)
        # 12: dense block: layer norm, widening linear layer, GELU,
        # projection back to c1, residual connection
        layer_norm_dense = nn.LayerNorm(c1)
        linear_dense1 = nn.Linear(c1, c1 * w)
        linear_dense2 = nn.Linear(c1 * w, c1)
        x = x + linear_dense2(F.gelu(linear_dense1(layer_norm_dense(x))))
        # 13
        return x
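    # Example (assumed sizes): x = torch.randn(8, 64) and y = torch.randn(10, 32)
    # give Attention(x, y) of shape [8, 64]; with x as both arguments and
    # causal_mask=True this becomes masked self-attention.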
    # This layer lets the three mode-pair views of the tensor exchange
    # information through attention and returns the updated views.
    # A.2
    def AttentiveModes(self, x1, x2, x3):
        # x1, x2, x3: torch tensors of shape [S, S, c]
        g = [x1, x2, x3]
        S = x1.shape[0]
        for m1, m2 in [(0, 1), (2, 0), (1, 2)]:
            # Concatenate g[m1] with g[m2] transposed in its first two
            # (position) axes, giving [S, 2S, c]
            a = torch.cat((g[m1], g[m2].transpose(0, 1)), dim=1)
            # Loop over rows (parallel in the paper)
            for i in range(S):
                c = self.Attention(a[i], a[i])  # [2S, c]
                # Write the two halves of the attention output back into
                # g[m1] and g[m2] (in-place here; a real implementation
                # should build new tensors so gradients flow cleanly)
                g[m1][i] = c[:S]
                g[m2][:, i] = c[S:]
        return g
    # A.3
    def Torso(self, x, s, c):
        # x: [T, S, S, S] (current tensor plus history), s: scalar features
        T, S = x.shape[0], x.shape[1]
        # Build the three mode-pair views of the tensor stack
        x1 = x.permute(1, 2, 3, 0).reshape(S, S, S * T)
        x2 = x.permute(3, 1, 2, 0).reshape(S, S, S * T)
        x3 = x.permute(2, 3, 1, 0).reshape(S, S, S * T)
        g = [x1, x2, x3]
        # Project the scalar features to an S x S map, append it to each view
        # as one extra channel, then project all channels down to c
        linear_s = nn.Linear(s.shape[0], S**2)
        p = linear_s(s).reshape(S, S, 1)
        for i in range(3):
            g[i] = torch.cat([g[i], p], dim=-1)        # [S, S, S*T + 1]
            linear_g = nn.Linear(g[i].shape[-1], c)
            g[i] = linear_g(g[i])                      # [S, S, c]
        x1, x2, x3 = g
        # 8 rounds of attention across the three views
        for _ in range(8):
            x1, x2, x3 = self.AttentiveModes(x1, x2, x3)
        # Stack and flatten into a [3*S^2, c] embedding
        e = torch.stack([x1, x2, x3], dim=0).reshape(3 * S**2, c)
        return e
    # A.4
    def predict_action_logits(self, a, e, is_training=False, Nfeatures=64, Nheads=32, Nlayers=2):
        # a: [Nsteps, Nlogits] one-hot (shifted) action tokens, e: torso embedding
        Nsteps, Nlogits = a.shape
        linear = nn.Linear(Nlogits, Nfeatures * Nheads)
        learnable_pos_enc = nn.Parameter(torch.randn(Nsteps, Nfeatures * Nheads))
        layer_norm = nn.LayerNorm(Nfeatures * Nheads)
        dropout = nn.Dropout(0.1)
        final_linear = nn.Linear(Nfeatures * Nheads, Nlogits)
        # 1: embed the tokens
        x = linear(a)
        # 2: learnable positional encoding
        x = x + learnable_pos_enc
        # 3: decoder layers
        for _ in range(Nlayers):
            # 4, 5: causal self-attention
            x = layer_norm(x)
            c = self.Attention(x, x, causal_mask=True, Nheads=Nheads, d=Nfeatures)
            # 6, 7: dropout during training
            if is_training:
                c = dropout(c)
            # 8, 9: residual connection, then normalize again
            x = x + c
            x = layer_norm(x)
            # 10: cross-attention into the torso embedding
            c = self.Attention(x, e, Nheads=Nheads, d=Nfeatures)
            # 11, 12: dropout during training
            if is_training:
                c = dropout(c)
            # 13: residual connection
            x = x + c
        # 14: output projection after a ReLU
        o = final_linear(F.relu(x))
        return o, x
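    # Example (assumed sizes): a = F.one_hot(torch.zeros(12, dtype=torch.long), 9).float()
    # with a torso embedding e of shape [48, 64] yields o of shape [12, 9]
    # (per-step logits) and x of shape [12, 2048] (per-step embeddings,
    # Nfeatures * Nheads = 2048 with the defaults above).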
    # A.5
    def PolicyHead_training(self, e, g, Nlogits, Nsteps):
        # g: ground-truth action tokens, an integer sequence of length Nsteps
        def shifted(g):
            # Shift right by one step: drop the last token, pad the front with 0
            out = torch.zeros_like(g)
            out[1:] = g[:-1]
            return out

        def onehot(g):
            # One-hot encode the integer tokens
            return F.one_hot(g, num_classes=Nlogits).float()

        # Feed the shifted one-hot targets through the decoder (teacher forcing)
        o, z = self.predict_action_logits(onehot(shifted(g)), e, is_training=True)
        # Return only the first-step embedding: it does not depend on the
        # ground-truth tokens, so the value head can reuse it at inference
        z1 = z[0]
        return o, z1
    # A.6
    def PolicyHead_inference(self, e, Nlogits, Nsteps, Nsamples=32):
        # Sampled action-token sequences and their probabilities
        a = torch.zeros((Nsamples, Nsteps), dtype=torch.long)
        p = torch.ones(Nsamples)
        # Outer loop over the samples
        for s in range(Nsamples):
            # Inner loop over the decoding steps
            for i in range(Nsteps):
                # Shift the tokens sampled so far by one, as in training
                prev = torch.zeros(Nsteps, dtype=torch.long)
                prev[1:] = a[s, :-1]
                o, z = self.predict_action_logits(F.one_hot(prev, num_classes=Nlogits).float(), e)
                # Sample step i's token and accumulate the sequence probability
                prob = F.softmax(o[i], dim=-1)
                action = torch.multinomial(prob, 1).item()
                a[s, i] = action
                p[s] *= prob[action]
                # Keep the first-step embedding for the value head; it is
                # identical across samples since no token precedes step 0
                if s == 0 and i == 0:
                    z1 = z[0]
        return a, p, z1
    # A.7
    def ValueHead(self, x, n=8):
        # Three 512-unit ReLU layers, then n quantile outputs
        for _ in range(3):
            x = F.relu(nn.Linear(x.shape[-1], 512)(x))
        q = nn.Linear(512, n)(x)
        return q
    # A.8
    def Quantile_loss(self, q, g, delta=1.0):
        n = q.shape[0]
        # The n quantile midpoints tau_i = (i - 0.5) / n
        tau = torch.arange(1, n + 1).float() / n - 0.5 / n
        # Signed distance between the target and each predicted quantile
        d = g - q
        # Elementwise Huber loss against the (broadcast) target
        h = F.huber_loss(q, g.expand_as(q), delta=delta, reduction='none')
        # Quantile weights: underestimates are weighted by tau,
        # overestimates by 1 - tau
        k = torch.abs(tau - (d < 0).float())
        # Weighted average loss
        return torch.mean(k * h)
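    # Worked example (hypothetical numbers): with n = 4 quantiles
    # q = [0.0, 0.1, 0.2, 0.3] and target g = 0.15, tau = [0.125, 0.375,
    # 0.625, 0.875]; the two quantiles below g get weights 0.125 and 0.375
    # (tau) and the two above get 0.375 and 0.125 (1 - tau), so the head
    # learns the return distribution rather than just its mean.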
    # A.9
    def ValueRiskManagement(self, q, uq=0.75):
        # Risk-seeking value: average the quantiles above the uq threshold
        n = q.shape[0]
        j = int(np.ceil(uq * n))
        # Python indexing is 0-based, hence j - 1
        return torch.mean(q[j - 1:])
    # A.10
    def AlphaTensor_Net_training(self, x, s, gaction, gvalue, c, Nlogits):
        # Extract the torso embedding
        e = self.Torso(x, s, c)
        # Run the policy head in training mode (teacher forcing)
        o, z1 = self.PolicyHead_training(e, gaction, Nlogits, gaction.shape[0])
        # Policy loss: cross-entropy against the ground-truth tokens
        lpolicy = torch.sum(F.cross_entropy(o, gaction, reduction='none'))
        # Value head on the first-step embedding
        q = self.ValueHead(z1)
        # Value loss: quantile regression against the target return
        lvalue = self.Quantile_loss(q, gvalue)
        return lpolicy, lvalue
    # A.11
    def AlphaTensor_Net_inference(self, x, s, c, Nsamples, Nsteps, Nlogits):
        # Extract the torso embedding
        e = self.Torso(x, s, c)
        # Run the policy head in inference mode: sample candidate actions
        a, p, z1 = self.PolicyHead_inference(e, Nlogits, Nsteps, Nsamples)
        # Value head on the first-step embedding
        q = self.ValueHead(z1)
        # Aggregate the quantiles into a single (risk-seeking) value estimate
        q_value = self.ValueRiskManagement(q)
        return a, p, q_value
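For sanity-checking the shapes, here is a tiny smoke test of how I believe the pieces above fit together. All sizes (T, S, c, Nsteps, Nlogits) and the scalar/value placeholders are guesses of mine, not values from the paper:

# Hypothetical smoke test for the class above; every size is an assumption.
net = NeuralNetwork()
T, S, c = 7, 4, 64             # history length, tensor side S, torso channels
Nsteps, Nlogits = 12, 9        # tokens per action, vocabulary per token
x = torch.randn(T, S, S, S)    # current residual tensor plus T-1 past states
s_feat = torch.randn(8)        # scalar features, e.g. the move counter
# Training pass: ground-truth tokens and a (negative-rank) value target.
gaction = torch.randint(0, Nlogits, (Nsteps,))
gvalue = torch.tensor(-12.0)
lpolicy, lvalue = net.AlphaTensor_Net_training(x, s_feat, gaction, gvalue, c, Nlogits)
print(lpolicy.item(), lvalue.item())
# Inference pass: sample 4 candidate token sequences with probabilities
# and a risk-managed value estimate.
a, p, q = net.AlphaTensor_Net_inference(x, s_feat, c, Nsamples=4, Nsteps=Nsteps, Nlogits=Nlogits)
print(a.shape, p.shape, q.item())  # torch.Size([4, 12]) torch.Size([4]) ...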