import tensorflow as tf
from matplotlib import pyplot as plt
import os
import numpy as np
import glob
from torch.utils import data
from PIL import Image
from torchvision import transforms
# 数据集内图片的验证
# Allow duplicate OpenMP runtimes (common MKL/libiomp clash workaround on Windows).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

dataset_path = "rps_data_sample"

# Collect class labels: one sub-directory per gesture class.
labels = []
for entry in os.listdir(dataset_path):
    if os.path.isdir(os.path.join(dataset_path, entry)):
        labels.append(entry)

# Preview a few example images per class.
NUM_EXAMPLES = 5
for label in labels:
    label_dir = os.path.join(dataset_path, label)
    example_filenames = os.listdir(label_dir)[:NUM_EXAMPLES]
    if not example_filenames:
        # Nothing to show for an empty class directory.
        continue
    # Size the figure to however many examples actually exist — the original
    # hard-coded 5 columns and indexed past the end for classes with <5 files.
    n = len(example_filenames)
    fig, axs = plt.subplots(1, n, figsize=(2 * n, 2), squeeze=False)
    for i, fname in enumerate(example_filenames):
        axs[0][i].imshow(plt.imread(os.path.join(label_dir, fname)))
        axs[0][i].get_xaxis().set_visible(False)
        axs[0][i].get_yaxis().set_visible(False)
    fig.suptitle(f'Showing {n} examples for {label}')
# Without show() the preview figures never appear when run as a script.
plt.show()
# 创建data.Dataset子类Mydataset
class Mydataset(data.Dataset):
    """Minimal dataset that simply indexes a list of image paths.

    No decoding happens here; items are the raw path strings themselves.
    """

    def __init__(self, root):
        # `root` is the full list of image paths to expose.
        self.imgs_path = root

    def __getitem__(self, index):
        # Hand back the path itself — loading is deferred to later stages.
        return self.imgs_path[index]

    def __len__(self):
        return len(self.imgs_path)
# glob遍历数据路径
# Gather every jpg under rps_data_sample/<class>/. os.path.join keeps the
# pattern portable — the original '\*\*.jpg' relied on backslashes surviving
# as literals and only worked on Windows.
all_imgs_path = glob.glob(os.path.join('rps_data_sample', '*', '*.jpg'))
for var in all_imgs_path:
    print(var)

# Build the path-only dataset.
gesture_dataset = Mydataset(all_imgs_path)
print(len(gesture_dataset))

# Map class names <-> integer ids.
species = ['none', 'paper', 'rock', 'scissors']
species_to_id = dict((c, i) for i, c in enumerate(species))
id_to_species = dict((v, k) for k, v in species_to_id.items())

# Derive each image's label from the class-name substring in its path.
all_labels = []
for img in all_imgs_path:
    for i, c in enumerate(species):
        if c in img:
            all_labels.append(i)

transform = transforms.Compose([transforms.ToTensor()])

# Shuffle paths and labels with the same permutation, then split 80/20.
index = np.random.permutation(len(all_imgs_path))
all_imgs_path = np.array(all_imgs_path)[index]
all_labels = np.array(all_labels)[index]
p = int(len(all_imgs_path) * 0.8)
x_train_ = all_imgs_path[:p]
y_train_ = all_labels[:p]
x_test_ = all_imgs_path[p:]
# BUG FIX: the original assigned image paths (all_imgs_path[p:]) to the test
# labels; use the label array so test targets are class ids.
y_test_ = all_labels[p:]
class MyDatasetpro(data.Dataset):
    """Dataset that decodes an image from disk and applies a transform.

    Parameters
    ----------
    img_paths : sequence of str — image file paths.
    labels    : sequence — per-image targets, parallel to ``img_paths``.
    transform : callable — applied to the decoded PIL image (e.g. ToTensor).
    """

    def __init__(self, img_paths, labels, transform):
        self.imgs = img_paths
        self.labels = labels
        self.transforms = transform

    def __getitem__(self, index):
        # Decode lazily at access time, then run the transform pipeline.
        pil_img = Image.open(self.imgs[index])
        tensor = self.transforms(pil_img)
        return tensor, self.labels[index]

    def __len__(self):
        return len(self.imgs)
