哎呀100 2018-11-13 01:41 采纳率: 0%
浏览 2531

请大神指点,VGG-16训练时权重不更新,怎么回事??

用tensorflow训练VGG-16时,权重不发生变化是怎么回事??

 import tensorflow as tf
import scipy.io as sio 
#import numpy as np
import matplotlib.image as mpimg
import pickle as cp
a = list(sio.loadmat('dataset/corel5k_train_annot.mat')['annot1'])
b = list(sio.loadmat('dataset/corel5k_test_annot.mat')['annot2'])
def get_batch(image, label, batch_size, now_batch, total_batch):

    if now_batch < total_batch-1:
        image_batch = image[now_batch*batch_size:(now_batch+1)*batch_size]
        label_batch = label[now_batch*batch_size:(now_batch+1)*batch_size]
    else:
        image_batch = image[now_batch*batch_size:]
        label_batch 


= label[now_batch*batch_size:]
    #image_batch = tf.reshape(image_batch,[-1,128,128,3])

    return  image_batch,label_batch

train_img = []
with open('image/corel5k_train_list.txt') as f:
    for i in f.readlines():
        train_img += [mpimg.imread('image/%s.jpg'%i.strip())]
    cp.dump(train_img,open("train.pkl","wb"))
test_img = []
with open('image/corel5k_test_list.txt') as f:
    for i in f.readlines():
        test_img += [mpimg.imread('image/%s.jpg'%i.strip())]
    cp.dump(test_img,open("test.pkl","wb"))#一种保存列表的方式

x = tf.placeholder(tf.float32,[None,128,128,3])
y_ = tf.placeholder(tf.float32, shape=[None, 260])
#loss=tf.Variable(tf.constant(0.0))
#第一
W1 = tf.Variable(tf.truncated_normal([3, 3, 3, 64], mean=0.0,stddev=1.0))
b1 = tf.Variable(tf.constant(0.0, shape=[64]))
h1 =  tf.nn.relu(tf.nn.conv2d(x, W1, strides=[1, 1, 1, 1], padding='SAME') + b1)
#p1 = tf.nn.max_pool(h1, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')

