eis/eqpalg/utility/XorShift128Plus.hpp

134 lines
3.3 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#pragma once
#include <chrono>
#include <iomanip>
#include <iostream>
#include <limits>
#include <random>
#include <string>
class XorShift128Plus {
private:
uint64_t s[2];
// 增强的字符串哈希
static std::pair<uint64_t, uint64_t>
hash_string_pair_std(const std::string &str) {
std::hash<std::string> hasher;
// 第一个哈希:原始字符串
uint64_t hash1 = static_cast<uint64_t>(hasher(str));
// 第二个哈希:反转字符串
std::string reversed = str;
std::reverse(reversed.begin(), reversed.end());
uint64_t hash2 = static_cast<uint64_t>(hasher(reversed));
// 如果字符串太短,添加额外混合
if (str.length() < 4) {
hash1 ^= 0x123456789ABCDEF0ULL;
hash2 ^= 0x0FEDCBA987654321ULL;
}
return {hash1, hash2};
}
public:
// 构造函数1使用字符串作为种子
XorShift128Plus(const std::string &seed_str) {
auto seeds = hash_string_pair_std(seed_str);
s[0] = seeds.first;
s[1] = seeds.second;
if (s[0] == 0 && s[1] == 0) {
s[0] = 0x123456789ABCDEF0ULL;
s[1] = 0x0FEDCBA987654321ULL;
}
// 预热,确保随机性
for (int i = 0; i < 10; ++i) {
next();
}
}
// 构造函数2使用两个64位整数作为种子
XorShift128Plus(uint64_t seed1 = 0, uint64_t seed2 = 0) {
if (seed1 == 0 && seed2 == 0) {
std::random_device rd;
// 正确组合rd()的结果到64位
s[0] = (static_cast<uint64_t>(rd()) << 32) | rd();
s[1] = (static_cast<uint64_t>(rd()) << 32) | rd();
} else {
s[0] = seed1;
s[1] = seed2;
}
if (s[0] == 0 && s[1] == 0) {
s[0] = 0x123456789ABCDEF0ULL;
s[1] = 0x0FEDCBA987654321ULL;
}
// 预热
for (int i = 0; i < 10; ++i) {
next();
}
}
// 正确的XorShift128Plus算法
uint64_t next() {
uint64_t x = s[0];
uint64_t y = s[1];
s[0] = y;
x ^= x << 23;
x ^= x >> 17;
x ^= y ^ (y >> 26);
s[1] = x;
return s[1] + y;
}
// 获取[min, max]范围内的随机整数(无偏)
int get_int(int min, int max) {
if (min >= max)
return min;
uint32_t range = static_cast<uint32_t>(max - min + 1);
// 拒绝采样法避免取模偏差
uint32_t limit = std::numeric_limits<uint32_t>::max() -
(std::numeric_limits<uint32_t>::max() % range);
uint32_t raw;
do {
// 取next()的高32位质量更好
raw = static_cast<uint32_t>(next() >> 32);
} while (raw >= limit);
return min + static_cast<int>(raw % range);
}
// 快速版本当range是2的幂时使用
int get_int_fast(int min, int max) {
uint32_t range = max - min + 1;
// 检查range是否是2的幂
if (range > 0 && (range & (range - 1)) == 0) {
uint64_t raw = next();
return min + static_cast<int>(raw & (range - 1));
}
// 否则使用通用版本
return get_int(min, max);
}
// 获取[0, 1)范围内的随机浮点数
double get_double() {
// 使用52位精度IEEE 754双精度的尾数位
uint64_t bits = (next() >> 12) | 0x3FF0000000000000ULL;
return *reinterpret_cast<double *>(&bits) - 1.0;
}
// 获取当前状态(用于调试)
void get_state(uint64_t &s0, uint64_t &s1) const {
s0 = s[0];
s1 = s[1];
}
};