100 lines
2.5 KiB
C
100 lines
2.5 KiB
C
|
|
/**
|
|||
|
|
* @file eqpalg/regression/frame.h
|
|||
|
|
* @brief 回归预测系统的框架
|
|||
|
|
* @author Cat (null.null.null@qq.com)
|
|||
|
|
* @version 0.1
|
|||
|
|
* @date 2021-07-08
|
|||
|
|
*
|
|||
|
|
* Company: Baosight Co. Ltd.
|
|||
|
|
* DO NOT COPY/USE WITHOUT PERMISSION
|
|||
|
|
*
|
|||
|
|
*/
|
|||
|
|
#pragma once
|
|||
|
|
|
|||
|
|
#include <vector>
|
|||
|
|
#include <iostream>
|
|||
|
|
#include <dlib/svm.h>
|
|||
|
|
#include <log4cplus/LOG.h>
|
|||
|
|
#include <chrono>
|
|||
|
|
#include <utility>
|
|||
|
|
#include <tuple>
|
|||
|
|
#include <eqpalg/regression/krr.h>
|
|||
|
|
#include <eqpalg/regression/train_result.h>
|
|||
|
|
namespace regression {
|
|||
|
|
using namespace std;
|
|||
|
|
using namespace dlib;
|
|||
|
|
|
|||
|
|
/**
|
|||
|
|
* @brief 回归系统的主框架
|
|||
|
|
*/
|
|||
|
|
struct Frame {
|
|||
|
|
// 用来为dlib提供输入的数据类型
|
|||
|
|
using SampleType = matrix<
|
|||
|
|
double>; ///< 因为有一维数据作为label(下方的LabelsType),所以需要输入的rows
|
|||
|
|
///< -1
|
|||
|
|
using Label = std::vector<double>; ///< 用来当作拟合目标数据y的输出集合
|
|||
|
|
using TrSample = std::vector<SampleType>; ///< 拟合数据x的输入集合
|
|||
|
|
|
|||
|
|
private:
|
|||
|
|
std::chrono::system_clock::duration train_time_cost_; ///< 训练所花费时间
|
|||
|
|
std::chrono::system_clock::duration
|
|||
|
|
regression_func_time_cost_; ///< 回归函数所花费时间
|
|||
|
|
/**
|
|||
|
|
* @todo
|
|||
|
|
* 上方的花费时间暂时没有使用,未来可以将这两个花费时间信息
|
|||
|
|
* 作为评判拟合优劣的一个选项
|
|||
|
|
*/
|
|||
|
|
|
|||
|
|
// 设置两种回归函数
|
|||
|
|
std::unique_ptr<Krr> krr_;
|
|||
|
|
// typename krls_t<rows - 1, cols>::TrainedResult krls_result_;
|
|||
|
|
|
|||
|
|
int best_regression_type = -1; ///< 默认回归函数无效
|
|||
|
|
|
|||
|
|
const std::string rule_id_;
|
|||
|
|
|
|||
|
|
const size_t dims_;
|
|||
|
|
|
|||
|
|
public:
|
|||
|
|
Frame(const std::string& ruleId, size_t dims);
|
|||
|
|
|
|||
|
|
/**
|
|||
|
|
* @brief Destroy the Frame object
|
|||
|
|
*/
|
|||
|
|
~Frame() = default;
|
|||
|
|
|
|||
|
|
/**
|
|||
|
|
* @brief 载入用来拟合的数据
|
|||
|
|
* @param data My Param doc
|
|||
|
|
* @return int
|
|||
|
|
*/
|
|||
|
|
int load(const SampleWindow& data);
|
|||
|
|
|
|||
|
|
protected:
|
|||
|
|
/**
|
|||
|
|
* @brief 训练和测试回归,选择合适的回归函数
|
|||
|
|
* @param gama gama
|
|||
|
|
* @param tr_samples 样本-x
|
|||
|
|
* @param tr_labels 样本y
|
|||
|
|
*/
|
|||
|
|
int train_and_test_regression(const double gama, const TrSample& tr_samples,
|
|||
|
|
const Label& tr_labels);
|
|||
|
|
|
|||
|
|
public:
|
|||
|
|
/**
|
|||
|
|
* @brief 预测回归的结果
|
|||
|
|
* @param data 输入-x
|
|||
|
|
* @return double
|
|||
|
|
*/
|
|||
|
|
double predict(const SamplePoint& data);
|
|||
|
|
|
|||
|
|
/**
|
|||
|
|
* @brief 回归函数是否有效,如果无效,则不进行计算
|
|||
|
|
* @return true
|
|||
|
|
* @return false
|
|||
|
|
*/
|
|||
|
|
bool valid();
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
} // namespace regression
|