Ming_Yan 2022-03-08 22:36 采纳率: 100%
浏览 86
已结题

tensorflow C++ 调用PB模型文件参数问题

问题遇到的现象和发生背景

最近在做NSFW识别,从github找到一个已经训练好的PB模型文件,使用python调用的,当尝试用C++调用的时候,发现参数不正确.通过python 生成日志,用TensorBoard查看日志时,发现模型输入参数是字符串??字符串类型是json,在我理解力,明明应该是一张图片才对?!python 脚本里也是这样传的。
gitee地址: https://gitee.com/yyj8209/CVSample/tree/master/TensorFlow/inception_model
github源地址:https://github.com/kingroc711/CVSample/tree/master/TensorFlow/inception_model

问题相关代码,请勿粘贴截图

img

#include <iostream>
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"

#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/ops/image_ops.h"
//#include "eigen3/unsupported/Eigen/CXX11/Tensor"
//#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include <vector>

//using namespace tensorflow;

using namespace std;
using namespace tensorflow;
using tensorflow::Tensor;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::int32;



Status ReadTensorFromImageFile(string file_name, const int input_height,
                                 const int input_width,
                                 vector<Tensor>* out_tensors);

int main(int argc, char *argv[])
{
    SessionOptions sessionOptions;
    Session *session = NewSession(sessionOptions);
    string modelPath = "/opt/work/build_work/TensorFlow/inception_model/output_graph.pb";
    //tensorflow 官方模型
//    modelPath = "/opt/work/c_work/qt_work/tensorflow_cc_demo/model/classify_image_graph_def.pb";
    GraphDef graphDef;
    Status statud_load = ReadBinaryProto(Env::Default(), modelPath, &graphDef);
    if(statud_load.ok()) {
        cout << "load pb file success : " << modelPath << endl;
    }
    cout << "node size:" << graphDef.node_size() << endl;
    graphDef.node();

    if( session->Create(graphDef).ok() ) {
        cout << "success graph in session " << endl;
    }
    string image_path("/root/test.png");
//    image_path = "/opt/work/c_work/qt_work/tensorflow_cc_demo/model/cropped_panda.jpg";
    vector<Tensor> inputs;
    if(ReadTensorFromImageFile(image_path, 100, 100, &inputs).ok()) {
        cout << "image load success!" << endl;
        cout << inputs.size() << endl;
    }

    vector<Tensor> outputs;
    string input = "DecodeJpeg/contents:0";
    string output = "final_result:0";
    cout << inputs[0].DebugString() << endl;
    pair<string, Tensor> img(input,inputs[0]);
    vector<pair<string, tensorflow::Tensor>> runInputs = {
        {"DecodeJpeg/contents:0", inputs[0]},
    };

    Status status = session->Run(runInputs, {output}, {}, &outputs);
    cout << status << endl;
    if (!status.ok()) {
        cout << "run failed!" << endl;
    }
    cout << outputs.size() << endl;

    cout << "hello wordl" << endl;
    return 0;
}



Status ReadTensorFromImageFile(string file_name, const int input_height,
                                 const int input_width,
                                 vector<Tensor>* out_tensors) {
      auto root = Scope::NewRootScope();
      using namespace ops;

      auto file_reader = ops::ReadFile(root.WithOpName("file_reader"),file_name);
      const int wanted_channels = 1;
      Output image_reader;
      std::size_t found = file_name.find(".png");
      //判断文件格式
      if (found!=std::string::npos) {
          image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,DecodePng::Channels(wanted_channels));
      }
      else {
          image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,DecodeJpeg::Channels(wanted_channels));
      }
      // 下面几步是读取图片并处理
      auto float_caster =Cast(root.WithOpName("float_caster"), image_reader, DT_FLOAT);
      auto dims_expander = ExpandDims(root, float_caster, 0);
      auto resized = ResizeBilinear(root, dims_expander,Const(root.WithOpName("resize"), {input_height, input_width}));
      // Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),{input_std});
     Transpose(root.WithOpName("transpose"),resized,{0,2,1,3});

      GraphDef graph;
      root.ToGraphDef(&graph);

      unique_ptr<Session> session(NewSession(SessionOptions()));
      session->Create(graph);
      session->Run({}, {"transpose"}, {}, out_tensors);//Run,获取图片数据保存到Tensor中

      return Status::OK();
 }

运行结果及报错内容

Invalid argument: Expects arg[0] to be string but float is provided

我的解答思路和尝试过的方法

