weixin_44626124 2021-04-17 15:36
浏览 20

tenforflow调用模型权重预测 测试集问题?

求教各位大佬,就是在训练模型时喂给模型有(x,y),在预测时候只喂x_test ,我用tf官网给出的一份代码试了试,有问题

def train_step(inp, tar):
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]

    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

    with tf.GradientTape() as tape:
        predictions, _ = transformer(
                inp, tar_inp, 
                True, 
                enc_padding_mask, 
                combined_mask, 
                dec_padding_mask
            )
        loss = loss_function(tar_real, predictions)
    
    gradients = tape.gradient(loss, transformer.trainable_variables)    
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
    
    
    train_loss(loss)


for epoch in range(config.EPOCHS):
    start = time.time()

    train_loss.reset_states()

    for (batch, (inp, tar)) in enumerate(train_dataset):
        
        train_step(inp, tar)

            # 55k samples
            # we display 3 batch results -- 0th, middle and last one (approx)
            # 55k / 64 ~ 858; 858 / 2 = 429
        if batch % 429 == 0:
            print (f'Epoch {epoch + 1} Batch {batch} Loss {train_loss.result()}')
    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path))
    
    
    print ('Epoch {} Loss {:.4f}'.format(epoch + 1, train_loss.result()))
    

    print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

以上为训练代码,训练出权重后,我调用权重预测数据出了问题

results=[]
for (batch,inp) in enumerate (test_dataset):
    encoder_input=inp
    output=[]
    tar_inp = tar[:, :-1]
    
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(encoder_input,output)
    predictions, _ = transformer(
                inp, tar_inp, 
                True, 
                enc_padding_mask, 
                combined_mask, 
                dec_padding_mask
            )
    results.append(predictions)

出的问题是

InvalidArgumentError: slice index 1 of dimension 0 out of bounds. [Op:StridedSlice] name: strided_slice/

求教各位大佬,多谢!!!!!!!

  • 写回答

0条回答 默认 最新

    报告相同问题?

    悬赏问题

    • ¥30 求一下解题思路,完全不懂。
    • ¥30 关于#硬件工程#的问题:求一下解题思路
    • ¥15 运筹学对偶单纯行法构造扩充问题
    • ¥20 XP系统的老电脑一开机就提示找不到rundll.exe,付费求解
    • ¥15 milvus查询出来的score怎么转换成0-1之间的相似性
    • ¥15 多ip服务器站群如何搭建l2tp服务器
    • ¥15 lvgl V9移植到linux开发板
    • ¥15 VB.net中在窗体中创建一个button控件来关闭窗体,但是提示错误,我该怎么办
    • ¥15 网上下载好的程序但是arduinoIDE编程报错,运行不了,哪里出错了,能具体给改一下吗
    • ¥15 Sharepoint JS开发 付费技术指导