vrd/xoshiro_simd.rs
1// Copyright © 2023-2026 vrd. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! SIMD-batched Xoshiro256++ for `fill_bytes`.
5//!
//! Holds **K independent** Xoshiro256++ states in SIMD registers and
//! advances all of them in one inner-loop iteration. Both back ends
//! derive K = 4 lane states per call: x86_64 AVX2 packs them into one
//! group of four 64-bit lanes, while AArch64 NEON packs them into two
//! groups of two 64-bit lanes that are stepped in an interleaved
//! fashion. Each `fill_bytes` call derives the lane states by
//! SplitMix64-whitening the scalar generator's state with a distinct
//! lane-specific constant: cheap (~10 ns of setup), and statistically
//! independent lanes by construction. After the SIMD loop the scalar
//! state is **replaced** by lane 0's final state, so the generator
//! still moves forward deterministically, but subsequent scalar output
//! continues lane 0's stream rather than the stream the scalar-only
//! path would have produced.
15//!
16//! An earlier draft used [`crate::xoshiro::Xoshiro256PlusPlus::jump`]
17//! for 2¹²⁸-step
18//! separation per lane, but at 256 scalar `next_u64`s per call its
19//! ~256 ns setup wiped out the SIMD win for buffers under ~4 KiB. The
20//! SplitMix derivation keeps lanes uncorrelated (probability of state
21//! collision is ≤ K²/2²⁵⁶ - negligible) at a fraction of the cost.
22//!
23//! # Reproducibility contract
24//!
25//! The same seed produces a **different byte stream** between the
26//! scalar path and the SIMD path. This is fundamental: there is no
27//! correctness-preserving way to interleave K independent Xoshiro
28//! generators into the *same* sequence a single-threaded generator
29//! would produce. Code that depends on bit-for-bit reproducibility
30//! across feature sets must use the scalar path.
31//!
32//! Statistical quality is unchanged - each lane is a full Xoshiro256++
33//! and inherits all of its properties.
34
35#![allow(unsafe_code)]
36
37use crate::xoshiro::Xoshiro256PlusPlus;
38
39/// Derives K lane states from the scalar generator's current state.
40/// Lane K's seed material is the scalar state's first word XORed with
41/// a lane-specific 64-bit constant, then run through four SplitMix64
42/// rounds. Output: `K` arrays of `[u64; 4]`.
43///
44/// SplitMix64 constants per <https://prng.di.unimi.it/splitmix64.c>.
45#[inline]
46fn derive_lanes<const K: usize>(
47 rng: &Xoshiro256PlusPlus,
48) -> [[u64; 4]; K] {
49 const LANE_SALT: [u64; 4] = [
50 0xA076_1D64_78BD_642F,
51 0xE703_7ED1_A0B4_28DB,
52 0x8EBC_6AF0_9C88_C6E3,
53 0x5899_65CC_7537_4CC3,
54 ];
55 let base = rng.state_snapshot();
56 let mut out = [[0u64; 4]; K];
57 for (k, lane) in out.iter_mut().enumerate() {
58 let mut sm = base[0] ^ LANE_SALT[k];
59 for slot in lane.iter_mut() {
60 *slot = splitmix64(&mut sm);
61 }
62 // Mix in the rest of the base state so a lane's distribution
63 // tracks the full 256-bit scalar seed, not just word 0.
64 for (slot, &b) in lane.iter_mut().zip(base.iter()) {
65 *slot ^= b.rotate_left(((k as u32 + 1) * 13) % 64);
66 }
67 }
68 out
69}
70
/// SplitMix64 - Stafford's variant 13.
///
/// Adds the SplitMix64 increment `0x9E37_79B9_7F4A_7C15` to `state`
/// (wrapping), stores the sum back, and returns the mixed value of the
/// new state. Used to whiten the per-lane seed material derived from
/// the scalar generator's state so the SIMD lanes are statistically
/// independent.
// The function is target-independent and cheap, so it stays compiled
// everywhere; the lint is silenced only on targets without a SIMD back
// end, where nothing reaches it and it would be reported as dead code.
#[cfg_attr(
    not(any(target_arch = "aarch64", target_arch = "x86_64")),
    allow(dead_code)
)]
#[inline]
fn splitmix64(state: &mut u64) -> u64 {
    *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
    let z = *state;
    let z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    let z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}
