douyan1972
2017-09-22 14:20 · 161 views

Saving a TF model trained with Keras, then evaluating it in Go

I'm trying to set up the classic MNIST model with Keras, save the TensorFlow graph, and then load it in Go and evaluate it on some input. I've been following this article, which provides the full code on GitHub. Nils uses plain TensorFlow to set up the computation graph, but I would like to use Keras. I managed to save the model the same way he does.

The model:

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=(28, 28, 1), name="inputNode"))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax', name="inferNode"))

This runs fine (it trains and evaluates), and I then save it as in the article:

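import tensorflow as tf
from keras import backend as K

sess = K.get_session()  # the TF session Keras trained in (assumed Keras 2 / TF 1.x setup)
builder = tf.saved_model.builder.SavedModelBuilder("mnistmodel_my")
# Go note: the model must be tagged so it can be retrieved at inference time
builder.add_meta_graph_and_variables(sess, ["serve"])
builder.save()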

I then try to evaluate it in Go as follows:

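// model is assumed to be the *tf.SavedModel returned by
// tf.LoadSavedModel("mnistmodel_my", []string{"serve"}, nil)
result, runErr := model.Session.Run(
    map[tf.Output]*tf.Tensor{
        model.Graph.Operation("inputNode").Output(0): tensor,
    },
    []tf.Output{
        model.Graph.Operation("inferNode").Output(0),
    },
    nil,
)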
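The snippet above feeds a variable `tensor` whose construction isn't shown. Because the first Conv2D layer was declared with `input_shape=(28,28,1)`, Keras expects a float32 batch of shape (N, 28, 28, 1). Below is a minimal sketch in plain Go that builds a one-image batch, assuming the training images were scaled to [0, 1] by dividing by 255 (an assumption: match whatever preprocessing the training script actually used). `mnistBatch` is a hypothetical helper name.

```go
package main

import "fmt"

// mnistBatch is a hypothetical helper: it converts one 28x28 grayscale image,
// given as 784 row-major pixel bytes, into a [1][28][28][1]float32 batch that
// matches input_shape=(28,28,1) plus a leading batch dimension.
// It assumes the model was trained on pixel/255 scaled inputs.
func mnistBatch(pixels []byte) ([][][][]float32, error) {
	const side = 28
	if len(pixels) != side*side {
		return nil, fmt.Errorf("expected %d pixels, got %d", side*side, len(pixels))
	}
	img := make([][][]float32, side)
	for row := 0; row < side; row++ {
		img[row] = make([][]float32, side)
		for col := 0; col < side; col++ {
			// Single channel: each pixel becomes a one-element slice.
			img[row][col] = []float32{float32(pixels[row*side+col]) / 255.0}
		}
	}
	// Wrap the single image in a batch of size 1.
	return [][][][]float32{img}, nil
}

func main() {
	pixels := make([]byte, 28*28)
	pixels[0] = 255
	batch, err := mnistBatch(pixels)
	if err != nil {
		panic(err)
	}
	fmt.Println(len(batch), len(batch[0]), len(batch[0][0]), len(batch[0][0][0]))
}
```

The nested slice can then be wrapped with `tf.NewTensor(batch)` from the TensorFlow Go bindings; the nesting depth is what gives the tensor its four dimensions.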

I follow the article's Go example, but evaluation panics with:

panic: nil-Operation. If the Output was created with a Scope object, see Scope.Err() for details.

goroutine 1 [running]:
github.com/tensorflow/tensorflow/tensorflow/go.Output.c(0x0, 0x0, 0x0, 0x0)
    /Users/air/go/src/github.com/tensorflow/tensorflow/tensorflow/go/operation.go:119 +0xbb
github.com/tensorflow/tensorflow/tensorflow/go.newCRunArgs(0xc42006e210, 0xc420047ef0, 0x1, 0x1, 0x0, 0x0, 0x0, 0xc4200723c8)
    /Users/air/go/src/github.com/tensorflow/tensorflow/tensorflow/go/session.go:307 +0x22d
github.com/tensorflow/tensorflow/tensorflow/go.(*Session).Run(0xc420078060, 0xc42006e210, 0xc420047ef0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, ...)
    /Users/air/go/src/github.com/tensorflow/tensorflow/tensorflow/go/session.go:85 +0x153
main.main()
    /Users/air/PycharmProjects/GoTensor/custom.go:36 +0x341
exit status 2

Since it says nil-Operation, I suspect I labelled the nodes incorrectly, but I don't know which other nodes I should label instead.

Many thanks!!!


1 answer

  • Accepted answer
    dpoppu4300 · 2017-09-22 16:23

    Your code should work fine. You're right about the cause of the nil-operation.

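    You just have to find the complete node name of your "inputNode". Keras uses each layer name as a name scope, so the graph has no operation called exactly "inputNode" (the same holds for "inferNode").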

    In Python, after defining the model, loop over the graph nodes and print every name that contains it:

    # sess: the TF session holding the Keras graph, e.g. keras.backend.get_session()
    for n in sess.graph.as_graph_def().node:
        if "inputNode" in n.name:
            print(n.name)

    Once you have the complete name, use it in your Go program.
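The same search can be done on the Go side after loading the SavedModel: the Go bindings expose the graph's operations (in recent versions, `model.Graph.Operations()` returns them and `op.Name()` gives each name). The sketch below keeps only the filtering logic, in plain Go, so it runs without TensorFlow installed. `matchingOps` is a hypothetical helper, and the names in `main` are illustrative only, shaped like what Keras typically produces (an input placeholder such as `inputNode_input`, ops under the `inferNode/` scope); check them against your own graph.

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// matchingOps mirrors the Python loop: it returns every operation name that
// contains substr, sorted for stable output. In a real program the names would
// come from the loaded graph, e.g. op.Name() for each op in model.Graph.Operations().
func matchingOps(names []string, substr string) []string {
	var out []string
	for _, n := range names {
		if strings.Contains(n, substr) {
			out = append(out, n)
		}
	}
	sort.Strings(out)
	return out
}

func main() {
	// Illustrative names only; list the real ones from your own graph.
	names := []string{
		"inputNode_input",
		"inputNode/kernel",
		"inputNode/convolution",
		"inferNode/BiasAdd",
		"inferNode/Softmax",
	}
	fmt.Println(matchingOps(names, "inputNode"))
	fmt.Println(matchingOps(names, "inferNode"))
}
```

From the matches, pick the input placeholder for the feed and the final softmax op for the fetch, and pass those exact strings to `model.Graph.Operation`.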

    Also, I suggest using tfgo, a more complete and easier-to-use wrapper around the TensorFlow Go API.
