_北海道人 2022-12-17 16:14 采纳率: 0%
浏览 70
已结题

tensorflow官网照抄代码正确率过低

问题遇到的现象和发生背景

我在tensorflow官网看到的一个代码,原版照抄下来之后,训练正确率只有0.1677

用代码块功能插入代码,请勿粘贴截图。 不用代码块回答率下降 50%
import keras
import tensorflow as tf
import numpy as np
from keras.layers import TextVectorization
import string
import re

batch_size = 32
raw_train_ds = tf.keras.preprocessing.text_dataset_from_directory(
    "aclImdb/train",
    batch_size=batch_size,
    validation_split=0.2,
    subset="training",
    seed=1337,
)
raw_val_ds = tf.keras.preprocessing.text_dataset_from_directory(
    "aclImdb/train",
    batch_size=batch_size,
    validation_split=0.2,
    subset="validation",
    seed=1337,
)
raw_test_ds = tf.keras.preprocessing.text_dataset_from_directory(
    "aclImdb/test", batch_size=batch_size
)

print(f"Number of batches in raw_train_ds: {raw_train_ds.cardinality()}")
print(f"Number of batches in raw_val_ds: {raw_val_ds.cardinality()}")
print(f"Number of batches in raw_test_ds: {raw_test_ds.cardinality()}")

def custom_standardization(input_data):
    lowercase = tf.strings.lower(input_data)
    stripped_html = tf.strings.regex_replace(lowercase,'<bar />',' ')
    '''
        string.punctuation = r"""!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"""
        regex_replace的第二个位置可以填入字符串或标量字符串‘tensor’,要使用的正则表达式
    '''
    return tf.strings.regex_replace(stripped_html,f'[{re.escape(string.punctuation)}]','')


max_features = 20000
embedding_dim = 128
sequence_length = 500

vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=max_features,
    output_mode='int',
    output_sequence_length=sequence_length,
)

text_ds = raw_train_ds.map(lambda x, y:x)
vectorize_layer.adapt(text_ds)

def vectorize_text(text,label):
    text = tf.expand_dims(text,-1)
    return vectorize_layer(text),label

train_ds = raw_train_ds.map(vectorize_text)
val_ds = raw_val_ds.map(vectorize_text)
test_ds = raw_test_ds.map(vectorize_text)

#Do async prefetching / buffering of the data for best performance on GPU
train_ds = train_ds.cache().prefetch(buffer_size = 10)
val_ds = val_ds.cache().prefetch(buffer_size = 10)
test_ds = test_ds.cache().prefetch(buffer_size = 10)

"""
    build a model
"""
from keras import layers
inputs = keras.Input(shape=(None,),dtype='int64')

x = layers.Embedding(max_features,embedding_dim)(inputs)
x = layers.Dropout(0.5)(x)

x = layers.Conv1D(128,7,padding='valid',activation='relu',strides=3)(x)
x = layers.Conv1D(128,7,padding='valid',activation='relu',strides=3)(x)
x = layers.GlobalMaxPool1D()(x)

x = layers.Dense(128,activation='relu')(x)
x = layers.Dropout(0.5)(x)

predictions = layers.Dense(1,activation='sigmoid',name='predictions')(x)

model = keras.Model(inputs,predictions)

model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])

epochs = 3
model.fit(train_ds,validation_data=val_ds,epochs=epochs)

运行结果及详细报错内容

1875/1875 - 30s 13ms/step - loss: -287010357248.0000 - accuracy: 0.1665 - val_loss: -1532206383104.0000 - val_accuracy: 0.1677
Epoch 2/3
1875/1875 - 14s 8ms/step - loss: -10998242082816.0000 - accuracy: 0.1664 - val_loss: -27677966729216.0000 - val_accuracy: 0.1677
Epoch 3/3
1875/1875 - 14s 7ms/step - loss: -73722828423168.0000 - accuracy: 0.1664 - val_loss: -134023263289344.0000 - val_accuracy: 0.1677

我的解答思路和尝试过的方法,不写自己思路的,回答率下降 60%

原版照抄官网,之前一段时间我也遇到过这样的问题,当时的正确率也是这个数字。附一下源代码链接:[https://keras.io/examples/nlp/text_classification_from_scratch/](Text classification from scratch)

  • 写回答

5条回答 默认 最新

  • |__WhoAmI__| 2022-12-17 16:21
    关注

    可能的原因:

    1、使用的数据集是 IMDB 电影评论数据集,它是一个二分类问题,而你的模型输出层是一个单神经元输出层,因此可能会导致正确率较低。可以在输出层使用 sigmoid 激活函数,并使用二元交叉熵损失函数来解决这个问题。

    2、你的模型可能缺少足够的深度和宽度,因此无法对数据进行足够的拟合。可以尝试增加模型的层数,或者增加每层的神经元数量来解决这个问题。

    3、可能没有足够的训练数据。深度学习模型通常需要大量的数据来进行训练,如果的数据量较少,可能会导致正确率较低。可以尝试扩充数据集或使用数据增强来解决这个问题。

    4、你的模型可能存在过拟合问题。过拟合意味着模型在训练集上的表现很好,但是在测试集上的表现较差。可以尝试使用正则化技术(如 Dropout)或使用早停法来解决这个问题。

    5、你的模型可能受到超参数的影响。超参数是指模型的设置(如学习率,优化器类型等),它们可以影响模型的训练过程和最终的性能。可以尝试调整超参数来提高模型的正确率。

    评论

报告相同问题?

问题事件

  • 系统已结题 12月25日
  • 创建了问题 12月17日

悬赏问题

  • ¥15 如何构建全国统一的物流管理平台?
  • ¥100 ijkplayer使用AndroidStudio/CMake编译,如何支持 rtsp 直播流?
  • ¥20 和学习数据的传参方式,选择正确的传参方式有关
  • ¥15 这是网络安全里面的poem code
  • ¥15 用js遍历数据并对非空元素添加css样式
  • ¥15 使用autodl云训练,希望有直接运行的代码(关键词-数据集)
  • ¥50 python写segy数据出错
  • ¥20 关于线性结构的问题:希望能从头到尾完整地帮我改一下,困扰我很久了
  • ¥30 3D多模态医疗数据集-视觉问答
  • ¥20 设计一个二极管稳压值检测电路