100 lines
2.1 KiB
C
100 lines
2.1 KiB
C
|
|
#pragma once
|
||
|
|
/**
 * @file loss_compress.h
 * @brief Declares distribution::LossCompressSingle, which stores double
 *        samples as counted value/precision intervals (lossy compression).
 * @author Cat (null.null.null@qq.com)
 * @version 0.1
 * @date 2021-08-18
 *
 * Copyright: Baosight Co. Ltd.
 * DO NOT COPY/USE WITHOUT PERMISSION
 *
 */
|
||
|
|
#include <vector>
|
||
|
|
#include <map>
|
||
|
|
#include <tuple>
|
||
|
|
#include <utility>
|
||
|
|
#include <random>
|
||
|
|
#include <numeric>
|
||
|
|
#include <limits>
|
||
|
|
#include <set>
|
||
|
|
#include <algorithm>
|
||
|
|
|
||
|
|
namespace distribution {

/**
 * @brief Lossy store of double samples as counted value/precision buckets.
 *
 * Buckets are kept in an ordered map (data_t) keyed by value_t intervals,
 * with a sample count per bucket. Most member functions are defined out of
 * line; their exact semantics (including the meaning of the int return
 * codes) live in the implementation file, not in this header.
 */
class LossCompressSingle {
public:
    /// One bucket: the interval from get_left() to get_right(), centred on
    /// `value` with total width `precision`.
    struct value_t {
        double value;
        double precision;

        /// Slack used by operator< so that intervals which merely touch
        /// (right edge equal to the neighbour's left edge, up to rounding)
        /// still order. Absolute, not relative to the magnitude of `value`;
        /// kept at the historical hard-coded value.
        static constexpr double order_tolerance = 0.01;

        /// Left boundary of the interval.
        double get_left() const { return value - (precision / 2); }

        /// Right boundary of the interval.
        double get_right() const { return value + (precision / 2); }

        /**
         * @brief Orders buckets left-to-right; this is the data_t key order.
         *
         * `*this < rhs` when this interval ends before rhs begins (within
         * order_tolerance) AND rhs does not also end before this one begins.
         *
         * The second condition can only change the result when both
         * directions test true, which requires the two precisions to sum to
         * less than 2 * order_tolerance (e.g. zero-width points). For those
         * pairs the bare first test made `a < b` and `b < a` both true, and
         * `a < a` true, violating the strict weak ordering std::map requires
         * (undefined behaviour). Such pairs now compare equivalent; every
         * other pair orders exactly as before.
         *
         * NOTE(review): "equivalent" here means "overlapping", which is not
         * transitive, so keys stored together in one data_t must stay
         * mutually non-overlapping -- presumably guaranteed by the
         * implementation's split/merge logic; confirm.
         */
        bool operator<(const value_t& rhs) const {
            const bool ends_before_rhs = get_right() - rhs.get_left() < order_tolerance;
            const bool rhs_ends_before = rhs.get_right() - get_left() < order_tolerance;
            return ends_before_rhs && !rhs_ends_before;
        }
    };

protected:
    /// A bucket together with its sample count.
    using v_pair_t = std::pair<value_t, size_t>;

public:
    /// Ordered bucket -> sample-count map.
    using data_t = std::map<value_t, size_t>;
    /// Flat list of (bucket, count) pairs exchanged with callers.
    using exchange_data_t = std::vector<v_pair_t>;

protected:
    data_t data_;

    bool is_first_commit_ = false;
    size_t first_commit_counts_ = 0;
    std::vector<double> data_to_commit_;

    exchange_data_t insert_list_;
    exchange_data_t delete_list_;
    data_t update_list_;

    // Default-initialised here so these are never indeterminate; the
    // out-of-line constructor may still assign its own values.
    size_t total_size_ = 0;                              // total sample count before compression
    constexpr static size_t dest_data_size_min = 200;    // target minimum sample count
    constexpr static size_t dest_data_size_max = 15000;  // target maximum sample count
    constexpr static size_t dest_decompress_size = 1000; // target sample count after decompression
    double scale_ = 0.0;                                 // decompression scale ratio

    int data_sub_div();

    int data_merge();

public:
    std::vector<double> decompress_data();

    std::vector<double> decompress_data_full();

    int Store(double value);

    int set_data(const exchange_data_t& data);

    int reforge_precision();

    int commit();

    exchange_data_t receive_delete_list();

    exchange_data_t receive_insert_list();

    exchange_data_t receive_update_list();

    exchange_data_t get_ordered_compress_pack();

    bool clear_changed_list();

    size_t get_total_size();

    size_t get_scaled_size();

public:
    LossCompressSingle(/* args */);

    // NOTE(review): non-virtual destructor while members are protected (which
    // suggests subclassing); never delete a derived object through a
    // LossCompressSingle* unless this is made virtual.
    ~LossCompressSingle();
};

} // namespace distribution
|