W2 = tf.Variable(tf.truncated_normal([3, 3, 64, 64], mean=0.0,stddev=1.0))
b2 = tf.Variable(tf.constant(0.0, shape=[64]))
h2 = tf.nn.relu(tf.nn.conv2d(h1, W2, strides=[1, 1, 1, 1], padding='SAME') + b2)
p2 = tf.nn.max_pool(h2, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')
#第二
W3 = tf.Variable(tf.truncated_normal([3, 3, 64, 128], mean=0.0,stddev=1.0))
b3 = tf.Variable(tf.constant(0.0, shape=[128]))
h3 = tf.nn.relu(tf.nn.conv2d(p2, W3, strides=[1, 1, 1, 1], padding='SAME') + b3)
#p3 = tf.nn.max_pool(h2, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')

W4 = tf.Variable(tf.truncated_normal([3, 3, 128, 128], mean=0.0,stddev=1.0))
b4 = tf.Variable(tf.constant(0.0, shape=[128]))
h4 = tf.nn.relu(tf.nn.conv2d(h3, W4, strides=[1, 1, 1, 1], padding='SAME') + b4)
p4 = tf.nn.max_pool(h4, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')
#第三
W5 = tf.Variable(tf.truncated_normal([3, 3, 128, 256], mean=0.0,stddev=1.0))
b5 = tf.Variable(tf.constant(0.0, shape=[256]))
h5 = tf.nn.relu(tf.nn.conv2d(p4, W5, strides=[1, 1, 1, 1], padding='SAME') + b5)

W6 = tf.Variable(tf.truncated_normal([3, 3, 256, 256], mean=0.0,stddev=1.0))
b6 = tf.Variable(tf.constant(0.0, shape=[256]))
h6 = tf.nn.relu(tf.nn.conv2d(h5, W6, strides=[1, 1, 1, 1], padding='SAME') + b6)

W7 = tf.Variable(tf.truncated_normal([3, 3, 256, 256], mean=0.0,stddev=1.0))
b7 = tf.Variable(tf.constant(0.0, shape=[256]))
h7 = tf.nn.relu(tf.nn.conv2d(h6, W7, strides=[1, 1, 1, 1], padding='SAME') + b7)
p7 = tf.nn.max_pool(h7, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')

#第四
W8 = tf.Variable(tf.truncated_normal([3, 3, 256, 512], mean=0.0,stddev=1.0))
b8 = tf.Variable(tf.constant(0.0, shape=[512]))
h8 = tf.nn.relu(tf.nn.conv2d(p7, W8, strides=[1, 1, 1, 1], padding='SAME') + b8)

W9 = tf.Variable(tf.truncated_normal([3, 3, 512, 512], mean=0.0,stddev=1.0))
b9 = tf.Variable(tf.constant(0.0, shape=[512]))
h9 = tf.nn.relu(tf.nn.conv2d(h8, W9, strides=[1, 1, 1, 1], padding='SAME') + b9)

W10 = tf.Variable(tf.truncated_normal([3, 3, 512, 512], mean=0.0,stddev=1.0))
b10 = tf.Variable(tf.constant(0.0, shape=[512]))
h10 = tf.nn.relu(tf.nn.conv2d(h9, W10, strides=[1, 1, 1, 1], padding='SAME') + b10)
p10 = tf.nn.max_pool(h10, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')

#第五
W11 = tf.Variable(tf.truncated_normal([3, 3, 512, 512], mean=0.0,stddev=1.0))
b11 = tf.Variable(tf.constant(0.0, shape=[512]))
h11 = tf.nn.relu(tf.nn.conv2d(p10, W11, strides=[1, 1, 1, 1], padding='SAME') + b11)

W12 = tf.Variable(tf.truncated_normal([3, 3, 512, 512], mean=0.0,stddev=1.0))
b12 = tf.Variable(tf.constant(0.0, shape=[512]))
h12 = tf.nn.relu(tf.nn.conv2d(h11, W12, strides=[1, 1, 1, 1], padding='SAME') + b12)

W13 = tf.Variable(tf.truncated_normal([3, 3, 512, 512], mean=0.0,stddev=1.0))
b13 = tf.Variable(tf.constant(0.0, shape=[512]))
h13 = tf.nn.relu(tf.nn.conv2d(h12, W13, strides=[1, 1, 1, 1], padding='SAME') + b13)
p13 = tf.nn.max_pool(h13, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')

W_fc1 = tf.Variable(tf.truncated_normal([4*4*512, 4096], mean=0.0,stddev=1.0))
b_fc1 = tf.Variable(tf.constant(0.0, shape=[4096]))
h_pool2_flat = tf.reshape(p13, [-1, 4*4*512])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

keep_prob1 = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob1)

W_fc2 = tf.Variable(tf.truncated_normal([4096, 4096], mean=0.0,stddev=1.0))
b_fc2 = tf.Variable(tf.constant(0.0, shape=[4096]))
h_fc2 = tf.nn.relu(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

keep_prob2 = tf.placeholder(tf.float32)
h_fc2_drop = tf.nn.dropout(h_fc2, keep_prob2)

W_fc3 = tf.Variable(tf.truncated_normal([4096, 260], mean=0.0,stddev=1.0))
b_fc3 = tf.Variable(tf.constant(0.0, shape=[260]))
y_conv=tf.nn.softmax(tf.matmul(h_fc2_drop, W_fc3) + b_fc3)


loss = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv + 1e-10), reduction_indices=[1]))

train_step = tf.train.AdamOptimizer(0.1).minimize(loss)



with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for v in range(100):
        for r in range(90):
            image_batch,label_batch = get_batch(train_img,a,50,r,90)
            sess.run(train_step,feed_dict={x:image_batch,y_:label_batch,keep_prob1: 0.5,keep_prob2: 0.5})
        print(sess.run(W_fc3))


    print("*")
    test_batch,testlabel_batch = get_batch(test_img,b,50,0,2)
    print(sess.run(loss,feed_dict={x:test_batch,y_:testlabel_batch,keep_prob1: 1,keep_prob2: 1}))
  • 写回答

1条回答 默认 最新

  • zqbnqsdsmd 2018-11-15 11:53
    关注
    评论

报告相同问题?

悬赏问题

  • ¥15 vue3+express部署到nginx
  • ¥20 搭建pt1000三线制高精度测温电路
  • ¥15 使用Jdk8自带的算法,和Jdk11自带的加密结果会一样吗,不一样的话有什么解决方案,Jdk不能升级的情况
  • ¥15 画两个图 python或R
  • ¥15 在线请求openmv与pixhawk 实现实时目标跟踪的具体通讯方法
  • ¥15 八路抢答器设计出现故障
  • ¥15 opencv 无法读取视频
  • ¥15 按键修改电子时钟,C51单片机
  • ¥60 Java中实现如何实现张量类,并用于图像处理(不运用其他科学计算库和图像处理库))
  • ¥20 5037端口被adb自己占了