Skip to content

Commit baf6c26

Browse files
authored
Merge pull request #352 from nathanAjacobs/ReadOnlySequenceDeserializationFix
Fix for generic struct deserialization when using ReadOnlySequence
2 parents c42d4a0 + 1cc9c4e commit baf6c26

File tree

2 files changed

+148
-1
lines changed

2 files changed

+148
-1
lines changed

src/MemoryPack.Core/MemoryPackSerializer.Deserialize.cs

+45
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,51 @@ public static int Deserialize<
7575
#endif
7676
T>(in ReadOnlySequence<byte> buffer, ref T? value, MemoryPackSerializerOptions? options = default)
7777
{
78+
if (!RuntimeHelpers.IsReferenceOrContainsReferences<T>())
79+
{
80+
int sizeOfT = Unsafe.SizeOf<T>();
81+
if (buffer.Length < sizeOfT)
82+
{
83+
MemoryPackSerializationException.ThrowInvalidRange(Unsafe.SizeOf<T>(), (int)buffer.Length);
84+
}
85+
86+
ReadOnlySequence<byte> sliced = buffer.Slice(0, sizeOfT);
87+
88+
if (sliced.IsSingleSegment)
89+
{
90+
value = Unsafe.ReadUnaligned<T>(ref MemoryMarshal.GetReference(sliced.FirstSpan));
91+
return sizeOfT;
92+
}
93+
else
94+
{
95+
// We can't read directly from ReadOnlySequence<byte> to T, so we copy to a temp array.
96+
// if less than 512 bytes, use stackalloc, otherwise use MemoryPool<byte>
97+
byte[]? tempArray = null;
98+
99+
Span<byte> tempSpan = sizeOfT <= 512 ? stackalloc byte[sizeOfT] : default;
100+
101+
try
102+
{
103+
if (sizeOfT > 512)
104+
{
105+
tempArray = ArrayPool<byte>.Shared.Rent(sizeOfT);
106+
tempSpan = tempArray;
107+
}
108+
109+
sliced.CopyTo(tempSpan);
110+
value = Unsafe.ReadUnaligned<T>(ref MemoryMarshal.GetReference(tempSpan));
111+
return sizeOfT;
112+
}
113+
finally
114+
{
115+
if (tempArray is not null)
116+
{
117+
ArrayPool<byte>.Shared.Return(tempArray);
118+
}
119+
}
120+
}
121+
}
122+
78123
var state = threadStaticReaderOptionalState;
79124
if (state == null)
80125
{

tests/MemoryPack.Tests/DeserializeTest.cs

+103-1
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
using System;
2+
using System.Buffers;
23
using System.Collections.Generic;
34
using System.Dynamic;
45
using System.IO;
56
using System.Linq;
7+
using System.Runtime.InteropServices;
68
using System.Text;
79
using System.Threading.Tasks;
810

911
namespace MemoryPack.Tests;
1012

11-
public class DeserializeTest
13+
public partial class DeserializeTest
1214
{
1315
[Fact]
1416
public async Task StreamTest()
@@ -30,6 +32,106 @@ public async Task StreamTest()
3032
result.Should().Equal(expected);
3133
}
3234

35+
[Fact]
36+
public void GenericValueStructTest()
37+
{
38+
GenericStruct<int> value = new() { Id = 75, Value = 23 };
39+
40+
RunMultiSegmentTest(value);
41+
}
42+
43+
[Fact]
44+
public void LargeGenericValueStructTest()
45+
{
46+
GenericStruct<PrePaddedInt> value = new() { Id = 75, Value = new PrePaddedInt() { Value = 23 } };
47+
48+
RunMultiSegmentTest(value);
49+
}
50+
51+
[Fact]
52+
public void GenericReferenceStructTest()
53+
{
54+
GenericStruct<string> value = new GenericStruct<string>() { Id = 75, Value = "Hello World!" };
55+
56+
RunMultiSegmentTest(value);
57+
}
58+
59+
[Fact]
60+
public void LargeGenericReferenceStructTest()
61+
{
62+
GenericStruct<PrePaddedString> value = new() { Id = 75, Value = new PrePaddedString() { Value = "Hello World!" } };
63+
64+
RunMultiSegmentTest(value);
65+
}
66+
67+
private void RunMultiSegmentTest<T>(T value)
68+
{
69+
byte[] bytes = MemoryPackSerializer.Serialize(value);
70+
71+
byte[] firstHalf = new byte[bytes.Length / 2];
72+
Array.Copy(bytes, 0, firstHalf, 0, firstHalf.Length);
73+
74+
int secondHalfLength = bytes.Length / 2;
75+
if (bytes.Length % 2 != 0)
76+
{
77+
secondHalfLength++;
78+
}
79+
80+
byte[] secondHalf = new byte[secondHalfLength];
81+
82+
Array.Copy(bytes, firstHalf.Length, secondHalf, 0, secondHalfLength);
83+
84+
ReadOnlySequence<byte> sequence = ReadOnlySequenceBuilder.Create(firstHalf, secondHalf);
85+
86+
T? result = MemoryPackSerializer.Deserialize<T>(sequence);
87+
result.Should().Be(value);
88+
}
89+
90+
[MemoryPackable]
91+
public partial struct GenericStruct<T>
92+
{
93+
public int Id;
94+
public T Value;
95+
96+
public override string ToString()
97+
{
98+
return $"{Id}, {Value}";
99+
}
100+
}
101+
102+
[StructLayout(LayoutKind.Explicit, Size = 516)]
103+
struct PrePaddedInt
104+
{
105+
[FieldOffset(512)]
106+
public int Value;
107+
}
108+
109+
[MemoryPackable]
110+
private partial class PrePaddedString : IEquatable<PrePaddedString>
111+
{
112+
private PrePaddedInt _padding;
113+
public string Value { get; set; } = "";
114+
115+
public bool Equals(PrePaddedString? other)
116+
{
117+
if (other is null)
118+
return false;
119+
120+
return Value.Equals(other.Value);
121+
}
122+
123+
public override bool Equals(object? obj)
124+
{
125+
if (obj is PrePaddedString other)
126+
return Equals(other);
127+
return false;
128+
}
129+
130+
public override int GetHashCode()
131+
{
132+
return Value.GetHashCode();
133+
}
134+
}
33135

34136
class RandomStream : Stream
35137
{

0 commit comments

Comments
 (0)