class LinearLayer(torch.nn.Module):
    """Two-layer perceptron: Linear -> LeakyReLU -> Linear [-> LeakyReLU].

    Both linear layers act on the last dimension, so any leading dimensions
    of the input (e.g. ``[batch, nodes, in_feature]``) are preserved.

    Args:
        in_feature: size of the input's last dimension.
        hid_feature: width of the hidden layer.
        out_feature: size of the output's last dimension.
        activate_output: apply LeakyReLU to the output as well (the original
            behaviour, kept as the default). Pass False when the layer must
            produce unbounded values such as Q-values: LeakyReLU's default
            negative slope of 0.01 would squash negative outputs and their
            gradients.
    """

    def __init__(self, in_feature, hid_feature, out_feature, activate_output=True):
        super().__init__()
        self.dense1 = torch.nn.Linear(in_feature, hid_feature)
        self.dense2 = torch.nn.Linear(hid_feature, out_feature)
        self.activate_output = activate_output

    def forward(self, x):
        """Map ``[..., in_feature]`` to ``[..., out_feature]``."""
        x = torch.nn.functional.leaky_relu(self.dense1(x))
        x = self.dense2(x)
        if self.activate_output:
            x = torch.nn.functional.leaky_relu(x)
        return x
class GraphAttentionLayer(torch.nn.Module):
    """Single-head graph-attention layer that returns the embedding of node 0.

    For an input ``x`` of shape ``[batch, N, in_feature]`` the layer:

    1. computes ``h = x @ W`` (``[batch, N, out_feature]``);
    2. computes logits ``e[i, j] = leaky_relu(a_i . [h_i || h_j])``, with a
       separate attention vector ``a_i`` for each query node ``i``;
    3. masks out pairs with ``adj[i, j] <= 0`` and takes a softmax over ``j``;
    4. aggregates ``h`` with those weights and, when ``concat`` is true,
       applies ELU.

    Only node 0's row is returned, giving ``[batch, out_feature]``.

    Args:
        in_feature: size of each input node feature vector.
        out_feature: size of each output node feature vector.
        concat: apply ELU to the aggregated features when true.
        dropout: dropout probability on the attention coefficients. The
            default 0.5 equals the ``F.dropout`` default the original code
            used implicitly. It is now applied only in training mode.
            NOTE(review): for DQN consider 0.0, because dropout makes the
            Q-estimates (including TD targets) noisy whenever the network is
            in train mode.
        adj: optional square ``[N, N]`` adjacency matrix. Defaults to the
            original 5-node star graph: node 0 is linked to every node
            (itself included) and nodes 1-4 are linked only to node 0.

    Raises:
        ValueError: if ``adj`` is not square, or if ``x`` is not
            ``[batch, N, in_feature]`` with ``N`` matching ``adj``.
    """

    def __init__(self, in_feature, out_feature, concat=True, dropout=0.5, adj=None):
        super().__init__()
        self.concat = concat
        self.in_feature = in_feature
        self.out_feature = out_feature
        self.dropout = dropout
        if adj is None:
            adj = [[1, 1, 1, 1, 1],
                   [1, 0, 0, 0, 0],
                   [1, 0, 0, 0, 0],
                   [1, 0, 0, 0, 0],
                   [1, 0, 0, 0, 0]]
        adj = torch.as_tensor(adj)
        if adj.dim() != 2 or adj.size(0) != adj.size(1):
            raise ValueError(f"adj must be a square matrix, got shape {tuple(adj.shape)}")
        # Registered as a buffer so .to(device)/.cuda() moves it together with
        # the parameters. It is non-persistent so state_dict keys stay the
        # same as before (the original stored it as a plain attribute).
        self.register_buffer("adj", adj, persistent=False)
        n_nodes = adj.size(0)
        # Trainable parameters, i.e. W and a from the GAT paper.
        self.W = torch.nn.Parameter(torch.zeros(size=(1, in_feature, out_feature)))
        torch.nn.init.xavier_uniform_(self.W.data, gain=1.414)
        # One attention vector per query node (leading dim 1 broadcasts over batch).
        self.a = torch.nn.Parameter(torch.zeros(size=(1, n_nodes, 2 * out_feature, 1)))
        torch.nn.init.xavier_uniform_(self.a.data, gain=1.414)

    def forward(self, x):
        """Return node 0's attention-aggregated embedding.

        Args:
            x: node features of shape ``[batch, N, in_feature]``.

        Returns:
            Tensor of shape ``[batch, out_feature]``.
        """
        n_nodes = self.adj.size(0)
        if x.dim() != 3 or x.size(1) != n_nodes or x.size(2) != self.in_feature:
            raise ValueError(
                f"expected x of shape [batch, {n_nodes}, {self.in_feature}], "
                f"got {tuple(x.shape)}"
            )
        # [batch, N, in] @ [1, in, out] broadcasts over batch; no repeat needed.
        h = torch.matmul(x, self.W)  # [batch, N, out]
        # All pairwise concatenations [h_i || h_j]: [batch, N, N, 2*out].
        h_i = h.unsqueeze(2).expand(-1, -1, n_nodes, -1)
        h_j = h.unsqueeze(1).expand(-1, n_nodes, -1, -1)
        a_input = torch.cat([h_i, h_j], dim=-1)
        # [batch, N, N, 2*out] @ [1, N, 2*out, 1] -> [batch, N, N, 1]; row i is
        # scored with its own attention vector a_i.
        e = torch.nn.functional.leaky_relu(torch.matmul(a_input, self.a).squeeze(-1))
        # Non-edges get a huge negative logit so softmax gives them ~0 weight.
        attention = e.masked_fill(self.adj <= 0, -1e12)
        attention = torch.nn.functional.softmax(attention, dim=-1)
        # F.dropout defaults to training=True; tie it to the module's mode so
        # .eval() really disables it.
        attention = torch.nn.functional.dropout(
            attention, p=self.dropout, training=self.training
        )
        h_prime = torch.matmul(attention, h)  # [batch, N, out]
        if self.concat:
            h_prime = torch.nn.functional.elu(h_prime)
        return h_prime[:, 0, :]
