72 lines
2.1 KiB
Plaintext
72 lines
2.1 KiB
Plaintext
/**
 * @file krls.h
 * @brief Kernel recursive least squares (KRLS) regression built on dlib.
 * @author Cat (null.null.null@qq.com)
 * @version 0.1
 * @date 2021-09-17
 *
 * Copyright: Baosight Co. Ltd.
 * DO NOT COPY/USE WITHOUT PERMISSION
 *
 */
|
|
#pragma once
|
|
#include <eqpalg/regression/train_result.h>

#include "mix_cc/algorithm/split.h"

#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>
|
|
namespace regression {
|
|
using namespace dlib;
|
|
/**
|
|
* @brief krls 回归预测
|
|
* @tparam rows
|
|
* @tparam cols
|
|
*/
|
|
template <size_t rows = 0, size_t cols = 0>
|
|
class krls_t {
|
|
public:
|
|
using SampleType = matrix<double, rows, cols>;
|
|
using KernelType = radial_basis_kernel<SampleType>;
|
|
using SamplesType = std::vector<SampleType>;
|
|
using LabelsType = std::vector<double>;
|
|
using TrainedResult = TrainResult<krls<KernelType>, rows, cols>;
|
|
|
|
private:
|
|
double split_precent_;
|
|
|
|
public:
|
|
krls_t() { this->split_precent_ = 0.7; }
|
|
|
|
explicit krls_t(double split_percent) {
|
|
this->split_precent_ = split_percent;
|
|
}
|
|
|
|
/**
|
|
* @brief 学习样本,产生判断函数
|
|
* @param gama My Param doc
|
|
* @param tr_samples My Param doc
|
|
* @param tr_labels My Param doc
|
|
* @return TrainedResult
|
|
*/
|
|
TrainedResult learn(const double gama, const SamplesType& tr_samples,
|
|
const LabelsType& tr_labels) {
|
|
krls<KernelType> trainer(KernelType(0.1), 0.001);
|
|
size_t split_size = tr_samples.size() * (split_precent_);
|
|
auto splitted_sample = mix_cc::split(tr_samples, {split_size});
|
|
auto splitted_label = mix_cc::split(tr_labels, {split_size});
|
|
auto train_size = splitted_sample[0].size();
|
|
for (size_t i = 0; i < train_size; i++) {
|
|
trainer.train(splitted_sample[0][i], splitted_label[0][i]);
|
|
}
|
|
TrainedResult result;
|
|
result.set_predict_func(std::make_unique<krls<KernelType>>(trainer));
|
|
std::vector<double> loo_values;
|
|
auto valide_size = splitted_sample[1].size();
|
|
for (auto i = 0; i < valide_size; i++) {
|
|
loo_values.push_back(trainer(splitted_sample[1][i]));
|
|
}
|
|
result.err_rate = mean_squared_error(splitted_label[1], loo_values);
|
|
return result;
|
|
}
|
|
};
|
|
} // namespace regression
|