eis/eqpalg/.do_not_use/regression/krr.h

64 lines
1.7 KiB
C++

/**
* @file eqpalg/regression/krr.h
* @brief Kernel ridge regression (KRR) wrapper built on dlib's krr_trainer
* @author Cat (null.null.null@qq.com)
* @version 0.1
* @date 2021-09-17
*
* Copyright: Baosight Co. Ltd.
* DO NOT COPY/USE WITHOUT PERMISSION
*
*/
#pragma once
#include <eqpalg/regression/train_result.h>
#include "mix_cc/algorithm/split.h"
#include <eqpalg/define/sample.h>
#include <vector>
#include <memory>
namespace regression {
using namespace dlib;
/**
 * @brief Kernel ridge regression model using dlib's krr_trainer with a
 *        radial basis kernel.
 *
 * The constructor splits the supplied data in two at 70% of its length,
 * trains on the first part and scores the trained model on the second part.
 */
class Krr {
 public:
  /// Sample type handed to dlib; predict() fills it as a dims x 1 matrix.
  using SampleType = matrix<double>;
  /// Kernel used by both the trainer and the resulting decision function.
  using KernelType = radial_basis_kernel<SampleType>;

 private:
  /// Fraction of the input used to compute the split index.
  double split_precent_ = 0.7;
  /// Number of elements copied from each input in predict().
  const size_t dims_;
  /// Decision function produced by training on the first split.
  std::unique_ptr<decision_function<KernelType>> predict_func_;
  /// Mean squared error of predict_func_ on the second (held-out) split.
  double err_rate_ = 0.0;

 public:
  /**
   * @brief Train the model and evaluate it on held-out data.
   * @param dims        number of elements read from each input in predict()
   * @param gama        value passed to the radial basis kernel constructor
   * @param tr_samples  samples; split at floor(size * 0.7)
   * @param tr_labels   labels, split at the same index as tr_samples
   *
   * NOTE(review): no validation is performed here; tr_samples and tr_labels
   * are presumably expected to have equal length and enough entries for both
   * splits to be non-empty -- confirm against callers.
   */
  Krr(size_t dims, const double gama, const std::vector<SampleType>& tr_samples,
      const std::vector<double>& tr_labels)
      : dims_(dims) {
    krr_trainer<KernelType> trainer;
    trainer.set_kernel(KernelType(gama));

    // Truncation towards zero is intended; the cast only makes it explicit.
    size_t split_size =
        static_cast<size_t>(tr_samples.size() * split_precent_);
    // Presumably element 0 holds the entries before split_size and element 1
    // the remainder -- confirm against mix_cc::split.
    auto splitted_sample = mix_cc::split(tr_samples, {split_size});
    auto splitted_label = mix_cc::split(tr_labels, {split_size});

    predict_func_ = std::make_unique<decision_function<KernelType>>(
        trainer.train(splitted_sample[0], splitted_label[0]));

    // Score the model that predict() actually uses on samples it was not
    // trained on, so err_rate_ describes predict_func_ itself.
    std::vector<double> held_out_predictions;
    held_out_predictions.reserve(splitted_sample[1].size());
    for (const auto& sample : splitted_sample[1]) {
      held_out_predictions.push_back((*predict_func_)(sample));
    }
    err_rate_ = mean_squared_error(splitted_label[1], held_out_predictions);
  }

  /**
   * @brief Evaluate the trained decision function on one input.
   * @param input  indexable point; elements [0, dims) are read
   * @return value returned by the decision function
   *
   * NOTE(review): input is indexed without a bounds check; it must provide at
   * least dims elements.
   */
  double predict(SamplePoint input) const {
    SampleType input_convert(dims_, 1);
    for (size_t i = 0; i < dims_; i++) {
      input_convert(i) = input[i];
    }
    return (*predict_func_)(input_convert);
  }

  /**
   * @brief Mean squared error measured on the held-out split at construction.
   */
  double err_rate() const { return err_rate_; }
};
} // namespace regression