张可盐 2021-07-01 15:23 采纳率: 0%
浏览 46
已结题

深度学习和参数相乘,矩阵维度不一致

目前在看深度学习入门(基于Python)这本书,运行的是里面的示例代码。 在训练完成模型后,对手写数字测试集取单一图片进行预测。在预测的时候,用到了y = Wx + b。wx是矩阵相乘,但是现在有3个W,3个b,上一次的输出的维度不能和下一个W的维度对应,是我哪方面有问题吗,欢迎各位指教

img = x_train[0].reshape(1,1,28,28)
x = im2col(img, 5, 5)  # (48400, 25)

def predict(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network["W3"]  # (30, 1, 5, 5)(363000, 100)(100, 2)
    b1, b2, b3 = network['b1'], network['b2'], network['b3']  # (30,)(100, )(2, )
    col_W1 = W1.reshape(30, -1).T  # (25, 30)
    # print(x.shape, col_W1.shape)
    a1 = np.dot(x, col_W1) + b1
    z1 = sigmoid(a1)  # (48400,30)
    # 运行这里时候  (48400,30) * (363000, 100) 我应该怎么办?变形或者什么?
    a2 = np.dot(z1, W2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)
    return y
  • 写回答

1条回答 默认 最新

  • 爱晚乏客游 2021-07-01 16:41
    关注

    将图片按照你的模型流程走啊,走到了参数这里不就是可以相乘了。不然不知道你这些参数维度含义,直接扩展维度相乘,但是结果没意义

    评论

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 7月18日

悬赏问题

  • ¥15 游戏盾如何溯源服务器真实ip?
  • ¥15 Mac版Fiddler Everywhere4.0.1提示强制更新
  • ¥15 android 集成sentry上报时报错。
  • ¥50 win10链接MySQL
  • ¥35 跳过我的世界插件ip验证
  • ¥15 抖音看过的视频,缓存在哪个文件
  • ¥15 自定义损失函数报输入参数的数目不足
  • ¥15 如果我想学习C大家有是的的资料吗
  • ¥15 根据文件名称对文件进行排序
  • ¥15 deploylinux的ubuntu系统无法成功安装使用MySQL❓