# NOTE(review): the names here are misleading — x_train/x_test are Dataset
# objects yielding (image, label) pairs, and y_train/y_test are torch
# DataLoaders, not label arrays. Later code must unpack accordingly.
x_train = MyDatasetpro(x_train_, y_train_, transform)
x_test = MyDatasetpro(x_test_, y_test_, transform)
# Shuffling loaders with batches of 5; note Keras cannot consume a torch
# DataLoader directly (it is not a tf.data pipeline).
y_train = data.DataLoader(x_train, batch_size=5,shuffle=True)
y_test = data.DataLoader(x_test, batch_size=5,shuffle=True)
# tensorflow训练
# CNN classifier: two conv/BN blocks -> average pooling -> L1-regularized MLP head.
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(filters=32, kernel_size=(5, 5), padding='same'),
    tf.keras.layers.BatchNormalization(),
    # NOTE(review): there is no activation after this first conv/BN pair —
    # presumably intentional, but most stacks would insert a ReLU here too.
    tf.keras.layers.Conv2D(filters=96, kernel_size=(5, 5), padding='same'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.AveragePooling2D(pool_size=(3, 3), padding='valid'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu', kernel_regularizer=tf.keras.regularizers.l1()),
    tf.keras.layers.Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.l1()),
    # 4 output units — one per class in `species`. The original used 5,
    # leaving one logit with no corresponding label.
    tf.keras.layers.Dense(4, activation='softmax', kernel_regularizer=tf.keras.regularizers.l1()),
])
# BUG FIX: the loss name was misspelled ("sparse_categoriical_crossentropy"),
# which makes Keras raise ValueError at compile time.
model.compile(optimizer='nadam',
              loss="sparse_categorical_crossentropy",
              metrics=["sparse_categorical_accuracy"])
# Optional checkpoint resume, kept disabled (useful when CPU runs must be
# restarted; note the index file check needs a dot: checkpoint path + '.index').
# checkpoint_save_path = './rps_data_sample/checkpoint.ckpt'
# if os.path.exists(checkpoint_save_path + '.index'):
#     print('load model')
#     model.load_weights(checkpoint_save_path)
# cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
#                                                  save_weights_only=True,
#                                                  save_best_only=True)

def _dataset_to_arrays(dataset):
    """Materialize an (image, label) torch dataset into dense numpy arrays.

    Keras cannot consume the torch Dataset directly: np.array(dataset) packs
    each (tensor, label) pair into one row, producing a ragged (N, 2) object
    array — the "inhomogeneous shape ... (400, 2)" error. Stack the image and
    label streams separately instead. Images come out NHWC, as Keras expects.
    Assumes all images share one size after the transform — TODO confirm.
    """
    images, targets = [], []
    for img, lbl in dataset:
        # CHW torch tensor -> HWC numpy image.
        images.append(img.numpy().transpose(1, 2, 0))
        targets.append(lbl)
    return np.stack(images), np.array(targets)

train_x, train_y = _dataset_to_arrays(x_train)
test_x, test_y = _dataset_to_arrays(x_test)

history = model.fit(train_x, train_y, batch_size=32, epochs=1000,
                    validation_data=(test_x, test_y), validation_freq=1)
model.summary()
# Dump the trainable variables for inspection. Use a relative path (the
# original '/weights.txt' needed root write access) and a context manager so
# the file is closed even if iteration raises.
with open('weights.txt', 'w') as weights_file:
    for v in model.trainable_variables:
        weights_file.write(str(v.name) + '\n')
        weights_file.write(str(v.shape) + '\n')
        # BUG FIX: the original wrote str(v.numpy() + '\n'), adding a string
        # to an ndarray — a TypeError. Stringify the array first.
        weights_file.write(str(v.numpy()) + '\n')
# Visualize training history: accuracy and loss curves side by side.
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label="Validation Accuracy")
plt.title('Training and Validation Accuracy')
plt.legend()

# The original computed loss/val_loss but never plotted them, despite
# subplot(1, 2, 1) reserving a second panel.
plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()

plt.show()
# NOTE(review): reported error — "setting an array element with a sequence.
# The requested array has an inhomogeneous shape after 2 dimensions. The
# detected shape was (400, 2) + inhomogeneous part." This comes from calling
# np.array(...) on the torch dataset: each item is an (image, label) tuple, so
# NumPy builds a ragged (N, 2) object array. Fix: iterate the dataset and
# stack images and labels into two separate numpy arrays before model.fit.