douyan1972 2017-09-22 14:20

Saving a TF model trained with Keras, then evaluating it in Go

I'm trying to set up the classic MNIST model with Keras, save the TensorFlow graph, and then load it in Go and evaluate it on some input. I've been following this article, which provides the full code on GitHub. Nils sets up the computational graph with plain TensorFlow, but I would like to use Keras. I managed to save the model the same way he does.

Model:

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=(28, 28, 1), name="inputNode"))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax', name="inferNode"))

It runs fine: it trains and evaluates. I then save it the same way as in the article:

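import tensorflow as tf
from keras import backend as K

sess = K.get_session()  # the TensorFlow session Keras trained the model in
builder = tf.saved_model.builder.SavedModelBuilder("mnistmodel_my")
# Note for Go: we must tag the model so that we can retrieve it at inference time
builder.add_meta_graph_and_variables(sess, ["serve"])
builder.save()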

I then try to evaluate it in Go like this:

result, runErr := model.Session.Run(
    map[tf.Output]*tf.Tensor{
        model.Graph.Operation("inputNode").Output(0): tensor,
    },
    []tf.Output{
        model.Graph.Operation("inferNode").Output(0),
    },
    nil,
)
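The snippet above does not show how `tensor` is built. A minimal sketch, assuming you have 784 grayscale pixels in row-major order, already preprocessed the same way as your training data: the hypothetical helper `toNHWC` lays them out as `[batch][height][width][channels]`, matching `input_shape=(28, 28, 1)` plus a batch dimension of 1. The resulting array can then be passed to `tf.NewTensor` to obtain the input tensor.

```go
package main

import "fmt"

// toNHWC (hypothetical helper) reshapes 28*28 row-major pixels into the
// [1][28][28][1] layout expected by a Keras input_shape of (28, 28, 1).
// Pass the returned array to tf.NewTensor to build the input tensor.
func toNHWC(pixels []float32) ([1][28][28][1]float32, error) {
	var img [1][28][28][1]float32
	if len(pixels) != 28*28 {
		return img, fmt.Errorf("expected %d pixels, got %d", 28*28, len(pixels))
	}
	for i, p := range pixels {
		// Row is i/28, column is i%28; batch and channel indices are 0.
		img[0][i/28][i%28][0] = p
	}
	return img, nil
}
```

A fixed-size array keeps the shape explicit at compile time; a nested slice would work just as well if the batch size varies.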

I follow the article's Go example, but evaluation fails with:

panic: nil-Operation. If the Output was created with a Scope object, see Scope.Err() for details.

goroutine 1 [running]:
github.com/tensorflow/tensorflow/tensorflow/go.Output.c(0x0, 0x0, 0x0, 0x0)
    /Users/air/go/src/github.com/tensorflow/tensorflow/tensorflow/go/operation.go:119 +0xbb
github.com/tensorflow/tensorflow/tensorflow/go.newCRunArgs(0xc42006e210, 0xc420047ef0, 0x1, 0x1, 0x0, 0x0, 0x0, 0xc4200723c8)
    /Users/air/go/src/github.com/tensorflow/tensorflow/tensorflow/go/session.go:307 +0x22d
github.com/tensorflow/tensorflow/tensorflow/go.(*Session).Run(0xc420078060, 0xc42006e210, 0xc420047ef0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, ...)
    /Users/air/go/src/github.com/tensorflow/tensorflow/tensorflow/go/session.go:85 +0x153
main.main()
    /Users/air/PycharmProjects/GoTensor/custom.go:36 +0x341
exit status 2

Since it says nil-Operation, I think I may have labelled the nodes incorrectly, but I don't know which other nodes I should label instead.

Many thanks!!!


1 answer

  • dpoppu4300 2017-09-22 16:23

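    Your code should work fine, and you're right about the cause of the nil-Operation: model.Graph.Operation("inputNode") returns nil because no node in the graph is named exactly "inputNode".

    You just have to find the full node name of your "inputNode" (and likewise of "inferNode").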

    In Python, after defining your model, you can loop over the graph nodes and print the full names like this:

    for n in sess.graph.as_graph_def().node:
        if "inputNode" in n.name:
            print(n.name)
    

    Once you have the full name, use it in your Go program.

    I also suggest using tfgo, a more complete and easier-to-use wrapper around the TensorFlow Go API.

    This answer was selected by the asker as the best answer.