程序是用来对蜜蜂(bee)和黄蜂(wasp)分类的,用的模型是在vgg16上拼接的,代码如下
from keras.applications.vgg16 import VGG16
from keras.layers import Dense, Flatten, Activation, Dropout
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator
import keras
import shutil
import os
def creatDataGenerator(train_dir, test_dir):
train_data_generator = ImageDataGenerator(rescale=.1/255)
test_data_generator = ImageDataGenerator(rescale=.1/255)
train_generator = train_data_generator.flow_from_directory(train_dir,
target_size=(150,150),
batch_size=32,
class_mode='binary')
test_generator = test_data_generator.flow_from_directory(test_dir,
target_size=(150,150),
batch_size=32,
class_mode='binary')
return train_generator, test_generator
vgg_model = VGG16(weights='imagenet', include_top=False, input_shape=(150,150,3))
cla_model = Sequential()
cla_model.add(Flatten())
cla_model.add(Dense(512, activation='relu'))
cla_model.add(Dropout(0.5))
cla_model.add(Dense(1, activation='sigmoid'))
model = Sequential()
model.add(vgg_model)
model.add(cla_model)
model.compile(loss='binary_crossentropy', optimizer='RMSprop', metrics=['accuracy'])
train_generator, test_generator = creatDataGenerator(train_dir=r'C:\Users\ayana\.keras\datasets\bee-vs-wasp\train',
test_dir=r'C:\Users\ayana\.keras\datasets\bee-vs-wasp\test')
H = model.fit(train_generator,
steps_per_epoch=50,
epochs=30,
validation_data=test_generator,
validation_steps=50)
然后训练以后进行预测,选择的是黄蜂的10张图(蜜蜂预测出来也是同样的结果)
顺便训练的准确率也比较低,不到0.6,也一直不知道怎么能高一些
from keras.preprocessing.image import load_img, img_to_array
import numpy as np
def predict(i):
img_path = os.listdir(r'C:\Users\ayana\.keras\datasets\bee-vs-wasp\test\wasp')[i]
img = load_img(path='C:\\Users\\ayana\\.keras\\datasets\\bee-vs-wasp\\test\\wasp\\'+img_path,
target_size=(150,150))
img = np.expand_dims(img, axis=0)/255
prediction = model.predict(img)
return prediction
for i in range(10):
print(predict(i))
#>>>[[0.4714901]]
# [[0.4714901]]
# [[0.4714901]]
# [[0.4714901]]
# [[0.4714901]]
# [[0.4714901]]
# [[0.4714901]]
# [[0.4714901]]
# [[0.4714901]]
# [[0.4714901]]
再用np.argmax()的话就都是0了
被困了一天了,#求救