eis/eqpalg/.do_not_use/krls.h.x

72 lines
2.1 KiB
Plaintext

/**
* @file krls.h
* @brief KRLS (kernel recursive least squares) regression wrapper around dlib::krls
* @author Cat (null.null.null@qq.com)
* @version 0.1
* @date 2021-09-17
*
* Copyright: Baosight Co. Ltd.
* DO NOT COPY/USE WITHOUT PERMISSION
*
*/
#pragma once
#include <algorithm>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

#include <eqpalg/regression/train_result.h>
#include "mix_cc/algorithm/split.h"
namespace regression {
using namespace dlib;
/**
 * @brief KRLS regression with a hold-out validation split.
 *
 * The leading `split_percent` fraction of the samples trains a dlib `krls`
 * model with an RBF kernel; the remaining samples are predicted with that
 * model and their mean squared error is reported as the error rate.
 *
 * @tparam rows compile-time row count of a sample matrix (0 = sized at runtime)
 * @tparam cols compile-time column count of a sample matrix (0 = sized at runtime)
 */
template <size_t rows = 0, size_t cols = 0>
class krls_t {
 public:
  using SampleType = matrix<double, rows, cols>;
  using KernelType = radial_basis_kernel<SampleType>;
  using SamplesType = std::vector<SampleType>;
  using LabelsType = std::vector<double>;
  using TrainedResult = TrainResult<krls<KernelType>, rows, cols>;

 private:
  // Fraction of the samples used for training; always within [0, 1].
  double split_percent_{0.7};

 public:
  /// Uses 70% of the samples for training and the rest for validation.
  krls_t() = default;

  /**
   * @param split_percent fraction of the samples used for training. Values
   *        outside [0, 1] are clamped into that range (NaN becomes 0), which
   *        keeps the later `size * fraction` -> size_t conversion well defined.
   */
  explicit krls_t(double split_percent)
      // Argument order matters: std::max(0.0, NaN) yields 0.0.
      : split_percent_(std::min(1.0, std::max(0.0, split_percent))) {}

  /**
   * @brief Trains a KRLS model on the leading samples and measures its error
   *        on the held-out remainder.
   * @param gama gamma of the RBF kernel (`radial_basis_kernel`) used by KRLS
   * @param tr_samples samples; the first part trains, the rest validates
   * @param tr_labels target value for each entry of `tr_samples`
   * @return TrainedResult owning the trained model as its predict function,
   *         with `err_rate` set to the mean squared error on the validation part
   * @throws std::invalid_argument if `tr_samples` and `tr_labels` differ in size
   */
  TrainedResult learn(const double gama, const SamplesType& tr_samples,
                      const LabelsType& tr_labels) {
    // Parallel containers: a size mismatch would index labels out of bounds.
    if (tr_samples.size() != tr_labels.size()) {
      throw std::invalid_argument(
          "krls_t::learn: tr_samples and tr_labels must have the same size");
    }

    // KRLS approximate-linear-dependence tolerance (value kept from before).
    const double tolerance = 0.001;
    krls<KernelType> trainer(KernelType(gama), tolerance);

    const auto split_size =
        static_cast<size_t>(tr_samples.size() * split_percent_);
    // NOTE(review): assumes mix_cc::split always yields two parts,
    // [0, split_size) and [split_size, end), even when one is empty — confirm.
    auto splitted_sample = mix_cc::split(tr_samples, {split_size});
    auto splitted_label = mix_cc::split(tr_labels, {split_size});

    const size_t train_size = splitted_sample[0].size();
    for (size_t i = 0; i < train_size; ++i) {
      trainer.train(splitted_sample[0][i], splitted_label[0][i]);
    }

    // Predict every held-out sample (this is a hold-out split, not LOO).
    const size_t valid_size = splitted_sample[1].size();
    std::vector<double> predictions;
    predictions.reserve(valid_size);
    for (size_t i = 0; i < valid_size; ++i) {
      predictions.push_back(trainer(splitted_sample[1][i]));
    }

    TrainedResult result;
    // The trainer is no longer needed, so hand it over instead of copying.
    result.set_predict_func(
        std::make_unique<krls<KernelType>>(std::move(trainer)));
    result.err_rate = mean_squared_error(splitted_label[1], predictions);
    return result;
  }
};
} // namespace regression