82
/// Minimum buffer length, in bytes, for which `fill_bytes` takes a SIMD
/// path. Below this many bytes, the SIMD setup cost outweighs the
/// per-byte savings, so the scalar generator is used instead.
/// Calibrated for the SplitMix-derived lanes, not for jump-derived
/// lanes (which had a ~4 KiB break-even).
///
/// Defined only on targets with a SIMD back end: the AArch64 and x86_64
/// dispatchers are the only readers, so elsewhere it would be an unused
/// constant.
#[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))]
const SIMD_THRESHOLD: usize = 64;
88
/// Fills `dest` with random bytes using the best SIMD path available
/// for the active target (AArch64 variant).
///
/// Buffers of at least `SIMD_THRESHOLD` bytes are handed to the NEON
/// implementation, which itself finishes any trailing bytes that do not
/// fill a full SIMD register with the scalar generator. Shorter buffers
/// are filled entirely by the scalar generator, because there the setup
/// cost beats the per-byte savings.
///
/// This is one of three target-specific `fill_bytes` definitions; rustc
/// only compiles the one matching the active `target_arch`.
#[cfg(target_arch = "aarch64")]
#[inline]
pub fn fill_bytes(rng: &mut Xoshiro256PlusPlus, dest: &mut [u8]) {
    if dest.len() >= SIMD_THRESHOLD {
        aarch64::fill_bytes_neon(rng, dest);
    } else {
        rng.fill_bytes_scalar(dest);
    }
}
106
107/// x86_64 dispatch: prefer AVX2 if the CPU supports it; else scalar.
108#[cfg(target_arch = "x86_64")]
109#[inline]
110pub fn fill_bytes(rng: &mut Xoshiro256PlusPlus, dest: &mut [u8]) {
111 if dest.len() < SIMD_THRESHOLD || !is_avx2_available() {
112 rng.fill_bytes_scalar(dest);
113 return;
114 }
115 // SAFETY: gated on runtime AVX2 detection above.
116 unsafe { x86_64::fill_bytes_avx2(rng, dest) };
117}
118
/// `fill_bytes` for every architecture other than AArch64 and x86_64.
///
/// No SIMD path exists for these targets, so the whole buffer is always
/// produced by the scalar generator.
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
#[inline]
pub fn fill_bytes(rng: &mut Xoshiro256PlusPlus, dest: &mut [u8]) {
    Xoshiro256PlusPlus::fill_bytes_scalar(rng, dest);
}
125
/// Reports whether the AVX2 back end may be used.
///
/// * With the `std` feature: run-time detection through
///   `std::is_x86_feature_detected!`.
/// * Without `std` (no_std path): run-time detection is unavailable,
///   so the answer is the compile-time value of
///   `cfg!(target_feature = "avx2")`, which only reports `true` when
///   the crate was built with a `target-cpu` / target feature set that
///   exposes AVX2.
#[cfg(target_arch = "x86_64")]
#[inline]
fn is_avx2_available() -> bool {
    #[cfg(feature = "std")]
    {
        std::is_x86_feature_detected!("avx2")
    }
    #[cfg(not(feature = "std"))]
    {
        cfg!(target_feature = "avx2")
    }
}
143
// --------------------------- AArch64 NEON -----------------------------
//
// 2-lane Xoshiro256++. Each step of a 2-lane group yields 16 output
// bytes; the main loop interleaves two such groups and writes 64 bytes
// per iteration. The computation is the scalar update with
// `uint64x2_t` substituted for `u64`. Throughput target on Apple
// M-series: ~20 GB/s (vs. 7.5 GB/s scalar baseline).
150
/// AArch64 NEON back end: Xoshiro256++ evaluated two lanes at a time.
///
/// A `Lanes` value keeps two generator states side by side in
/// `uint64x2_t` registers, so one step yields 16 bytes of output.
#[cfg(target_arch = "aarch64")]
mod aarch64 {
    use super::Xoshiro256PlusPlus;
    use core::arch::aarch64::*;

