import tensorflow as tf
from matplotlib import pyplot as plt
import os
import numpy as np
import glob
from torch.utils import data
from PIL import Image
from torchvision import transforms
# Sanity-check the dataset by previewing a few images from every class folder.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

dataset_path = "rps_data_sample"

# Every sub-directory of the dataset root is treated as one class label.
labels = [entry for entry in os.listdir(dataset_path)
          if os.path.isdir(os.path.join(dataset_path, entry))]

NUM_EXAMPLES = 5

# Render the first NUM_EXAMPLES images of each class in a 1x5 strip.
for label in labels:
    label_dir = os.path.join(dataset_path, label)
    example_filenames = os.listdir(label_dir)[:NUM_EXAMPLES]
    fig, axs = plt.subplots(1, 5, figsize=(10, 2))
    for i in range(NUM_EXAMPLES):
        axs[i].imshow(plt.imread(os.path.join(label_dir, example_filenames[i])))
        # Hide both axes -- only the raw image matters here.
        axs[i].get_xaxis().set_visible(False)
        axs[i].get_yaxis().set_visible(False)
    fig.suptitle(f'Showing {NUM_EXAMPLES} examples for {label}')
# data.Dataset subclass that stores image paths and returns them verbatim.
class Mydataset(data.Dataset):
    """Minimal Dataset over a list of image paths; indexing yields the path."""

    def __init__(self, root):
        # `root` is the complete list of image file paths.
        self.imgs_path = root

    def __getitem__(self, index):
        # No image decoding here -- just hand back the stored path.
        return self.imgs_path[index]

    def __len__(self):
        # Size of the dataset equals the number of stored paths.
        return len(self.imgs_path)
# Collect every image path with glob (layout: rps_data_sample/<class>/<img>.jpg).
all_imgs_path = glob.glob('rps_data_sample\\*\\*.jpg')
# Echo each discovered path for a quick visual check.
for var in all_imgs_path:
    print(var)

# Wrap the paths in the path-only dataset and report how many were found.
gesture_dataset = Mydataset(all_imgs_path)
print(len(gesture_dataset))
# Map class names to integer ids (and back) for label encoding.
species = ['none', 'paper', 'rock', 'scissors']
species_to_id = dict((c, i) for i, c in enumerate(species))
id_to_species = dict((v, k) for k, v in species_to_id.items())

# Derive each image's integer label from the class-name substring of its path.
all_labels = []
for img in all_imgs_path:
    for i, c in enumerate(species):
        if c in img:
            all_labels.append(i)

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # unify image size
    transforms.ToTensor()           # NOTE: produces channels-first (C, H, W) tensors
])

# Shuffle paths and labels with the same permutation, then split 80/20.
index = np.random.permutation(len(all_imgs_path))
all_imgs_path = np.array(all_imgs_path)[index]
all_labels = np.array(all_labels)[index]
p = int(len(all_imgs_path) * 0.8)
x_train_ = all_imgs_path[:p]
y_train_ = all_labels[:p]
x_test_ = all_imgs_path[p:]
# BUG FIX: test labels must come from all_labels -- the original assigned
# the image *paths* (all_imgs_path[p:]) as test labels.
y_test_ = all_labels[p:]
class MyDatasetpro(data.Dataset):
    """Dataset that lazily opens an image file and applies a transform.

    `img_paths` and `labels` are parallel sequences: img_paths[i] is the
    file for label labels[i].
    """

    def __init__(self, img_paths, labels, transform):
        self.imgs = img_paths
        self.labels = labels
        self.transforms = transform

    def __getitem__(self, index):
        # Decode on access; force RGB so grayscale/RGBA inputs get 3 channels.
        pil_img = Image.open(self.imgs[index]).convert('RGB')
        return self.transforms(pil_img), self.labels[index]

    def __len__(self):
        return len(self.imgs)
