eis/eqpalg/.do_not_use/regression/train_result.h

40 lines
948 B
C++

/**
* @file eqpalg/regression/train_result.h
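* @brief 训练结果,包含预测函数和错误率 (Training result: holds the trained prediction function together with its error rate and cross-validation score)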
* @author Cat (null.null.null@qq.com)
* @version 0.1
* @date 2021-09-17
*
* Copyright: Baosight Co. Ltd.
* DO NOT COPY/USE WITHOUT PERMISSION
*
*/
#pragma once
#include <memory>
#include <utility>
namespace regression {

/**
 * @brief Result of a regression training run: quality metrics plus the
 *        trained prediction function.
 *
 * @tparam FuncType Type of the trained predictor. It is invoked with a
 *         `dlib::matrix<double, 0, 1>` (dynamic column vector) and its result
 *         is returned as `double`.
 *
 * NOTE(review): this header uses dlib::matrix but includes no dlib header
 * itself; it relies on the includer having included <dlib/matrix.h> (or an
 * equivalent) first.
 */
template <typename FuncType>
struct TrainResult {
 public:
  /// Whether training produced a usable result. Defaults to false.
  bool is_result_legal{false};
  /// Cross-validation score of the trained model.
  double cross_validation_score{0.0};
  /// Error rate of the trained model.
  double err_rate{0.0};

  /**
   * @brief Takes ownership of the trained predictor.
   * @param func_ptr Trained prediction function; ownership is transferred
   *                 into this object.
   * @param dims     Number of input features the predictor expects.
   */
  TrainResult(std::unique_ptr<FuncType>&& func_ptr, size_t dims)
      // std::forward needs an explicit template argument; this is a plain
      // rvalue-reference parameter, so std::move is the correct cast.
      : predict_func_ptr_(std::move(func_ptr)), dims_(dims) {}

  /**
   * @brief Evaluates the trained predictor on a single sample.
   * @tparam Tp Any type indexable as `input[i]` with a result convertible to
   *            double (e.g. std::vector<double>, std::array, a raw array).
   * @param input Sample to predict; its first `dims` elements are read, so it
   *              must hold at least that many.
   * @return The predictor's output for `input`, as a double.
   *
   * Precondition: this object holds a non-null predictor (the pointer passed
   * to the constructor was not null).
   */
  template <typename Tp>
  double predict(const Tp& input) const {
    // Size the dynamic column vector up front: a default-constructed
    // dlib::matrix<double, 0, 1> has zero rows, so indexing it would be UB.
    dlib::matrix<double, 0, 1> input_convert(static_cast<long>(dims_));
    for (size_t i = 0; i < dims_; ++i) {
      input_convert(static_cast<long>(i)) = input[i];
    }
    return (*predict_func_ptr_)(input_convert);
  }

 private:
  // Deliberately non-const so TrainResult stays movable (e.g. can be
  // returned by value from a training function via a named local).
  std::unique_ptr<FuncType> predict_func_ptr_;
  size_t dims_;
};

}  // namespace regression