Skip to content

Commit 847c12c

Browse files
committed
Add optional support for borsh serialisation
Behind a feature flag.
1 parent 40bb0b2 commit 847c12c

File tree

5 files changed

+180
-2
lines changed

5 files changed

+180
-2
lines changed

Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ libc = { version = "0.2.82", optional = true }
4343

4444
matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm"] }
4545

46+
borsh = { version = "1.2", optional = true, default-features = false }
4647
serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] }
4748
rawpointer = { version = "0.2" }
4849

@@ -66,7 +67,7 @@ serde-1 = ["serde"]
6667
test = []
6768

6869
# This feature is used for docs
69-
docs = ["approx", "approx-0_5", "serde", "rayon"]
70+
docs = ["approx", "approx-0_5", "serde", "borsh", "rayon"]
7071

7172
std = ["num-traits/std", "matrixmultiply/std"]
7273
rayon = ["rayon_", "std"]

src/array_borsh.rs

+101
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
use crate::imp_prelude::*;
2+
use crate::IntoDimension;
3+
use alloc::vec::Vec;
4+
use borsh::{BorshDeserialize, BorshSerialize};
5+
use core::ops::Deref;
6+
7+
/// **Requires crate feature `"borsh"`**
8+
impl<I> BorshSerialize for Dim<I>
9+
where
10+
I: BorshSerialize,
11+
{
12+
fn serialize<W: borsh::io::Write>(&self, writer: &mut W) -> borsh::io::Result<()> {
13+
<I as BorshSerialize>::serialize(&self.ix(), writer)
14+
}
15+
}
16+
17+
/// **Requires crate feature `"borsh"`**
18+
impl<I> BorshDeserialize for Dim<I>
19+
where
20+
I: BorshDeserialize,
21+
{
22+
fn deserialize_reader<R: borsh::io::Read>(reader: &mut R) -> borsh::io::Result<Self> {
23+
<I as BorshDeserialize>::deserialize_reader(reader).map(Dim::new)
24+
}
25+
}
26+
27+
/// **Requires crate feature `"borsh"`**
28+
impl BorshSerialize for IxDyn {
29+
fn serialize<W: borsh::io::Write>(&self, writer: &mut W) -> borsh::io::Result<()> {
30+
let elts = self.ix().deref();
31+
// Output length of dimensions.
32+
<usize as BorshSerialize>::serialize(&elts.len(), writer)?;
33+
// Followed by actual data.
34+
for elt in elts {
35+
<Ix as BorshSerialize>::serialize(elt, writer)?;
36+
}
37+
Ok(())
38+
}
39+
}
40+
41+
/// **Requires crate feature `"borsh"`**
42+
impl BorshDeserialize for IxDyn {
43+
fn deserialize_reader<R: borsh::io::Read>(reader: &mut R) -> borsh::io::Result<Self> {
44+
// Deserialize the length.
45+
let len = <usize as BorshDeserialize>::deserialize_reader(reader)?;
46+
// Deserialize the given number of elements. We assume the source is
47+
// trusted so we use a capacity hint...
48+
let mut elts = Vec::with_capacity(len);
49+
for _ix in 0..len {
50+
elts.push(<Ix as BorshDeserialize>::deserialize_reader(reader)?);
51+
}
52+
Ok(elts.into_dimension())
53+
}
54+
}
55+
56+
/// **Requires crate feature `"borsh"`**
57+
impl<A, D, S> BorshSerialize for ArrayBase<S, D>
58+
where
59+
A: BorshSerialize,
60+
D: Dimension + BorshSerialize,
61+
S: Data<Elem = A>,
62+
{
63+
fn serialize<W: borsh::io::Write>(&self, writer: &mut W) -> borsh::io::Result<()> {
64+
// Dimensions
65+
<D as BorshSerialize>::serialize(&self.raw_dim(), writer)?;
66+
// Followed by length of data
67+
let iter = self.iter();
68+
<usize as BorshSerialize>::serialize(&iter.len(), writer)?;
69+
// Followed by data itself.
70+
for elt in iter {
71+
<A as BorshSerialize>::serialize(elt, writer)?;
72+
}
73+
Ok(())
74+
}
75+
}
76+
77+
/// **Requires crate feature `"borsh"`**
78+
impl<A, D, S> BorshDeserialize for ArrayBase<S, D>
79+
where
80+
A: BorshDeserialize,
81+
D: BorshDeserialize + Dimension,
82+
S: DataOwned<Elem = A>,
83+
{
84+
fn deserialize_reader<R: borsh::io::Read>(reader: &mut R) -> borsh::io::Result<Self> {
85+
// Dimensions
86+
let dim = <D as BorshDeserialize>::deserialize_reader(reader)?;
87+
// Followed by length of data
88+
let len = <usize as BorshDeserialize>::deserialize_reader(reader)?;
89+
// Followed by data itself.
90+
let mut data = Vec::with_capacity(len);
91+
for _ix in 0..len {
92+
data.push(<A as BorshDeserialize>::deserialize_reader(reader)?);
93+
}
94+
ArrayBase::from_shape_vec(dim, data).map_err(|_shape_err| {
95+
borsh::io::Error::new(
96+
borsh::io::ErrorKind::InvalidData,
97+
"data and dimensions must match in size",
98+
)
99+
})
100+
}
101+
}

