yangyanzhao 2019-12-02 09:36 采纳率: 87.5%
浏览 319
已采纳

OpenNN 预测结果总是 1.0

OPENNN 预测结果总是1.0

// Build, train and evaluate a 4-6-1 classifier that predicts a single binary
// target (column 4, "iris_setosa") from the four iris measurements.
//
// NOTE(review): `neural_network` is declared outside this fragment and is
// allocated with a raw `new`; no matching `delete` is visible here. If the
// surrounding code allows it, make that declaration an owning
// std::unique_ptr<NeuralNetwork> instead.
neural_network = new NeuralNetwork(4, 6, 1);
ModelSelection model_selection;
TrainingStrategy training_strategy;
LossIndex loss_index;
DataSet data_set;
data_set.set_data_file_name("D://iris_plant1.csv");

data_set.set_separator("Comma");

data_set.load_data();

// Column metadata: columns 0-3 are the inputs (centimetres), column 4 is the
// single target.
OpenNN::Variables * variables_pointer = data_set.get_variables_pointer();
variables_pointer->set_name(0, "sepal_length");
variables_pointer->set_units(0, "centimeters");
variables_pointer->set_use(0, Variables::Input);

variables_pointer->set_name(1, "sepal_width");
variables_pointer->set_units(1, "centimeters");
variables_pointer->set_use(1, Variables::Input);

variables_pointer->set_name(2, "petal_length");
variables_pointer->set_units(2, "centimeters");
variables_pointer->set_use(2, Variables::Input);

variables_pointer->set_name(3, "petal_width");
variables_pointer->set_units(3, "centimeters");
variables_pointer->set_use(3, Variables::Input);

variables_pointer->set_name(4, "iris_setosa");
variables_pointer->set_use(4, Variables::Target);

const Matrix<std::string> inputs_information = variables_pointer->arrange_inputs_information();
const Matrix<std::string> targets_information = variables_pointer->arrange_targets_information();

Instances* instances_pointer = data_set.get_instances_pointer();

instances_pointer->split_random_indices();

// Min-max scale the data set inputs in place. The returned statistics must
// also be given to the network's scaling layer; otherwise the network cannot
// reproduce this scaling when it is later fed raw measurements.
const Vector< Statistics<double> > inputs_statistics = data_set.scale_inputs_minimum_maximum();

Inputs* inputs_pointer = neural_network->get_inputs_pointer();
inputs_pointer->set_information(inputs_information);

Outputs* outputs_pointer = neural_network->get_outputs_pointer();
outputs_pointer->set_information(targets_information);

// Scaling layer: the data set is already scaled, so training and testing run
// with NoScaling to avoid scaling the inputs twice. The statistics are still
// recorded here so the layer can be switched to MinimumMaximum before saving
// (see "Save results" below).
neural_network->construct_scaling_layer();
ScalingLayer* scaling_layer_pointer = neural_network->get_scaling_layer_pointer();
scaling_layer_pointer->set_statistics(inputs_statistics);
scaling_layer_pointer->set_scaling_method(ScalingLayer::NoScaling);

neural_network->construct_unscaling_layer();
UnscalingLayer* unscaling_layer_pointer = neural_network->get_unscaling_layer_pointer();
unscaling_layer_pointer->set_unscaling_method(UnscalingLayer::NoUnscaling);

neural_network->construct_probabilistic_layer();

// Bug fix: this network has ONE output neuron. Softmax normalises the outputs
// so they sum to 1, and with a single output that is exp(y) / exp(y) == 1.0
// for every input. That is why every prediction was 1.0. Softmax is only
// meaningful for two or more one-hot encoded outputs; a single binary target
// needs a per-output probability instead.
// NOTE(review): probabilistic-method names differ between OpenNN versions
// (`Probability` in 3.x, a `Logistic` activation in later releases). Confirm
// against the installed headers.
ProbabilisticLayer* probabilistic_layer_pointer = neural_network->get_probabilistic_layer_pointer();
probabilistic_layer_pointer->set_probabilistic_method(ProbabilisticLayer::Probability);

// Loss index: measures the network's error over the data set.
loss_index.set_data_set_pointer(&data_set);
loss_index.set_neural_network_pointer(neural_network);

// Training strategy: quasi-Newton optimisation.
training_strategy.set(&loss_index);
training_strategy.set_main_type(TrainingStrategy::QUASI_NEWTON_METHOD);
QuasiNewtonMethod* quasi_Newton_method_pointer = training_strategy.get_quasi_Newton_method_pointer();
quasi_Newton_method_pointer->set_minimum_loss_increase(1.0e-6);
training_strategy.set_display(false);

TrainingStrategy::Results results = training_strategy.perform_training();

// Model selection: configured and saved below, but no order selection is
// performed in this fragment.
model_selection.set_training_strategy_pointer(&training_strategy);
model_selection.set_order_selection_type(ModelSelection::GOLDEN_SECTION);
GoldenSectionOrder* golden_section_order_pointer = model_selection.get_golden_section_order_pointer();
golden_section_order_pointer->set_tolerance(1.0e-7);

// Testing analysis: the data set is still scaled and the scaling layer is
// still NoScaling, which matches the conditions used in training.
TestingAnalysis testing_analysis(neural_network, &data_set);
const Matrix<size_t> confusion = testing_analysis.calculate_confusion();

// Save results.
// From here on the network is intended to receive raw (unscaled)
// measurements, e.g. through calculate_outputs(). The scaling layer therefore
// applies the min-max statistics recorded above.
scaling_layer_pointer->set_scaling_method(ScalingLayer::MinimumMaximum);

data_set.save("data_set.xml");

neural_network->save("neural_network.xml");
neural_network->save_expression("expression.txt");

training_strategy.save("training_strategy.xml");

model_selection.save("model_selection.xml");

使用 calculate_outputs 进行预测时，输出总是 1.0。
  • 写回答

1条回答 默认 最新

  • zqbnqsdsmd 2019-12-02 22:40
    关注
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

问题事件

  • 已采纳回答 10月22日

悬赏问题

  • ¥60 版本过低apk如何修改可以兼容新的安卓系统
  • ¥25 由IPR导致的DRIVER_POWER_STATE_FAILURE蓝屏
  • ¥50 有数据,怎么建立模型求影响全要素生产率的因素
  • ¥50 有数据,怎么用matlab求全要素生产率
  • ¥15 TI的insta-spin例程
  • ¥15 完成下列问题完成下列问题
  • ¥15 C#算法问题, 不知道怎么处理这个数据的转换
  • ¥15 YoloV5 第三方库的版本对照问题
  • ¥15 请完成下列相关问题!
  • ¥15 drone 推送镜像时候 purge: true 推送完毕后没有删除对应的镜像,手动拷贝到服务器执行结果正确在样才能让指令自动执行成功删除对应镜像,如何解决?