class Network(torch.nn.Module):
    """Q-network: per-node MLP encoder -> graph attention -> Q-value head.

    Per the original shape comments, ``x`` is ``[batch, 5, n_state]`` (the
    node count must match the attention layer's graph). The output is
    ``[batch, n_action]`` raw Q-values.

    Args:
        n_state: per-node state size; defaults to the module-level ``N_STATE``.
        n_action: number of actions; defaults to the module-level ``N_ACTION``.
    """

    def __init__(self, n_state=None, n_action=None):
        super().__init__()
        # Fall back to the module-level constants the original code used.
        if n_state is None:
            n_state = N_STATE
        if n_action is None:
            n_action = N_ACTION
        self.mlp_hid_feature = 64
        self.mlp_out_feature = 64
        self.att_out_feature = 64
        self.out_hid_feature = 64
        # [batch, 5, state] -> [batch, 5, mlp_out_feature]
        self.mlp = LinearLayer(n_state, self.mlp_hid_feature, self.mlp_out_feature)
        # [batch, 5, mlp_out_feature] -> [batch, att_out_feature]
        self.att = GraphAttentionLayer(self.mlp_out_feature, self.att_out_feature)
        # [batch, att_out_feature] -> [batch, action]
        self.out = LinearLayer(self.att_out_feature, self.out_hid_feature, n_action)

    def forward(self, x):
        """Return Q-values of shape ``[batch, n_action]``."""
        node_features = self.mlp(x)
        graph_embedding = self.att(node_features)
        # LinearLayer.forward ends with leaky_relu, which would compress
        # negative Q-values (slope 0.01) and their gradients. Apply the head's
        # layers by hand so the Q-value output stays linear. Parameter names
        # (out.dense1 / out.dense2) are unchanged.
        hidden = torch.nn.functional.leaky_relu(self.out.dense1(graph_embedding))
        return self.out.dense2(hidden)
代码如上所示。训练采用的是强化学习 DQN。可以断定是网络结构的问题:之前把网络改成简单的全连接结构时,loss 是可以收敛的;现在在中间加了一层 GAT,loss 就一直居高不下——刚开始 loss 只有零点几,随后慢慢飙升到一千多,并在这个水平上下震荡……