    /// A pair of independent Xoshiro256++ states stored word-major
    /// across four `uint64x2_t` registers.
    struct Lanes {
        /// `s[j]` holds state word `j` of both generators: vector lane 0
        /// belongs to the first generator, vector lane 1 to the second.
        s: [uint64x2_t; 4],
    }

    impl Lanes {
        /// Packs two pre-computed scalar Xoshiro256++ states into one
        /// 2-lane group: `lane0` occupies vector lane 0 of every
        /// register and `lane1` occupies vector lane 1.
        #[inline]
        fn from_pair(lane0: [u64; 4], lane1: [u64; 4]) -> Self {
            // Broadcast the first generator's word to both vector lanes,
            // then overwrite vector lane 1 with the second generator's.
            let pack = |j: usize| -> uint64x2_t {
                // SAFETY: lane index 1 is a valid index in [0, 2) for a
                // two-lane vector; both intrinsics only touch registers.
                unsafe { vsetq_lane_u64::<1>(lane1[j], vdupq_n_u64(lane0[j])) }
            };
            Self {
                s: [pack(0), pack(1), pack(2), pack(3)],
            }
        }

        /// Reads the current state of the generator held in vector
        /// lane 0, one word per register.
        #[inline]
        fn lane0_state(&self) -> [u64; 4] {
            // SAFETY: lane index 0 is a valid index in [0, 2).
            self.s.map(|word| unsafe { vgetq_lane_u64::<0>(word) })
        }

        /// Advances both generators by one Xoshiro256++ step and returns
        /// their two 64-bit outputs as a single `uint64x2_t` (16 bytes
        /// when stored).
        ///
        /// # Safety
        /// Only NEON register intrinsics are used; the function is
        /// `unsafe` because it calls them and the `unsafe` helper `rotl`.
        #[inline]
        unsafe fn step(&mut self) -> uint64x2_t {
            let [s0, s1, s2, s3] = self.s;

            // out = rotl(s0 + s3, 23) + s0
            let out = vaddq_u64(rotl::<23, 41>(vaddq_u64(s0, s3)), s0);

            // Taken from the *old* s1, before the XOR cascade below.
            let t = vshlq_n_u64::<17>(s1);

            // Each line shadows the word it updates; the order matters
            // because later lines consume the already-updated words.
            let s2 = veorq_u64(s2, s0);
            let s3 = veorq_u64(s3, s1);
            let s1 = veorq_u64(s1, s2);
            let s0 = veorq_u64(s0, s3);

            let s2 = veorq_u64(s2, t);
            let s3 = rotl::<45, 19>(s3);

            self.s = [s0, s1, s2, s3];
            out
        }
    }

    /// Rotates both 64-bit lanes of `x` left by `N` bits, built from a
    /// left shift by `N` OR-ed with a right shift by `N_INV`.
    /// `N_INV` must equal `64 - N` (the call sites use 23/41 and 45/19).
    ///
    /// # Safety
    /// Operates on registers only. The shift amounts are const-generic
    /// arguments of the underlying shift intrinsics and are therefore
    /// checked at compile time.
    #[inline]
    unsafe fn rotl<const N: i32, const N_INV: i32>(
        x: uint64x2_t,
    ) -> uint64x2_t {
        let high = vshlq_n_u64::<N>(x);
        let low = vshrq_n_u64::<N_INV>(x);
        vorrq_u64(high, low)
    }

