vrd/mersenne_twister.rs
1// Copyright © 2023-2026 vrd. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! Mersenne Twister (MT19937) configuration types.
5//!
6//! The actual MT19937 generator is implemented in [`crate::random`]; this
7//! module provides the configuration parameters and validation.
8
9use core::fmt;
10
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize};
13
/// Error type reported when a [`MersenneTwisterConfig`] cannot be built
/// or updated.
///
/// The enum is marked `#[non_exhaustive]`, so code in other crates that
/// matches on it needs a wildcard arm.
///
/// # Examples
///
/// ```
/// use vrd::MersenneTwisterError;
///
/// let err = MersenneTwisterError::InvalidConfig("N must be at least 1");
/// assert_eq!(format!("{err}"), "Invalid configuration: N must be at least 1");
/// ```
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum MersenneTwisterError {
    /// One of the configuration values was rejected by validation.
    ///
    /// The message is carried as a `&'static str`: no heap allocation is
    /// involved, so the error remains usable in `no_std` builds that do
    /// not enable `alloc`.
    InvalidConfig(&'static str),
}
33
34impl fmt::Display for MersenneTwisterError {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 match self {
37 MersenneTwisterError::InvalidConfig(msg) => {
38 write!(f, "Invalid configuration: {}", msg)
39 }
40 }
41 }
42}
43
44#[cfg(feature = "std")]
45impl std::error::Error for MersenneTwisterError {}
46
/// Constant parameters of the Mersenne Twister algorithm.
///
/// [`Default`] yields the canonical MT19937 values. Any other set of
/// values has to pass [`MersenneTwisterConfig::validate`] before
/// `new_custom` or `set_config` will accept it.
///
/// # Examples
///
/// ```
/// use vrd::MersenneTwisterParams;
///
/// let p = MersenneTwisterParams::default();
/// assert_eq!(p.matrix_a, 0x9908b0df);
/// assert_eq!(p.upper_mask, 0x80000000);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct MersenneTwisterParams {
    /// Matrix constant; validation requires its highest bit to be set
    /// (canonical: `0x9908b0df`).
    pub matrix_a: u32,
    /// Upper-bit mask; validation accepts only `0x80000000`.
    pub upper_mask: u32,
    /// Lower-bit mask; validation accepts only `0x7fffffff`.
    pub lower_mask: u32,
    /// Tempering mask B; validation accepts only `0x9d2c5680`.
    pub tempering_mask_b: u32,
    /// Tempering mask C; validation accepts only `0xefc60000`.
    pub tempering_mask_c: u32,
}
76
77impl Default for MersenneTwisterParams {
78 /// Returns the canonical MT19937 constants.
79 ///
80 /// # Examples
81 ///
82 /// ```
83 /// use vrd::MersenneTwisterParams;
84 ///
85 /// let p = MersenneTwisterParams::default();
86 /// assert_eq!(p.matrix_a, 0x9908b0df);
87 /// ```
88 fn default() -> Self {
89 MersenneTwisterParams {
90 matrix_a: 0x9908b0df,
91 upper_mask: 0x80000000,
92 lower_mask: 0x7fffffff,
93 tempering_mask_b: 0x9d2c5680,
94 tempering_mask_c: 0xefc60000,
95 }
96 }
97}
98
99/// Configuration for an MT19937-style Mersenne Twister.
100///
101/// `N` is the array length; `M` is the recurrence offset. The canonical
102/// MT19937 instantiation is `MersenneTwisterConfig::<624, 397>`.
103///
104/// # Examples
105///
106/// ```
107/// use vrd::MersenneTwisterConfig;
108///
109/// let cfg = MersenneTwisterConfig::<624, 397>::default();
110/// assert_eq!(cfg.params.upper_mask, 0x80000000);
111/// ```
112#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
113#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
114pub struct MersenneTwisterConfig<const N: usize, const M: usize> {
115 /// The validated configuration parameters.
116 pub params: MersenneTwisterParams,
117}
118
119impl<const N: usize, const M: usize> MersenneTwisterConfig<N, M> {
120 /// Builds a config with custom parameters, validating the invariants.
121 ///
122 /// # Examples
123 ///
124 /// ```
125 /// use vrd::{MersenneTwisterConfig, MersenneTwisterParams};
126 ///
127 /// let p = MersenneTwisterParams::default();
128 /// let cfg = MersenneTwisterConfig::<624, 397>::new_custom(p).unwrap();
129 /// assert_eq!(cfg.params, p);
130 /// ```
131 ///
132 /// # Errors
133 ///
134 /// Returns [`MersenneTwisterError::InvalidConfig`] when any parameter
135 /// (or the `N`/`M` const generics) violates a Mersenne-Twister
136 /// invariant.
137 pub fn new_custom(
138 params: MersenneTwisterParams,
139 ) -> Result<Self, MersenneTwisterError> {
140 Self::validate(¶ms)?;
141 Ok(MersenneTwisterConfig { params })
142 }
143
144 /// Validates `params` against the Mersenne-Twister invariants.
145 ///
146 /// # Examples
147 ///
148 /// ```
149 /// use vrd::{MersenneTwisterConfig, MersenneTwisterParams};
150 ///
151 /// let p = MersenneTwisterParams::default();
152 /// assert!(MersenneTwisterConfig::<624, 397>::validate(&p).is_ok());
153 ///
154 /// // Invalid: M >= N.
155 /// assert!(MersenneTwisterConfig::<10, 10>::validate(&p).is_err());
156 /// ```
157 ///
158 /// # Errors
159 ///
160 /// Returns [`MersenneTwisterError::InvalidConfig`] when the invariants
161 /// are violated. The static message in the error names which one.
162 pub fn validate(
163 params: &MersenneTwisterParams,
164 ) -> Result<(), MersenneTwisterError> {
165 if N < 1 {
166 return Err(MersenneTwisterError::InvalidConfig(
167 "N must be at least 1",
168 ));
169 }
170 if M < 1 || M >= N {
171 return Err(MersenneTwisterError::InvalidConfig(
172 "M must be at least 1 and less than N",
173 ));
174 }
175 if params.matrix_a & 0x80000000 != 0x80000000 {
176 return Err(MersenneTwisterError::InvalidConfig(
177 "matrix_a must have its highest bit set",
178 ));
179 }
180 if params.upper_mask != 0x80000000 {
181 return Err(MersenneTwisterError::InvalidConfig(
182 "upper_mask must be 0x80000000",
183 ));
184 }
185 if params.lower_mask != 0x7fffffff {
186 return Err(MersenneTwisterError::InvalidConfig(
187 "lower_mask must be 0x7fffffff",
188 ));
189 }
190 if params.tempering_mask_b != 0x9d2c5680 {
191 return Err(MersenneTwisterError::InvalidConfig(
192 "tempering_mask_b must be 0x9d2c5680",
193 ));
194 }
195 if params.tempering_mask_c != 0xefc60000 {
196 return Err(MersenneTwisterError::InvalidConfig(
197 "tempering_mask_c must be 0xefc60000",
198 ));
199 }
200 Ok(())
201 }
202
203 /// Builds a config using the canonical MT19937 parameters.
204 ///
205 /// # Examples
206 ///
207 /// ```
208 /// use vrd::MersenneTwisterConfig;
209 ///
210 /// let cfg = MersenneTwisterConfig::<624, 397>::new().unwrap();
211 /// assert_eq!(cfg.params.matrix_a, 0x9908b0df);
212 /// ```
213 ///
214 /// # Errors
215 ///
216 /// Returns [`MersenneTwisterError::InvalidConfig`] only if the
217 /// `N`/`M` const generics violate the Mersenne-Twister invariants.
218 /// `MersenneTwisterConfig::<624, 397>::new()` is infallible.
219 pub fn new() -> Result<Self, MersenneTwisterError> {
220 Self::new_custom(MersenneTwisterParams::default())
221 }
222
223 /// Replaces the parameters in place after re-validating them.
224 /// On error, the existing parameters are preserved.
225 ///
226 /// # Examples
227 ///
228 /// ```
229 /// use vrd::{MersenneTwisterConfig, MersenneTwisterParams};
230 ///
231 /// let mut cfg = MersenneTwisterConfig::<624, 397>::default();
232 /// cfg.set_config(MersenneTwisterParams::default()).unwrap();
233 /// ```
234 ///
235 /// # Errors
236 ///
237 /// Returns [`MersenneTwisterError::InvalidConfig`] when `params`
238 /// fails validation; the existing `self.params` is left unchanged.
239 pub fn set_config(
240 &mut self,
241 params: MersenneTwisterParams,
242 ) -> Result<(), MersenneTwisterError> {
243 Self::validate(¶ms)?;
244 self.params = params;
245 Ok(())
246 }
247}
248
249impl Default for MersenneTwisterConfig<624, 397> {
250 /// Returns the canonical MT19937 configuration.
251 ///
252 /// # Examples
253 ///
254 /// ```
255 /// use vrd::MersenneTwisterConfig;
256 ///
257 /// let cfg = MersenneTwisterConfig::<624, 397>::default();
258 /// assert_eq!(cfg.params.lower_mask, 0x7fffffff);
259 /// ```
260 fn default() -> Self {
261 MersenneTwisterConfig::new()
262 .expect("canonical MT19937 parameters always validate")
263 }
264}
265
266impl<const N: usize, const M: usize> fmt::Display
267 for MersenneTwisterConfig<N, M>
268{
269 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270 let p = &self.params;
271 write!(f, "MersenneTwisterConfig {{ matrix_a: 0x{:08x}, upper_mask: 0x{:08x}, lower_mask: 0x{:08x}, tempering_mask_b: 0x{:08x}, tempering_mask_c: 0x{:08x} }}", p.matrix_a, p.upper_mask, p.lower_mask, p.tempering_mask_b, p.tempering_mask_c)
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "alloc")]
    use alloc::format;
    #[cfg(all(not(feature = "alloc"), feature = "std"))]
    use std::format;

    /// Asserts that `validate` for `<N, M>` rejects `params` with exactly
    /// the `expected` message, so each check is pinned individually
    /// instead of passing on any error at all.
    fn assert_rejected<const N: usize, const M: usize>(
        params: &MersenneTwisterParams,
        expected: &'static str,
    ) {
        assert_eq!(
            MersenneTwisterConfig::<N, M>::validate(params),
            Err(MersenneTwisterError::InvalidConfig(expected))
        );
    }

    #[test]
    fn test_exhaustive_mt_coverage() {
        let canonical = MersenneTwisterParams::default();

        // Formatting checks need `format!`, which this module only
        // imports when `alloc` or `std` is enabled, so they are gated the
        // same way as the dedicated Display test below.
        #[cfg(any(feature = "alloc", feature = "std"))]
        {
            // Error Display
            let err = MersenneTwisterError::InvalidConfig("foo");
            assert_eq!(format!("{}", err), "Invalid configuration: foo");
            #[cfg(feature = "std")]
            {
                use std::error::Error;
                assert!(err.source().is_none());
            }

            // Params Debug
            assert!(format!("{:?}", canonical).contains("matrix_a"));

            // Config Display
            let c = MersenneTwisterConfig::<624, 397>::default();
            assert!(!format!("{}", c).is_empty());
        }

        // Parameter checks: each case starts from the canonical values
        // and breaks exactly one field.
        assert!(
            MersenneTwisterConfig::<624, 397>::validate(&canonical).is_ok()
        );
        assert_rejected::<624, 397>(
            &MersenneTwisterParams { matrix_a: 0, ..canonical },
            "matrix_a must have its highest bit set",
        );
        assert_rejected::<624, 397>(
            &MersenneTwisterParams { upper_mask: 0, ..canonical },
            "upper_mask must be 0x80000000",
        );
        assert_rejected::<624, 397>(
            &MersenneTwisterParams { lower_mask: 0, ..canonical },
            "lower_mask must be 0x7fffffff",
        );
        assert_rejected::<624, 397>(
            &MersenneTwisterParams { tempering_mask_b: 0, ..canonical },
            "tempering_mask_b must be 0x9d2c5680",
        );
        assert_rejected::<624, 397>(
            &MersenneTwisterParams { tempering_mask_c: 0, ..canonical },
            "tempering_mask_c must be 0xefc60000",
        );

        // N, M bounds: checked with valid parameters so that only the
        // const generics can be responsible for the rejection.
        assert_rejected::<0, 0>(&canonical, "N must be at least 1");
        assert_rejected::<10, 0>(
            &canonical,
            "M must be at least 1 and less than N",
        );
        assert_rejected::<10, 10>(
            &canonical,
            "M must be at least 1 and less than N",
        );
        assert!(MersenneTwisterConfig::<10, 9>::validate(&canonical).is_ok());

        // new_custom mirrors validate.
        assert!(
            MersenneTwisterConfig::<624, 397>::new_custom(canonical).is_ok()
        );
        assert!(MersenneTwisterConfig::<10, 10>::new_custom(canonical).is_err());

        // set_config: accepts valid parameters, and leaves the stored
        // parameters untouched when the new ones are rejected.
        let mut cfg = MersenneTwisterConfig::<624, 397>::default();
        assert!(cfg.set_config(canonical).is_ok());
        let rejected = MersenneTwisterParams { lower_mask: 0, ..canonical };
        assert!(cfg.set_config(rejected).is_err());
        assert_eq!(cfg.params, canonical);
    }

    /// Checks the `Display` output of `MersenneTwisterConfig`: every
    /// parameter name appears, and the canonical config renders to the
    /// exact expected string.
    #[test]
    #[cfg(any(feature = "alloc", feature = "std"))]
    fn test_mersenne_twister_config_display() {
        let cfg = MersenneTwisterConfig::<624, 397>::default();
        let s = format!("{}", cfg);
        // All five field names must appear in the formatted output.
        assert!(s.contains("matrix_a"), "got: {s}");
        assert!(s.contains("upper_mask"), "got: {s}");
        assert!(s.contains("lower_mask"), "got: {s}");
        assert!(s.contains("tempering_mask_b"), "got: {s}");
        assert!(s.contains("tempering_mask_c"), "got: {s}");
        assert_eq!(
            s,
            "MersenneTwisterConfig { matrix_a: 0x9908b0df, \
             upper_mask: 0x80000000, lower_mask: 0x7fffffff, \
             tempering_mask_b: 0x9d2c5680, tempering_mask_c: 0xefc60000 }"
        );
    }
}