ayanamiprpr 2021-06-18 20:41 采纳率: 100%
浏览 1076
已采纳

keras 二分类预测结果几乎全是一个值

程序是用来对蜜蜂(bee)和黄蜂(wasp)分类的,用的模型是在vgg16上拼接的,代码如下

from keras.applications.vgg16 import VGG16
from keras.layers import Dense, Flatten, Activation, Dropout
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator
import keras
import shutil
import os

def creatDataGenerator(train_dir, test_dir):
    train_data_generator = ImageDataGenerator(rescale=.1/255)
    test_data_generator = ImageDataGenerator(rescale=.1/255)

    train_generator = train_data_generator.flow_from_directory(train_dir,
                                                            target_size=(150,150),
                                                            batch_size=32,
                                                            class_mode='binary')
    test_generator = test_data_generator.flow_from_directory(test_dir,
                                                            target_size=(150,150),
                                                            batch_size=32,
                                                            class_mode='binary')
    return train_generator, test_generator

vgg_model = VGG16(weights='imagenet', include_top=False, input_shape=(150,150,3))

cla_model = Sequential()
cla_model.add(Flatten())
cla_model.add(Dense(512, activation='relu'))
cla_model.add(Dropout(0.5))
cla_model.add(Dense(1, activation='sigmoid'))

model = Sequential()
model.add(vgg_model)
model.add(cla_model)

model.compile(loss='binary_crossentropy', optimizer='RMSprop', metrics=['accuracy'])

train_generator, test_generator = creatDataGenerator(train_dir=r'C:\Users\ayana\.keras\datasets\bee-vs-wasp\train',
                                                    test_dir=r'C:\Users\ayana\.keras\datasets\bee-vs-wasp\test')
H = model.fit(train_generator,
              steps_per_epoch=50,
              epochs=30,
              validation_data=test_generator,
              validation_steps=50)

然后训练以后进行预测,选择的是黄蜂的10张图(蜜蜂预测出来也是同样的结果)

顺便训练的准确率也比较低,不到0.6,也一直不知道怎么能高一些

from keras.preprocessing.image import load_img, img_to_array
import numpy as np

def predict(i):
    img_path = os.listdir(r'C:\Users\ayana\.keras\datasets\bee-vs-wasp\test\wasp')[i]
    img = load_img(path='C:\\Users\\ayana\\.keras\\datasets\\bee-vs-wasp\\test\\wasp\\'+img_path, 
                    target_size=(150,150))
    img = np.expand_dims(img, axis=0)/255
    prediction = model.predict(img)
    return prediction

for i in range(10):
    print(predict(i))

#>>>[[0.4714901]]
#    [[0.4714901]]
#    [[0.4714901]]
#    [[0.4714901]]
#    [[0.4714901]]
#    [[0.4714901]]
#    [[0.4714901]]
#    [[0.4714901]]
#    [[0.4714901]]
#    [[0.4714901]]

再用np.argmax()的话就都是0了

被困了一天了,#求救

  • 写回答

2条回答 默认 最新

  • 兰振lanzhen 2021-06-18 22:55
    关注

    应该是这个吧,你训练之后得到的模型是H,prediction = H.predict(img)  

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

悬赏问题

  • ¥15 有没有人会打学生成绩管理系统呀
  • ¥15 在使用Fiddler和夜神模拟器抓包的时候一直出现443该怎么办啊QAQ搜了好几个笔记都没有解决
  • ¥15 3x7的二维数组A、B、C,A中的任意1个数组元素与B的任意1个数组元素、同时又与C的任意1个数组元素比较,把不同位置出现相同数的比较称为无意义,反之称为有意义,把有意义的比较打印输出。
  • ¥20 预测模型怎么处理原始数据(随机森林)
  • ¥20 请问discuz3.5如何实现插入ckplayer全能播放器功能呢?
  • ¥15 thingsboard代码编译出错误
  • ¥15 博途v18仿真报错怎么解决
  • ¥15 欧姆龙plc枕式包装机 ST编程
  • ¥15 为啥快手广告联盟的广告这么难出来
  • ¥15 k8s集群重启后,kubelet一直报systemctl restart kubelet.service "Failed to delete cgroup paths"