Skip to content

Commit a20b6bf

Browse files
committed
Implement _random using the MT19937 algorithm
1 parent 908abef commit a20b6bf

File tree

4 files changed

+299
-80
lines changed

4 files changed

+299
-80
lines changed

Cargo.lock

Lines changed: 1 addition & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vm/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ num-traits = "0.2.8"
3030
num-integer = "0.1.41"
3131
num-rational = "0.2.2"
3232
num-iter = "0.1.39"
33-
rand = { version = "0.7", features = ["small_rng"] }
34-
rand_distr = "0.2"
33+
rand = "0.7"
34+
rand_core = "0.5"
3535
getrandom = "0.1"
3636
log = "0.4"
3737
rustpython-derive = {path = "../derive", version = "0.1.1"}

vm/src/stdlib/random.rs

Lines changed: 94 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,60 @@
33
use std::cell::RefCell;
44

55
use num_bigint::{BigInt, Sign};
6+
use num_traits::Signed;
7+
use rand::RngCore;
68

7-
use rand::distributions::Distribution;
8-
use rand::{RngCore, SeedableRng};
9-
use rand::rngs::SmallRng;
10-
use rand_distr::Normal;
11-
12-
use crate::function::OptionalArg;
9+
use crate::function::OptionalOption;
10+
use crate::obj::objint::PyIntRef;
1311
use crate::obj::objtype::PyClassRef;
14-
use crate::pyobject::{PyClassImpl, PyObjectRef, PyRef, PyValue, PyResult};
12+
use crate::pyobject::{PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue};
13+
use crate::VirtualMachine;
14+
15+
mod mersenne;
16+
17+
#[derive(Debug)]
18+
enum PyRng {
19+
Std(rand::rngs::ThreadRng),
20+
MT(mersenne::MT19937),
21+
}
1522

16-
use crate::vm::VirtualMachine;
23+
impl Default for PyRng {
24+
fn default() -> Self {
25+
PyRng::Std(rand::thread_rng())
26+
}
27+
}
28+
29+
impl RngCore for PyRng {
30+
fn next_u32(&mut self) -> u32 {
31+
match self {
32+
Self::Std(s) => s.next_u32(),
33+
Self::MT(m) => m.next_u32(),
34+
}
35+
}
36+
fn next_u64(&mut self) -> u64 {
37+
match self {
38+
Self::Std(s) => s.next_u64(),
39+
Self::MT(m) => m.next_u64(),
40+
}
41+
}
42+
fn fill_bytes(&mut self, dest: &mut [u8]) {
43+
match self {
44+
Self::Std(s) => s.fill_bytes(dest),
45+
Self::MT(m) => m.fill_bytes(dest),
46+
}
47+
}
48+
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
49+
match self {
50+
Self::Std(s) => s.try_fill_bytes(dest),
51+
Self::MT(m) => m.try_fill_bytes(dest),
52+
}
53+
}
54+
}
1755

1856
#[pyclass(name = "Random")]
1957
#[derive(Debug)]
2058
struct PyRandom {
21-
rng: RefCell<SmallRng>
59+
rng: RefCell<PyRng>,
2260
}
2361

2462
impl PyValue for PyRandom {
@@ -27,86 +65,74 @@ impl PyValue for PyRandom {
2765
}
2866
}
2967

30-
#[pyimpl]
68+
#[pyimpl(flags(BASETYPE))]
3169
impl PyRandom {
3270
#[pyslot(new)]
3371
fn new(cls: PyClassRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
3472
PyRandom {
35-
rng: RefCell::new(SmallRng::from_entropy())
36-
}.into_ref_with_type(vm, cls)
73+
rng: RefCell::new(PyRng::default()),
74+
}
75+
.into_ref_with_type(vm, cls)
76+
}
77+
78+
#[pymethod]
79+
fn random(&self) -> f64 {
80+
gen_res53(&mut *self.rng.borrow_mut())
3781
}
3882

