尊敬的作者:
您好!拜读了您写的“社区分享 | 从零开始学习 TinyML”文章,有个问题很想提问,冒昧打扰,请见谅。我自己训练了两层神经网络,并利用xdd生成16进制数组,并将代码替换alignas(8) const unsigned char g_model[],将const int g_model_len改成我的参数,但是下载到板子里预测值是一个恒值,请问还需要改动其他地方的代码么?我只改了这两个参数,其他没发现需要改动的,烦请您指点一下,谢谢
关于#人工智能#的问题:拜读了您写的“社区分享 | 从零开始学习 TinyML”文章(相关搜索:神经网络)
- 写回答
- 好问题 0 提建议
- 追加酬金
- 关注问题
- 邀请回答
-
4条回答 默认 最新
关注 让阿豪来帮你解答,本回答参考chatgpt3.5编写提供,如果还有疑问可以评论或留言
尊敬的读者: 感谢您对我的文章的关注和反馈。针对您提出的问题,可能有以下几种可能性:- 您的神经网络模型存在问题。请检查您训练的神经网络模型是否存在过拟合或欠拟合的情况,或者是否进行了正确的数据预处理、正则化等操作,以及是否使用了正确的损失函数、训练算法等。您可以尝试在PC端或其他平台上使用您的模型进行预测,看看是否有相同的问题出现。
- 您将生成的16进制数组正确地替换到了代码中,但是预测代码的输入数据格式可能不符合您的预期。请确保您的输入数据符合模型的期望输入格式,并检查是否正确地传递了数据到模型中。
- 您可能还需要修改其他与模型相关的代码。您可以检查一下模型初始化、推理等过程中是否存在其他代码需要修改的地方,例如您的模型结构是否和原始代码中的模型结构相同,是否需要修改其他与模型相关的参数、函数等。 如果您能提供更多的信息和代码,以及问题出现的具体环境和平台,我将更加容易给出具体的答案和代码示例。感谢您的理解和支持! 具体案例和代码示例: 假设您使用的是TensorFlow Lite for Microcontrollers模块,并采用了以下代码来进行模型预测:
#include "tensorflow/lite/micro/micro_error_reporter.h" #include "tensorflow/lite/micro/micro_interpreter.h" #include "tensorflow/lite/micro/micro_mutable_op_resolver.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "model_quantized.h" //原始模型定义 #include "input_data.h" //输入数据 const int kNumCols = 28; const int kNumRows = 28; const int kNumChannels = 1; const int kNumSamples = 1; //定义TensorFlow Lite模型运行相关对象 tflite::ErrorReporter* error_reporter = nullptr; const tflite::Model* model = nullptr; tflite::MicroInterpreter* interpreter = nullptr; TfLiteTensor* input = nullptr; TfLiteTensor* output = nullptr; //设置TensorFlow Lite模型运行环境及各种参数 void Setup() { static tflite::MicroErrorReporter micro_error_reporter; error_reporter = µ_error_reporter; model = tflite::GetModel(model_quantized); if (model->version() != TFLITE_SCHEMA_VERSION) { error_reporter->Report( "Model provided is schema version %d not equal " "to supported version %d.", model->version(), TFLITE_SCHEMA_VERSION); return; } static tflite::MicroMutableOpResolver<6> micro_op_resolver; micro_op_resolver.AddConv2D(); micro_op_resolver.AddMaxPool2D(); micro_op_resolver.AddFullyConnected(); micro_op_resolver.AddReshape(); micro_op_resolver.AddRelu(); micro_op_resolver.AddSoftmax(); static tflite::MicroInterpreter static_interpreter( model, micro_op_resolver, tensor_arena, kTensorArenaSize, error_reporter); interpreter = &static_interpreter; TfLiteStatus allocate_status = interpreter->AllocateTensors(); if (allocate_status != kTfLiteOk) { error_reporter->Report("AllocateTensors() failed"); return; } input = interpreter->input(0); //获取输入Tensor output = interpreter->output(0); //获取输出Tensor } //将输入数据转换成TensorFlow Lite模型期望的格式 void FillInputTensor(float* input_data) { for (int i = 0; i < kNumSamples; ++i) { for (int row = 0; row < kNumRows; ++row) { for (int col = 0; col < kNumCols; ++col) { for (int channel = 0; channel < kNumChannels; ++channel) { const int input_index = (i * kNumRows * kNumCols * kNumChannels) + (row * kNumCols * kNumChannels) + (col * kNumChannels) + channel; input->data.f[input_index] = input_data[input_index]; } } } } } //运行TensorFlow Lite模型 void RunInference() { TfLiteStatus invoke_status = interpreter->Invoke(); if (invoke_status != kTfLiteOk) { error_reporter->Report("Invoke failed"); return; } } //获取模型预测结果并返回 int GetPrediction() { float max_value = output->data.f[0]; int max_index = 0; for (int i = 1; i < output->dims->data[1]; ++i) { if (output->data.f[i] > max_value) { max_value = output->data.f[i]; max_index = i; } } return max_index; } int main() { Setup(); //初始化TensorFlow Lite模型运行环境 FillInputTensor(input_data); //将输入数据转换成TensorFlow Lite模型期望的格式 RunInference(); //运行TensorFlow Lite模型 int prediction = GetPrediction(); //获取模型预测结果 printf("prediction: %d\n", prediction); //输出预测结果 return 0; }
如果您的模型和输入数据符合以上代码的要求,那么您只需要将您生成的16进制数组替换到model_quantized.h中的相应位置,并将const int g_model_len改成您的数组大小即可。如果您使用的是其他平台或模块,可以参考相关的代码示例来进行修改。
解决 无用评论 打赏 举报
悬赏问题
- ¥15 python使用selenium工具爬取网站的问题
- ¥15 关于#python#的问题:如何通过pywinauto获取到图中“窗格”内部的内容
- ¥15 visionMaster4.3.0 与QT 的二次开发异常
- ¥50 关于#pcb工艺#的问题:这个设计电路中,最终组合起来起到了什么作用
- ¥15 鼎捷t100或鼎捷E10生产模块与odoo17详细区别和鼎捷t100或鼎捷E10物料清单(BOM)可以在子级的子级在同一界面操作吗
- ¥50 VS2019,xamarin框架串口调试报错Java.Lang.SecurityException: Exception of type
- ¥20 QT如何判断QLineF线鼠标划过事件
- ¥15 关于#phpstorm#的问题:phpstorm编辑工具 光标选中了就会自动复制到粘贴板上 这样我之前复制的内容就失效了
- ¥15 pychram安装jupyter插件
- ¥60 悬赏破解越狱iphone4s中360保险箱密码遗忘