Skip to main content

vrd/
distribution.rs

// Copyright © 2023-2026 vrd. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 OR MIT
3
//! Trait-based distributions and pluggable user-defined samplers.
//!
//! [`Distribution<T>`] is the universal handshake between a sampler
//! and an RNG: anything that can turn draws from a [`Random`] into
//! values of type `T` is a distribution. The built-in samplers
//! (`Normal`, `Exponential`, `Uniform`, `Poisson`) live here as
//! concrete impls and forward to the optimised methods on [`Random`].
//! Users add their own distributions by implementing the trait.
//!
//! # Examples
//!
//! ```
//! use vrd::{Random, Distribution};
//! use vrd::distribution::{Normal, Exponential};
//!
//! let mut rng = Random::from_u64_seed(1);
//! let z = Normal { mu: 0.0, sigma: 1.0 }.sample(&mut rng);
//! let x = Exponential { rate: 1.5 }.sample(&mut rng);
//! # let _ = (z, x);
//! ```
//!
//! ```
//! use vrd::{Random, Distribution};
//!
//! // User-defined distribution: Bernoulli(p).
//! struct Bernoulli { p: f64 }
//!
//! impl Distribution<bool> for Bernoulli {
//!     fn sample(&self, rng: &mut Random) -> bool {
//!         rng.double() < self.p
//!     }
//! }
//!
//! let mut rng = Random::from_u64_seed(1);
//! let coin = Bernoulli { p: 0.5 }.sample(&mut rng);
//! # let _ = coin;
//! ```
41
42use crate::Random;
43
44/// A distribution that can be sampled with a mutable [`Random`].
45///
46/// `T` is the sample type - usually `f64` for continuous
47/// distributions, `u64` for integer ones, but any `T` is allowed.
48pub trait Distribution<T> {
49    /// Draws one sample from `self` using `rng`.
50    fn sample(&self, rng: &mut Random) -> T;
51
52    /// Returns an iterator that calls `sample` repeatedly. Useful
53    /// for `take(n).collect()`-style consumption.
54    fn samples<'a>(&'a self, rng: &'a mut Random) -> Iter<'a, Self, T>
55    where
56        Self: Sized,
57    {
58        Iter {
59            dist: self,
60            rng,
61            _t: core::marker::PhantomData,
62        }
63    }
64}
65
66/// Iterator returned by [`Distribution::samples`].
67#[derive(Debug)]
68pub struct Iter<'a, D: Distribution<T> + ?Sized, T> {
69    /// The distribution being sampled.
70    dist: &'a D,
71    /// The RNG `dist` draws from on each `next()` call.
72    rng: &'a mut Random,
73    /// Anchors the unconstrained type parameter `T`.
74    _t: core::marker::PhantomData<T>,
75}
76
77impl<D: Distribution<T>, T> Iterator for Iter<'_, D, T> {
78    type Item = T;
79
80    fn next(&mut self) -> Option<T> {
81        Some(self.dist.sample(self.rng))
82    }
83}
84
85// ---------------- Built-in continuous distributions ----------------------
86
/// Normal (Gaussian) distribution `N(mu, sigma^2)` - Ziggurat
/// sampler, see [`Random::normal`].
///
/// `Normal { mu: 0.0, sigma: 1.0 }` is the standard normal.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Normal {
    /// Mean.
    pub mu: f64,
    /// Standard deviation (not the variance). Must be ≥ 0.
    pub sigma: f64,
}
96
97impl Distribution<f64> for Normal {
98    fn sample(&self, rng: &mut Random) -> f64 {
99        rng.normal(self.mu, self.sigma)
100    }
101}
102
/// Exponential distribution with rate `λ` (the `rate` field).
/// The mean is `1/λ`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Exponential {
    /// Rate parameter `λ`. Must be > 0.
    pub rate: f64,
}
109
110impl Distribution<f64> for Exponential {
111    fn sample(&self, rng: &mut Random) -> f64 {
112        rng.exponential(self.rate)
113    }
114}
115
/// Continuous uniform distribution on the half-open range
/// `[low, high)`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Uniform {
    /// Inclusive lower bound.
    pub low: f64,
    /// Exclusive upper bound. Must be > `low`.
    pub high: f64,
}
124
125impl Distribution<f64> for Uniform {
126    fn sample(&self, rng: &mut Random) -> f64 {
127        rng.uniform(self.low, self.high)
128    }
129}
130
/// Poisson distribution with mean `λ` (the `mean` field).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Poisson {
    /// Mean parameter `λ`. Must be > 0.
    pub mean: f64,
}
137
138impl Distribution<u64> for Poisson {
139    fn sample(&self, rng: &mut Random) -> u64 {
140        rng.poisson(self.mean)
141    }
142}
143
#[cfg(test)]
mod tests {
    // No collection types are needed: every test consumes samples
    // directly or via iterator adapters, so the module builds the same
    // way regardless of the `alloc` / `std` features.
    use super::*;

    #[test]
    fn normal_distribution_samples() {
        let mut rng = Random::from_u64_seed(1);
        let n = Normal {
            mu: 0.0,
            sigma: 1.0,
        };
        for _ in 0..256 {
            assert!(n.sample(&mut rng).is_finite());
        }
    }

    #[test]
    fn exponential_distribution_samples() {
        let mut rng = Random::from_u64_seed(1);
        let e = Exponential { rate: 2.0 };
        for _ in 0..256 {
            assert!(e.sample(&mut rng) >= 0.0);
        }
    }

    #[test]
    fn uniform_distribution_in_range() {
        let mut rng = Random::from_u64_seed(1);
        let u = Uniform {
            low: -5.0,
            high: 5.0,
        };
        for _ in 0..256 {
            let s = u.sample(&mut rng);
            assert!((-5.0..5.0).contains(&s));
        }
    }

    #[test]
    fn poisson_distribution_samples() {
        let mut rng = Random::from_u64_seed(1);
        let p = Poisson { mean: 3.0 };
        for _ in 0..256 {
            // Every `u64` is >= 0, so there is nothing to assert about
            // the value itself; this only checks sampling doesn't panic.
            let _ = p.sample(&mut rng);
        }
    }

    /// `samples` must yield exactly the stream that repeated `sample`
    /// calls produce from an identically seeded RNG.
    #[test]
    fn samples_iterator_matches_repeated_sample() {
        let e = Exponential { rate: 2.0 };
        let mut via_iter = Random::from_u64_seed(3);
        let mut via_sample = Random::from_u64_seed(3);
        for x in e.samples(&mut via_iter).take(32) {
            // Compare bit patterns: both paths run the same computation.
            assert_eq!(x.to_bits(), e.sample(&mut via_sample).to_bits());
        }
    }

    /// User-defined distribution interop check.
    #[test]
    fn custom_distribution_compiles_and_runs() {
        struct Coin {
            bias: f64,
        }
        impl Distribution<bool> for Coin {
            fn sample(&self, rng: &mut Random) -> bool {
                rng.double() < self.bias
            }
        }
        let mut rng = Random::from_u64_seed(7);
        let fair = Coin { bias: 0.5 };
        assert_eq!(fair.samples(&mut rng).take(64).count(), 64);

        // 64 fair flips landing all the same way has probability 2^-63,
        // so a mix of outcomes shows the sampler really consults `rng`.
        let heads = fair.samples(&mut rng).take(64).filter(|&h| h).count();
        assert!(
            heads > 0 && heads < 64,
            "64 fair flips all landed the same way: {} heads",
            heads
        );
    }
}
213}