/**
 * @file eqpalg/regression/train_result.h
 * @brief Training result: holds the trained prediction function and its
 *        error rate.
 * @author Cat (null.null.null@qq.com)
 * @version 0.1
 * @date 2021-09-17
 *
 * Copyright: Baosight Co. Ltd.
 * DO NOT COPY/USE WITHOUT PERMISSION
 *
 */

#pragma once

#include <cstddef>
#include <memory>
#include <utility>

namespace regression {
|
|
template <typename FuncType>
|
|
struct TrainResult {
|
|
public:
|
|
bool is_result_legal;
|
|
double cross_validation_score;
|
|
double err_rate;
|
|
|
|
private:
|
|
const std::unique_ptr<FuncType> predict_func_ptr_;
|
|
const size_t dims_;
|
|
|
|
TrainResult(std::unique_ptr<FuncType>&& func_ptr, size_t dims)
|
|
: predict_func_ptr_(std::forward(func_ptr)), dims_(dims) {}
|
|
|
|
template <typename Tp>
|
|
double predict(Tp input) {
|
|
dlib::matrix<double, 0, 1> input_convert;
|
|
for (size_t i = 0; i < dims_; i++) {
|
|
input_convert(i) = input[i];
|
|
}
|
|
return (*predict_func_ptr_)(input_convert);
|
|
}
|
|
};
|
|
} // namespace regression
|