vrd/distribution.rs
1// Copyright © 2023-2026 vrd. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! Trait-based distributions and pluggable user-defined samplers.
5//!
6//! [`Distribution<T>`] is the universal handshake between a sampler
7//! and an RNG: anything that can map a [`Random`] into a `T` is a
8//! distribution. The built-in samplers (`Normal`, `Exponential`,
9//! `Uniform`, `Poisson`) live here as concrete impls and forward to
10//! the optimised methods on [`Random`]. Users add their own
11//! distributions by implementing the trait.
12//!
13//! # Examples
14//!
15//! ```
16//! use vrd::{Random, Distribution};
17//! use vrd::distribution::{Normal, Exponential};
18//!
19//! let mut rng = Random::from_u64_seed(1);
20//! let z = Normal { mu: 0.0, sigma: 1.0 }.sample(&mut rng);
21//! let x = Exponential { rate: 1.5 }.sample(&mut rng);
22//! # let _ = (z, x);
23//! ```
24//!
25//! ```
26//! use vrd::{Random, Distribution};
27//!
28//! // User-defined distribution: Bernoulli(p).
29//! struct Bernoulli { p: f64 }
30//!
31//! impl Distribution<bool> for Bernoulli {
//!     fn sample(&self, rng: &mut Random) -> bool {
//!         rng.double() < self.p
//!     }
35//! }
36//!
37//! let mut rng = Random::from_u64_seed(1);
38//! let coin = Bernoulli { p: 0.5 }.sample(&mut rng);
39//! # let _ = coin;
40//! ```
41
42use crate::Random;
43
44/// A distribution that can be sampled with a mutable [`Random`].
45///
46/// `T` is the sample type - usually `f64` for continuous
47/// distributions, `u64` for integer ones, but any `T` is allowed.
48pub trait Distribution<T> {
49 /// Draws one sample from `self` using `rng`.
50 fn sample(&self, rng: &mut Random) -> T;
51
52 /// Returns an iterator that calls `sample` repeatedly. Useful
53 /// for `take(n).collect()`-style consumption.
54 fn samples<'a>(&'a self, rng: &'a mut Random) -> Iter<'a, Self, T>
55 where
56 Self: Sized,
57 {
58 Iter {
59 dist: self,
60 rng,
61 _t: core::marker::PhantomData,
62 }
63 }
64}
65
66/// Iterator returned by [`Distribution::samples`].
67#[derive(Debug)]
68pub struct Iter<'a, D: Distribution<T> + ?Sized, T> {
69 /// The distribution being sampled.
70 dist: &'a D,
71 /// The RNG `dist` draws from on each `next()` call.
72 rng: &'a mut Random,
73 /// Anchors the unconstrained type parameter `T`.
74 _t: core::marker::PhantomData<T>,
75}
76
77impl<D: Distribution<T>, T> Iterator for Iter<'_, D, T> {
78 type Item = T;
79
80 fn next(&mut self) -> Option<T> {
81 Some(self.dist.sample(self.rng))
82 }
83}
84
// ---------------- Built-in distributions ---------------------------------
86
/// Normal (Gaussian) distribution `N(mu, sigma^2)` - Ziggurat sampler, see
/// [`Random::normal`].
///
/// Implements `PartialEq` as a field-wise `f64` comparison so parameter
/// sets can be compared directly (for example in assertions). As with any
/// `f64` comparison, a NaN parameter never compares equal.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Normal {
    /// Mean.
    pub mu: f64,
    /// Standard deviation. Must be ≥ 0.
    pub sigma: f64,
}
96
97impl Distribution<f64> for Normal {
98 fn sample(&self, rng: &mut Random) -> f64 {
99 rng.normal(self.mu, self.sigma)
100 }
101}
102
/// Exponential distribution with rate `lambda` (the `rate` field). Mean is
/// `1/lambda`.
///
/// Implements `PartialEq` as an `f64` comparison of `rate`, so two
/// parameterisations can be compared directly; a NaN rate never compares
/// equal.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Exponential {
    /// Rate parameter `λ`. Must be > 0.
    pub rate: f64,
}
109
110impl Distribution<f64> for Exponential {
111 fn sample(&self, rng: &mut Random) -> f64 {
112 rng.exponential(self.rate)
113 }
114}
115
/// Continuous uniform distribution on the half-open interval `[low, high)`.
///
/// Implements `PartialEq` as a field-wise `f64` comparison so two intervals
/// can be compared directly; a NaN bound never compares equal.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Uniform {
    /// Inclusive lower bound.
    pub low: f64,
    /// Exclusive upper bound. Must be > `low`.
    pub high: f64,
}
124
125impl Distribution<f64> for Uniform {
126 fn sample(&self, rng: &mut Random) -> f64 {
127 rng.uniform(self.low, self.high)
128 }
129}
130
/// Poisson distribution with mean `lambda` (the `mean` field).
///
/// Implements `PartialEq` as an `f64` comparison of `mean`, so two
/// parameterisations can be compared directly; a NaN mean never compares
/// equal.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Poisson {
    /// Mean parameter `λ`. Must be > 0.
    pub mean: f64,
}
137
138impl Distribution<u64> for Poisson {
139 fn sample(&self, rng: &mut Random) -> u64 {
140 rng.poisson(self.mean)
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "alloc")]
    use alloc::vec::Vec;
    #[cfg(all(not(feature = "alloc"), feature = "std"))]
    use std::vec::Vec;

    /// Number of draws each built-in sampler smoke test takes.
    const DRAWS: usize = 256;

    /// Every draw from `Normal { mu: 0.0, sigma: 1.0 }` is a finite float.
    #[test]
    fn normal_distribution_samples() {
        let unit = Normal {
            mu: 0.0,
            sigma: 1.0,
        };
        let mut rng = Random::from_u64_seed(1);
        for z in unit.samples(&mut rng).take(DRAWS) {
            assert!(z.is_finite());
        }
    }

    /// Draws from `Exponential { rate: 2.0 }` are never negative.
    #[test]
    fn exponential_distribution_samples() {
        let dist = Exponential { rate: 2.0 };
        let mut rng = Random::from_u64_seed(1);
        for x in dist.samples(&mut rng).take(DRAWS) {
            assert!(x >= 0.0);
        }
    }

    /// Draws from `Uniform { low: -5.0, high: 5.0 }` stay inside `[-5, 5)`.
    #[test]
    fn uniform_distribution_in_range() {
        let dist = Uniform {
            low: -5.0,
            high: 5.0,
        };
        let mut rng = Random::from_u64_seed(1);
        for s in dist.samples(&mut rng).take(DRAWS) {
            assert!((-5.0..5.0).contains(&s));
        }
    }

    /// Drawing from `Poisson { mean: 3.0 }` completes without panicking.
    #[test]
    fn poisson_distribution_samples() {
        let dist = Poisson { mean: 3.0 };
        let mut rng = Random::from_u64_seed(1);
        for k in dist.samples(&mut rng).take(DRAWS) {
            // A `u64` cannot be negative, so there is nothing to assert on
            // the value; the test only checks that every draw completes.
            let _ = k;
        }
    }

    /// A distribution defined outside the crate's built-ins plugs into
    /// `samples` and yields exactly as many items as requested.
    #[test]
    fn custom_distribution_compiles_and_runs() {
        struct Coin {
            bias: f64,
        }
        impl Distribution<bool> for Coin {
            fn sample(&self, rng: &mut Random) -> bool {
                self.bias > rng.double()
            }
        }
        let coin = Coin { bias: 0.5 };
        let mut rng = Random::from_u64_seed(7);
        let flips = coin.samples(&mut rng).take(64).collect::<Vec<bool>>();
        assert_eq!(flips.len(), 64);
    }
}
213}