src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ mod aliases;
164164
#[macro_use]
165165
mod itertools;
166166
mod argument_traits;
167+
#[cfg(feature = "borsh")]
168+
mod array_borsh;
167169
#[cfg(feature = "serde")]
168170
mod array_serde;
169171
mod arrayformat;

xtest-serialization/Cargo.toml

+5-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ publish = false
88
test = false
99

1010
[dependencies]
11-
ndarray = { path = "..", features = ["serde"] }
11+
ndarray = { path = "..", features = ["serde", "borsh"] }
1212

1313
[features]
1414
default = ["ron"]
@@ -23,6 +23,10 @@ version = "1.0.40"
2323
[dev-dependencies.rmp-serde]
2424
version = "0.14.0"
2525

26+
[dev-dependencies.borsh]
27+
version = "1.2"
28+
default-features = false
29+
2630
[dependencies.ron]
2731
version = "0.5.1"
2832
optional = true

xtest-serialization/tests/serialize.rs

+70
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ extern crate rmp_serde;
99
#[cfg(feature = "ron")]
1010
extern crate ron;
1111

12+
extern crate borsh;
13+
1214
use ndarray::{arr0, arr1, arr2, s, ArcArray, ArcArray2, ArrayD, IxDyn};
1315

1416
#[test]
@@ -218,3 +220,71 @@ fn serial_many_dim_ron() {
218220
assert_eq!(a, a_de);
219221
}
220222
}
223+
224+
#[test]
225+
fn serial_ixdyn_borsh() {
226+
{
227+
let a = arr0::<f32>(2.72).into_dyn();
228+
let serial = borsh::to_vec(&a).unwrap();
229+
println!("Borsh encode {:?} => {:?}", a, serial);
230+
let res = borsh::from_slice::<ArcArray<f32, _>>(&serial);
231+
println!("{:?}", res);
232+
assert_eq!(a, res.unwrap());
233+
}
234+
235+
{
236+
let a = arr1::<f32>(&[2.72, 1., 2.]).into_dyn();
237+
let serial = borsh::to_vec(&a).unwrap();
238+
println!("Borsh encode {:?} => {:?}", a, serial);
239+
let res = borsh::from_slice::<ArrayD<f32>>(&serial);
240+
println!("{:?}", res);
241+
assert_eq!(a, res.unwrap());
242+
}
243+
244+
{
245+
let a = arr2(&[[3., 1., 2.2], [3.1, 4., 7.]])
246+
.into_shape(IxDyn(&[3, 1, 1, 1, 2, 1]))
247+
.unwrap();
248+
let serial = borsh::to_vec(&a).unwrap();
249+
println!("Borsh encode {:?} => {:?}", a, serial);
250+
let res = borsh::from_slice::<ArrayD<f32>>(&serial);
251+
println!("{:?}", res);
252+
assert_eq!(a, res.unwrap());
253+
}
254+
}
255+
256+
#[test]
257+
fn serial_many_dim_borsh() {
258+
use borsh::from_slice as borsh_deserialize;
259+
use borsh::to_vec as borsh_serialize;
260+
261+
{
262+
let a = arr0::<f32>(2.72);
263+
let a_s = borsh_serialize(&a).unwrap();
264+
let a_de: ArcArray<f32, _> = borsh_deserialize(&a_s).unwrap();
265+
assert_eq!(a, a_de);
266+
}
267+
268+
{
269+
let a = arr1::<f32>(&[2.72, 1., 2.]);
270+
let a_s = borsh_serialize(&a).unwrap();
271+
let a_de: ArcArray<f32, _> = borsh_deserialize(&a_s).unwrap();
272+
assert_eq!(a, a_de);
273+
}
274+
275+
{
276+
let a = arr2(&[[3., 1., 2.2], [3.1, 4., 7.]]);
277+
let a_s = borsh_serialize(&a).unwrap();
278+
let a_de: ArcArray<f32, _> = borsh_deserialize(&a_s).unwrap();
279+
assert_eq!(a, a_de);
280+
}
281+
282+
{
283+
// Test a sliced array.
284+
let mut a = ArcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4));
285+
a.slice_collapse(s![..;-1, .., .., ..2]);
286+
let a_s = borsh_serialize(&a).unwrap();
287+
let a_de: ArcArray<f32, _> = borsh_deserialize(&a_s).unwrap();
288+
assert_eq!(a, a_de);
289+
}
290+
}

0 commit comments

Comments
 (0)