Skip to main content

vrd/
xoshiro_simd.rs

1// Copyright © 2023-2026 vrd. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! SIMD-batched Xoshiro256++ for `fill_bytes`.
5//!
//! Holds **K = 4 independent** Xoshiro256++ states in SIMD registers
//! and advances them in lock-step: on AArch64 NEON as two 2-lane
//! `uint64x2_t` groups, on x86_64 AVX2 as one 4-lane `__m256i` group.
//! Each `fill_bytes` call derives the K lane states by
//! SplitMix64-whitening the scalar generator's state with a distinct
//! lane-specific constant: cheap (~10 ns of setup), and statistically
//! independent lanes by construction. After the SIMD loop the scalar
//! generator's state is replaced with lane 0's final state, so the
//! scalar tail and any subsequent scalar calls continue lane 0's
//! stream. They do **not** match what a scalar-only run from the same
//! seed would have produced.
15//!
//! An earlier draft used [`crate::xoshiro::Xoshiro256PlusPlus::jump`]
//! for 2¹²⁸-step separation per lane, but at 256 scalar `next_u64`s per
//! call its ~256 ns setup wiped out the SIMD win for buffers under
//! ~4 KiB. The SplitMix derivation gives up the *guaranteed*
//! non-overlap that `jump` provides. For effectively random starting
//! points, however, the chance that any two lanes' n-step output windows
//! overlap is roughly K²·n/2²⁵⁶, which is negligible, and it costs a
//! fraction of the setup time.
22//!
23//! # Reproducibility contract
24//!
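//! The same seed produces a **different byte stream** between the
//! scalar path and the SIMD path. This is fundamental: there is no
//! correctness-preserving way to interleave K independent Xoshiro
//! generators into the *same* sequence a single-threaded generator
//! would produce. The SIMD stream is not portable either:
//! - AArch64 and x86_64 lay out lane outputs differently.
//! - On x86_64, the choice between AVX2 and scalar depends on the CPU
//!   (or, without `std`, on the compile-time target features).
//! - Buffers below the SIMD threshold always take the scalar path.
//!
//! Code that depends on bit-for-bit reproducibility across feature sets,
//! targets or machines must use the scalar path.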
31//!
32//! Statistical quality is unchanged - each lane is a full Xoshiro256++
33//! and inherits all of its properties.
34
35#![allow(unsafe_code)]
36
37use crate::xoshiro::Xoshiro256PlusPlus;
38
39/// Derives K lane states from the scalar generator's current state.
40/// Lane K's seed material is the scalar state's first word XORed with
41/// a lane-specific 64-bit constant, then run through four SplitMix64
42/// rounds. Output: `K` arrays of `[u64; 4]`.
43///
44/// SplitMix64 constants per <https://prng.di.unimi.it/splitmix64.c>.
45#[inline]
46fn derive_lanes<const K: usize>(
47    rng: &Xoshiro256PlusPlus,
48) -> [[u64; 4]; K] {
49    const LANE_SALT: [u64; 4] = [
50        0xA076_1D64_78BD_642F,
51        0xE703_7ED1_A0B4_28DB,
52        0x8EBC_6AF0_9C88_C6E3,
53        0x5899_65CC_7537_4CC3,
54    ];
55    let base = rng.state_snapshot();
56    let mut out = [[0u64; 4]; K];
57    for (k, lane) in out.iter_mut().enumerate() {
58        let mut sm = base[0] ^ LANE_SALT[k];
59        for slot in lane.iter_mut() {
60            *slot = splitmix64(&mut sm);
61        }
62        // Mix in the rest of the base state so a lane's distribution
63        // tracks the full 256-bit scalar seed, not just word 0.
64        for (slot, &b) in lane.iter_mut().zip(base.iter()) {
65            *slot ^= b.rotate_left(((k as u32 + 1) * 13) % 64);
66        }
67    }
68    out
69}
70
/// One SplitMix64 step: advances `state` by the golden-ratio increment
/// `0x9E37_79B9_7F4A_7C15` and returns the new state passed through
/// Stafford's "Mix13" finalizer.
///
/// Used to whiten the per-lane seed material derived from the scalar
/// generator's state so the SIMD lanes are statistically independent.
/// Constants per <https://prng.di.unimi.it/splitmix64.c>.
#[cfg_attr(
    not(any(target_arch = "aarch64", target_arch = "x86_64")),
    allow(dead_code) // only reachable from the SIMD lane derivation
)]
#[inline]
fn splitmix64(state: &mut u64) -> u64 {
    *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
    let mut z = *state;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}
82
/// Minimum buffer length, in bytes, for taking a SIMD path. Below this,
/// the SIMD setup cost outweighs the per-byte savings and the scalar
/// generator is used instead.
///
/// Calibrated for the SplitMix-derived lanes; the earlier jump-derived
/// lanes had a break-even of ~4 KiB.
#[cfg_attr(
    not(any(target_arch = "aarch64", target_arch = "x86_64")),
    allow(dead_code) // the portable fallback never dispatches on size
)]
const SIMD_THRESHOLD: usize = 64;
88
/// Fills `dest` with random bytes, using the NEON path on AArch64.
///
/// Buffers shorter than `SIMD_THRESHOLD` bytes go straight to the scalar
/// generator, because SIMD setup would cost more than it saves. Bytes
/// left over after the last whole 16-byte NEON store are also produced
/// by the scalar generator.
///
/// This function has one definition per target family (AArch64, x86_64,
/// everything else). Only the one matching the active `target_arch` is
/// compiled.
#[cfg(target_arch = "aarch64")]
#[inline]
pub fn fill_bytes(rng: &mut Xoshiro256PlusPlus, dest: &mut [u8]) {
    if dest.len() >= SIMD_THRESHOLD {
        aarch64::fill_bytes_neon(rng, dest);
    } else {
        rng.fill_bytes_scalar(dest);
    }
}
106
107/// x86_64 dispatch: prefer AVX2 if the CPU supports it; else scalar.
108#[cfg(target_arch = "x86_64")]
109#[inline]
110pub fn fill_bytes(rng: &mut Xoshiro256PlusPlus, dest: &mut [u8]) {
111    if dest.len() < SIMD_THRESHOLD || !is_avx2_available() {
112        rng.fill_bytes_scalar(dest);
113        return;
114    }
115    // SAFETY: gated on runtime AVX2 detection above.
116    unsafe { x86_64::fill_bytes_avx2(rng, dest) };
117}
118
/// Portable fallback for targets without a SIMD path. Every byte comes
/// from the scalar generator, so the output is exactly that of
/// `fill_bytes_scalar`.
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
#[inline]
pub fn fill_bytes(rng: &mut Xoshiro256PlusPlus, dest: &mut [u8]) {
    Xoshiro256PlusPlus::fill_bytes_scalar(rng, dest)
}
125
/// Reports whether the AVX2 path may be used.
///
/// With `std`, this probes the running CPU through
/// `std::is_x86_feature_detected!`.
#[cfg(all(target_arch = "x86_64", feature = "std"))]
#[inline]
fn is_avx2_available() -> bool {
    std::is_x86_feature_detected!("avx2")
}

/// Reports whether the AVX2 path may be used.
///
/// Without `std` there is no runtime probe. The answer is fixed at
/// compile time by `cfg!(target_feature = "avx2")`, which is only true
/// when the crate is built with AVX2 enabled (e.g. via `-C target-cpu`
/// or `-C target-feature=+avx2`).
#[cfg(all(target_arch = "x86_64", not(feature = "std")))]
#[inline]
fn is_avx2_available() -> bool {
    cfg!(target_feature = "avx2")
}
143
144// --------------------------- AArch64 NEON -----------------------------
145//
// Two 2-lane Xoshiro256++ groups (four lanes total). Each `Lanes::step`
// yields 16 output bytes; the main loop interleaves both groups and
// writes 64 bytes per iteration. The per-step computation is the scalar
// update with `uint64x2_t` substituted for `u64`. Throughput target on
// Apple M-series: ~20 GB/s (vs. 7.5 GB/s scalar baseline).
150
/// AArch64 NEON implementation. Four Xoshiro256++ lanes are split into
/// two independent 2-lane groups held in `uint64x2_t` registers. Each
/// group step emits 16 bytes; the main loop writes 64 bytes per
/// iteration.
///
/// Assumes NEON is enabled for the target, as it is on standard AArch64
/// Rust targets; no `#[target_feature]` gating is applied here.
#[cfg(target_arch = "aarch64")]
mod aarch64 {
    use super::Xoshiro256PlusPlus;
    use core::arch::aarch64::*;

    /// Two independent Xoshiro256++ states packed into 4 × `uint64x2_t`
    /// registers. Lane i of register `s[j]` is word j of state i.
    struct Lanes {
        /// `s[j]` holds word `j` of both lanes' Xoshiro256++ state.
        s: [uint64x2_t; 4],
    }

    impl Lanes {
        /// Packs two precomputed Xoshiro256++ states into transposed
        /// registers: lane 0 of every register comes from `lane0`, lane 1
        /// from `lane1`. Pure construction; no generator is touched.
        #[inline]
        fn from_pair(lane0: [u64; 4], lane1: [u64; 4]) -> Self {
            // SAFETY: NEON intrinsics on a NEON-enabled target; the
            // const lane indices 0 and 1 are in range for `uint64x2_t`.
            unsafe {
                let mut s = [vdupq_n_u64(0); 4];
                for (j, slot) in s.iter_mut().enumerate() {
                    let r = vsetq_lane_u64::<0>(lane0[j], *slot);
                    *slot = vsetq_lane_u64::<1>(lane1[j], r);
                }
                Self { s }
            }
        }

        /// Reads lane 0's current state out of the SIMD registers.
        #[inline]
        fn lane0_state(&self) -> [u64; 4] {
            // SAFETY: NEON intrinsics; lane index 0 is in range.
            unsafe {
                [
                    vgetq_lane_u64::<0>(self.s[0]),
                    vgetq_lane_u64::<0>(self.s[1]),
                    vgetq_lane_u64::<0>(self.s[2]),
                    vgetq_lane_u64::<0>(self.s[3]),
                ]
            }
        }

        /// Advances both lanes by one Xoshiro256++ step and returns the
        /// two 64-bit outputs (16 bytes when stored).
        ///
        /// # Safety
        /// Requires NEON, which this module assumes is enabled for the
        /// target. The body is pure register arithmetic.
        #[inline]
        unsafe fn step(&mut self) -> uint64x2_t {
            let s = &mut self.s;
            // res = rotl(s0 + s3, 23) + s0
            let sum = vaddq_u64(s[0], s[3]);
            let res = vaddq_u64(rotl::<23, 41>(sum), s[0]);

            let t = vshlq_n_u64::<17>(s[1]);

            s[2] = veorq_u64(s[2], s[0]);
            s[3] = veorq_u64(s[3], s[1]);
            s[1] = veorq_u64(s[1], s[2]);
            s[0] = veorq_u64(s[0], s[3]);

            s[2] = veorq_u64(s[2], t);
            s[3] = rotl::<45, 19>(s[3]);

            res
        }
    }

    /// Vector rotate-left for two lanes. NEON has no native 64-bit
    /// rotate, so it is emulated via shift + or. `N_INV` must equal
    /// `64 - N`; the call sites use 23/41 and 45/19.
    ///
    /// # Safety
    /// Requires NEON. `vshlq_n_u64` accepts `N` in `0..=63` and
    /// `vshrq_n_u64` accepts `N_INV` in `1..=64`; both ranges are checked
    /// at compile time.
    #[inline]
    unsafe fn rotl<const N: i32, const N_INV: i32>(
        x: uint64x2_t,
    ) -> uint64x2_t {
        debug_assert!(N + N_INV == 64, "rotl requires N_INV == 64 - N");
        vorrq_u64(vshlq_n_u64::<N>(x), vshrq_n_u64::<N_INV>(x))
    }

    /// Stores the 16 bytes of `v` at `dst`.
    ///
    /// Goes through a `u8`-typed store, so an arbitrary byte pointer is
    /// correctly aligned for the intrinsic's pointer type. On
    /// little-endian targets the bytes written are identical to
    /// `vst1q_u64`.
    ///
    /// # Safety
    /// `dst` must be valid for writes of 16 bytes; requires NEON.
    #[inline]
    unsafe fn store16(dst: *mut u8, v: uint64x2_t) {
        vst1q_u8(dst, vreinterpretq_u8_u64(v));
    }

    /// NEON `fill_bytes` entry point. The parent dispatcher calls it for
    /// buffers of at least `SIMD_THRESHOLD` bytes on AArch64.
    ///
    /// Bytes are produced as follows:
    /// - 64-byte blocks, alternating outputs from the two lane groups
    ///   (a, b, a, b);
    /// - then at most one 32-byte block (a, b);
    /// - then at most one 16-byte block (a).
    ///
    /// Afterwards `rng` is set to group a's lane-0 state and any
    /// remaining bytes (fewer than 16) come from the scalar generator.
    pub(super) fn fill_bytes_neon(
        rng: &mut Xoshiro256PlusPlus,
        dest: &mut [u8],
    ) {
        // Correct for any length: input shorter than 16 bytes simply
        // falls through to the scalar tail.
        // Two independent lane groups give two dependency chains that the
        // out-of-order core can overlap.
        let four = super::derive_lanes::<4>(rng);
        let mut lanes_a = Lanes::from_pair(four[0], four[1]);
        let mut lanes_b = Lanes::from_pair(four[2], four[3]);
        let dest_len = dest.len();
        let mut i = 0;
        while i + 64 <= dest_len {
            // SAFETY: `i + 64 <= dest_len`, so the stores at offsets
            // i, i+16, i+32 and i+48 (16 bytes each) stay inside `dest`.
            unsafe {
                let oa0 = lanes_a.step();
                let ob0 = lanes_b.step();
                let oa1 = lanes_a.step();
                let ob1 = lanes_b.step();
                let p = dest.as_mut_ptr().add(i);
                store16(p, oa0);
                store16(p.add(16), ob0);
                store16(p.add(32), oa1);
                store16(p.add(48), ob1);
            }
            i += 64;
        }
        while i + 32 <= dest_len {
            // SAFETY: `i + 32 <= dest_len`; two 16-byte stores in bounds.
            unsafe {
                let oa = lanes_a.step();
                let ob = lanes_b.step();
                let p = dest.as_mut_ptr().add(i);
                store16(p, oa);
                store16(p.add(16), ob);
            }
            i += 32;
        }
        while i + 16 <= dest_len {
            // SAFETY: `i + 16 <= dest_len`; one 16-byte store in bounds.
            unsafe {
                let out = lanes_a.step();
                store16(dest.as_mut_ptr().add(i), out);
            }
            i += 16;
        }
        // Hand the scalar generator lane 0 of group a. Its next outputs
        // have not been emitted yet, so the tail does not repeat bytes.
        rng.set_state(lanes_a.lane0_state());
        if i < dest_len {
            rng.fill_bytes_scalar(&mut dest[i..]);
        }
    }
}
298
299// ---------------------------- x86_64 AVX2 -----------------------------
300//
301// 4-lane Xoshiro256++. Each iteration writes 32 output bytes.
302// Throughput target on a modern AVX2 part: ~25–40 GB/s.
303
304/// x86_64 AVX2 4-lane implementation. K = 4 Xoshiro256++
305/// states held in `__m256i` registers; 32 bytes emitted per
306/// inner step.
307#[cfg(target_arch = "x86_64")]
308mod x86_64 {
309    use super::Xoshiro256PlusPlus;
310    use core::arch::x86_64::*;
311
312    /// Four independent Xoshiro256++ states packed into 4 ×
313    /// `__m256i` registers. Each register holds the i-th word
314    /// of all four lanes' state.
315    struct Lanes {
316        /// Four 256-bit registers; each register `s[j]` holds
317        /// the j-th word of all four lanes' Xoshiro256++ state.
318        s: [__m256i; 4],
319    }
320
321    impl Lanes {
322        /// Derives four independent lane states from the
323        /// scalar generator (via SplitMix64 whitening) and
324        /// loads them into `__m256i` registers in transposed
325        /// layout.
326        ///
327        /// # Safety
328        /// Requires the AVX2 target feature at runtime; enforced
329        /// by `#[target_feature(enable = "avx2")]` and the
330        /// runtime detection in the parent `fill_bytes`.
331        #[target_feature(enable = "avx2")]
332        unsafe fn from_rng(rng: &Xoshiro256PlusPlus) -> Self {
333            let lane_states = super::derive_lanes::<4>(rng);
334            // Transpose: register j holds
335            //   [lane0[j], lane1[j], lane2[j], lane3[j]].
336            let mut s = [_mm256_setzero_si256(); 4];
337            for (j, slot) in s.iter_mut().enumerate() {
338                *slot = _mm256_set_epi64x(
339                    lane_states[3][j] as i64,
340                    lane_states[2][j] as i64,
341                    lane_states[1][j] as i64,
342                    lane_states[0][j] as i64,
343                );
344            }
345            Self { s }
346        }
347
348        /// One Xoshiro256++ step across all four lanes.
349        /// Returns the per-lane outputs as a single `__m256i`
350        /// (= 32 bytes when stored).
351        ///
352        /// # Safety
353        /// Same AVX2 contract as [`Self::from_rng`].
354        #[inline]
355        #[target_feature(enable = "avx2")]
356        unsafe fn step(&mut self) -> __m256i {
357            let s = &mut self.s;
358            let sum = _mm256_add_epi64(s[0], s[3]);
359            let res = _mm256_add_epi64(rotl::<23, 41>(sum), s[0]);
360
361            let t = _mm256_slli_epi64::<17>(s[1]);
362
363            s[2] = _mm256_xor_si256(s[2], s[0]);
364            s[3] = _mm256_xor_si256(s[3], s[1]);
365            s[1] = _mm256_xor_si256(s[1], s[2]);
366            s[0] = _mm256_xor_si256(s[0], s[3]);
367
368            s[2] = _mm256_xor_si256(s[2], t);
369            s[3] = rotl::<45, 19>(s[3]);
370            res
371        }
372
373        /// Extracts lane 0's final state from the SIMD registers.
374        #[target_feature(enable = "avx2")]
375        unsafe fn lane0_state(&self) -> [u64; 4] {
376            let mut tmp = [0i64; 4];
377            let mut out = [0u64; 4];
378            for (j, slot) in out.iter_mut().enumerate() {
379                _mm256_storeu_si256(
380                    tmp.as_mut_ptr() as *mut __m256i,
381                    self.s[j],
382                );
383                *slot = tmp[0] as u64;
384            }
385            out
386        }
387    }
388
389    /// Vector rotate-left for four lanes. AVX2 has no native
390    /// rotate; emulate via shift + or. `N_INV` must equal
391    /// `64 - N` (call sites use 23/41 and 45/19).
392    ///
393    /// # Safety
394    /// AVX2 target feature required.
395    #[inline]
396    #[target_feature(enable = "avx2")]
397    unsafe fn rotl<const N: i32, const N_INV: i32>(
398        x: __m256i,
399    ) -> __m256i {
400        _mm256_or_si256(
401            _mm256_slli_epi64::<N>(x),
402            _mm256_srli_epi64::<N_INV>(x),
403        )
404    }
405
406    /// AVX2 `fill_bytes` entry point - called by the parent
407    /// module's dispatch when the buffer is ≥ `SIMD_THRESHOLD`
408    /// bytes on x86_64 and the runtime CPU advertises AVX2.
409    ///
410    /// # Safety
411    /// Caller (`super::fill_bytes`) verifies AVX2 availability
412    /// via `is_avx2_available()` before calling.
413    #[target_feature(enable = "avx2")]
414    pub(super) unsafe fn fill_bytes_avx2(
415        rng: &mut Xoshiro256PlusPlus,
416        dest: &mut [u8],
417    ) {
418        // Caller (super::fill_bytes) guarantees dest.len() >= SIMD_THRESHOLD,
419        // which is well above 32 - no redundant early-return needed.
420        let mut lanes = Lanes::from_rng(rng);
421        let mut i = 0;
422        while i + 32 <= dest.len() {
423            let out = lanes.step();
424            _mm256_storeu_si256(
425                dest.as_mut_ptr().add(i) as *mut __m256i,
426                out,
427            );
428            i += 32;
429        }
430        let new_scalar = lanes.lane0_state();
431        rng.set_state(new_scalar);
432        if i < dest.len() {
433            rng.fill_bytes_scalar(&mut dest[i..]);
434        }
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Whether `fill_bytes` actually takes the SIMD path on this machine.
    /// On x86_64 this needs AVX2 at runtime, or at compile time on
    /// no_std builds.
    #[cfg(target_arch = "x86_64")]
    fn simd_path_active() -> bool {
        is_avx2_available()
    }

    /// AArch64 always dispatches large buffers to the NEON path.
    #[cfg(target_arch = "aarch64")]
    fn simd_path_active() -> bool {
        true
    }

    /// Statistical check: bytes should be uniformly distributed. With
    /// 64 KiB, each of the 256 byte values is expected ~256 times. That is
    /// far above the minimum expected count (about 5) a χ² test needs to
    /// be valid.
    #[test]
    fn fill_produces_uniform_bytes() {
        let mut rng = Xoshiro256PlusPlus::from_u64_seed(0xC0DE_BEEF);
        let mut buf = [0u8; 64 * 1024];
        fill_bytes(&mut rng, &mut buf);

        let mut counts = [0u32; 256];
        for &b in &buf[..] {
            counts[b as usize] += 1;
        }
        let mean = (buf.len() / 256) as f64;
        let chi2: f64 = counts
            .iter()
            .map(|&c| {
                let diff = c as f64 - mean;
                diff * diff / mean
            })
            .sum();
        // χ² with 255 degrees of freedom has a 99.99% upper critical value
        // of ≈ 348. The seed is fixed, but the byte stream differs per
        // target and per SIMD path, so use a loose 500.
        assert!(chi2 < 500.0, "χ² = {chi2} too high");
    }

    /// Lengths that reach every branch must be completely filled. The
    /// branches are the scalar dispatch below the threshold, the
    /// 64/32/16-byte SIMD chunks and the scalar tail. Buffers that start
    /// at an odd address are included.
    #[test]
    fn fill_handles_short_and_unaligned_lengths() {
        const LENGTHS: [usize; 17] = [
            0, 1, 7, 15, 16, 17, 31, 33, 63, 64, 65, 80, 96, 112, 127, 128,
            129,
        ];
        let mut rng = Xoshiro256PlusPlus::from_u64_seed(1);
        // Stack buffer large enough for the longest length at offset 1;
        // avoids depending on `alloc`/`std` for `vec!`.
        let mut backing = [0u8; 136];
        for &offset in &[0usize, 1] {
            for &len in &LENGTHS {
                let buf = &mut backing[offset..offset + len];
                buf.fill(0);
                fill_bytes(&mut rng, buf);
                // Beyond 4 bytes an all-zero window is astronomically
                // unlikely. Checking the head and the last 8 bytes catches
                // both an unwritten buffer and a skipped tail.
                if len > 4 {
                    let head = &buf[..len.min(8)];
                    let tail = &buf[len.saturating_sub(8)..];
                    assert!(
                        head.iter().any(|&b| b != 0),
                        "no entropy in head at len {len}, offset {offset}"
                    );
                    assert!(
                        tail.iter().any(|&b| b != 0),
                        "no entropy in tail at len {len}, offset {offset}"
                    );
                }
            }
        }
    }

    /// `fill_bytes` must advance the caller's generator, so two
    /// back-to-back fills from the same `rng` must not repeat.
    #[test]
    fn consecutive_fills_differ() {
        let mut rng = Xoshiro256PlusPlus::from_u64_seed(7);
        let mut first = [0u8; 256];
        let mut second = [0u8; 256];
        fill_bytes(&mut rng, &mut first);
        fill_bytes(&mut rng, &mut second);
        assert_ne!(first, second, "rng state was not advanced");
    }

    /// SIMD must produce a different stream than scalar from the same
    /// seed; this is the documented contract. The test is only
    /// meaningful when a SIMD path is actually taken.
    #[test]
    #[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))]
    fn simd_diverges_from_scalar() {
        // On x86_64 without AVX2, `fill_bytes` *is* the scalar path and
        // the streams are expected to match.
        if !simd_path_active() {
            return;
        }
        let mut a = Xoshiro256PlusPlus::from_u64_seed(42);
        let mut b = Xoshiro256PlusPlus::from_u64_seed(42);
        let mut sa = [0u8; 256];
        let mut sb = [0u8; 256];
        fill_bytes(&mut a, &mut sa);
        b.fill_bytes_scalar(&mut sb);
        assert_ne!(sa, sb, "SIMD and scalar must diverge");
    }
}