39-
#[pymethod]
40-
fn seed(&self, n: Option<usize>, vm: &VirtualMachine) -> PyResult {
41-
let rng = match n {
42-
None => SmallRng::from_entropy(),
83+
#[pymethod]
84+
fn seed(&self, n: OptionalOption<PyIntRef>) {
85+
let new_rng = match n.flat_option() {
86+
None => PyRng::default(),
4387
Some(n) => {
44-
let seed = n as u64;
45-
SmallRng::seed_from_u64(seed)
88+
let (_, mut key) = n.as_bigint().abs().to_u32_digits();
89+
if cfg!(target_endian = "big") {
90+
key.reverse();
91+
}
92+
PyRng::MT(mersenne::MT19937::new_with_slice_seed(&key))
4693
}
4794
};
48-
49-
*self.rng.borrow_mut() = rng;
50-
51-
Ok(vm.ctx.none())
95+
96+
*self.rng.borrow_mut() = new_rng;
5297
}
5398

5499
#[pymethod]
55-
fn getrandbits(&self, k: usize, vm: &VirtualMachine) -> PyResult {
56-
let bytes = (k - 1) / 8 + 1;
57-
let mut bytearray = vec![0u8; bytes];
58-
self.rng.borrow_mut().fill_bytes(&mut bytearray);
59-
60-
let bits = bytes % 8;
61-
if bits > 0 {
62-
bytearray[0] >>= 8 - bits;
100+
fn getrandbits(&self, mut k: usize) -> BigInt {
101+
let mut rng = self.rng.borrow_mut();
102+
103+
let mut gen_u32 = |k| rng.next_u32() >> (32 - k) as u32;
104+
105+
if k <= 32 {
106+
return gen_u32(k).into();
63107
}
64-
65-
println!("{:?}", k);
66-
println!("{:?}", bytearray);
67108

68-
let result = BigInt::from_bytes_be(Sign::Plus, &bytearray);
69-
Ok(vm.ctx.new_bigint(&result))
109+
let words = (k - 1) / 8 + 1;
110+
let mut wordarray = vec![0u32; words];
111+
112+
let it = wordarray.iter_mut();
113+
#[cfg(target_endian = "big")]
114+
let it = it.rev();
115+
for word in it {
116+
*word = gen_u32(k);
117+
k -= 32;
118+
}
119+
120+
BigInt::from_slice(Sign::NoSign, &wordarray)
70121
}
71122
}
72123

73124
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
74125
let ctx = &vm.ctx;
75-
let random_type = PyRandom::make_class(ctx);
76-
77126
py_module!(vm, "_random", {
78-
"Random" => random_type,
79-
"gauss" => ctx.new_function(random_normalvariate), // TODO: is this the same?
80-
"normalvariate" => ctx.new_function(random_normalvariate),
81-
"random" => ctx.new_function(random_random),
82-
// "weibull", ctx.new_function(random_weibullvariate),
127+
"Random" => PyRandom::make_class(ctx),
83128
})
84129
}
85130

86-
fn random_normalvariate(mu: f64, sigma: f64, vm: &VirtualMachine) -> PyResult<f64> {
87-
let normal = Normal::new(mu, sigma).map_err(|rand_err| {
88-
vm.new_exception_msg(
89-
vm.ctx.exceptions.arithmetic_error.clone(),
90-
format!("invalid normal distribution: {:?}", rand_err),
91-
)
92-
})?;
93-
let value = normal.sample(&mut rand::thread_rng());
94-
Ok(value)
95-
}
96-
97-
fn random_random(_vm: &VirtualMachine) -> f64 {
98-
rand::random()
99-
}
100-
101-
/*
102-
* TODO: enable this function:
103-
fn random_weibullvariate(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
104-
arg_check!(vm, args, required = [(alpha, Some(vm.ctx.float_type())), (beta, Some(vm.ctx.float_type()))]);
105-
let alpha = objfloat::get_value(alpha);
106-
let beta = objfloat::get_value(beta);
107-
let weibull = Weibull::new(alpha, beta);
108-
let value = weibull.sample(&mut rand::thread_rng());
109-
let py_value = vm.ctx.new_float(value);
110-
Ok(py_value)
131+
// taken from the translated mersenne twister
132+
/* generates a random number on [0,1) with 53-bit resolution*/
133+
fn gen_res53<R: RngCore>(rng: &mut R) -> f64 {
134+
let a = rng.next_u32() >> 5;
135+
let b = rng.next_u32() >> 6;
136+
(a as f64 * 67108864.0 + b as f64) * (1.0 / 9007199254740992.0)
111137
}
112-
*/
138+
/* These real versions are due to Isaku Wada, 2002/01/09 added */

0 commit comments

Comments
 (0)