我想知道码 2022-12-15 17:43 采纳率: 100%
浏览 762
已结题

报错:error: (-215:Assertion failed) !_src.empty() in function 'cv::cvtColor'

报错:cv2.error: OpenCV(4.5.5) D:\a\opencv-python\opencv-python\opencv\modules\imgproc\src\color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cv::cvtColor'
图片读取转RGB,有些图片运行后会报错,但是由于图片有1800多张,没法细找事哪张图片的问题,图片路径没有中文
# 载入与模型网络构建
import numpy as np
import glob
import cv2
from random import shuffle
import streamlit as st
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten
from tensorflow.keras.preprocessing.image import ImageDataGenerator  # 图像生成器
from tensorflow.keras.preprocessing import image
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import load_model
import os

def load_data(path):
    """
    获取数据并对应标签
    Args:
        path: 所有图片所在的大文件夹路径
    Returns: 训练集和验证集数据和对应标签
    """
    # 所有图片所在路径
    img_path = glob.glob(path + '/*/*')
    # 花对应的数值型标签, 和文件夹下的顺序一致
    classes = os.listdir(path)
    f_class = {k: v for k, v in zip(classes, range(len(classes)))}
    images = []  # 储存图片
    labels = []  # 储存图片对应的标签
    # 打乱数据
    shuffle(img_path)
    for im in img_path:
        # 读取图片
        img = cv2.imread(im)
        # 转换为RGB
        img = cv2.cvtColor(img, code=cv2.COLOR_BGR2RGB)
        # 中心裁剪(按比例裁剪, 只是为了剪掉周边的水印, 而尽可能的保留原图片信息)
        # img = np.array(tf.image.central_crop(img, 0.8))  # 裁剪比例0.8 # 或许没必要转换为array?
        h, w = img.shape[0], img.shape[1]
        img = img[int(h * 0.1):int(h * 0.9), int(w * 0.1):int(w * 0.9), :]
        # 缩放
        img = cv2.resize(img, (256, 256))
        # 图片缩放并标准化--这一步在ImageGenerator做
        lab = im.split('\\')[1]
        # 储存
        images.append(img)
        labels.append(f_class[lab])

    # 分训练集和验证集 训练集:验证集 = 8:2
    Length = len(images)
    tarin_img = np.array(images[: int(Length * 0.8)])
    val_img = np.array(images[int(Length * 0.8):])

    train_lab = np.array(labels[: int(Length * 0.8)])
    val_lab = np.array(labels[int(Length * 0.8):])

    # ----------定义图像增强生成器-----------
    # 训练集生成器
    train_genretor = ImageDataGenerator(rotation_range=90,  # 随机旋转
                                        zoom_range=0.2,     # 缩放
                                        rescale=1. / 255,   # 标准化
                                        )
    # 验证集生成器
    val_genretor = ImageDataGenerator(rescale=1. / 255, )   # 标准化
    # 实现增强, 并批量打包
    train = train_genretor.flow(x=tarin_img, y=train_lab, batch_size=32, )  # 要事先在load_data中shuffle,flow中shuffle会报错
    val = val_genretor.flow(x=val_img, y=val_lab, batch_size=32, )  # 若内存不够,则减小batch_size

    return train, val

