问题遇到的现象和发生背景
最近在做 NSFW 识别，从 GitHub 找到一个已经训练好的 PB 模型文件，原项目是用 Python 调用的。当尝试用 C++ 调用时，发现输入参数不正确。通过 Python 生成日志并用 TensorBoard 查看后，发现模型的输入参数是字符串？？字符串类型是 json。在我的理解里，输入明明应该是一张图片才对？！Python 脚本里也是这样传的。
gitee地址: https://gitee.com/yyj8209/CVSample/tree/master/TensorFlow/inception_model
github源地址:https://github.com/kingroc711/CVSample/tree/master/TensorFlow/inception_model
问题相关代码,请勿粘贴截图
#include <iostream>
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/ops/image_ops.h"
//#include "eigen3/unsupported/Eigen/CXX11/Tensor"
//#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include <vector>
//using namespace tensorflow;
using namespace std;
using namespace tensorflow;
using tensorflow::Tensor;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::int32;
// Forward declaration. Reads the PNG/JPEG image at `file_name`, decodes it
// (single channel), casts it to float, adds a leading batch dimension,
// resizes it to input_height x input_width and transposes axes 1 and 2;
// the resulting tensor is written to `out_tensors`. See the definition below.
// NOTE(review): this produces a decoded float tensor, which is NOT what a
// string input such as "DecodeJpeg/contents:0" accepts.
Status ReadTensorFromImageFile(string file_name, const int input_height,
const int input_width,
vector<Tensor>* out_tensors);
int main(int argc, char *argv[])
{
SessionOptions sessionOptions;
Session *session = NewSession(sessionOptions);
string modelPath = "/opt/work/build_work/TensorFlow/inception_model/output_graph.pb";
//tensorflow 官方模型
// modelPath = "/opt/work/c_work/qt_work/tensorflow_cc_demo/model/classify_image_graph_def.pb";
GraphDef graphDef;
Status statud_load = ReadBinaryProto(Env::Default(), modelPath, &graphDef);
if(statud_load.ok()) {
cout << "load pb file success : " << modelPath << endl;
}
cout << "node size:" << graphDef.node_size() << endl;
graphDef.node();
if( session->Create(graphDef).ok() ) {
cout << "success graph in session " << endl;
}
string image_path("/root/test.png");
// image_path = "/opt/work/c_work/qt_work/tensorflow_cc_demo/model/cropped_panda.jpg";
vector<Tensor> inputs;
if(ReadTensorFromImageFile(image_path, 100, 100, &inputs).ok()) {
cout << "image load success!" << endl;
cout << inputs.size() << endl;
}
vector<Tensor> outputs;
string input = "DecodeJpeg/contents:0";
string output = "final_result:0";
cout << inputs[0].DebugString() << endl;
pair<string, Tensor> img(input,inputs[0]);
vector<pair<string, tensorflow::Tensor>> runInputs = {
{"DecodeJpeg/contents:0", inputs[0]},
};
Status status = session->Run(runInputs, {output}, {}, &outputs);
cout << status << endl;
if (!status.ok()) {
cout << "run failed!" << endl;
}
cout << outputs.size() << endl;
cout << "hello wordl" << endl;
return 0;
}
// Reads the image at `file_name` and preprocesses it with a small TF graph:
//   ReadFile -> DecodePng/DecodeJpeg (1 channel) -> Cast(float)
//   -> ExpandDims(axis 0) -> ResizeBilinear(input_height, input_width)
//   -> Transpose(perm {0, 2, 1, 3})
// The decoder is chosen from the ".png" file-name suffix; anything else is
// decoded with DecodeJpeg.
//
// On success `out_tensors` holds exactly one tensor (the "transpose" node's
// output). Any failure while building the graph, creating the session or
// running it is returned as a non-OK Status; the caller must check it before
// touching `out_tensors`.
Status ReadTensorFromImageFile(string file_name, const int input_height,
                               const int input_width,
                               vector<Tensor>* out_tensors) {
  auto root = Scope::NewRootScope();
  using namespace ops;

  auto file_reader = ops::ReadFile(root.WithOpName("file_reader"), file_name);

  // NOTE(review): 1 channel means grayscale; models trained on RGB images
  // usually expect 3 channels — confirm against the consuming model.
  const int wanted_channels = 1;

  // Check the suffix, not a substring anywhere in the path (the old find()
  // treated e.g. "/data.png/a.jpg" as a PNG).
  const string png_suffix = ".png";
  const bool is_png =
      file_name.size() >= png_suffix.size() &&
      file_name.compare(file_name.size() - png_suffix.size(),
                        png_suffix.size(), png_suffix) == 0;

  Output image_reader;
  if (is_png) {
    image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
                             DecodePng::Channels(wanted_channels));
  } else {
    image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,
                              DecodeJpeg::Channels(wanted_channels));
  }

  // Decode -> float -> add batch dimension -> resize.
  auto float_caster = Cast(root.WithOpName("float_caster"), image_reader, DT_FLOAT);
  auto dims_expander = ExpandDims(root, float_caster, 0);
  auto resized = ResizeBilinear(
      root, dims_expander,
      Const(root.WithOpName("resize"), {input_height, input_width}));
  // Normalization was never enabled here, e.g.:
  // Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}), {input_std});

  // NOTE(review): perm {0, 2, 1, 3} swaps the height and width axes; check
  // that the consumer really expects that layout.
  Transpose(root.WithOpName("transpose"), resized, {0, 2, 1, 3});

  GraphDef graph;
  Status status = root.ToGraphDef(&graph);
  if (!status.ok()) {
    return status;
  }

  Session* raw_session = nullptr;
  status = NewSession(SessionOptions(), &raw_session);
  unique_ptr<Session> session(raw_session);
  if (!status.ok()) {
    return status;
  }
  status = session->Create(graph);
  if (!status.ok()) {
    return status;
  }

  // Run the graph and fetch the preprocessed image tensor.
  return session->Run({}, {"transpose"}, {}, out_tensors);
}
运行结果及报错内容
Invalid argument: Expects arg[0] to be string but float is provided
我的解答思路和尝试过的方法
希望用 C++ 调用过 PB 模型的朋友给出一些建议。我没有接触过太多 TensorFlow，并不清楚我现在的想法对不对。