云深395 2022-04-12 09:36 采纳率: 20%
浏览 52
已结题

tensorflow怎么解决这个问题,是什么问题,解决方法?

问题遇到的现象和发生背景 模型无法按照我想的运行,不知道是否是shape还是什么没搞好?
问题相关代码,请勿粘贴截图

from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D,MaxPool2D,Activation,Dropout,Flatten,Dense
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.preprocessing.image import ImageDataGenerator,img_to_array,load_img
import numpy as np
import tensorflow as tf
import pathlib
data_dir = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
fname='flower_photos', untar=True)
data_root = pathlib.Path(data_dir)
print(data_root)
import random
all_image_paths=list(data_root.glob('/'))
all_image_paths=[str(path) for path in all_image_paths]
random.shuffle(all_image_paths)
print(len(all_image_paths))
label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())#读取目录并排序为类别名
label_to_index = dict((name, index) for index, name in enumerate(label_names))#创建类别字典
all_image_labels = [label_to_index[pathlib.Path(path).parent.name]
for path in all_image_paths] #图像parent path 对应类
@tf.function
def preprocess_image(path):
image_size=224
image = tf.io.read_file(path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize(image, [image_size, image_size])

# 数据增强

x=tf.image.random_brightness(x, 1)#亮度调整

x = tf.image.random_flip_up_down(x) #上下颠倒

x= tf.image.random_flip_left_right(x) # 左右镜像

x = tf.image.random_crop(x, [image_size, image_size, 3]) # 随机裁剪

image /= 255.0  # normalize to [0,1] range

image= normalize(image) # 标准化

return image

ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))
def load_and_preprocess_from_path_label(path, label):
return preprocess_image(path), label

image_label_ds = ds.map(load_and_preprocess_from_path_label)
image_label_ds
AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 16
image_count = len(all_image_paths)

设置一个和数据集大小一致的 shuffle buffer size(随机缓冲区大小)以保证数据

被充分打乱。

ds = image_label_ds.shuffle(buffer_size=image_count) # buffer_size等于数据集大小确保充分打乱
ds = ds.repeat() #repeat 适用于next(iter(ds))
ds = ds.batch(BATCH_SIZE)

当模型在训练的时候,prefetch 使数据集在后台取得 batch。

ds = ds.prefetch(buffer_size=AUTOTUNE)#随机缓冲区相关
vgg16_model = VGG16(weights='imagenet',include_top=False, input_shape=(224,224,3))
vgg16_model.summary()

搭建全连接层

top_model = Sequential()
top_model.add(Flatten(input_shape=vgg16_model.output_shape[1:]))
top_model.add(Dense(256,activation='relu'))
top_model.add(Dropout(0.5))
top_model.add(Dense(5,activation='softmax'))

model = Sequential()
model.add(vgg16_model)
model.add(top_model)
def change_range(image,label):
return 2*image-1, label
keras_ds = ds.map(change_range)

数据集可能需要几秒来启动,因为要填满其随机缓冲区。

image_batch, label_batch = next(iter(keras_ds))
feature_map_batch = vgg16_model(image_batch)
print(feature_map_batch.shape)

定义优化器,代价函数,训练过程中计算准确率

model.compile(optimizer=SGD(lr=1e-3,momentum=0.9),loss='categorical_crossentropy',metrics=['accuracy'])

model.fit(ds, epochs=1, steps_per_epoch=3)

运行结果及报错内容

model.compile(optimizer=SGD(lr=1e-3,momentum=0.9),loss='categorical_crossentropy',metrics=['accuracy'])

model.fit(ds, epochs=1, steps_per_epoch=3)
ValueError: Shapes (None, 1) and (None, 5) are incompatible

我的解答思路和尝试过的方法 将全连接层改了

top_model.add(Dense(5,activation='softmax'))改成top_model.add(Dense(1,activation='softmax'))
可以运行- 8s 347ms/step - loss: 0.0000e+00 - accuracy: 0.2083
<keras.callbacks.History at 0x209a21a4c18>
但我想要输出5个分类

我想要达到的结果可以对图片进行预测输出5个类别中的一个,就是全连接层输出为5个分类可以运行,可以预测,可以输出准确率,召回率,损失率三率
  • 写回答

1条回答 默认 最新

  • 爱晚乏客游 2022-04-12 10:21
    关注

    代码用控件提交,你这代码乱的.
    这个报错你要检查下是不是你优化器中loss的问题,你的数据标签是什么样子的,直接类别id的话不能用CategoricalCrossentropy()
    换成SparseCategoricalCrossentropy()或者binary_crossentropy试试看,如果你真的要用的话,要对标签进行编码才行。至于两者的区别,你可以看下链接
    https://blog.csdn.net/qq_40212975/article/details/108245786

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论 编辑记录

报告相同问题?

问题事件

  • 系统已结题 4月20日
  • 已采纳回答 4月12日
  • 创建了问题 4月12日

悬赏问题

  • ¥15 Python信息系统tkinter代码错在哪里
  • ¥15 FOR循环语句显示查询超过300S错误怎么办
  • ¥15 个人通讯录管理系统 C语言 程序设计 需要能运行成功的 结构体和数组 Visual C++编译器
  • ¥15 数电设计题 没有设计思路 不知道用什么芯片进行设计 求提供设计思路
  • ¥15 在动态多目标优化问题中,第一幅图展示的是问题DF6的相关定义和绘制的POS和POF图,请问图中公式PS(t)和PF(t)是如何推导的
  • ¥60 先数学建模,接着设计一种优化算法结合案例给出智能仓储四向穿梭车的调度计划
  • ¥15 Errno2:No such file or directory,在当前文件确实没有该图片,怎么解决?
  • ¥15 博世摄像头数据存储的问题(iscsi)
  • ¥15 如何实现对学生籍贯信息管理系统的选择排序
  • ¥15 写一个51单片机的时钟代码