Skip to content

Commit 057c326

Browse files
Pass Automated Testing SuitePass Automated Testing Suite
Pass Automated Testing Suite
authored and
Pass Automated Testing Suite
committed
Implement PCG64 without using the uin128 type.
1 parent 29f6478 commit 057c326

File tree

2 files changed

+65
-31
lines changed

2 files changed

+65
-31
lines changed

lib/pcg.ml

+64-30
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,35 @@
55
SPDX-License-Identifier: BSD-3-Clause *)
66
open Stdint
77

8+
module U128 = struct
9+
type t = { high : uint64; low : uint64 }
10+
11+
let of_u64 high low = {high; low}
12+
13+
let one = Uint64.{high = zero; low = one}
14+
15+
let zero = Uint64.{high = zero; low = zero}
16+
17+
let ( + ) a b =
18+
match Uint64.{high = a.high + b.high; low = a.low + b.low} with
19+
| x when x.low < b.low -> {x with high = Uint64.(x.high + one)}
20+
| x -> x
21+
22+
let max32 = Uint32.(max_int |> to_uint64)
23+
let mult64 x y =
24+
let open Uint64 in
25+
let x0 = logand max32 x and y0 = logand max32 y
26+
and x1 = shift_right x 32 and y1 = shift_right y 32 in
27+
let t = shift_right (x0 * y0) 32 + x1 * y0 in
28+
{high = shift_right (logand max32 t + x0 * y1) 32 + (shift_right t 32) + x1 * y1; low = x * y}
29+
30+
let ( * ) a b = match mult64 a.low b.low with
31+
| {high;low} -> {high = Uint64.(high + a.high * b.low + a.low * b.high); low}
32+
33+
(* let ( ** ) a b = match mult64 a.low b with
34+
| x -> {x with high = Uint64.(x.high + a.high * b)} *)
35+
end
36+
837

938
module PCG64 : sig
1039
(** PCG-64 is a 128-bit implementation of O'Neill's permutation congruential
@@ -20,7 +49,7 @@ module PCG64 : sig
2049

2150
include Common.BITGEN
2251

23-
val advance : int128 -> t -> t
52+
val advance : uint64 * uint64 -> t -> t
2453
(** [advance delta] Advances the underlying RNG as if [delta] draws have been made.
2554
The returned state is that of the generator [delta] steps forward. *)
2655

@@ -29,55 +58,60 @@ module PCG64 : sig
2958
(0, bound) as well as the state of the generator advanced one step forward. *)
3059
end = struct
3160
type t = {s : setseq; ustore : uint32 option}
32-
and setseq = {state : uint128; increment : uint128}
61+
and setseq = {state : U128.t; increment : U128.t}
62+
63+
64+
let sixtythree = Uint32.of_int32 63l
65+
let multiplier = U128.of_u64 (Uint64.of_int64 2549297995355413924L)
66+
(Uint64.of_int64 4865540595714422341L)
3367

34-
let multiplier = Uint128.of_string "0x2360ed051fc65da44385df649fccf645"
35-
let sixtythree = Uint32.of_int 63
3668

3769
(* Uses the XSL-RR output function *)
38-
let output state =
39-
let v = Uint128.(shift_right state 64 |> logxor state |> to_uint64)
40-
and r = Uint128.(shift_right state 122 |> to_int) in
41-
let nr = Uint32.(of_int r |> neg |> logand sixtythree |> to_int) in
42-
Uint64.(logor (shift_left v nr) (shift_right v r))
43-
70+
let output U128.{high; low} =
71+
let v = Uint64.(logxor high low) in
72+
let r = Uint64.(shift_right high 58 |> to_int) in
73+
let nr = Uint32.(of_int r |> neg |> logand sixtythree |> to_int) in
74+
Uint64.(logor (shift_left v nr) (shift_right v r))
75+
4476

4577
let next {state; increment} =
46-
let state' = Uint128.(state * multiplier + increment) in
78+
let state' = U128.(state * multiplier + increment) in
4779
output state', {state = state'; increment}
4880

4981

5082
let next_uint64 t = match next t.s with
5183
| u, s -> u, {t with s}
52-
84+
5385

5486
let next_uint32 t =
5587
match Common.next_uint32 ~next:next t.s t.ustore with
56-
| u, s, ustore -> u, {s; ustore}
88+
| u, s, ustore -> u, {s; ustore}
5789

5890

59-
let next_double t = Common.next_double ~nextu64:next_uint64 t
91+
let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t
6092

6193

62-
let advance delta {s = {state; increment}; _} =
63-
let open Uint128 in
64-
let rec lcg d am ap cm cp = (* advance state using LCG method *)
65-
match d = zero, logand d one = one with
66-
| true, _ -> am * state + ap
67-
| false, true -> lcg (shift_right d 1) (am * cm) (ap * cm + cp) (cm * cm) (cp * (cm + one))
68-
| false, false -> lcg (shift_right d 1) am ap (cm * cm) (cp * (cm + one))
69-
in {s = {state = lcg (Uint128.of_int128 delta) one zero multiplier increment; increment}; ustore = None}
94+
let next_double t = Common.next_double ~nextu64:next_uint64 t
7095

7196

7297
let set_seed seed =
73-
let open Uint128 in
74-
let s = logor (shift_left (of_uint64 seed.(0)) 64) (of_uint64 seed.(1))
75-
and i = logor (shift_left (of_uint64 seed.(2)) 64) (of_uint64 seed.(3)) in
76-
let increment = logor (shift_left i 1) one in
77-
{state = (increment + s) * multiplier + increment; increment}
78-
79-
80-
let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t
98+
let s2 = Uint64.(logor (shift_left seed.(2) 1) (shift_right seed.(3) 63)) in
99+
let s3 = Uint64.(logor (shift_left seed.(3) 1) one) in
100+
let increment = U128.of_u64 s2 s3 in
101+
let state = U128.(zero * multiplier + increment) in
102+
{state = U128.((of_u64 seed.(0) seed.(1) + state) * multiplier + increment); increment}
103+
104+
105+
let advance (d1, d0) {s = {state; increment}; _} =
106+
let open U128 in
107+
let half x = U128.{low = Uint64.(logor (shift_right x.low 1) (shift_left x.high 63));
108+
high = Uint64.(shift_right x.high 1)} in
109+
let rec lcg d am ap cm cp =
110+
match Uint64.(d.high <= zero && d.low <= zero, logand d.low one = one) with
111+
| true, _ -> am * state + ap
112+
| false, true -> lcg (half d) (am * cm) (ap * cm + cp) (cm * cm) (cp * (cm + one))
113+
| false, false -> lcg (half d) am ap (cm * cm) (cp * (cm + one))
114+
in {s = {state = lcg (of_u64 d1 d0) one zero multiplier increment; increment}; ustore = None}
81115

82116

83117
let initialize seed =

test/test_pcg.ml

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ let test_advance _ =
77
let t = SeedSequence.initialize [Uint128.of_int 12345] |> PCG64.initialize in
88
let advance n = Seq.(iterate (fun s -> PCG64.next_uint64 s |> snd) t |> drop n |> uncons |> Option.get |> fst) in
99
assert_equal
10-
(PCG64.advance (Int128.of_int 100) t |> PCG64.next_uint64 |> fst |> Uint64.to_string)
10+
(PCG64.advance Uint64.(of_int 0, of_int 100) t |> PCG64.next_uint64 |> fst |> Uint64.to_string)
1111
(advance 100 |> PCG64.next_uint64 |> fst |> Uint64.to_string)
1212
~printer:(fun x -> x)
1313

0 commit comments

Comments
 (0)