 ###### guet_gjl
2021-04-02 20:08

    # GradNorm training loop (Chen et al., 2018): learns per-task weights
    # w_i(t) so that per-task gradient norms are balanced across tasks.
    # Relies on names from the enclosing scope: `n_iterations`, `model`,
    # `args.alpha`, and `batch` — NOTE(review): `batch` presumably comes from
    # an enclosing data-loader loop that was lost in this fragment; confirm.
    for t in range(n_iterations):

        # Get the X and the target values from a single batch.
        # NOTE(review): the original bound the whole batch to both X and ts;
        # the reference implementation indexes (inputs, targets) — confirm
        # against the DataLoader's collate format.
        X = batch[0]
        ts = batch[1]
        if torch.cuda.is_available():
            X = X.cuda()
            ts = ts.cuda()

        # Forward pass: the model returns the vector of task losses L_i(t).
        task_loss = model(X, ts)
        # Weighted per-task losses w_i(t) * L_i(t).
        weighted_task_loss = torch.mul(model.weights, task_loss)

        # Record the initial losses L_i(0) once, at the first iteration;
        # they are needed below for the inverse training rate r_i(t).
        if t == 0:
            if torch.cuda.is_available():
                initial_task_loss = task_loss.data.cpu().numpy()
            else:
                initial_task_loss = task_loss.data.numpy()

        # Total loss L(t) = sum_i w_i(t) * L_i(t).
        loss = torch.sum(weighted_task_loss)
        # Backward pass for the whole set of network weights; this is
        # equivalent to computing each \nabla_W L_i(t). retain_graph=True
        # because the graph is reused below for the per-task gradient norms.
        loss.backward(retain_graph=True)

        # Zero the gradients of w_i(t): the task weights are updated with the
        # GradNorm loss below, not with dL/dw from the backward pass above.
        model.weights.grad.data = model.weights.grad.data * 0.0

        # Layer of shared weights at which gradient norms are measured.
        W = model.get_last_shared_layer()

        # G^{(i)}_W(t) = || \nabla_W [ w_i(t) * L_i(t) ] ||
        norms = []
        for i in range(len(task_loss)):
            # Gradient of task loss i w.r.t. the shared parameters.
            # torch.autograd.grad returns a tuple of gradients — take [0].
            gygw = torch.autograd.grad(task_loss[i], W.parameters(),
                                       retain_graph=True)
            norms.append(torch.norm(torch.mul(model.weights[i], gygw[0])))
        norms = torch.stack(norms)

        # Inverse training rate r_i(t): \tilde{L}_i(t) = L_i(t) / L_i(0),
        # normalized by its mean over tasks.
        if torch.cuda.is_available():
            loss_ratio = task_loss.data.cpu().numpy() / initial_task_loss
        else:
            loss_ratio = task_loss.data.numpy() / initial_task_loss
        inverse_train_rate = loss_ratio / np.mean(loss_ratio)

        # Mean gradient norm \tilde{G}_W(t) across tasks.
        if torch.cuda.is_available():
            mean_norm = np.mean(norms.data.cpu().numpy())
        else:
            mean_norm = np.mean(norms.data.numpy())

        # Target norms \tilde{G}_W(t) * r_i(t)^alpha — treated as constants:
        # no gradient flows through them.
        constant_term = torch.tensor(
            mean_norm * (inverse_train_rate ** args.alpha),
            requires_grad=False)
        if torch.cuda.is_available():
            constant_term = constant_term.cuda()

        # The GradNorm loss itself: L_grad = sum_i | G^{(i)}_W(t) - target_i |.
        grad_norm_loss = torch.sum(torch.abs(norms - constant_term))

        # Gradient of the GradNorm loss w.r.t. the task weights w_i(t).
        # torch.autograd.grad returns a tuple — assign its single element,
        # otherwise .grad would hold a tuple and the optimizer step would fail.
        model.weights.grad = torch.autograd.grad(grad_norm_loss,
                                                 model.weights)[0]