guet_gjl 2021-04-02 20:08 采纳率: 100%

    # Training loop excerpt: for each iteration it runs the model on one batch to get
    # the per-task losses, back-propagates the total loss, and then computes a separate
    # gradient for `model.weights` from a GradNorm loss.
    #
    # NOTE(review): this excerpt has lost its indentation and several statements
    # (see the NOTE(review) markers below), so it is not runnable as shown. The
    # enclosing function and the definitions of `n_iterations`, `model` and `args`
    # are outside this excerpt.
    for t in range(n_iterations):

# get a single batch
# NOTE(review): `batch` is never assigned in the visible lines — presumably it comes
# from an inner loop over a data loader that is missing here; confirm.
# get the X and the targets values
X = batch[0]
ts = batch[1]
# move the inputs and targets to the GPU when one is available
if torch.cuda.is_available():
X = X.cuda()
ts = ts.cuda()

# evaluate each task loss L_i(t)
task_loss = model(X, ts) # this will do a forward pass in the model and will also evaluate the loss
# compute the weighted loss w_i(t) * L_i(t)
# NOTE(review): the statement that computes the weighted loss is missing here.
# initialize the initial loss L(0) if t=0
if t == 0:
# set L(0)
# NOTE(review): both branches below are empty — the statements that store the
# initial task loss L(0) are missing.
if torch.cuda.is_available():
else:

# get the total loss
# NOTE(review): `loss` is never assigned in the visible lines — the statement that
# builds the total loss is missing.
# do the backward pass to compute the gradients for the whole set of weights
# This is equivalent to compute each \nabla_W L_i(t)
# retain_graph=True keeps the autograd graph alive so further gradients can be taken below
loss.backward(retain_graph=True)

# set the gradients of w_i(t) to zero because these gradients have to be updated using the GradNorm loss
# NOTE(review): the statement that zeroes the gradients of the task weights is missing.

# switch for each weighting algorithm:

# get layer of shared weights
W = model.get_last_shared_layer()
# NOTE(review): debugging output; if W is a torch.nn.Module, parameters() returns an
# iterator, so this prints the iterator object rather than the values — confirm.
print(W.parameters())
# G^{(i)}_w(t)
norms = []
# get the gradient of this task loss with respect to the shared parameters
# NOTE(review): `i` and `gygw` are never assigned in the visible lines — presumably a
# per-task loop and the gradient computation for each task loss are missing; confirm.
# compute the norm
norms.append(torch.norm(torch.mul(model.weights[i], gygw[0])))
norms = torch.stack(norms)
#print('G_w(t): {}'.format(norms))

# compute the inverse training rate r_i(t)
# \curl{L}_i
# NOTE(review): both branches below are empty and `loss_ratio` is never assigned in
# the visible lines — the statements computing the loss ratio are missing.
if torch.cuda.is_available():
else:
# r_i(t)
inverse_train_rate = loss_ratio / np.mean(loss_ratio)
#print('r_i(t): {}'.format(inverse_train_rate))

# compute the mean norm \tilde{G}_w(t)
# the norms are copied to the CPU first when CUDA is available, then converted to NumPy
if torch.cuda.is_available():
mean_norm = np.mean(norms.data.cpu().numpy())
else:
mean_norm = np.mean(norms.data.numpy())
#print('tilde G_w(t): {}'.format(mean_norm))

# this term has to remain constant
# built as a fresh tensor with requires_grad=False, so no gradient is tracked for it
constant_term = torch.tensor(mean_norm * (inverse_train_rate ** args.alpha), requires_grad=False)
if torch.cuda.is_available():
constant_term = constant_term.cuda()
#print('Constant term: {}'.format(constant_term))
#this is the GradNorm loss itself
# NOTE(review): `grad_norm_loss` is never assigned in the visible lines — the statement
# that computes the GradNorm loss is missing.

# compute the gradient for the weights
model.weights.grad = torch.autograd.grad(grad_norm_loss, model.weights)[0]

# NOTE(review): this repeats the assignment just above but passes only `model.weights`;
# torch.autograd.grad takes both `outputs` and `inputs`, so this call looks like a
# truncated duplicate and would fail as written — confirm and remove.
model.weights.grad = torch.autograd.grad(
model.weights)[0]

• 写回答

#### 4条回答默认 最新

• CSDN专家-孙老师 2021-04-04 01:55
关注

本回答被题主选为最佳回答 , 对您是否有帮助呢?
评论

#### 悬赏问题

• ¥15 关于#c语言#的问题：构成555单稳态触发器，采用LED指示灯延时时间，对延时时间进行测量并显示（如楼道声控延时灯）需要Proteus仿真图和C语言代码
• ¥50 神舟笔记本，没有linux的驱动，装的Ubuntu系统，想把风扇速度调到最大
• ¥15 workstation加载centos进入emergency模式，查看日志报警如图，怎样解决呢？
• ¥50 如何用单纯形法寻优不能精准找不到给定的参数，并联机构误差识别，给定误差有7个？matlab
• ¥15 workstation加载centos进入emergency模式，查看日志报警如图，没有XFS,怎样解决呢？
• ¥15 应用商店如何检测在架应用内容是否违规？
• ¥15 Ubuntu系统配置PX4
• ¥50 nw.js调用activex
• ¥15 数据库获取信息反馈出错，直接查询了ref字段并且还使用了User文档的_id而不是自己的
• ¥15 将安全信息用到以下对象时发生以下错误：c:dumpstack.log.tmp 另一个程序正在使用此文件，因此无法访问