64 lines
1.7 KiB
C++
64 lines
1.7 KiB
C++
/**
 * @file eqpalg/regression/krr.h
 * @brief Kernel ridge regression (KRR) model wrapper built on dlib.
 * @author Cat (null.null.null@qq.com)
 * @version 0.1
 * @date 2021-09-17
 *
 * Copyright: Baosight Co. Ltd.
 * DO NOT COPY/USE WITHOUT PERMISSION
 *
 */
|
|
#pragma once
|
|
#include <eqpalg/regression/train_result.h>
#include "mix_cc/algorithm/split.h"
#include <eqpalg/define/sample.h>

#include <cstddef>
#include <memory>
#include <stdexcept>
#include <vector>
|
|
namespace regression {
|
|
// NOTE(review): this using-directive lives in a header, so every file that
// includes it sees all dlib names unqualified inside namespace regression.
// Confirm that no includer relies on that before removing it.
using namespace dlib;
|
|
/**
|
|
* @brief krr
|
|
* @tparam rows
|
|
* @tparam cols
|
|
*/
|
|
class Krr {
|
|
public:
|
|
//
|
|
using SampleType = matrix<double>;
|
|
using KernelType = radial_basis_kernel<SampleType>;
|
|
|
|
private:
|
|
double split_precent_;
|
|
const size_t dims_;
|
|
std::unique_ptr<decision_function<KernelType>> predict_func_;
|
|
|
|
double err_rate_;
|
|
|
|
public:
|
|
Krr(size_t dims, const double gama, const std::vector<SampleType>& tr_samples,
|
|
const std::vector<double>& tr_labels)
|
|
: dims_(dims) {
|
|
this->split_precent_ = 0.7;
|
|
krr_trainer<KernelType> trainer;
|
|
trainer.set_kernel(KernelType(gama));
|
|
size_t split_size = tr_samples.size() * (split_precent_);
|
|
auto splitted_sample = mix_cc::split(tr_samples, {split_size});
|
|
auto splitted_label = mix_cc::split(tr_labels, {split_size});
|
|
predict_func_ = std::make_unique<decision_function<KernelType>>(
|
|
trainer.train(splitted_sample[0], splitted_label[0]));
|
|
std::vector<double> loo_values;
|
|
trainer.train(splitted_sample[1], splitted_label[1], loo_values);
|
|
err_rate_ = mean_squared_error(splitted_label[1], loo_values);
|
|
}
|
|
|
|
double predict(SamplePoint input) {
|
|
dlib::matrix<double> input_convert(dims_,1);
|
|
for (size_t i = 0; i < dims_; i++) {
|
|
input_convert(i) = input[i];
|
|
}
|
|
return (*predict_func_)(input_convert);
|
|
}
|
|
};
|
|
} // namespace regression
|