云深395 2022-04-12 09:36 采纳率: 20%
浏览 52
已结题

tensorflow怎么解决这个问题,是什么问题,解决方法?

问题遇到的现象和发生背景 模型无法按照我想的运行,不知道是否是shape还是什么没搞好?
问题相关代码,请勿粘贴截图

from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D,MaxPool2D,Activation,Dropout,Flatten,Dense
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.preprocessing.image import ImageDataGenerator,img_to_array,load_img
import numpy as np
import tensorflow as tf
import pathlib
data_dir = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
fname='flower_photos', untar=True)
data_root = pathlib.Path(data_dir)
print(data_root)
import random
all_image_paths=list(data_root.glob('/'))
all_image_paths=[str(path) for path in all_image_paths]
random.shuffle(all_image_paths)
print(len(all_image_paths))
label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())#读取目录并排序为类别名
label_to_index = dict((name, index) for index, name in enumerate(label_names))#创建类别字典
all_image_labels = [label_to_index[pathlib.Path(path).parent.name]
for path in all_image_paths] #图像parent path 对应类
@tf.function
def preprocess_image(path):
image_size=224
image = tf.io.read_file(path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize(image, [image_size, image_size])

# 数据增强

x=tf.image.random_brightness(x, 1)#亮度调整

x = tf.image.random_flip_up_down(x) #上下颠倒

x= tf.image.random_flip_left_right(x) # 左右镜像

x = tf.image.random_crop(x, [image_size, image_size, 3]) # 随机裁剪

image /= 255.0  # normalize to [0,1] range

image= normalize(image) # 标准化

return image

ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))
def load_and_preprocess_from_path_label(path, label):
return preprocess_image(path), label

image_label_ds = ds.map(load_and_preprocess_from_path_label)
image_label_ds
AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 16
image_count = len(all_image_paths)

设置一个和数据集大小一致的 shuffle buffer size(随机缓冲区大小)以保证数据

被充分打乱。

ds = image_label_ds.shuffle(buffer_size=image_count) # buffer_size等于数据集大小确保充分打乱
ds = ds.repeat() #repeat 适用于next(iter(ds))
ds = ds.batch(BATCH_SIZE)

当模型在训练的时候,prefetch 使数据集在后台取得 batch。

ds = ds.prefetch(buffer_size=AUTOTUNE)#随机缓冲区相关
vgg16_model = VGG16(weights='imagenet',include_top=False, input_shape=(224,224,3))
vgg16_model.summary()

搭建全连接层

top_model = Sequential()
top_model.add(Flatten(input_shape=vgg16_model.output_shape[1:]))
top_model.add(Dense(256,activation='relu'))
top_model.add(Dropout(0.5))
top_model.add(Dense(5,activation='softmax'))

model = Sequential()
model.add(vgg16_model)
model.add(top_model)
def change_range(image,label):
return 2*image-1, label
keras_ds = ds.map(change_range)

数据集可能需要几秒来启动,因为要填满其随机缓冲区。

image_batch, label_batch = next(iter(keras_ds))
feature_map_batch = vgg16_model(image_batch)
print(feature_map_batch.shape)

定义优化器,代价函数,训练过程中计算准确率

model.compile(optimizer=SGD(lr=1e-3,momentum=0.9),loss='categorical_crossentropy',metrics=['accuracy'])

model.fit(ds, epochs=1, steps_per_epoch=3)

运行结果及报错内容

model.compile(optimizer=SGD(lr=1e-3,momentum=0.9),loss='categorical_crossentropy',metrics=['accuracy'])

model.fit(ds, epochs=1, steps_per_epoch=3)
ValueError: Shapes (None, 1) and (None, 5) are incompatible

我的解答思路和尝试过的方法 将全连接层改了

top_model.add(Dense(5,activation='softmax'))改成top_model.add(Dense(1,activation='softmax'))
可以运行- 8s 347ms/step - loss: 0.0000e+00 - accuracy: 0.2083
<keras.callbacks.History at 0x209a21a4c18>
但我想要输出5个分类

我想要达到的结果可以对图片进行预测输出5个类别中的一个,就是全连接层输出为5个分类可以运行,可以预测,可以输出准确率,召回率,损失率三率
  • 写回答

1条回答 默认 最新

  • 爱晚乏客游 2022-04-12 10:21
    关注

    代码用控件提交,你这代码乱的.
    这个报错你要检查下是不是你优化器中loss的问题,你的数据标签是什么样子的,直接类别id的话不能用CategoricalCrossentropy()
    换成SparseCategoricalCrossentropy()或者binary_crossentropy试试看,如果你真的要用的话,要对标签进行编码才行。至于两者的区别,你可以看下链接
    https://blog.csdn.net/qq_40212975/article/details/108245786

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论 编辑记录

报告相同问题?

问题事件

  • 系统已结题 4月20日
  • 已采纳回答 4月12日
  • 创建了问题 4月12日

悬赏问题

  • ¥15 如何构建全国统一的物流管理平台?
  • ¥100 ijkplayer使用AndroidStudio/CMake编译,如何支持 rtsp 直播流?
  • ¥20 和学习数据的传参方式,选择正确的传参方式有关
  • ¥15 这是网络安全里面的poem code
  • ¥15 用js遍历数据并对非空元素添加css样式
  • ¥15 使用autodl云训练,希望有直接运行的代码(关键词-数据集)
  • ¥50 python写segy数据出错
  • ¥20 关于线性结构的问题:希望能从头到尾完整地帮我改一下,困扰我很久了
  • ¥30 3D多模态医疗数据集-视觉问答
  • ¥20 设计一个二极管稳压值检测电路