weixin_44626124 2021-04-17 15:36
浏览 20

tenforflow调用模型权重预测 测试集问题?

求教各位大佬,就是在训练模型时喂给模型有(x,y),在预测时候只喂x_test ,我用tf官网给出的一份代码试了试,有问题

def train_step(inp, tar):
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]

    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

    with tf.GradientTape() as tape:
        predictions, _ = transformer(
                inp, tar_inp, 
                True, 
                enc_padding_mask, 
                combined_mask, 
                dec_padding_mask
            )
        loss = loss_function(tar_real, predictions)
    
    gradients = tape.gradient(loss, transformer.trainable_variables)    
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
    
    
    train_loss(loss)


for epoch in range(config.EPOCHS):
    start = time.time()

    train_loss.reset_states()

    for (batch, (inp, tar)) in enumerate(train_dataset):
        
        train_step(inp, tar)

            # 55k samples
            # we display 3 batch results -- 0th, middle and last one (approx)
            # 55k / 64 ~ 858; 858 / 2 = 429
        if batch % 429 == 0:
            print (f'Epoch {epoch + 1} Batch {batch} Loss {train_loss.result()}')
    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path))
    
    
    print ('Epoch {} Loss {:.4f}'.format(epoch + 1, train_loss.result()))
    

    print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

以上为训练代码,训练出权重后,我调用权重预测数据出了问题

results=[]
for (batch,inp) in enumerate (test_dataset):
    encoder_input=inp
    output=[]
    tar_inp = tar[:, :-1]
    
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(encoder_input,output)
    predictions, _ = transformer(
                inp, tar_inp, 
                True, 
                enc_padding_mask, 
                combined_mask, 
                dec_padding_mask
            )
    results.append(predictions)

出的问题是

InvalidArgumentError: slice index 1 of dimension 0 out of bounds. [Op:StridedSlice] name: strided_slice/

求教各位大佬,多谢!!!!!!!

  • 写回答

0条回答 默认 最新

    报告相同问题?

    悬赏问题

    • ¥50 永磁型步进电机PID算法
    • ¥15 sqlite 附加(attach database)加密数据库时,返回26是什么原因呢?
    • ¥88 找成都本地经验丰富懂小程序开发的技术大咖
    • ¥15 如何处理复杂数据表格的除法运算
    • ¥15 如何用stc8h1k08的片子做485数据透传的功能?(关键词-串口)
    • ¥15 有兄弟姐妹会用word插图功能制作类似citespace的图片吗?
    • ¥200 uniapp长期运行卡死问题解决
    • ¥15 latex怎么处理论文引理引用参考文献
    • ¥15 请教:如何用postman调用本地虚拟机区块链接上的合约?
    • ¥15 为什么使用javacv转封装rtsp为rtmp时出现如下问题:[h264 @ 000000004faf7500]no frame?