    /// NEON `fill_bytes` entry point - called by the parent module's
    /// dispatch when the buffer is at least `SIMD_THRESHOLD` bytes on
    /// AArch64.
    ///
    /// Four lane states are derived from `rng` and packed into two
    /// 2-lane groups, A and B. Output layout:
    ///
    /// * every full 64-byte block: A, B, A, B (16 bytes each, two steps
    ///   per group);
    /// * then at most one 32-byte block: A, B;
    /// * then at most one 16-byte block: A;
    /// * any remaining bytes (fewer than 16) come from the scalar
    ///   generator.
    ///
    /// Before that scalar tail is produced, `rng`'s state is overwritten
    /// with the final state of group A's vector lane 0.
    pub(super) fn fill_bytes_neon(
        rng: &mut Xoshiro256PlusPlus,
        dest: &mut [u8],
    ) {
        let [a_first, a_second, b_first, b_second] =
            super::derive_lanes::<4>(rng);
        let mut group_a = Lanes::from_pair(a_first, a_second);
        let mut group_b = Lanes::from_pair(b_first, b_second);

        let mut blocks64 = dest.chunks_exact_mut(64);
        for block in &mut blocks64 {
            // SAFETY: each step is a pure register update; `block` is
            // exactly 64 bytes long, so the four 16-byte stores at u64
            // offsets 0, 2, 4 and 6 stay inside it.
            // NOTE(review): the stores go through a `*mut u64` derived
            // from a byte pointer, which relies on `vst1q_u64` accepting
            // unaligned addresses - confirm against the intrinsic docs.
            unsafe {
                let a0 = group_a.step();
                let b0 = group_b.step();
                let a1 = group_a.step();
                let b1 = group_b.step();
                let out = block.as_mut_ptr().cast::<u64>();
                vst1q_u64(out, a0);
                vst1q_u64(out.add(2), b0);
                vst1q_u64(out.add(4), a1);
                vst1q_u64(out.add(6), b1);
            }
        }

        // Fewer than 64 bytes remain, so this loop runs at most once.
        let mut blocks32 = blocks64.into_remainder().chunks_exact_mut(32);
        for block in &mut blocks32 {
            // SAFETY: `block` is exactly 32 bytes long; both 16-byte
            // stores stay inside it.
            unsafe {
                let a = group_a.step();
                let b = group_b.step();
                let out = block.as_mut_ptr().cast::<u64>();
                vst1q_u64(out, a);
                vst1q_u64(out.add(2), b);
            }
        }

        // Fewer than 32 bytes remain, so this loop runs at most once.
        let mut blocks16 = blocks32.into_remainder().chunks_exact_mut(16);
        for block in &mut blocks16 {
            // SAFETY: `block` is exactly 16 bytes long, the size of the
            // single store.
            unsafe {
                let a = group_a.step();
                vst1q_u64(block.as_mut_ptr().cast::<u64>(), a);
            }
        }
        let tail = blocks16.into_remainder();

        // Move the scalar generator forward by adopting the final state
        // of group A's lane 0; the scalar tail continues from there.
        rng.set_state(group_a.lane0_state());
        if !tail.is_empty() {
            rng.fill_bytes_scalar(tail);
        }
    }
}
298
299// ---------------------------- x86_64 AVX2 -----------------------------
300//
301// 4-lane Xoshiro256++. Each iteration writes 32 output bytes.
302// Throughput target on a modern AVX2 part: ~25–40 GB/s.
303
304/// x86_64 AVX2 4-lane implementation. K = 4 Xoshiro256++
305/// states held in `__m256i` registers; 32 bytes emitted per
306/// inner step.
307#[cfg(target_arch = "x86_64")]
308mod x86_64 {
309 use super::Xoshiro256PlusPlus;
310 use core::arch::x86_64::*;
311
312 /// Four independent Xoshiro256++ states packed into 4 ×
313 /// `__m256i` registers. Each register holds the i-th word
314 /// of all four lanes' state.
315 struct Lanes {
316 /// Four 256-bit registers; each register `s[j]` holds
317 /// the j-th word of all four lanes' Xoshiro256++ state.
318 s: [__m256i; 4],
319 }
320
321 impl Lanes {
322 /// Derives four independent lane states from the
323 /// scalar generator (via SplitMix64 whitening) and
324 /// loads them into `__m256i` registers in transposed
325 /// layout.
326 ///
327 /// # Safety
328 /// Requires the AVX2 target feature at runtime; enforced
329 /// by `#[target_feature(enable = "avx2")]` and the
330 /// runtime detection in the parent `fill_bytes`.
331 #[target_feature(enable = "avx2")]
332 unsafe fn from_rng(rng: &Xoshiro256PlusPlus) -> Self {
333 let lane_states = super::derive_lanes::<4>(rng);
334 // Transpose: register j holds
335 // [lane0[j], lane1[j], lane2[j], lane3[j]].
336 let mut s = [_mm256_setzero_si256(); 4];
337 for (j, slot) in s.iter_mut().enumerate() {
338 *slot = _mm256_set_epi64x(
339 lane_states[3][j] as i64,
340 lane_states[2][j] as i64,
341 lane_states[1][j] as i64,
342 lane_states[0][j] as i64,
343 );
344 }
345 Self { s }
346 }
347
348 /// One Xoshiro256++ step across all four lanes.
349 /// Returns the per-lane outputs as a single `__m256i`
350 /// (= 32 bytes when stored).
351 ///
352 /// # Safety
353 /// Same AVX2 contract as [`Self::from_rng`].
354 #[inline]
355 #[target_feature(enable = "avx2")]
356 unsafe fn step(&mut self) -> __m256i {
357 let s = &mut self.s;
358 let sum = _mm256_add_epi64(s[0], s[3]);
359 let res = _mm256_add_epi64(rotl::<23, 41>(sum), s[0]);
360
361 let t = _mm256_slli_epi64::<17>(s[1]);
362
363 s[2] = _mm256_xor_si256(s[2], s[0]);
364 s[3] = _mm256_xor_si256(s[3], s[1]);
365 s[1] = _mm256_xor_si256(s[1], s[2]);
366 s[0] = _mm256_xor_si256(s[0], s[3]);
367
368 s[2] = _mm256_xor_si256(s[2], t);
369 s[3] = rotl::<45, 19>(s[3]);
370 res
371 }
372
373 /// Extracts lane 0's final state from the SIMD registers.
374 #[target_feature(enable = "avx2")]
375 unsafe fn lane0_state(&self) -> [u64; 4] {
376 let mut tmp = [0i64; 4];
377 let mut out = [0u64; 4];
378 for (j, slot) in out.iter_mut().enumerate() {
379 _mm256_storeu_si256(
380 tmp.as_mut_ptr() as *mut __m256i,
381 self.s[j],
382 );
383 *slot = tmp[0] as u64;
384 }
385 out
386 }
387 }
388
389 /// Vector rotate-left for four lanes. AVX2 has no native
390 /// rotate; emulate via shift + or. `N_INV` must equal
391 /// `64 - N` (call sites use 23/41 and 45/19).
392 ///
393 /// # Safety
394 /// AVX2 target feature required.
395 #[inline]
396 #[target_feature(enable = "avx2")]
397 unsafe fn rotl<const N: i32, const N_INV: i32>(
398 x: __m256i,
399 ) -> __m256i {
400 _mm256_or_si256(
401 _mm256_slli_epi64::<N>(x),
402 _mm256_srli_epi64::<N_INV>(x),
403 )
404 }
405
406 /// AVX2 `fill_bytes` entry point - called by the parent
407 /// module's dispatch when the buffer is ≥ `SIMD_THRESHOLD`
408 /// bytes on x86_64 and the runtime CPU advertises AVX2.
409 ///
410 /// # Safety
411 /// Caller (`super::fill_bytes`) verifies AVX2 availability
412 /// via `is_avx2_available()` before calling.
413 #[target_feature(enable = "avx2")]
414 pub(super) unsafe fn fill_bytes_avx2(
415 rng: &mut Xoshiro256PlusPlus,
416 dest: &mut [u8],
417 ) {
418 // Caller (super::fill_bytes) guarantees dest.len() >= SIMD_THRESHOLD,
419 // which is well above 32 - no redundant early-return needed.
420 let mut lanes = Lanes::from_rng(rng);
421 let mut i = 0;
422 while i + 32 <= dest.len() {
423 let out = lanes.step();
424 _mm256_storeu_si256(
425 dest.as_mut_ptr().add(i) as *mut __m256i,
426 out,
427 );
428 i += 32;
429 }
430 let new_scalar = lanes.lane0_state();
431 rng.set_state(new_scalar);
432 if i < dest.len() {
433 rng.fill_bytes_scalar(&mut dest[i..]);
434 }
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Statistical smoke test: bytes should be uniformly distributed.
    /// With 64 KiB of output the expected count per byte value is 256,
    /// which is ample for a χ² test on a fair 8-bit distribution.
    #[test]
    fn fill_produces_uniform_bytes() {
        let mut rng = Xoshiro256PlusPlus::from_u64_seed(0xC0DE_BEEF);
        let mut buf = [0u8; 64 * 1024];
        fill_bytes(&mut rng, &mut buf);

        let mut counts = [0u32; 256];
        for &b in &buf[..] {
            counts[b as usize] += 1;
        }
        let mean = (buf.len() / 256) as f64;
        let chi2: f64 = counts
            .iter()
            .map(|&c| {
                let diff = c as f64 - mean;
                diff * diff / mean
            })
            .sum();
        // χ² with 255 degrees of freedom: the 99.99% upper critical
        // value is roughly 350. The seed is fixed, so the result is
        // deterministic; the loose bound of 500 keeps the test from
        // becoming brittle if the lane derivation is ever retuned.
        assert!(chi2 < 500.0, "χ² = {chi2} too high");
    }

    /// The SIMD path must handle short buffers, lengths that are not a
    /// multiple of the SIMD block size, and unaligned start addresses,
    /// without ever writing outside the destination slice.
    ///
    /// Uses a stack buffer instead of `vec!`, so the test builds with or
    /// without the `alloc` / `std` features.
    #[test]
    fn fill_handles_short_and_unaligned_lengths() {
        const MAX_LEN: usize = 129;
        const GUARD: usize = 16;
        const CANARY: u8 = 0xA5;
        const LENGTHS: [usize; 17] = [
            0, 1, 7, 15, 16, 17, 31, 32, 33, 48, 63, 64, 65, 96, 127, 128,
            MAX_LEN,
        ];

        let mut rng = Xoshiro256PlusPlus::from_u64_seed(1);
        for &len in &LENGTHS {
            // Two consecutive start addresses: at least one of them is
            // not aligned to the SIMD register width.
            for &shift in &[0usize, 1] {
                let mut arena = [CANARY; GUARD + 1 + MAX_LEN + GUARD];
                let start = GUARD + shift;
                let end = start + len;
                for b in &mut arena[start..end] {
                    *b = 0;
                }

                fill_bytes(&mut rng, &mut arena[start..end]);

                assert!(
                    arena[..start].iter().all(|&b| b == CANARY),
                    "wrote before the buffer at len {len}, shift {shift}"
                );
                assert!(
                    arena[end..].iter().all(|&b| b == CANARY),
                    "wrote past the buffer at len {len}, shift {shift}"
                );
                // Very short buffers can legitimately come out all-zero,
                // so the entropy check only applies above four bytes
                // (this also skips the vacuous len = 0 case).
                if len > 4 {
                    assert!(
                        arena[start..end].iter().any(|&b| b != 0),
                        "no entropy at len {len}"
                    );
                }
            }
        }
    }

    /// SIMD must produce a different stream than scalar from the same
    /// seed - this is the documented contract. Only meaningful when a
    /// SIMD path is actually taken: on x86_64 `fill_bytes` falls back to
    /// the scalar generator when AVX2 is unavailable, in which case both
    /// streams are identical by design and the check is skipped.
    #[test]
    #[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))]
    fn simd_diverges_from_scalar() {
        #[cfg(target_arch = "x86_64")]
        let simd_active = is_avx2_available();
        #[cfg(target_arch = "aarch64")]
        let simd_active = true;
        if !simd_active {
            return;
        }

        let mut a = Xoshiro256PlusPlus::from_u64_seed(42);
        let mut b = Xoshiro256PlusPlus::from_u64_seed(42);
        let mut sa = [0u8; 256];
        let mut sb = [0u8; 256];
        fill_bytes(&mut a, &mut sa);
        b.fill_bytes_scalar(&mut sb);
        assert_ne!(sa, sb, "SIMD and scalar must diverge");
    }
}
507}