穆穆青风至 2022-09-22 14:57 采纳率: 97.4%
浏览 81
已结题

tensorflow中的.numpy()函数是啥


        with tf.GradientTape() as tape:  # with结构记录梯度信息
            y = tf.matmul(x_train, w1) + b1  # 神经网络乘加运算
            y = tf.nn.softmax(y)  # 使输出y符合概率分布(此操作后与独热码同量级,可相减求loss)
            y_ = tf.one_hot(y_train, depth=3)  # 将标签值转换为独热码格式,方便计算loss和accuracy
            loss = tf.reduce_mean(tf.square(y_ - y))  # 采用均方误差损失函数mse = mean(sum(y-out)^2)
            print('loss 类型:\t',type(loss))
            loss_all += loss.numpy()  # 将每个step计算出的loss累加,为后续求loss平均值提供数据,这样计算的loss更准确

最后一行,loss.numpy()是干了啥,这是个啥函数,有啥用

  • 写回答

4条回答 默认 最新

  • 万里鹏程转瞬至 人工智能领域优质创作者 2022-09-22 15:28
    关注

    因为这里的loss是tensor,所以调用loss.numpy()将其转换为numpy数组。这里最主要的原因是loss_all一开始不是tensor类型,如果loss_all是tensor类型,则可以不用loss.numpy()。进行类型转换,只是为了确保语法正确

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(3条)

报告相同问题?

问题事件

  • 系统已结题 10月6日
  • 已采纳回答 9月28日
  • 创建了问题 9月22日

悬赏问题

  • ¥20 关于#c++#的问题:水果店管理系统
  • ¥30 dbLinq最新版linq sqlite
  • ¥20 对D盘进行分盘之前没有将visual studio2022卸载掉,现在该如何下载回来
  • ¥15 完成虚拟机环境配置,还有安装kettle
  • ¥15 2024年全国大学生数据分析大赛A题:直播带货与电商产品的大数据分析 问题5. 请设计一份优惠券的投放策略,需要考虑优惠券的数量、优惠券的金额、投放时间段和投放商品种类等因素。求具体的python代码
  • ¥15 有人会搭建生鲜配送自营+平台的管理系统吗
  • ¥15 用matlab写代码
  • ¥30 motoradmin系统的多对多配置
  • ¥15 求组态王串口自定义通信配置方法或代码?
  • ¥15 实验 :UML2.0 结构建模