/** * @file eqpalg/oneClassSvm/frame.h * @brief oneClassSvm的框架 * @author Cat (null.null.null@qq.com) * @version 0.1 * @date 2021-07-08 * * Company: Baosight Co. Ltd. * DO NOT COPY/USE WITHOUT PERMISSION * */ #pragma once #include #include #include #include #include #include //#include #include #include #include // #include // #include // #include //#include namespace oneClassSvm { using std::tuple; using std::vector; using namespace dlib; using sample_type = matrix; using kernel_type = radial_basis_kernel; using TrainResult = decision_function; using SamplesType = vector; TrainResult decision_function_get(SamplesType samples); /** * @brief one class svm的主框架 * @tparam rows */ template struct Frame { using ExchangedMetaData = std::array; ///< 用来做数据输入的单元数据类型 using InData = std::vector; ///< 用来做数据输入的数组数据类型 Frame(const std::string& ruleId) { this->rule_id_ = ruleId; } TrainResult svm_result_; ///< 训练得到的svm判别函数 private: std::string rule_id_; //< 算法实例id public: /** * @brief Construct a new Frame object * @param ruleId My Param doc */ // Frame(const std::string& ruleId) { this->rule_id_ = ruleId; } /** * @brief Destroy the Frame object */ ~Frame() = default; /** * @brief 载入用来训练的数据 * @param data My Param doc * @return int */ int load(const InData& data); protected: /** * @brief 训练svm模型 * @param tr_samples 样本-x * @return df函数 */ TrainResult& train_get_df_one_class_svm(SamplesType& tr_samples); public: /** * @brief 预测svm的结果 * @param data 输入-x * @return double */ double predict(ExchangedMetaData& data); double standard_error; //< 误差标准 train_get_df_one_class_svm后得到 /** * @brief valid */ bool valid(); }; } // namespace oneClassSvm