NEWBIE829 2021-06-06 10:50 采纳率: 100%
浏览 399
已采纳

网上找到的densenet分类的代码 运行以后训练数据的准确率很高 测试数据的准确率一直很低

from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import tensorflow as tf
from PIL import Image
import numpy as np
import itertools
import os

im_height = 512
im_width = 512
batch_size = 32
epochs = 10


image_path = "./input/chest-xray-pneumonia/chest_xray/"
train_dir = image_path + "train"
validation_dir = image_path + "test"
test_dir = image_path + "valid"

train_image_generator = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
validation_image_generator = ImageDataGenerator(rescale=1. / 255)
test_image_generator = ImageDataGenerator(rescale=1. / 255)

train_data_gen = train_image_generator.flow_from_directory(directory=train_dir,
                                                           batch_size=batch_size,
                                                           shuffle=True,
                                                           target_size=(im_height, im_width),
                                                           class_mode='categorical')

total_train = train_data_gen.n

val_data_gen = validation_image_generator.flow_from_directory(directory=validation_dir,
                                                              batch_size=batch_size,
                                                              shuffle=False,
                                                              target_size=(im_height, im_width),
                                                              class_mode='categorical')

total_val = val_data_gen.n

test_data_gen = test_image_generator.flow_from_directory(directory=test_dir,
                                                         batch_size=batch_size,
                                                         shuffle=False,
                                                         target_size=(im_height, im_width),
                                                         class_mode='categorical')

total_test = test_data_gen.n

covn_base = tf.keras.applications.DenseNet201(weights='imagenet', include_top = False,input_shape=(im_height,im_width,3))
covn_base.trainable = False

model = tf.keras.Sequential()
model.add(covn_base)
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dropout(rate=0.2))
model.add(tf.keras.layers.Dense(2, activation='softmax'))
model.summary()


model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss = 'categorical_crossentropy',
    metrics=['accuracy']
)


reduce_lr = ReduceLROnPlateau(
                                monitor='val_loss',
                                factor=0.1,
                                patience=2,
                                mode='auto',
                                verbose=1
                             )


checkpoint = ModelCheckpoint(
                                filepath='DenseNet201.ckpt',
                                monitor='val_accuracy',
                                save_weights_only=False,
                                save_best_only=True,
                                mode='auto',
                                period=1
                            )

history = model.fit(x=train_data_gen,
                    steps_per_epoch=total_train // batch_size,
                    epochs=epochs,
                    validation_data=val_data_gen,
                    validation_steps=total_val // batch_size,
                    callbacks=[checkpoint, reduce_lr])

 

  • 写回答

5条回答 默认 最新

  • Briwisdom 2021-06-06 11:26
    关注

    train_acc和val_acc在这3个epoch都是在增长的,从网上直接下载的代码精度一个是0.9+,一个0.6+也算是正常的。接下来就需要题主针对自己的数据集找特点进行网络调参,或者数据预处理的优化了。比如数据的预处理优化,或者学习率,优化器的调整。

    另外,不知道你的测试集和验证集的比例,一般7:3,以及它们是否随机分配产生(即训练集和数据集是否符合同一分布特点)。这个也会是影响训练集和验证集精度相差较多的原因。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(4条)

报告相同问题?

悬赏问题

  • ¥50 有数据,怎么建立模型求影响全要素生产率的因素
  • ¥50 有数据,怎么用matlab求全要素生产率
  • ¥15 TI的insta-spin例程
  • ¥15 完成下列问题完成下列问题
  • ¥15 C#算法问题, 不知道怎么处理这个数据的转换
  • ¥15 YoloV5 第三方库的版本对照问题
  • ¥15 请完成下列相关问题!
  • ¥15 drone 推送镜像时候 purge: true 推送完毕后没有删除对应的镜像,手动拷贝到服务器执行结果正确在样才能让指令自动执行成功删除对应镜像,如何解决?
  • ¥15 求daily translation(DT)偏差订正方法的代码
  • ¥15 js调用html页面需要隐藏某个按钮