eis/eqpalg/utility/XorShift128Plus.hpp

134 lines
3.3 KiB
C++
Raw Normal View History

#pragma once
#include <chrono>
#include <iomanip>
#include <iostream>
#include <limits>
#include <random>
#include <string>
class XorShift128Plus {
private:
uint64_t s[2];
// 增强的字符串哈希
static std::pair<uint64_t, uint64_t>
hash_string_pair_std(const std::string &str) {
std::hash<std::string> hasher;
// 第一个哈希:原始字符串
uint64_t hash1 = static_cast<uint64_t>(hasher(str));
// 第二个哈希:反转字符串
std::string reversed = str;
std::reverse(reversed.begin(), reversed.end());
uint64_t hash2 = static_cast<uint64_t>(hasher(reversed));
// 如果字符串太短,添加额外混合
if (str.length() < 4) {
hash1 ^= 0x123456789ABCDEF0ULL;
hash2 ^= 0x0FEDCBA987654321ULL;
}
return {hash1, hash2};
}
public:
// 构造函数1使用字符串作为种子
XorShift128Plus(const std::string &seed_str) {
auto seeds = hash_string_pair_std(seed_str);
s[0] = seeds.first;
s[1] = seeds.second;
if (s[0] == 0 && s[1] == 0) {
s[0] = 0x123456789ABCDEF0ULL;
s[1] = 0x0FEDCBA987654321ULL;
}
// 预热,确保随机性
for (int i = 0; i < 10; ++i) {
next();
}
}
// 构造函数2使用两个64位整数作为种子
XorShift128Plus(uint64_t seed1 = 0, uint64_t seed2 = 0) {
if (seed1 == 0 && seed2 == 0) {
std::random_device rd;
// 正确组合rd()的结果到64位
s[0] = (static_cast<uint64_t>(rd()) << 32) | rd();
s[1] = (static_cast<uint64_t>(rd()) << 32) | rd();
} else {
s[0] = seed1;
s[1] = seed2;
}
if (s[0] == 0 && s[1] == 0) {
s[0] = 0x123456789ABCDEF0ULL;
s[1] = 0x0FEDCBA987654321ULL;
}
// 预热
for (int i = 0; i < 10; ++i) {
next();
}
}
// 正确的XorShift128Plus算法
uint64_t next() {
uint64_t x = s[0];
uint64_t y = s[1];
s[0] = y;
x ^= x << 23;
x ^= x >> 17;
x ^= y ^ (y >> 26);
s[1] = x;
return s[1] + y;
}
// 获取[min, max]范围内的随机整数(无偏)
int get_int(int min, int max) {
if (min >= max)
return min;
uint32_t range = static_cast<uint32_t>(max - min + 1);
// 拒绝采样法避免取模偏差
uint32_t limit = std::numeric_limits<uint32_t>::max() -
(std::numeric_limits<uint32_t>::max() % range);
uint32_t raw;
do {
// 取next()的高32位质量更好
raw = static_cast<uint32_t>(next() >> 32);
} while (raw >= limit);
return min + static_cast<int>(raw % range);
}
// 快速版本当range是2的幂时使用
int get_int_fast(int min, int max) {
uint32_t range = max - min + 1;
// 检查range是否是2的幂
if (range > 0 && (range & (range - 1)) == 0) {
uint64_t raw = next();
return min + static_cast<int>(raw & (range - 1));
}
// 否则使用通用版本
return get_int(min, max);
}
// 获取[0, 1)范围内的随机浮点数
double get_double() {
// 使用52位精度IEEE 754双精度的尾数位
uint64_t bits = (next() >> 12) | 0x3FF0000000000000ULL;
return *reinterpret_cast<double *>(&bits) - 1.0;
}
// 获取当前状态(用于调试)
void get_state(uint64_t &s0, uint64_t &s1) const {
s0 = s[0];
s1 = s[1];
}
};