dongmi1864
2017-11-06 17:58

How to run a DL-RNN model on text in Golang?

I have built an RNN model in TensorFlow, trained on Reddit/Twitter conversations, and saved it as a .pb file. Does anyone know how to pass a raw text string through the model in Go and get its output?

package main

import (
    "bytes"
    "log"
    "os"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

func main() {
    modeldir := "/my_model.pb"

    // Buffer input text: join all command-line arguments with spaces.
    var buffer bytes.Buffer

    args := os.Args[1:]

    for _, arg := range args {
        buffer.WriteString(arg + " ")
    }

    inputText := buffer.String()

    // Load the serialized GraphDef from a file.
    model, err := os.ReadFile(modeldir)
    if err != nil {
        log.Fatal(err)
    }
    // Construct an in-memory graph from the serialized form.
    graph := tf.NewGraph()
    if err := graph.Import(model, ""); err != nil {
        log.Fatal(err)
    }
    // Create a session for inference over graph.
    session, err := tf.NewSession(graph, nil)
    if err != nil {
        log.Fatal(err)
    }
    defer session.Close()

    // Open question: how do I feed inputText to the graph and read the output?
    _ = inputText
}

1 answer

  • dream3323 2017-11-06 18:40
    Accepted

    You can use tfgo to easily load a trained TensorFlow model into Go and use it: just export the trained model using tf.saved_model.builder.SavedModelBuilder, as shown in the tfgo README.

    Then you only have to extract the input placeholder from the graph and feed the network through it.

    Let's suppose you exported your model to the directory my_model and tagged it with the tag "tag", and that your input placeholder is named "Placeholder". You also need to know the name of your output node; let's call it output/node/path/op. Then your code should look like this:

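    package main

    import (
            "bytes"
            "fmt"
            "log"
            "os"

            tg "github.com/galeone/tfgo"
            tf "github.com/tensorflow/tensorflow/tensorflow/go"
    )

    func main() {
            // Load the SavedModel in directory "my_model" exported with tag "tag".
            // tg.LoadModel panics if the model cannot be loaded.
            model := tg.LoadModel("my_model", []string{"tag"}, nil)

            // Buffer input text: join all command-line arguments with spaces.
            var buffer bytes.Buffer
            for _, arg := range os.Args[1:] {
                    buffer.WriteString(arg + " ")
            }

            // Wrap the raw string in a scalar string tensor. If the placeholder
            // expects a batch (shape [None]), pass []string{...} instead.
            inputText, err := tf.NewTensor(buffer.String())
            if err != nil {
                    log.Fatal(err)
            }

            // Feed the placeholder and fetch output 0 of the output node.
            results := model.Exec([]tf.Output{
                    model.Op("output/node/path/op", 0),
            }, map[tf.Output]*tf.Tensor{
                    model.Op("Placeholder", 0): inputText,
            })

            // results[0] is the output tensor; Value() converts it to a Go value.
            fmt.Println(results[0].Value())
    }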
    
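The answer needs the exact names of the placeholder and output node. TensorFlow reports tensor names in the form `op_name:index` (for example `Placeholder:0`, as shown by `tensor.name` in Python or by `saved_model_cli show --dir my_model --all` when the model was exported with a signature), whereas tfgo's `model.Op` takes the operation name and output index as two separate arguments. A small sketch of a conversion helper; `splitTensorName` is a hypothetical convenience, not part of tfgo:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// splitTensorName splits a tensor name in TensorFlow's "op_name:index" form
// (for example "Placeholder:0") into the operation name and output index,
// the pair that tfgo's model.Op(name, index) takes. A name without a
// ":index" suffix is treated as output 0.
func splitTensorName(name string) (string, int, error) {
	i := strings.LastIndex(name, ":")
	if i < 0 {
		return name, 0, nil
	}
	idx, err := strconv.Atoi(name[i+1:])
	if err != nil || idx < 0 {
		return "", 0, fmt.Errorf("invalid output index in tensor name %q", name)
	}
	return name[:i], idx, nil
}

func main() {
	for _, n := range []string{"Placeholder:0", "output/node/path/op:0", "output/node/path/op"} {
		op, idx, err := splitTensorName(n)
		if err != nil {
			fmt.Println(err)
			continue
		}
		fmt.Printf("model.Op(%q, %d)\n", op, idx)
	}
}
```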

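In the answer's code, `results[0].Value()` returns an `interface{}` whose concrete type depends on the output tensor's dtype and shape (for example `[][]int64` for an int64 tensor of shape [1, T]). If the output node emits token IDs rather than text (an assumption; it depends on how the graph was built), they still have to be mapped back to words. A sketch with a hypothetical inverse vocabulary `inv` and a hypothetical end-of-sequence ID:

```go
package main

import (
	"fmt"
	"strings"
)

// eosID is a hypothetical end-of-sequence token ID; use the one from your
// own vocabulary.
const eosID int64 = 2

// decode turns one row of predicted token IDs back into text. It stops at
// eosID and skips IDs that have no entry in the inverse vocabulary inv.
func decode(ids []int64, inv map[int64]string) string {
	words := make([]string, 0, len(ids))
	for _, id := range ids {
		if id == eosID {
			break
		}
		if w, ok := inv[id]; ok {
			words = append(words, w)
		}
	}
	return strings.Join(words, " ")
}

func main() {
	// Hypothetical inverse vocabulary, for illustration only.
	inv := map[int64]string{7: "i", 8: "am", 9: "fine"}

	// Stand-in for results[0].Value(), assuming an int64 output of shape [1, T].
	var value interface{} = [][]int64{{7, 8, 9, eosID, 0, 0}}

	batch, ok := value.([][]int64)
	if !ok {
		fmt.Printf("unexpected output type %T\n", value)
		return
	}
	fmt.Println(decode(batch[0], inv))
}
```

If the output is instead a float tensor of per-token scores, each step would first need an argmax over the vocabulary dimension before decoding.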