唠嗑24 2017-04-28 11:35 采纳率: 0%
浏览 1940

Alexnet分类问题,程序输入不匹配

用Alexnet网络做一个二分类问题,输入的图片也是227乘227的彩图。遇到了如下的问题说是形状不匹配图片说明也不知道怎么解决,求大神帮忙
from future import division, print_function, absolute_import
import os
import random
from PIL import Image
import numpy as np
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d, upsample_2d
from tflearn.layers.normalization import local_response_normalization
from tflearn.layers.estimator import regression

#import tflearn.datasets.oxflower17 as oxflower17
#X, Y = oxflower17.load_data(one_hot=True, resize_pics=(227, 227))
np.random.seed(170)
def load_data(DataDir):

data = np.empty((170,227,227,3),dtype="float32") #37是图片个数,800*800为图片大小,3是图片通道数
label = np.empty((170,),dtype="int")

imgs = os.listdir(DataDir)

num = len(imgs)

for i in range(num):

img = Image.open(DataDir+imgs[i])

arr = np.asarray(img,dtype="float32")

data[i,:,:,:] = arr

if i<53:

label[i] = int(0) #o是无缺陷类,共170张图,第0-52张为无缺陷类。
else:

label[i] = int(1)

data /= np.max(data) #这两行是数据归一化,不用管
data -= np.mean(data)

return data,label

data,label=load_data('C:/Users/Administrator/Desktop/cnntest/picture/')
index = [i for i in range(len(data))]

random.shuffle(index) #之前做标签时,数据是按类排的,这边直接打乱顺序。所以标签还是一一对应的。
data = data[index]

label = labelindex = (data[0:119],data[120:]) #traindata包括了两类数据,不用分开来输入。7:3训练集:预测集
(TrainLabel,TestLabel) = (label[0:119],label[120:])

Building 'AlexNet'

network = input_data(shape=[None, 227, 227, 3])
network = conv_2d(network, 96, 11, strides=4, activation='relu') #96为滤波器个数,11为滤波器大小

network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)

network = conv_2d(network, 256, 5, activation='relu', group=2)

network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)

network = conv_2d(network, 384, 3, activation='relu')

network = conv_2d(network, 384, 3, activation='relu')

network = conv_2d(network, 256, 3, activation='relu')

network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)

network = upsample_2d(network,2,name='upsample')

network = fully_connected(network, 4096, activation='relu')
network = dropout(network, 0.5)

network = fully_connected(network, 4096, activation='relu')
network = dropout(network, 0.5)
#net = tflearn.global_avg_pool(net)
network = fully_connected(network, 2, activation='softmax')
network = regression(network, optimizer='momentum',
loss='categorical_crossentropy',
learning_rate=0.001)

Training

print('Training ------------')
model = tflearn.DNN(network, checkpoint_path='model_alexnet', max_checkpoints=1, tensorboard_verbose=2)
model.fit(TrainData,TrainLabel, n_epoch=5, validation_set=0.1, shuffle=True,
show_metric=True, batch_size=64, snapshot_step=200,
snapshot_epoch=False, run_id='CNNPOTATO')

model.save('CNNPOTATO.model')
model.load('CNNPOTATO.model')

print('\nTesting ------------')

Evaluate the model with the metrics we defined earlier

loss, accuracy = model.evaluate(TestData, TestLabel)
print('\ntest loss: ', loss)
print('\ntest accuracy: ', accuracy)
#print(model.predict([Y[1]]))


  • 写回答

2条回答

  • threenewbee 2017-04-28 15:54
    关注
    评论

报告相同问题?

悬赏问题

  • ¥15 HFSS 中的 H 场图与 MATLAB 中绘制的 B1 场 部分对应不上
  • ¥15 如何在scanpy上做差异基因和通路富集?
  • ¥20 关于#硬件工程#的问题,请各位专家解答!
  • ¥15 关于#matlab#的问题:期望的系统闭环传递函数为G(s)=wn^2/s^2+2¢wn+wn^2阻尼系数¢=0.707,使系统具有较小的超调量
  • ¥15 FLUENT如何实现在堆积颗粒的上表面加载高斯热源
  • ¥30 截图中的mathematics程序转换成matlab
  • ¥15 动力学代码报错,维度不匹配
  • ¥15 Power query添加列问题
  • ¥50 Kubernetes&Fission&Eleasticsearch
  • ¥15 報錯:Person is not mapped,如何解決?