dongmi1864 2017-11-06 17:58

How do I run a DL-RNN model on text in Golang?

I have built an RNN model in TensorFlow based on Reddit/Twitter conversations and saved it as a .pb file. Does anyone know how to pass a raw text string through the model in Go and get its output?

package main

import (
    "bytes"
    "log"
    "os"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

func main() {
    modelPath := "/my_model.pb"

    // Buffer input text: join the command-line arguments with spaces
    var buffer bytes.Buffer
    for _, arg := range os.Args[1:] {
        buffer.WriteString(arg + " ")
    }
    inputText := buffer.String()

    // Load the serialized GraphDef from a file.
    model, err := os.ReadFile(modelPath)
    if err != nil {
        log.Fatal(err)
    }

    // Construct an in-memory graph from the serialized form.
    graph := tf.NewGraph()
    if err := graph.Import(model, ""); err != nil {
        log.Fatal(err)
    }

    // Create a session for inference over the graph.
    session, err := tf.NewSession(graph, nil)
    if err != nil {
        log.Fatal(err)
    }
    defer session.Close()

    // Open question: how do I feed inputText to the model and read its output?
    _ = inputText
}

  • dream3323 2017-11-06 18:40

    You can use tfgo to easily load a trained TensorFlow model into Go and use it: just export the trained model using tf.saved_model.builder.SavedModelBuilder, as shown in the tfgo README.

    Then you just have to extract the input placeholder from the graph and feed the network through it.

    Suppose you exported your model to a directory called my_model and tagged it with the tag "tag", and that your input placeholder is named "Placeholder". You also need to know the name of your output node; let's call it output/node/path/op. Your code should then look like this:

    package main

    import (
        "bytes"
        "fmt"
        "log"
        "os"

        tg "github.com/galeone/tfgo"
        tf "github.com/tensorflow/tensorflow/tensorflow/go"
    )

    func main() {
        // Load the SavedModel exported to the "my_model" directory with the tag "tag".
        model := tg.LoadModel("my_model", []string{"tag"}, nil)

        // Buffer input text
        var buffer bytes.Buffer
        for _, arg := range os.Args[1:] {
            buffer.WriteString(arg + " ")
        }

        // Wrap the text in a scalar string tensor
        inputText, err := tf.NewTensor(buffer.String())
        if err != nil {
            log.Fatal(err)
        }

        // Run the output node, feeding the text into the input placeholder
        results := model.Exec([]tf.Output{
            model.Op("output/node/path/op", 0),
        }, map[tf.Output]*tf.Tensor{
            model.Op("Placeholder", 0): inputText,
        })

        // Do something with results[0]; here we just print its value
        fmt.Println(results[0].Value())
    }
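The snippet above feeds the raw string as a scalar string tensor, which only works if the exported graph accepts a string placeholder and tokenizes the text itself. Many RNN text models instead expect a batch of integer word IDs, in which case the Go side has to reproduce the training-time preprocessing, and the output usually needs decoding too. The following is a minimal sketch under stated assumptions, not part of tfgo or TensorFlow: whitespace tokenization with lowercasing, a vocabulary exported from training, ID 0 for padding and 1 for unknown words, and an output node that yields one score per vocabulary entry. The names encodeText, argmax, vocab and idToWord are hypothetical.

```go
package main

import (
	"fmt"
	"strings"
)

// Hypothetical special IDs; they must match the convention used when the
// model was trained.
const (
	padID int32 = 0
	unkID int32 = 1
)

// encodeText lowercases text, splits it on whitespace (so a trailing space
// from the argument-joining loop adds no empty token), and maps each word to
// its vocabulary ID, truncating or right-padding to maxLen. The result has
// shape [1][maxLen]: a batch containing a single sequence.
func encodeText(text string, vocab map[string]int32, maxLen int) [][]int32 {
	ids := make([]int32, maxLen)
	for i := range ids {
		ids[i] = padID
	}
	for i, word := range strings.Fields(strings.ToLower(text)) {
		if i >= maxLen {
			break
		}
		if id, ok := vocab[word]; ok {
			ids[i] = id
		} else {
			ids[i] = unkID
		}
	}
	return [][]int32{ids}
}

// argmax returns the index of the largest score, e.g. the most likely next
// word when the output node produces one score per vocabulary entry.
// It returns 0 for an empty slice.
func argmax(scores []float32) int {
	best := 0
	for i, s := range scores {
		if s > scores[best] {
			best = i
		}
	}
	return best
}

func main() {
	// Toy vocabulary for illustration only; a real one must come from the
	// training pipeline.
	vocab := map[string]int32{"hello": 2, "world": 3}
	idToWord := []string{"<pad>", "<unk>", "hello", "world"}

	input := encodeText("Hello there world ", vocab, 5)
	fmt.Println(input) // [[2 1 3 0 0]]

	// Stand-in for the model output, e.g. results[0].Value().([][]float32).
	scores := [][]float32{{0.1, 0.2, 0.6, 0.1}}
	fmt.Println(idToWord[argmax(scores[0])]) // hello
}
```

If those assumptions hold for your model, the result of encodeText could be passed to tf.NewTensor in place of buffer.String(), giving an int32 tensor of shape [1, maxLen]; a float32 output of shape [1, vocabSize] can be read back with results[0].Value().([][]float32) and each row passed to argmax.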
    
    This answer was selected by the asker as the best answer.