def model_train(train, val, epochs):
    # ----------模型搭建-----------
    # 搭建多几层
    tf.keras.backend.clear_session()
    model = Sequential([
        Conv2D(64, kernel_size=3, input_shape=(256, 256, 3), activation='relu'),
        MaxPooling2D(),
        Conv2D(128, kernel_size=3, activation='relu'),
        MaxPooling2D(),
        Conv2D(128, kernel_size=3, activation='relu'),
        MaxPooling2D(),
        Conv2D(256, kernel_size=3, activation='relu'),
        MaxPooling2D(),
        Flatten(),
        Dense(128, activation='relu'),
        Dense(7, activation='softmax')
    ])

    # model.summary()
    # 模型编译
    lr = 0.0001  # 学习率
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
                  loss='sparse_categorical_crossentropy',
                  metrics=['acc'])

    checkpoint = ModelCheckpoint('./models/model.h5',     # 文件保存路径
                                   monitor='val_acc',       # 保存h5文件的条件
                                  verbose=1,               # 训练进度条状态
                                  save_weights_only=False, # 是否只保存权重
                                  save_best_only=True,     # 是否只保存最优的
                                  mode='max',              # monitor的判断条件
                                  perio=1)                 # CheckPoint之间的间隔的epoch数

    history = model.fit(train,                  # 训练集的生成器
                        epochs=epochs,          # 训练n轮
                        callbacks=[checkpoint],
                        validation_data=val,    # 验证集的生成器
                        shuffle=True,
                        )
    print('训练完毕!')
    # 保存网络
    model.save('./models/model.h5')
    print('模型保存成功!')
    print(history.history['val_acc'])
    accuracy = max(history.history['val_acc'])
    print('模型准确率:', accuracy)
    return accuracy

# streamlit缓存
@st.cache
def model_pred(predimg_path='./', model_path='./'):
    """
    模型预测
    Args:
        predimg_path: 预测数据所在路径
    model_path: 模型所在路径
    Returns: 预测类别
    """
    # 查看单张图片结果(load_img中路径名称需修改)
    img = image.load_img(predimg_path, target_size=(256, 256))
    data = image.img_to_array(img) / 255.0
    data = np.expand_dims(data, axis=0)

    # 花对应的数值型标签
    classes = os.listdir('./flowers/')
    f_class = {v: k for k, v in zip(classes, range(len(classes)))}
    # 加载训练好的模型
    model = load_model(model_path)
    # 预测
    result = model.predict(data).argmax()
    prediction = f_class[result]

    return prediction

# 获取数据
train,val = load_data('./flowers')
# 基于cnn的花卉识别模型
accuracy = model_train(train, val, epochs=20)
# 模型预测
model_pred('./test_images/rose.jpg','./models/model.h5')

img

  • 写回答

1条回答 默认 最新

  • ShowMeAI 2022-12-15 20:20
    关注

    你可以用一点小技巧,剔除掉有问题的图片,只读取OK的图片,中间部分修改后的代码如下,望采纳

    for im in img_path:
            try:
                    # 读取图片
                    img = cv2.imread(im)
                    # 转换为RGB
                    img = cv2.cvtColor(img, code=cv2.COLOR_BGR2RGB)
                    # 中心裁剪(按比例裁剪, 只是为了剪掉周边的水印, 而尽可能的保留原图片信息)
                    # img = np.array(tf.image.central_crop(img, 0.8))  # 裁剪比例0.8 # 或许没必要转换为array?
                    h, w = img.shape[0], img.shape[1]
                    img = img[int(h * 0.1):int(h * 0.9), int(w * 0.1):int(w * 0.9), :]
                    # 缩放
                    img = cv2.resize(img, (256, 256))
                    # 图片缩放并标准化--这一步在ImageGenerator做
                    lab = im.split('\\')[1]
                    # 储存
                    images.append(img)
                    labels.append(f_class[lab])
            except:
                    continue
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

问题事件

  • 系统已结题 12月23日
  • 已采纳回答 12月15日
  • 创建了问题 12月15日

悬赏问题

  • ¥15 用verilog实现tanh函数和softplus函数
  • ¥15 Hadoop集群部署启动Hadoop时碰到问题
  • ¥15 求京东批量付款能替代天诚
  • ¥15 slaris 系统断电后,重新开机后一直自动重启
  • ¥15 QTableWidget重绘程序崩溃
  • ¥15 谁能帮我看看这拒稿理由啥意思啊阿啊
  • ¥15 关于vue2中methods使用call修改this指向的问题
  • ¥15 idea自动补全键位冲突
  • ¥15 请教一下写代码,代码好难
  • ¥15 iis10中如何阻止别人网站重定向到我的网站