102 lines
2.3 KiB
C++
102 lines
2.3 KiB
C++
/**
 * @file eqpalg/oneClassSvm/frame.h
 * @brief Framework for the one-class SVM (oneClassSvm).
 * @author Cat (null.null.null@qq.com)
 * @version 0.1
 * @date 2021-07-08
 *
 * Company: Baosight Co. Ltd.
 * DO NOT COPY/USE WITHOUT PERMISSION
 *
 */
|
|
#pragma once
|
|
|
|
#include <array>
#include <chrono>
#include <iostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include <dlib/array2d.h>
#include <dlib/svm.h>
//#include <dlib/svm.h>
#include <log4cplus/LOG.h>

// #include <eqpalg/regression/krls.h>
// #include <eqpalg/regression/krr.h>
// #include <eqpalg/regression/train_result.h>
//#include <eqpalg/oneClassSvm/oneSVM.hpp>
|
|
|
|
namespace oneClassSvm {
|
|
|
|
using std::tuple;
|
|
using std::vector;
|
|
using namespace dlib;
|
|
using sample_type = matrix<double, 0, 1>;
|
|
using kernel_type = radial_basis_kernel<sample_type>;
|
|
using TrainResult = decision_function<kernel_type>;
|
|
using SamplesType = vector<sample_type>;
|
|
TrainResult decision_function_get(SamplesType samples);
|
|
|
|
/**
|
|
* @brief one class svm的主框架
|
|
* @tparam rows
|
|
*/
|
|
template <size_t dims>
|
|
struct Frame {
|
|
using ExchangedMetaData =
|
|
std::array<double, dims>; ///< 用来做数据输入的单元数据类型
|
|
using InData =
|
|
std::vector<ExchangedMetaData>; ///< 用来做数据输入的数组数据类型
|
|
|
|
|
|
|
|
Frame(const std::string& ruleId) { this->rule_id_ = ruleId; }
|
|
|
|
TrainResult svm_result_; ///< 训练得到的svm判别函数
|
|
|
|
private:
|
|
std::string rule_id_; //< 算法实例id
|
|
public:
|
|
/**
|
|
* @brief Construct a new Frame object
|
|
* @param ruleId My Param doc
|
|
*/
|
|
// Frame(const std::string& ruleId) { this->rule_id_ = ruleId; }
|
|
|
|
/**
|
|
* @brief Destroy the Frame object
|
|
*/
|
|
~Frame() = default;
|
|
|
|
/**
|
|
* @brief 载入用来训练的数据
|
|
* @param data My Param doc
|
|
* @return int
|
|
*/
|
|
int load(const InData& data);
|
|
|
|
protected:
|
|
/**
|
|
* @brief 训练svm模型
|
|
* @param tr_samples 样本-x
|
|
* @return df函数
|
|
*/
|
|
TrainResult& train_get_df_one_class_svm(SamplesType& tr_samples);
|
|
|
|
public:
|
|
/**
|
|
* @brief 预测svm的结果
|
|
* @param data 输入-x
|
|
* @return double
|
|
*/
|
|
double predict(ExchangedMetaData& data);
|
|
double standard_error; //< 误差标准 train_get_df_one_class_svm后得到
|
|
|
|
/**
|
|
* @brief valid
|
|
*/
|
|
bool valid();
|
|
};
|
|
|
|
} // namespace oneClassSvm
|