# Materialize train/test datasets that load and transform images on access.
x_train = MyDatasetpro(x_train_, y_train_, transform)
x_test = MyDatasetpro(x_test_, y_test_, transform)
# Bridge the torch-style datasets into batched tf.data pipelines.
def tf_dataset(dataset, batch_size):
    """Wrap an indexable (image_tensor, label) dataset as a tf.data.Dataset.

    torchvision's ToTensor() yields channels-first (3, 224, 224) tensors,
    while the pipeline below declares channels-last output shapes, so each
    image is transposed to (224, 224, 3) before being yielded.
    """
    def generator():
        for img, label in dataset:
            # FIX for "generator yielded an element of shape (3, 224, 224)
            # where an element of shape (224, 224, 3) was expected":
            # convert CHW -> HWC to match output_shapes.
            yield np.transpose(img.numpy(), (1, 2, 0)), label

    ds = tf.data.Dataset.from_generator(
        generator,
        output_types=(tf.float32, tf.int32),
        output_shapes=((224, 224, 3), ()))
    ds = ds.batch(batch_size)
    return ds


batch_size = 32
train_ds = tf_dataset(x_train, batch_size)
test_ds = tf_dataset(x_test, batch_size)
# Build, compile and train a small CNN classifier with Keras.
model = tf.keras.models.Sequential([
    # NOTE(review): no activation between the first conv and the second --
    # confirm this stacking is intentional.
    tf.keras.layers.Conv2D(filters=32, kernel_size=(5, 5), padding='same', input_shape=(224, 224, 3)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(filters=96, kernel_size=(5, 5), padding='same'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.AveragePooling2D(pool_size=(3, 3), padding='valid'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu', kernel_regularizer=tf.keras.regularizers.l1()),
    tf.keras.layers.Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.l1()),
    # FIX: the dataset has exactly 4 classes (none/paper/rock/scissors),
    # so the softmax head needs 4 units, not 5.
    tf.keras.layers.Dense(4, activation='softmax', kernel_regularizer=tf.keras.regularizers.l1())
])
model.compile(optimizer='nadam',
              loss="sparse_categorical_crossentropy",
              metrics=["sparse_categorical_accuracy"])
# FIX: 10000 epochs on a small sample dataset would run for a very long time
# and badly overfit; a modest epoch count is appropriate here.
history = model.fit(train_ds, epochs=10, validation_data=test_ds, validation_freq=1)
model.summary()
# Dump every trainable variable (name, shape, values) to a text file.
# FIX: the original wrote to '/weights.txt' at the filesystem root, which
# fails without elevated permissions; write to the working directory and
# use a context manager so the file is closed even if a write raises.
with open('weights.txt', 'w') as weights_file:
    for v in model.trainable_variables:
        weights_file.write(str(v.name) + '\n')
        weights_file.write(str(v.shape) + '\n')
        weights_file.write(str(v.numpy()) + '\n')
# Plot the training curves: accuracy in the left panel, loss in the right.
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

# (panel position, title, train curve + label, validation curve + label)
panels = [
    (1, 'Training and Validation Accuracy',
     acc, 'Training Accuracy', val_acc, "Validation Accuracy"),
    (2, 'Training and Validation Loss',
     loss, 'Training Loss', val_loss, "Validation Loss"),
]
for pos, title, train_curve, train_lbl, val_curve, val_lbl in panels:
    plt.subplot(1, 2, pos)
    plt.plot(train_curve, label=train_lbl)
    plt.plot(val_curve, label=val_lbl)
    plt.title(title)
    plt.legend()
plt.show()
# Question: this code raised the error 'generator yielded an element of shape
# (3, 224, 224) where an element of shape (224, 224, 3) was expected' -- how
# should it be handled?
# Answer: torchvision's ToTensor() returns channels-first (C, H, W) tensors,
# but the tf.data pipeline declares channels-last output_shapes. Transpose each
# image inside tf_dataset's generator before yielding it:
#     yield np.transpose(img.numpy(), (1, 2, 0)), label