今天尝试根据网上教程写了手写识别的实例,发现打印权重weights的时候,随着迭代次数权重不会改变。但其他变量如bias是随迭代而改变的。请问大家这个权重为什么不会更新呢?
import tensorflow as tf
import input_data
tf.compat.v1.disable_eager_execution()
def full_connection():
# 1. 准备数据
mnist = input_data.read_data_sets("./mnist_data", one_hot=True)
x = tf.compat.v1.placeholder(dtype=tf.float32, shape=(None, 784))
y_true = tf.compat.v1.placeholder(dtype=tf.float32, shape=(None, 10))
# 2. 构建模型
weights = tf.compat.v1.Variable(initial_value=tf.compat.v1.random_normal(shape=[784, 10]), name="weights")
bias = tf.compat.v1.Variable(initial_value=tf.compat.v1.random_normal(shape=[10]), name="bias")
y_predict = tf.matmul(x, weights) + bias
# 3. 构造损失函数
error = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict)
error = tf.reduce_mean(error)
# 4. 优化损失
optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.03).minimize(error)
# 5. 准确率计算
equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))
accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))
# 6. 增加变量显示
tf.compat.v1.summary.scalar("error", error)
tf.compat.v1.summary.histogram("weights", weights)
tf.compat.v1.summary.histogram("bias", bias)
merged = tf.compat.v1.summary.merge_all()
saver = tf.compat.v1.train.Saver()
with tf.compat.v1.Session() as sess:
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
image, label = mnist.train.next_batch(100)
file_writer = tf.compat.v1.summary.FileWriter("./log", graph=sess.graph)
for i in range(1000):
_, loss, accuracy_value = sess.run([optimizer, error, accuracy], feed_dict={x: image, y_true: label})
print("第%d次训练,损失为%f,准确率为%f" % (i+1, loss, accuracy_value))
print(weights.eval())
summary_result = sess.run(merged, feed_dict={x: image, y_true: label})
file_writer.add_summary(summary_result, i)
if i % 10 ==0:
saver.save(sess, "./log/model/model-1.ckpt")
return None
if __name__ == "__main__":
full_connection()
第6次训练,损失为16.070484,准确率为0.090000
权重为:
[[-6.5291589e-01 -1.6071169e+00 2.1290806e-01 ... -2.1784568e+00
-1.2091833e+00 -8.6140698e-01]
[ 1.2597474e+00 -1.4910623e-01 -3.8707003e-01 ... 1.0214819e-01
5.9081298e-01 -6.1707306e-01]
[-1.4509798e+00 -1.9147521e-02 1.6300718e+00 ... -2.3356327e-01
2.2199700e+00 1.0290977e+00]
...
[ 8.4160960e-01 -4.5380855e-01 4.1801175e-01 ... 3.1385022e-01
-1.0918922e+00 1.6047851e+00]
[ 1.5874079e-01 1.3555746e+00 7.1338147e-01 ... 5.6245077e-01
1.2595990e+00 -8.3930290e-01]
[ 7.4030966e-01 2.1435671e+00 -4.2898212e-02 ... -5.6554389e-01
7.3194194e-01 -7.8909863e-05]]
第112次训练,损失为6.315756,准确率为0.270000
权重为:
[[-6.5291589e-01 -1.6071169e+00 2.1290806e-01 ... -2.1784568e+00
-1.2091833e+00 -8.6140698e-01]
[ 1.2597474e+00 -1.4910623e-01 -3.8707003e-01 ... 1.0214819e-01
5.9081298e-01 -6.1707306e-01]
[-1.4509798e+00 -1.9147521e-02 1.6300718e+00 ... -2.3356327e-01
2.2199700e+00 1.0290977e+00]
...
[ 8.4160960e-01 -4.5380855e-01 4.1801175e-01 ... 3.1385022e-01
-1.0918922e+00 1.6047851e+00]
[ 1.5874079e-01 1.3555746e+00 7.1338147e-01 ... 5.6245077e-01
1.2595990e+00 -8.3930290e-01]
[ 7.4030966e-01 2.1435671e+00 -4.2898212e-02 ... -5.6554389e-01
7.3194194e-01 -7.8909863e-05]]
[[-6.5291589e-01 -1.6071169e+00 2.1290806e-01 ... -2.1784568e+00
-1.2091833e+00 -8.6140698e-01]
[ 1.2597474e+00 -1.4910623e-01 -3.8707003e-01 ... 1.0214819e-01
5.9081298e-01 -6.1707306e-01]
[-1.4509798e+00 -1.9147521e-02 1.6300718e+00 ... -2.3356327e-01
2.2199700e+00 1.0290977e+00]
...
[ 8.4160960e-01 -4.5380855e-01 4.1801175e-01 ... 3.1385022e-01
-1.0918922e+00 1.6047851e+00]
[ 1.5874079e-01 1.3555746e+00 7.1338147e-01 ... 5.6245077e-01
1.2595990e+00 -8.3930290e-01]
[ 7.4030966e-01 2.1435671e+00 -4.2898212e-02 ... -5.6554389e-01
7.3194194e-01 -7.8909863e-05]]