weixin_41138872 2018-03-02 05:51 采纳率: 66.7%
浏览 2125
已采纳

python测试集结果调取问题

test_acc = sess.run(accr,feed_dict=feeds_test)
这个语句是用来调出测试准确率的,如何才能调出测试期间对每一个样本的预测数值????

x=tf.placeholder("float", [None,784])
#placeholder 占位,不赋给x实际值,784 像素值, None无穷样本
y=tf.placeholder("float", [None,10])
#10个分类目标
W=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))
#tf.zeros 初始化

actv= tf.nn.softmax(tf.matmul(x,W)+b) #cost function
cost=tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv), reduction_indices=1))
learning_rate=0.01
optm= tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
print ('1')

pred=tf.equal(tf.argmax(actv, 1), tf.argmax(y, 1))
#ACCURACY
accr=tf.reduce_mean(tf.cast(pred,"float"))
#INITIALIZER
init=tf.global_variables_initializer()

training_epochs = 50 #所有样本迭代次数=50
batch_size = 100 #每次迭代用多少样本
display_step = 5 #展示
sess=tf.Session()
sess.run(init) #跑初始化
for epoch in range (training_epochs):
avg_cost=0
num_batch=int(mnist.train.num_examples/batch_size)
for i in range (num_batch):
batch_xs, batch_ys= mnist.train.next_batch(batch_size) #一步一步的往下找
sess.run(optm, feed_dict={x: batch_xs, y: batch_ys})
feeds={x:batch_xs, y: batch_ys}
avg_cost += sess.run (cost, feed_dict=feeds)/num_batch
#display
if epoch % display_step == 0:
feeds_train = {x: batch_xs, y: batch_ys}
feeds_test = {x: mnist.test.images, y: mnist.test.labels}
train_acc = sess.run(accr, feed_dict=feeds_train) #feed_dict 针对place holder占位
test_acc = sess.run(accr,feed_dict=feeds_test)
print ("Epoch: %03d/%03d cost: %.9f train_acc: %.3f test_acc: %.3f"
% (epoch, training_epochs, avg_cost, train_acc, test_acc))

  • 写回答

2条回答 默认 最新

  • violin_1229 2018-03-03 02:30
    关注

    简单的看了下,把sess.run(accr,feed_dict=feeds_test)改成sess.run([accr,actv],feed_dict=feeds_test)应该可以

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

悬赏问题

  • ¥30 Matlab打开默认名称带有/的光谱数据
  • ¥50 easyExcel模板 动态单元格合并列
  • ¥15 res.rows如何取值使用
  • ¥15 在odoo17开发环境中,怎么实现库存管理系统,或独立模块设计与AGV小车对接?开发方面应如何设计和开发?请详细解释MES或WMS在与AGV小车对接时需完成的设计和开发
  • ¥15 CSP算法实现EEG特征提取,哪一步错了?
  • ¥15 游戏盾如何溯源服务器真实ip?需要30个字。后面的字是凑数的
  • ¥15 vue3前端取消收藏的不会引用collectId
  • ¥15 delphi7 HMAC_SHA256方式加密
  • ¥15 关于#qt#的问题:我想实现qcustomplot完成坐标轴
  • ¥15 下列c语言代码为何输出了多余的空格