希望用C++调过PB模型的,给出一些建议。我没有接触过太多tensorflow,并不清楚我现在想法 对不对

我想要达到的结果
  • 写回答

1条回答 默认 最新

  • Ming_Yan 2022-03-15 23:31
    关注

    问题已解决,首先,传入参数是一张图片,类型确实是一个字符串,只不过类型是tstring,通过std::unique_ptrtensorflow::RandomAccessFile生成字符串。

    tensorflow已经提供了针对C++预测图片提供了示例,示例地址:https://gitee.com/mirrors/tensorflow/blob/master/tensorflow/examples/label_image/main.cc
    示例中ReadEntireFile函数就是专门处理把图片转为字符串数据的函数。
    最后附上运行成功全部代码:

    #include "tensorflow/core/public/session.h"
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/graph/default_device.h"
    #include "tensorflow/core/graph/graph_def_builder.h"
    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/core/platform/types.h"
    #include "tensorflow/core/public/session.h"
    #include <vector>
    
    using namespace std;
    using namespace tensorflow;
    using tensorflow::Tensor;
    using tensorflow::Status;
    using tensorflow::string;
    using tensorflow::int32;
    
    static Status ReadEntireFile(tensorflow::Env* env, const string& filename,
                                     Tensor* output);
    
    
    int main(void){
        SessionOptions sessionOptions;
        Session *session = NewSession(sessionOptions);
        //pb文件路径
        string modelPath = "/opt/work/build_work/TensorFlow/inception_model/output_graph.pb";
        GraphDef graphDef;
        Status statud_load = ReadBinaryProto(Env::Default(), modelPath, &graphDef);
        if(statud_load.ok()) {
            cout << "load pb file success : " << modelPath << endl;
        }
    
    
        if( session->Create(graphDef).ok() ) {
            cout << "success graph in session " << endl;
        }
        vector<Tensor> outputs;
        string input = "DecodeJpeg/contents:0";
        string output = "final_result:0";
        Tensor input0(DT_STRING, TensorShape());
        //图片文件
        if(ReadEntireFile(tensorflow::Env::Default(), "/root/test.png", &input0).ok()) {
            cout << "图片读取成功!" << endl;
        }
        vector<pair<string, tensorflow::Tensor>> runInputs = {
            {"DecodeJpeg/contents:0", input0},
        };
        //预测
        Status status = session->Run(runInputs, {output}, {}, &outputs);
        cout << status << endl;
        if (!status.ok()) {
            cout << "run failed!" << endl;
        }
        //处理输出结果,模型输出结果就是一维数组,按照索引0,1,2,3,4分别对应porn。neutral、hentai、drawings、sexy
        Tensor scores;
        scores = outputs[0];
        tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
        //scores_flat.size() 数量是 5,打印每一个分类分数
        for(int i = 0; i < scores_flat.size(); i++) {
            cout << scores_flat(i) << endl;
        }
        return 0;
    }
    
    
    static Status ReadEntireFile(tensorflow::Env* env, const string& filename,
                                 Tensor* output) {
      tensorflow::uint64 file_size = 0;
      TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
    
      string contents;
      contents.resize(file_size);
    
      std::unique_ptr<tensorflow::RandomAccessFile> file;
      TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
    
      tensorflow::StringPiece data;
      TF_RETURN_IF_ERROR(file->Read(0, file_size, &data, &(contents)[0]));
      if (data.size() != file_size) {
        return tensorflow::errors::DataLoss("Truncated read of '", filename,
                                            "' expected ", file_size, " got ",
                                            data.size());
      }
      output->scalar<tstring>()() = tstring(data);
      return Status::OK();
    }
    ```c++
    
    
    

    ```

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

问题事件

  • 系统已结题 3月23日
  • 已采纳回答 3月15日
  • 创建了问题 3月8日

悬赏问题

  • ¥15 matlab不知道怎么改,求解答!!
  • ¥15 永磁直线电机的电流环pi调不出来
  • ¥15 用stata实现聚类的代码
  • ¥15 请问paddlehub能支持移动端开发吗?在Android studio上该如何部署?
  • ¥20 docker里部署springboot项目,访问不到扬声器
  • ¥15 netty整合springboot之后自动重连失效
  • ¥15 悬赏!微信开发者工具报错,求帮改
  • ¥20 wireshark抓不到vlan
  • ¥20 关于#stm32#的问题:需要指导自动酸碱滴定仪的原理图程序代码及仿真
  • ¥20 设计一款异域新娘的视频相亲软件需要哪些技术支持