Skip to main content

vrd/
mersenne_twister.rs

1// Copyright © 2023-2026 vrd. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! Mersenne Twister (MT19937) configuration types.
5//!
6//! The actual MT19937 generator is implemented in [`crate::random`]; this
7//! module provides the configuration parameters and validation.
8
9use core::fmt;
10
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize};
13
/// Error type returned when a [`MersenneTwisterConfig`] cannot be built or
/// updated.
///
/// # Examples
///
/// ```
/// use vrd::MersenneTwisterError;
///
/// let err = MersenneTwisterError::InvalidConfig("N must be at least 1");
/// assert_eq!(format!("{err}"), "Invalid configuration: N must be at least 1");
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum MersenneTwisterError {
    /// A parameter, or one of the `N`/`M` const generics, failed
    /// validation. The payload is a fixed message naming the violated rule.
    ///
    /// Using `&'static str` keeps the error allocation-free, so the type is
    /// usable under `no_std` even without `alloc`.
    InvalidConfig(&'static str),
}
33
34impl fmt::Display for MersenneTwisterError {
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        match self {
37            MersenneTwisterError::InvalidConfig(msg) => {
38                write!(f, "Invalid configuration: {}", msg)
39            }
40        }
41    }
42}
43
44#[cfg(feature = "std")]
45impl std::error::Error for MersenneTwisterError {}
46
/// The five 32-bit constants used by the Mersenne Twister recurrence and
/// output tempering.
///
/// [`Default`] yields the canonical MT19937 values.
/// [`MersenneTwisterConfig::validate`] accepts any `matrix_a` whose highest
/// bit is set, but requires the other four fields to equal their canonical
/// values.
///
/// # Examples
///
/// ```
/// use vrd::MersenneTwisterParams;
///
/// let p = MersenneTwisterParams::default();
/// assert_eq!(p.matrix_a, 0x9908b0df);
/// assert_eq!(p.upper_mask, 0x80000000);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct MersenneTwisterParams {
    /// Twist-matrix constant; its highest bit must be set
    /// (canonical: `0x9908b0df`).
    pub matrix_a: u32,
    /// Mask selecting the upper bit of a state word (canonical: `0x80000000`).
    pub upper_mask: u32,
    /// Mask selecting the lower 31 bits of a state word
    /// (canonical: `0x7fffffff`).
    pub lower_mask: u32,
    /// Tempering mask B (canonical: `0x9d2c5680`).
    pub tempering_mask_b: u32,
    /// Tempering mask C (canonical: `0xefc60000`).
    pub tempering_mask_c: u32,
}
76
77impl Default for MersenneTwisterParams {
78    /// Returns the canonical MT19937 constants.
79    ///
80    /// # Examples
81    ///
82    /// ```
83    /// use vrd::MersenneTwisterParams;
84    ///
85    /// let p = MersenneTwisterParams::default();
86    /// assert_eq!(p.matrix_a, 0x9908b0df);
87    /// ```
88    fn default() -> Self {
89        MersenneTwisterParams {
90            matrix_a: 0x9908b0df,
91            upper_mask: 0x80000000,
92            lower_mask: 0x7fffffff,
93            tempering_mask_b: 0x9d2c5680,
94            tempering_mask_c: 0xefc60000,
95        }
96    }
97}
98
99/// Configuration for an MT19937-style Mersenne Twister.
100///
101/// `N` is the array length; `M` is the recurrence offset. The canonical
102/// MT19937 instantiation is `MersenneTwisterConfig::<624, 397>`.
103///
104/// # Examples
105///
106/// ```
107/// use vrd::MersenneTwisterConfig;
108///
109/// let cfg = MersenneTwisterConfig::<624, 397>::default();
110/// assert_eq!(cfg.params.upper_mask, 0x80000000);
111/// ```
112#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
113#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
114pub struct MersenneTwisterConfig<const N: usize, const M: usize> {
115    /// The validated configuration parameters.
116    pub params: MersenneTwisterParams,
117}
118
119impl<const N: usize, const M: usize> MersenneTwisterConfig<N, M> {
120    /// Builds a config with custom parameters, validating the invariants.
121    ///
122    /// # Examples
123    ///
124    /// ```
125    /// use vrd::{MersenneTwisterConfig, MersenneTwisterParams};
126    ///
127    /// let p = MersenneTwisterParams::default();
128    /// let cfg = MersenneTwisterConfig::<624, 397>::new_custom(p).unwrap();
129    /// assert_eq!(cfg.params, p);
130    /// ```
131    ///
132    /// # Errors
133    ///
134    /// Returns [`MersenneTwisterError::InvalidConfig`] when any parameter
135    /// (or the `N`/`M` const generics) violates a Mersenne-Twister
136    /// invariant.
137    pub fn new_custom(
138        params: MersenneTwisterParams,
139    ) -> Result<Self, MersenneTwisterError> {
140        Self::validate(&params)?;
141        Ok(MersenneTwisterConfig { params })
142    }
143
144    /// Validates `params` against the Mersenne-Twister invariants.
145    ///
146    /// # Examples
147    ///
148    /// ```
149    /// use vrd::{MersenneTwisterConfig, MersenneTwisterParams};
150    ///
151    /// let p = MersenneTwisterParams::default();
152    /// assert!(MersenneTwisterConfig::<624, 397>::validate(&p).is_ok());
153    ///
154    /// // Invalid: M >= N.
155    /// assert!(MersenneTwisterConfig::<10, 10>::validate(&p).is_err());
156    /// ```
157    ///
158    /// # Errors
159    ///
160    /// Returns [`MersenneTwisterError::InvalidConfig`] when the invariants
161    /// are violated. The static message in the error names which one.
162    pub fn validate(
163        params: &MersenneTwisterParams,
164    ) -> Result<(), MersenneTwisterError> {
165        if N < 1 {
166            return Err(MersenneTwisterError::InvalidConfig(
167                "N must be at least 1",
168            ));
169        }
170        if M < 1 || M >= N {
171            return Err(MersenneTwisterError::InvalidConfig(
172                "M must be at least 1 and less than N",
173            ));
174        }
175        if params.matrix_a & 0x80000000 != 0x80000000 {
176            return Err(MersenneTwisterError::InvalidConfig(
177                "matrix_a must have its highest bit set",
178            ));
179        }
180        if params.upper_mask != 0x80000000 {
181            return Err(MersenneTwisterError::InvalidConfig(
182                "upper_mask must be 0x80000000",
183            ));
184        }
185        if params.lower_mask != 0x7fffffff {
186            return Err(MersenneTwisterError::InvalidConfig(
187                "lower_mask must be 0x7fffffff",
188            ));
189        }
190        if params.tempering_mask_b != 0x9d2c5680 {
191            return Err(MersenneTwisterError::InvalidConfig(
192                "tempering_mask_b must be 0x9d2c5680",
193            ));
194        }
195        if params.tempering_mask_c != 0xefc60000 {
196            return Err(MersenneTwisterError::InvalidConfig(
197                "tempering_mask_c must be 0xefc60000",
198            ));
199        }
200        Ok(())
201    }
202
203    /// Builds a config using the canonical MT19937 parameters.
204    ///
205    /// # Examples
206    ///
207    /// ```
208    /// use vrd::MersenneTwisterConfig;
209    ///
210    /// let cfg = MersenneTwisterConfig::<624, 397>::new().unwrap();
211    /// assert_eq!(cfg.params.matrix_a, 0x9908b0df);
212    /// ```
213    ///
214    /// # Errors
215    ///
216    /// Returns [`MersenneTwisterError::InvalidConfig`] only if the
217    /// `N`/`M` const generics violate the Mersenne-Twister invariants.
218    /// `MersenneTwisterConfig::<624, 397>::new()` is infallible.
219    pub fn new() -> Result<Self, MersenneTwisterError> {
220        Self::new_custom(MersenneTwisterParams::default())
221    }
222
223    /// Replaces the parameters in place after re-validating them.
224    /// On error, the existing parameters are preserved.
225    ///
226    /// # Examples
227    ///
228    /// ```
229    /// use vrd::{MersenneTwisterConfig, MersenneTwisterParams};
230    ///
231    /// let mut cfg = MersenneTwisterConfig::<624, 397>::default();
232    /// cfg.set_config(MersenneTwisterParams::default()).unwrap();
233    /// ```
234    ///
235    /// # Errors
236    ///
237    /// Returns [`MersenneTwisterError::InvalidConfig`] when `params`
238    /// fails validation; the existing `self.params` is left unchanged.
239    pub fn set_config(
240        &mut self,
241        params: MersenneTwisterParams,
242    ) -> Result<(), MersenneTwisterError> {
243        Self::validate(&params)?;
244        self.params = params;
245        Ok(())
246    }
247}
248
249impl Default for MersenneTwisterConfig<624, 397> {
250    /// Returns the canonical MT19937 configuration.
251    ///
252    /// # Examples
253    ///
254    /// ```
255    /// use vrd::MersenneTwisterConfig;
256    ///
257    /// let cfg = MersenneTwisterConfig::<624, 397>::default();
258    /// assert_eq!(cfg.params.lower_mask, 0x7fffffff);
259    /// ```
260    fn default() -> Self {
261        MersenneTwisterConfig::new()
262            .expect("canonical MT19937 parameters always validate")
263    }
264}
265
266impl<const N: usize, const M: usize> fmt::Display
267    for MersenneTwisterConfig<N, M>
268{
269    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270        let p = &self.params;
271        write!(f, "MersenneTwisterConfig {{ matrix_a: 0x{:08x}, upper_mask: 0x{:08x}, lower_mask: 0x{:08x}, tempering_mask_b: 0x{:08x}, tempering_mask_c: 0x{:08x} }}", p.matrix_a, p.upper_mask, p.lower_mask, p.tempering_mask_b, p.tempering_mask_c)
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "alloc")]
    use alloc::format;
    #[cfg(all(not(feature = "alloc"), feature = "std"))]
    use std::format;

    /// The canonical MT19937 configuration type.
    type Mt = MersenneTwisterConfig<624, 397>;

    /// Expected `validate` result for a rule whose message is `msg`.
    fn invalid(msg: &'static str) -> Result<(), MersenneTwisterError> {
        Err(MersenneTwisterError::InvalidConfig(msg))
    }

    #[test]
    fn test_canonical_params_validate() {
        let p = MersenneTwisterParams::default();
        assert_eq!(Mt::validate(&p), Ok(()));
        assert_eq!(Mt::new(), Ok(Mt::default()));
        assert_eq!(Mt::new_custom(p).map(|c| c.params), Ok(p));
    }

    #[test]
    fn test_each_param_rule_reports_its_own_message() {
        let base = MersenneTwisterParams::default();
        let cases = [
            (
                MersenneTwisterParams { matrix_a: 0, ..base },
                "matrix_a must have its highest bit set",
            ),
            (
                MersenneTwisterParams { upper_mask: 0, ..base },
                "upper_mask must be 0x80000000",
            ),
            (
                MersenneTwisterParams { lower_mask: 0, ..base },
                "lower_mask must be 0x7fffffff",
            ),
            (
                MersenneTwisterParams { tempering_mask_b: 0, ..base },
                "tempering_mask_b must be 0x9d2c5680",
            ),
            (
                MersenneTwisterParams { tempering_mask_c: 0, ..base },
                "tempering_mask_c must be 0xefc60000",
            ),
        ];
        for &(params, msg) in cases.iter() {
            assert_eq!(Mt::validate(&params), invalid(msg));
        }

        // Any matrix_a with the top bit set is accepted.
        let custom_a = MersenneTwisterParams { matrix_a: 0x8000_0000, ..base };
        assert_eq!(Mt::validate(&custom_a), Ok(()));
    }

    #[test]
    fn test_n_m_bounds_with_otherwise_valid_params() {
        // Canonical params, so only the N/M rules can fail here.
        let p = MersenneTwisterParams::default();
        let m_msg = "M must be at least 1 and less than N";
        assert_eq!(
            MersenneTwisterConfig::<0, 0>::validate(&p),
            invalid("N must be at least 1")
        );
        assert_eq!(MersenneTwisterConfig::<10, 0>::validate(&p), invalid(m_msg));
        assert_eq!(MersenneTwisterConfig::<10, 10>::validate(&p), invalid(m_msg));
        assert_eq!(MersenneTwisterConfig::<10, 11>::validate(&p), invalid(m_msg));
        assert_eq!(MersenneTwisterConfig::<10, 9>::validate(&p), Ok(()));
        assert_eq!(MersenneTwisterConfig::<2, 1>::validate(&p), Ok(()));
    }

    #[test]
    fn test_new_custom_rejects_invalid_params() {
        let bad = MersenneTwisterParams {
            upper_mask: 1,
            ..MersenneTwisterParams::default()
        };
        assert_eq!(
            Mt::new_custom(bad),
            Err(MersenneTwisterError::InvalidConfig(
                "upper_mask must be 0x80000000"
            ))
        );
    }

    #[test]
    fn test_set_config_applies_or_preserves() {
        let mut cfg = Mt::default();
        let custom = MersenneTwisterParams {
            matrix_a: 0x8000_0001,
            ..MersenneTwisterParams::default()
        };
        assert_eq!(cfg.set_config(custom), Ok(()));
        assert_eq!(cfg.params, custom);

        // A rejected update must leave the previous params in place.
        let bad = MersenneTwisterParams { lower_mask: 0, ..custom };
        assert_eq!(
            cfg.set_config(bad),
            invalid("lower_mask must be 0x7fffffff")
        );
        assert_eq!(cfg.params, custom);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_error_has_no_source() {
        use std::error::Error;
        let err = MersenneTwisterError::InvalidConfig("foo");
        assert!(err.source().is_none());
    }

    #[test]
    #[cfg(any(feature = "alloc", feature = "std"))]
    fn test_error_display_and_debug() {
        let err = MersenneTwisterError::InvalidConfig("foo");
        assert_eq!(format!("{err}"), "Invalid configuration: foo");
        assert_eq!(format!("{err:?}"), "InvalidConfig(\"foo\")");

        let p = MersenneTwisterParams::default();
        let dbg = format!("{p:?}");
        assert!(dbg.contains("matrix_a"), "got: {dbg}");
    }

    /// Pins the exact `Display` output, including zero padding.
    #[test]
    #[cfg(any(feature = "alloc", feature = "std"))]
    fn test_mersenne_twister_config_display() {
        let s = format!("{}", Mt::default());
        assert_eq!(
            s,
            "MersenneTwisterConfig { matrix_a: 0x9908b0df, \
             upper_mask: 0x80000000, lower_mask: 0x7fffffff, \
             tempering_mask_b: 0x9d2c5680, tempering_mask_c: 0xefc60000 }"
        );

        // Small values must still render as 8 zero-padded hex digits.
        let raw = Mt {
            params: MersenneTwisterParams {
                matrix_a: 0x1,
                upper_mask: 0x20,
                lower_mask: 0x300,
                tempering_mask_b: 0x4000,
                tempering_mask_c: 0,
            },
        };
        assert_eq!(
            format!("{raw}"),
            "MersenneTwisterConfig { matrix_a: 0x00000001, \
             upper_mask: 0x00000020, lower_mask: 0x00000300, \
             tempering_mask_b: 0x00004000, tempering_mask_c: 0x00000000 }"
        );
    }
}