100 lines
2.5 KiB
C++
100 lines
2.5 KiB
C++
/**
|
||
* @file eqpalg/regression/frame.h
|
||
* @brief 回归预测系统的框架
|
||
* @author Cat (null.null.null@qq.com)
|
||
* @version 0.1
|
||
* @date 2021-07-08
|
||
*
|
||
* Company: Baosight Co. Ltd.
|
||
* DO NOT COPY/USE WITHOUT PERMISSION
|
||
*
|
||
*/
|
||
#pragma once
|
||
|
||
#include <vector>
|
||
#include <iostream>
|
||
#include <dlib/svm.h>
|
||
#include <log4cplus/LOG.h>
|
||
#include <chrono>
|
||
#include <utility>
|
||
#include <tuple>
|
||
#include <eqpalg/regression/krr.h>
|
||
#include <eqpalg/regression/train_result.h>
|
||
namespace regression {
|
||
using namespace std;
|
||
using namespace dlib;
|
||
|
||
/**
|
||
* @brief 回归系统的主框架
|
||
*/
|
||
struct Frame {
|
||
// 用来为dlib提供输入的数据类型
|
||
using SampleType = matrix<
|
||
double>; ///< 因为有一维数据作为label(下方的LabelsType),所以需要输入的rows
|
||
///< -1
|
||
using Label = std::vector<double>; ///< 用来当作拟合目标数据y的输出集合
|
||
using TrSample = std::vector<SampleType>; ///< 拟合数据x的输入集合
|
||
|
||
private:
|
||
std::chrono::system_clock::duration train_time_cost_; ///< 训练所花费时间
|
||
std::chrono::system_clock::duration
|
||
regression_func_time_cost_; ///< 回归函数所花费时间
|
||
/**
|
||
* @todo
|
||
* 上方的花费时间暂时没有使用,未来可以将这两个花费时间信息
|
||
* 作为评判拟合优劣的一个选项
|
||
*/
|
||
|
||
// 设置两种回归函数
|
||
std::unique_ptr<Krr> krr_;
|
||
// typename krls_t<rows - 1, cols>::TrainedResult krls_result_;
|
||
|
||
int best_regression_type = -1; ///< 默认回归函数无效
|
||
|
||
const std::string rule_id_;
|
||
|
||
const size_t dims_;
|
||
|
||
public:
|
||
Frame(const std::string& ruleId, size_t dims);
|
||
|
||
/**
|
||
* @brief Destroy the Frame object
|
||
*/
|
||
~Frame() = default;
|
||
|
||
/**
|
||
* @brief 载入用来拟合的数据
|
||
* @param data My Param doc
|
||
* @return int
|
||
*/
|
||
int load(const SampleWindow& data);
|
||
|
||
protected:
|
||
/**
|
||
* @brief 训练和测试回归,选择合适的回归函数
|
||
* @param gama gama
|
||
* @param tr_samples 样本-x
|
||
* @param tr_labels 样本y
|
||
*/
|
||
int train_and_test_regression(const double gama, const TrSample& tr_samples,
|
||
const Label& tr_labels);
|
||
|
||
public:
|
||
/**
|
||
* @brief 预测回归的结果
|
||
* @param data 输入-x
|
||
* @return double
|
||
*/
|
||
double predict(const SamplePoint& data);
|
||
|
||
/**
|
||
* @brief 回归函数是否有效,如果无效,则不进行计算
|
||
* @return true
|
||
* @return false
|
||
*/
|
||
bool valid();
|
||
};
|
||
|
||
} // namespace regression
|