134 lines
3.3 KiB
C++
134 lines
3.3 KiB
C++
#pragma once
|
||
#include <chrono>
|
||
#include <iomanip>
|
||
#include <iostream>
|
||
#include <limits>
|
||
#include <random>
|
||
#include <string>
|
||
|
||
class XorShift128Plus {
|
||
private:
|
||
uint64_t s[2];
|
||
|
||
// 增强的字符串哈希
|
||
static std::pair<uint64_t, uint64_t>
|
||
hash_string_pair_std(const std::string &str) {
|
||
std::hash<std::string> hasher;
|
||
|
||
// 第一个哈希:原始字符串
|
||
uint64_t hash1 = static_cast<uint64_t>(hasher(str));
|
||
|
||
// 第二个哈希:反转字符串
|
||
std::string reversed = str;
|
||
std::reverse(reversed.begin(), reversed.end());
|
||
uint64_t hash2 = static_cast<uint64_t>(hasher(reversed));
|
||
|
||
// 如果字符串太短,添加额外混合
|
||
if (str.length() < 4) {
|
||
hash1 ^= 0x123456789ABCDEF0ULL;
|
||
hash2 ^= 0x0FEDCBA987654321ULL;
|
||
}
|
||
|
||
return {hash1, hash2};
|
||
}
|
||
|
||
public:
|
||
// 构造函数1:使用字符串作为种子
|
||
XorShift128Plus(const std::string &seed_str) {
|
||
auto seeds = hash_string_pair_std(seed_str);
|
||
s[0] = seeds.first;
|
||
s[1] = seeds.second;
|
||
|
||
if (s[0] == 0 && s[1] == 0) {
|
||
s[0] = 0x123456789ABCDEF0ULL;
|
||
s[1] = 0x0FEDCBA987654321ULL;
|
||
}
|
||
|
||
// 预热,确保随机性
|
||
for (int i = 0; i < 10; ++i) {
|
||
next();
|
||
}
|
||
}
|
||
|
||
// 构造函数2:使用两个64位整数作为种子
|
||
XorShift128Plus(uint64_t seed1 = 0, uint64_t seed2 = 0) {
|
||
if (seed1 == 0 && seed2 == 0) {
|
||
std::random_device rd;
|
||
// 正确组合rd()的结果到64位
|
||
s[0] = (static_cast<uint64_t>(rd()) << 32) | rd();
|
||
s[1] = (static_cast<uint64_t>(rd()) << 32) | rd();
|
||
} else {
|
||
s[0] = seed1;
|
||
s[1] = seed2;
|
||
}
|
||
|
||
if (s[0] == 0 && s[1] == 0) {
|
||
s[0] = 0x123456789ABCDEF0ULL;
|
||
s[1] = 0x0FEDCBA987654321ULL;
|
||
}
|
||
|
||
// 预热
|
||
for (int i = 0; i < 10; ++i) {
|
||
next();
|
||
}
|
||
}
|
||
|
||
// 正确的XorShift128Plus算法
|
||
uint64_t next() {
|
||
uint64_t x = s[0];
|
||
uint64_t y = s[1];
|
||
s[0] = y;
|
||
x ^= x << 23;
|
||
x ^= x >> 17;
|
||
x ^= y ^ (y >> 26);
|
||
s[1] = x;
|
||
return s[1] + y;
|
||
}
|
||
|
||
// 获取[min, max]范围内的随机整数(无偏)
|
||
int get_int(int min, int max) {
|
||
if (min >= max)
|
||
return min;
|
||
|
||
uint32_t range = static_cast<uint32_t>(max - min + 1);
|
||
|
||
// 拒绝采样法避免取模偏差
|
||
uint32_t limit = std::numeric_limits<uint32_t>::max() -
|
||
(std::numeric_limits<uint32_t>::max() % range);
|
||
uint32_t raw;
|
||
|
||
do {
|
||
// 取next()的高32位(质量更好)
|
||
raw = static_cast<uint32_t>(next() >> 32);
|
||
} while (raw >= limit);
|
||
|
||
return min + static_cast<int>(raw % range);
|
||
}
|
||
|
||
// 快速版本:当range是2的幂时使用
|
||
int get_int_fast(int min, int max) {
|
||
uint32_t range = max - min + 1;
|
||
|
||
// 检查range是否是2的幂
|
||
if (range > 0 && (range & (range - 1)) == 0) {
|
||
uint64_t raw = next();
|
||
return min + static_cast<int>(raw & (range - 1));
|
||
}
|
||
|
||
// 否则使用通用版本
|
||
return get_int(min, max);
|
||
}
|
||
|
||
// 获取[0, 1)范围内的随机浮点数
|
||
double get_double() {
|
||
// 使用52位精度(IEEE 754双精度的尾数位)
|
||
uint64_t bits = (next() >> 12) | 0x3FF0000000000000ULL;
|
||
return *reinterpret_cast<double *>(&bits) - 1.0;
|
||
}
|
||
|
||
// 获取当前状态(用于调试)
|
||
void get_state(uint64_t &s0, uint64_t &s1) const {
|
||
s0 = s[0];
|
||
s1 = s[1];
|
||
}
|
||
}; |