dongpo5207 2017-10-01 12:21

Cannot get predictions when loading a Tensorflow model in Go

I've loaded a Tensorflow model in Go and cannot get predictions: it keeps complaining about a shape mismatch, even though the input is a simple 2D array. I would appreciate any ideas; thank you in advance.

Error running the session with input, err: You must feed a value for placeholder tensor 'theoutput_target' with dtype float
 [[Node: theoutput_target = Placeholder[_output_shapes=[[?,?]], dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

Input tensor being sent is a [][]float32{ {1.0}, }

a := [][]float32{ {1.0}, }
tensor, terr :=  tf.NewTensor(a)
if terr != nil {
    fmt.Printf("Error creating input tensor: %s\n", terr.Error())
    return
}
result, runErr := model.Session.Run(
    map[tf.Output]*tf.Tensor{
        model.Graph.Operation("theinput").Output(0): tensor,
    },
    []tf.Output{
        model.Graph.Operation("theoutput_target").Output(0),
    },
    nil,
)
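
As a sanity check on the shape question, here is a minimal sketch in plain Go (no TensorFlow dependency) that derives the shape of a nested slice, which is how `tf.NewTensor` infers tensor shape, and compares it with an expected shape. The helpers `shapeOf` and `compatible` are hypothetical names introduced for illustration, and `-1` is an assumed marker for an unknown dimension such as the batch size.

```go
package main

import (
	"errors"
	"fmt"
)

// shapeOf returns the [rows, cols] shape of a rectangular 2-D slice,
// or an error if the rows have different lengths.
func shapeOf(a [][]float32) ([]int64, error) {
	if len(a) == 0 {
		return []int64{0, 0}, nil
	}
	cols := len(a[0])
	for _, row := range a {
		if len(row) != cols {
			return nil, errors.New("ragged input: rows differ in length")
		}
	}
	return []int64{int64(len(a)), int64(cols)}, nil
}

// compatible reports whether shape fits expected, where -1 in expected
// stands for an unknown dimension (assumption made for this sketch).
func compatible(shape, expected []int64) bool {
	if len(shape) != len(expected) {
		return false
	}
	for i := range shape {
		if expected[i] != -1 && expected[i] != shape[i] {
			return false
		}
	}
	return true
}

func main() {
	a := [][]float32{{1.0}}
	shape, err := shapeOf(a)
	if err != nil {
		fmt.Println("bad input:", err)
		return
	}
	// [-1, 1] mirrors a Keras Input(shape=(1,)) with an unknown batch size.
	fmt.Println(shape, compatible(shape, []int64{-1, 1}))
}
```

If this check passes for the input above, the shape of the fed tensor is probably not the cause, and the error text (a placeholder that was not fed) is the better lead.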

The model is built with Keras and then exported to TF using SavedModelBuilder:

from keras.models import Sequential, Model
from keras.layers import Dense, Activation, Input

layer_name_input = "theinput"
layer_name_output = "theoutput"

def get_encoder():
    model = Sequential()
    model.add(Dense(5, input_dim=1))
    model.add(Activation("relu"))
    model.add(Dense(5, input_dim=1))
    return model

inputs = Input(shape=(1, ), name=layer_name_input)
encoder = get_encoder()
model = encoder(inputs)
model = Activation("relu")(model)
objective = Dense(1, name=layer_name_output)(model)
model = Model(inputs=[inputs], outputs=objective)
model.compile(loss='mean_squared_error', optimizer='sgd')

EDIT: fixed. The problem was in the export from Keras to TF (layer names). Pasting the export code here in case it helps someone else:

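import os
import pprint

import tensorflow as tf
from keras import backend as K
from keras.models import load_model, Model, Sequential
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import tag_constants, signature_constants
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def

def export_to_tf(keras_model_path, export_path, export_version, is_functional=False):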

    sess = tf.Session()
    K.set_session(sess)
    K.set_learning_phase(0)

    export_path = os.path.join(export_path, str(export_version))

    model = load_model(keras_model_path)
    config = model.get_config()
    weights = model.get_weights()
    if is_functional:
        model = Model.from_config(config)
    else:
        model = Sequential.from_config(config)
    model.set_weights(weights)

    with K.get_session() as sess:
        inputs = [ (model_input.name.split(":")[0], model_input) for model_input in model.inputs]
        outputs = [ (model_output.name.split(":")[0], model_output) for model_output in model.outputs]
        signature = predict_signature_def(inputs=dict(inputs),
                                      outputs=dict(outputs))
        input_descriptor = [ { 'name': item[0], 'shape': item[1].shape.as_list() } for item in inputs]
        output_descriptor = [ { 'name': item[0], 'shape': item[1].shape.as_list() } for item in outputs]
        builder = saved_model_builder.SavedModelBuilder(export_path)
        builder.add_meta_graph_and_variables(
            sess=sess,
            tags=[tag_constants.SERVING],
            signature_def_map={signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
        builder.save()

        descriptor = dict()
        descriptor["inputs"] = input_descriptor
        descriptor["outputs"] = output_descriptor
        pprint.pprint(descriptor)               
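
The export above keys each tensor by its name with the `:0` suffix removed (`name.split(":")[0]`). On the Go side, `Graph.Operation` takes that operation name and `Output` takes the index, so the same split is useful there. A minimal sketch in plain Go, assuming names of the form `operation:index`; `splitTensorName` is a hypothetical helper, and `theoutput/BiasAdd:0` is only an illustrative name, not one confirmed by the descriptor.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// splitTensorName splits a name such as "theinput:0" into the operation
// name and the output index. A name without ":<n>" gets index 0.
func splitTensorName(name string) (string, int, error) {
	i := strings.LastIndex(name, ":")
	if i < 0 {
		return name, 0, nil
	}
	idx, err := strconv.Atoi(name[i+1:])
	if err != nil || idx < 0 {
		return "", 0, fmt.Errorf("invalid output index in %q", name)
	}
	return name[:i], idx, nil
}

func main() {
	// The second name is an assumed example of a Keras output tensor name.
	for _, n := range []string{"theinput:0", "theoutput/BiasAdd:0", "theinput"} {
		op, idx, err := splitTensorName(n)
		if err != nil {
			fmt.Println("error:", err)
			continue
		}
		fmt.Printf("%s -> operation %q, output %d\n", n, op, idx)
	}
}
```

The resulting pair can then be passed as `model.Graph.Operation(op).Output(idx)`, using whichever names the printed descriptor actually reports.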

1 answer

  • dongluan5740 2017-10-02 12:34

    There is something strange in your code and the error. Tensorflow is complaining about a missing value for the placeholder named 'theoutput_target', but this placeholder is never defined in the code you posted. Instead, your code defines a placeholder named 'theinput'.

    Also, I suggest using a more complete and easier-to-use wrapper around the Tensorflow API: tfgo

    This answer was selected as the best answer by the original poster.