我的代码如下所示:
import tensorflow as tf
import numpy as np
import cv2
import os
#全局变量
classes = {'grass':0,'soldiers':1}#数据分类
path_grass='D:/AIR_space/data/grass' #草丛图片路径
path_soldiers='D:/AIR_space/data/soldiers' #士兵图片路径
picture=[] #存放图片
labels=[] #存放标签
files_grass=os.listdir(path_grass)
files_soldiers=os.listdir(path_soldiers)
#读取图片数据并且将其存入到图片和标签变量中
def read_image(path,files,shape=(32,32)):
for f in files:
f_path=path+'/'+f
img=cv2.imread(f_path)
img=cv2.resize(img,shape)
img=img.astype(np.float32)
picture.append(img)
split_f=f.split('_')
label_f=int(classes[split_f[0]])
labels.append(label_f)
#建立dataset
def data_set(data,label):
train_data=tf.data.Dataset.from_tensor_slices(data)
train_labels=tf.data.Dataset.from_tensor_slices(label).map(lambda z: tf.one_hot(z,len(classes)))
train_dataset=tf.data.Dataset.zip((train_data,train_labels)).shuffle(1000).repeat(10).batch(256)
return train_dataset
#建立CNN模型
def build_model():
model=tf.keras.Sequential()
#第一层卷积
model.add(tf.keras.layers.Conv2D(64,(3,3),padding='same',activation='relu',input_shape=(32,32,3)))
model.add(tf.keras.layers.MaxPooling2D(padding='same'))
#第二层卷积
model.add(tf.keras.layers.Conv2D(128,(3,3),padding='same',activation='relu'))
model.add(tf.keras.layers.Conv2D(256,(3,3),padding='same',activation='relu'))
model.add(tf.keras.layers.MaxPooling2D(padding='same'))
#全连接层
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(64,activation='relu'))
model.add(tf.keras.layers.Dense(32,activation='relu'))
model.add(tf.keras.layers.Dropout(0.3))
model.add(tf.keras.layers.Dense(2,activation='softmax'))
return model
#主函数部分
read_image(path_grass,files_grass)
read_image(path_soldiers,files_soldiers)
train_data=data_set(picture,labels)
print(train_data)
model_cnn=build_model()
#model_cnn.build(input_shape=[None,32,32,3])
model_cnn.summary()
model_cnn.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss='binary_crossentropy',metrics=['accuracy'])
model_cnn.fit(train_data,batch_size=19,epochs=20)
数据集是192张图片,小兵96张,对应soldiers,草丛96张,对应grass,然后我在训练的时候,发现输出如下图:
我很不理解epoch下面那个x/8的那个8表示什么,哪里来的,看参考书说表示训练集数量,可我加载的数据集是192张,求解释那个8是怎么回事,如何修改,谢谢啦