/** * @file krls.h * @brief krls * @author Cat (null.null.null@qq.com) * @version 0.1 * @date 2021-09-17 * * Copyright: Baosight Co. Ltd. * DO NOT COPY/USE WITHOUT PERMISSION * */ #pragma once #include #include "mix_cc/algorithm/split.h" #include #include namespace regression { using namespace dlib; /** * @brief krls 回归预测 * @tparam rows * @tparam cols */ template class krls_t { public: using SampleType = matrix; using KernelType = radial_basis_kernel; using SamplesType = std::vector; using LabelsType = std::vector; using TrainedResult = TrainResult, rows, cols>; private: double split_precent_; public: krls_t() { this->split_precent_ = 0.7; } explicit krls_t(double split_percent) { this->split_precent_ = split_percent; } /** * @brief 学习样本,产生判断函数 * @param gama My Param doc * @param tr_samples My Param doc * @param tr_labels My Param doc * @return TrainedResult */ TrainedResult learn(const double gama, const SamplesType& tr_samples, const LabelsType& tr_labels) { krls trainer(KernelType(0.1), 0.001); size_t split_size = tr_samples.size() * (split_precent_); auto splitted_sample = mix_cc::split(tr_samples, {split_size}); auto splitted_label = mix_cc::split(tr_labels, {split_size}); auto train_size = splitted_sample[0].size(); for (size_t i = 0; i < train_size; i++) { trainer.train(splitted_sample[0][i], splitted_label[0][i]); } TrainedResult result; result.set_predict_func(std::make_unique>(trainer)); std::vector loo_values; auto valide_size = splitted_sample[1].size(); for (auto i = 0; i < valide_size; i++) { loo_values.push_back(trainer(splitted_sample[1][i])); } result.err_rate = mean_squared_error(splitted_label[1], loo_values); return result; } }; } // namespace regression