yangyanzhao 2019-12-02 09:36 采纳率: 0%
浏览 321
已采纳

OPENNN 预测结果总是1.0

OPENNN 预测结果总是1.0

neural_network = new NeuralNetwork(4, 6, 1);
ModelSelection model_selection;
TrainingStrategy training_strategy;
LossIndex loss_index;
DataSet data_set;
data_set.set_data_file_name("D://iris_plant1.csv");

data_set.set_separator("Comma");

data_set.load_data();

OpenNN::Variables * variables_pointer = data_set.get_variables_pointer();
variables_pointer->set_name(0, "sepal_length");
variables_pointer->set_units(0, "centimeters");
variables_pointer->set_use(0, Variables::Input);

variables_pointer->set_name(1, "sepal_width");
variables_pointer->set_units(1, "centimeters");
variables_pointer->set_use(1, Variables::Input);

variables_pointer->set_name(2, "petal_length");
variables_pointer->set_units(2, "centimeters");
variables_pointer->set_use(2, Variables::Input);

variables_pointer->set_name(3, "petal_width");
variables_pointer->set_units(3, "centimeters");
variables_pointer->set_use(3, Variables::Input);

variables_pointer->set_name(4, "iris_setosa");
variables_pointer->set_use(4, Variables::Target);


const Matrix<std::string> inputs_information = variables_pointer->arrange_inputs_information();
const Matrix<std::string> targets_information = variables_pointer->arrange_targets_information();



Instances* instances_pointer = data_set.get_instances_pointer();

instances_pointer->split_random_indices();

const Vector< Statistics<double> > inputs_statistics = data_set.scale_inputs_minimum_maximum();

Inputs* inputs_pointer = neural_network->get_inputs_pointer();
inputs_pointer->set_information(inputs_information);

Outputs* outputs_pointer = neural_network->get_outputs_pointer();
outputs_pointer->set_information(targets_information);

neural_network->construct_scaling_layer();

ScalingLayer* scaling_layer_pointer = neural_network->get_scaling_layer_pointer();
//scaling_layer_pointer->set_scaling_method(ScalingLayer::MinimumMaximum);
scaling_layer_pointer->set_scaling_method(ScalingLayer::NoScaling);

neural_network->construct_unscaling_layer();
UnscalingLayer* unscaling_layer_pointer = neural_network->get_unscaling_layer_pointer();
unscaling_layer_pointer->set_unscaling_method(UnscalingLayer::NoUnscaling);

neural_network->construct_probabilistic_layer();


//neural_network->set_bounding_layer_pointer(true);

ProbabilisticLayer* probabilistic_layer_pointer = neural_network->get_probabilistic_layer_pointer();
probabilistic_layer_pointer->set_probabilistic_method(ProbabilisticLayer::Softmax);

// Loss index
loss_index.set_data_set_pointer(&data_set);
loss_index.set_neural_network_pointer(neural_network);


training_strategy.set(&loss_index);
training_strategy.set_main_type(TrainingStrategy::QUASI_NEWTON_METHOD);
QuasiNewtonMethod* quasi_Newton_method_pointer = training_strategy.get_quasi_Newton_method_pointer();
quasi_Newton_method_pointer->set_minimum_loss_increase(1.0e-6);
training_strategy.set_display(false);

TrainingStrategy::Results results = training_strategy.perform_training();

// Model selection
model_selection.set_training_strategy_pointer(&training_strategy);
model_selection.set_order_selection_type(ModelSelection::GOLDEN_SECTION);
GoldenSectionOrder* golden_section_order_pointer = model_selection.get_golden_section_order_pointer();
golden_section_order_pointer->set_tolerance(1.0e-7);


// Testing analysis

TestingAnalysis testing_analysis(neural_network, &data_set);
const Matrix<size_t> confusion = testing_analysis.calculate_confusion();

// Save results

data_set.save("data_set.xml");

neural_network->save("neural_network.xml");
neural_network->save_expression("expression.txt");

training_strategy.save("training_strategy.xml");

model_selection.save("model_selection.xml");

使用calculate_outputs预测结果,输出总是1.
  • 写回答

1条回答 默认 最新

  • zqbnqsdsmd 2019-12-02 22:40
    关注
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

问题事件

  • 已采纳回答 10月22日

悬赏问题

  • ¥15 onlyoffice编辑完后立即下载,下载的不是最新编辑的文档
  • ¥15 求caverdock使用教程
  • ¥15 Coze智能助手搭建过程中的问题请教
  • ¥15 12864只亮屏 不显示汉字
  • ¥20 三极管1000倍放大电路
  • ¥15 vscode报错如何解决
  • ¥15 前端vue CryptoJS Aes CBC加密后端java解密
  • ¥15 python随机森林对两个excel表格读取,shap报错
  • ¥15 基于STM32心率血氧监测(OLED显示)相关代码运行成功后烧录成功OLED显示屏不显示的原因是什么
  • ¥100 X轴为分离变量(因子变量),如何控制X轴每个分类变量的长度。