唠嗑24 2017-04-28 11:35 采纳率: 0%
浏览 1940

Alexnet分类问题,程序输入不匹配

用Alexnet网络做一个二分类问题,输入的图片也是227乘227的彩图。遇到了如下的问题说是形状不匹配图片说明也不知道怎么解决,求大神帮忙
from future import division, print_function, absolute_import
import os
import random
from PIL import Image
import numpy as np
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d, upsample_2d
from tflearn.layers.normalization import local_response_normalization
from tflearn.layers.estimator import regression

#import tflearn.datasets.oxflower17 as oxflower17
#X, Y = oxflower17.load_data(one_hot=True, resize_pics=(227, 227))
np.random.seed(170)
def load_data(DataDir):

data = np.empty((170,227,227,3),dtype="float32") #37是图片个数,800*800为图片大小,3是图片通道数
label = np.empty((170,),dtype="int")

imgs = os.listdir(DataDir)

num = len(imgs)

for i in range(num):

img = Image.open(DataDir+imgs[i])

arr = np.asarray(img,dtype="float32")

data[i,:,:,:] = arr

if i<53:

label[i] = int(0) #o是无缺陷类,共170张图,第0-52张为无缺陷类。
else:

label[i] = int(1)

data /= np.max(data) #这两行是数据归一化,不用管
data -= np.mean(data)

return data,label

data,label=load_data('C:/Users/Administrator/Desktop/cnntest/picture/')
index = [i for i in range(len(data))]

random.shuffle(index) #之前做标签时,数据是按类排的,这边直接打乱顺序。所以标签还是一一对应的。
data = data[index]

label = labelindex = (data[0:119],data[120:]) #traindata包括了两类数据,不用分开来输入。7:3训练集:预测集
(TrainLabel,TestLabel) = (label[0:119],label[120:])

Building 'AlexNet'

network = input_data(shape=[None, 227, 227, 3])
network = conv_2d(network, 96, 11, strides=4, activation='relu') #96为滤波器个数,11为滤波器大小

network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)

network = conv_2d(network, 256, 5, activation='relu', group=2)

network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)

network = conv_2d(network, 384, 3, activation='relu')

network = conv_2d(network, 384, 3, activation='relu')

network = conv_2d(network, 256, 3, activation='relu')

network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)

network = upsample_2d(network,2,name='upsample')

network = fully_connected(network, 4096, activation='relu')
network = dropout(network, 0.5)

network = fully_connected(network, 4096, activation='relu')
network = dropout(network, 0.5)
#net = tflearn.global_avg_pool(net)
network = fully_connected(network, 2, activation='softmax')
network = regression(network, optimizer='momentum',
loss='categorical_crossentropy',
learning_rate=0.001)

Training

print('Training ------------')
model = tflearn.DNN(network, checkpoint_path='model_alexnet', max_checkpoints=1, tensorboard_verbose=2)
model.fit(TrainData,TrainLabel, n_epoch=5, validation_set=0.1, shuffle=True,
show_metric=True, batch_size=64, snapshot_step=200,
snapshot_epoch=False, run_id='CNNPOTATO')

model.save('CNNPOTATO.model')
model.load('CNNPOTATO.model')

print('\nTesting ------------')

Evaluate the model with the metrics we defined earlier

loss, accuracy = model.evaluate(TestData, TestLabel)
print('\ntest loss: ', loss)
print('\ntest accuracy: ', accuracy)
#print(model.predict([Y[1]]))


  • 写回答

2条回答 默认 最新

  • threenewbee 2017-04-28 15:54
    关注
    评论

报告相同问题?

悬赏问题

  • ¥88 找成都本地经验丰富懂小程序开发的技术大咖
  • ¥15 如何处理复杂数据表格的除法运算
  • ¥15 如何用stc8h1k08的片子做485数据透传的功能?(关键词-串口)
  • ¥15 有兄弟姐妹会用word插图功能制作类似citespace的图片吗?
  • ¥200 uniapp长期运行卡死问题解决
  • ¥15 请教:如何用postman调用本地虚拟机区块链接上的合约?
  • ¥15 为什么使用javacv转封装rtsp为rtmp时出现如下问题:[h264 @ 000000004faf7500]no frame?
  • ¥15 乘性高斯噪声在深度学习网络中的应用
  • ¥15 关于docker部署flink集成hadoop的yarn,请教个问题 flink启动yarn-session.sh连不上hadoop,这个整了好几天一直不行,求帮忙看一下怎么解决
  • ¥15 深度学习根据CNN网络模型,搭建BP模型并训练MNIST数据集