/** * @file eqpalg/regression/krr.h * @brief krr * @author Cat (null.null.null@qq.com) * @version 0.1 * @date 2021-09-17 * * Copyright: Baosight Co. Ltd. * DO NOT COPY/USE WITHOUT PERMISSION * */ #pragma once #include #include "mix_cc/algorithm/split.h" #include #include #include namespace regression { using namespace dlib; /** * @brief krr * @tparam rows * @tparam cols */ class Krr { public: // using SampleType = matrix; using KernelType = radial_basis_kernel; private: double split_precent_; const size_t dims_; std::unique_ptr> predict_func_; double err_rate_; public: Krr(size_t dims, const double gama, const std::vector& tr_samples, const std::vector& tr_labels) : dims_(dims) { this->split_precent_ = 0.7; krr_trainer trainer; trainer.set_kernel(KernelType(gama)); size_t split_size = tr_samples.size() * (split_precent_); auto splitted_sample = mix_cc::split(tr_samples, {split_size}); auto splitted_label = mix_cc::split(tr_labels, {split_size}); predict_func_ = std::make_unique>( trainer.train(splitted_sample[0], splitted_label[0])); std::vector loo_values; trainer.train(splitted_sample[1], splitted_label[1], loo_values); err_rate_ = mean_squared_error(splitted_label[1], loo_values); } double predict(SamplePoint input) { dlib::matrix input_convert(dims_,1); for (size_t i = 0; i < dims_; i++) { input_convert(i) = input[i]; } return (*predict_func_)(input_convert); } }; } // namespace regression