weixin_44626124 2021-04-17 15:36 采纳率: 0%
浏览 21

tenforflow调用模型权重预测 测试集问题?

求教各位大佬,就是在训练模型时喂给模型有(x,y),在预测时候只喂x_test ,我用tf官网给出的一份代码试了试,有问题

def train_step(inp, tar):
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]

    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

    with tf.GradientTape() as tape:
        predictions, _ = transformer(
                inp, tar_inp, 
                True, 
                enc_padding_mask, 
                combined_mask, 
                dec_padding_mask
            )
        loss = loss_function(tar_real, predictions)
    
    gradients = tape.gradient(loss, transformer.trainable_variables)    
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
    
    
    train_loss(loss)


for epoch in range(config.EPOCHS):
    start = time.time()

    train_loss.reset_states()

    for (batch, (inp, tar)) in enumerate(train_dataset):
        
        train_step(inp, tar)

            # 55k samples
            # we display 3 batch results -- 0th, middle and last one (approx)
            # 55k / 64 ~ 858; 858 / 2 = 429
        if batch % 429 == 0:
            print (f'Epoch {epoch + 1} Batch {batch} Loss {train_loss.result()}')
    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path))
    
    
    print ('Epoch {} Loss {:.4f}'.format(epoch + 1, train_loss.result()))
    

    print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

以上为训练代码,训练出权重后,我调用权重预测数据出了问题

results=[]
for (batch,inp) in enumerate (test_dataset):
    encoder_input=inp
    output=[]
    tar_inp = tar[:, :-1]
    
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(encoder_input,output)
    predictions, _ = transformer(
                inp, tar_inp, 
                True, 
                enc_padding_mask, 
                combined_mask, 
                dec_padding_mask
            )
    results.append(predictions)

出的问题是

InvalidArgumentError: slice index 1 of dimension 0 out of bounds. [Op:StridedSlice] name: strided_slice/

求教各位大佬,多谢!!!!!!!

  • 写回答

1条回答 默认 最新

  • 码农阿豪@新空间代码工作室 Java领域优质创作者 2024-07-16 09:31
    关注
    让阿豪来帮你解答,本回答参考chatgpt3.5编写提供,如果还有疑问可以评论或留言
    根据您提供的代码片段和报错信息,问题出在您的预测代码中。您在预测时使用了`tar[:, :-1]`这个操作,但是在预测阶段并没有`tar`这个输入数据。这个操作是用于在训练中获取目标序列中除最后一个词外的所有词,但在预测中不需要这样做。 以下是修改后的预测代码示例:
    results = []
    for (batch, inp) in enumerate(test_dataset):
        encoder_input = inp
        output = tf.expand_dims([START_TOKEN], 0)  # Assuming you have a START_TOKEN defined
        for i in range(MAX_LENGTH):  # MAX_LENGTH is the maximum length for output sequence
            enc_padding_mask, combined_mask, dec_padding_mask = create_masks(encoder_input, output)
            predictions, _ = transformer(
                encoder_input, output, False,
                enc_padding_mask, combined_mask, dec_padding_mask
            )
            predictions = predictions[:, -1:, :]  # Get the last prediction
            predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
            if predicted_id == END_TOKEN:  # Assuming you have an END_TOKEN defined
                break
            output = tf.concat([output, predicted_id], axis=-1)
        results.append(output)
    

    在这个修改后的代码中,我们针对每个输入数据进行预测,动态根据模型的输出生成下一个单词,并将其添加到输出序列中。请确保定义了START_TOKENEND_TOKEN,以便在生成序列时使用。 另外,根据您的具体情况可能需要做一些调整,比如设置合适的MAX_LENGTH,确保模型输出的序列不会无限增长等。希望这能帮助解决您的问题,如有其他疑问,请随时询问。

    评论

报告相同问题?

悬赏问题

  • ¥15 metadata提取的PDF元数据,如何转换为一个Excel
  • ¥15 关于arduino编程toCharArray()函数的使用
  • ¥100 vc++混合CEF采用CLR方式编译报错
  • ¥15 coze 的插件输入飞书多维表格 app_token 后一直显示错误,如何解决?
  • ¥15 vite+vue3+plyr播放本地public文件夹下视频无法加载
  • ¥15 c#逐行读取txt文本,但是每一行里面数据之间空格数量不同
  • ¥50 如何openEuler 22.03上安装配置drbd
  • ¥20 ING91680C BLE5.3 芯片怎么实现串口收发数据
  • ¥15 无线连接树莓派,无法执行update,如何解决?(相关搜索:软件下载)
  • ¥15 Windows11, backspace, enter, space键失灵