From c9d60a6fc7ad9286131cbb0e4b58d006df2d40a3 Mon Sep 17 00:00:00 2001 From: "Ihor.Zhuravlov" Date: Sun, 26 Jan 2025 13:42:43 +0100 Subject: [PATCH] Migrated binary serialization to use Messagepack, extracted serialization logic into a separate class --- ...ensions.DotNet.Interactive.UnitTest.csproj | 1 - .../ExternalClass.cs | 6 +- .../IpcTests/BroadcastTests.cs | 8 +- .../Microsoft.Spark.E2ETest.csproj | 1 - .../BinarySerDeTests.cs | 154 ++++++++++++++++++ .../Microsoft.Spark.UnitTest.csproj | 1 - .../Microsoft.Spark.UnitTest/UdfSerDeTests.cs | 22 +-- .../CommandExecutorTests.cs | 15 +- .../Microsoft.Spark.Worker.UnitTest.csproj | 1 - .../Command/RDDCommandExecutor.cs | 11 +- .../Microsoft.Spark.Worker.csproj | 1 - .../Processor/BroadcastVariableProcessor.cs | 20 +-- src/csharp/Microsoft.Spark/Broadcast.cs | 43 +++-- .../Microsoft.Spark/Microsoft.Spark.csproj | 2 +- src/csharp/Microsoft.Spark/RDD/Collector.cs | 9 +- src/csharp/Microsoft.Spark/SparkContext.cs | 5 +- .../Microsoft.Spark/Utils/BinarySerDe.cs | 113 +++++++++++++ .../Microsoft.Spark/Utils/CommandSerDe.cs | 7 +- .../Utils/DependencyProviderUtils.cs | 7 +- 19 files changed, 329 insertions(+), 98 deletions(-) create mode 100644 src/csharp/Microsoft.Spark.UnitTest/BinarySerDeTests.cs create mode 100644 src/csharp/Microsoft.Spark/Utils/BinarySerDe.cs diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest/Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest.csproj b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest/Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest.csproj index 5128070fb..d3c63bd22 100644 --- a/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest/Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest.csproj +++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest/Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest.csproj @@ -3,7 +3,6 @@ net8.0 Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest - true diff --git a/src/csharp/Microsoft.Spark.E2ETest.ExternalLibrary/ExternalClass.cs b/src/csharp/Microsoft.Spark.E2ETest.ExternalLibrary/ExternalClass.cs index db525172d..086a38e4a 100644 --- a/src/csharp/Microsoft.Spark.E2ETest.ExternalLibrary/ExternalClass.cs +++ b/src/csharp/Microsoft.Spark.E2ETest.ExternalLibrary/ExternalClass.cs @@ -9,11 +9,11 @@ namespace Microsoft.Spark.E2ETest.ExternalLibrary [Serializable] public class ExternalClass { - private string _s; + private string s; public ExternalClass(string s) { - _s = s; + this.s = s; } public static string HelloWorld() @@ -23,7 +23,7 @@ public static string HelloWorld() public string Concat(string s) { - return _s + s; + return this.s + s; } } } diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs index e0443f04c..e9e5f98a5 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs @@ -1,5 +1,6 @@ using System; using System.Linq; +using MessagePack; using Microsoft.Spark.Sql; using Xunit; using static Microsoft.Spark.Sql.Functions; @@ -12,10 +13,11 @@ public class TestBroadcastVariable public int IntValue { get; private set; } public string StringValue { get; private set; } - public TestBroadcastVariable(int intVal, string stringVal) + [SerializationConstructor] + public TestBroadcastVariable(int intValue, string stringValue) { - IntValue = intVal; - StringValue = stringVal; + IntValue = intValue; + StringValue = stringValue; } } diff --git a/src/csharp/Microsoft.Spark.E2ETest/Microsoft.Spark.E2ETest.csproj b/src/csharp/Microsoft.Spark.E2ETest/Microsoft.Spark.E2ETest.csproj index e242da0d0..8af2fae53 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/Microsoft.Spark.E2ETest.csproj +++ b/src/csharp/Microsoft.Spark.E2ETest/Microsoft.Spark.E2ETest.csproj @@ -1,7 +1,6 @@  net8.0 - true diff --git a/src/csharp/Microsoft.Spark.UnitTest/BinarySerDeTests.cs b/src/csharp/Microsoft.Spark.UnitTest/BinarySerDeTests.cs new file mode 100644 index 000000000..7c0a2642a --- /dev/null +++ b/src/csharp/Microsoft.Spark.UnitTest/BinarySerDeTests.cs @@ -0,0 +1,154 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.IO; +using MessagePack; +using Microsoft.Spark.Utils; +using Xunit; + +namespace Microsoft.Spark.UnitTest; + +[Collection("Spark Unit Tests")] +public class BinarySerDeTests +{ + [Theory] + [InlineData(42)] + [InlineData("Test")] + [InlineData(99.99)] + public void Serialize_ShouldWriteObjectToStream(object input) + { + using var memoryStream = new MemoryStream(); + BinarySerDe.Serialize(memoryStream, input); + memoryStream.Position = 0; + + var deserializedObject = MessagePackSerializer.Typeless.Deserialize(memoryStream); + + Assert.Equal(input, deserializedObject); + } + + [Fact] + public void Deserialize_ShouldReturnExpectedObject_WhenTypeMatches() + { + var employee = new Employee { Id = 101, Name = "John Doe" }; + using var memoryStream = new MemoryStream(); + MessagePackSerializer.Typeless.Serialize(memoryStream, employee); + memoryStream.Position = 0; + + var result = BinarySerDe.Deserialize(memoryStream); + + Assert.Equal(employee.Id, result.Id); + Assert.Equal(employee.Name, result.Name); + } + + [Fact] + public void Deserialize_ShouldThrowInvalidCastEx_WhenTypeDoesNotMatch() + { + var employee = new Employee { Id = 101, Name = "John Doe" }; + using var memoryStream = new MemoryStream(); + MessagePackSerializer.Typeless.Serialize(memoryStream, employee); + memoryStream.Position = 0; + + var action = () => BinarySerDe.Deserialize(memoryStream); + + Assert.Throws(action); + } + + [Fact] + public void Serialize_CustomFunctionAndObject_ShouldBeSerializable() + { + var department = new Department { Name = "HR", EmployeeCount = 27 }; + var employeeStub = new Employee + { + EmbeddedObject = department, + Id = 11, + Name = "Derek", + }; + using var memoryStream = new MemoryStream(); + MessagePackSerializer.Typeless.Serialize(memoryStream, employeeStub); + memoryStream.Position = 0; + + var deserializedCalculation = BinarySerDe.Deserialize(memoryStream); + + Assert.IsType(deserializedCalculation.EmbeddedObject); + Assert.Equal(27, ((Department)deserializedCalculation.EmbeddedObject).EmployeeCount); + Assert.Equal("HR", ((Department)deserializedCalculation.EmbeddedObject).Name); + } + + [Fact] + public void Serialize_ClassWithoutSerializableAttribute_ShouldThrowException() + { + var nonSerializableClass = new NonSerializableClass { Value = 123 }; + using var memoryStream = new MemoryStream(); + BinarySerDe.Serialize(memoryStream, nonSerializableClass); + memoryStream.Position = 0; + + Assert.Throws(() => BinarySerDe.Deserialize(memoryStream)); + } + + [Fact] + public void Serialize_CollectionAndDictionary_ShouldBeSerializable() + { + var list = new List { 1, 2, 3 }; + var dictionary = new Dictionary { { "one", 1 }, { "two", 2 } }; + + using var memoryStream = new MemoryStream(); + BinarySerDe.Serialize(memoryStream, list); + memoryStream.Position = 0; + var deserializedList = MessagePackSerializer.Typeless.Deserialize(memoryStream) as List; + + Assert.Equal(list, deserializedList); + + memoryStream.SetLength(0); + BinarySerDe.Serialize(memoryStream, dictionary); + memoryStream.Position = 0; + var deserializedDictionary = MessagePackSerializer.Typeless.Deserialize(memoryStream) as Dictionary; + + Assert.Equal(dictionary, deserializedDictionary); + } + + [Fact] + public void Serialize_PolymorphicObject_ShouldBeSerializable() + { + Employee manager = new Manager { Id = 1, Name = "Alice", Role = "Account manager" }; + using var memoryStream = new MemoryStream(); + BinarySerDe.Serialize(memoryStream, manager); + memoryStream.Position = 0; + + var deserializedEmployee = BinarySerDe.Deserialize(memoryStream); + + Assert.IsType(deserializedEmployee); + Assert.Equal("Alice", deserializedEmployee.Name); + Assert.Equal("Account manager", ((Manager)deserializedEmployee).Role); + } + + [Serializable] + private class Employee + { + public int Id { get; set; } + + public string Name { get; set; } + + public object EmbeddedObject { get; set; } + } + + [Serializable] + private class Department + { + public string Name { get; set; } + public int EmployeeCount { get; set; } + } + + [Serializable] + private class Manager : Employee + { + public string Role { get; set; } + } + + private class NonSerializableClass + { + public int Value { get; init; } + } +} diff --git a/src/csharp/Microsoft.Spark.UnitTest/Microsoft.Spark.UnitTest.csproj b/src/csharp/Microsoft.Spark.UnitTest/Microsoft.Spark.UnitTest.csproj index e635434ae..deb63cd3d 100644 --- a/src/csharp/Microsoft.Spark.UnitTest/Microsoft.Spark.UnitTest.csproj +++ b/src/csharp/Microsoft.Spark.UnitTest/Microsoft.Spark.UnitTest.csproj @@ -3,7 +3,6 @@ net8.0 Microsoft.Spark.UnitTest - true diff --git a/src/csharp/Microsoft.Spark.UnitTest/UdfSerDeTests.cs b/src/csharp/Microsoft.Spark.UnitTest/UdfSerDeTests.cs index b655fdf1b..ba6ebc0b9 100644 --- a/src/csharp/Microsoft.Spark.UnitTest/UdfSerDeTests.cs +++ b/src/csharp/Microsoft.Spark.UnitTest/UdfSerDeTests.cs @@ -5,7 +5,6 @@ using System; using System.IO; using System.Reflection; -using System.Runtime.Serialization.Formatters.Binary; using Microsoft.Spark.Utils; using Xunit; @@ -17,21 +16,21 @@ public class UdfSerDeTests [Serializable] private class TestClass { - private readonly string _str; + private readonly string str; - public TestClass(string s) + public TestClass(string str) { - _str = s; + this.str = str; } public string Concat(string s) { - if (_str == null) + if (str == null) { return s + s; } - return _str + s; + return str + s; } public override bool Equals(object obj) @@ -43,7 +42,7 @@ public override bool Equals(object obj) return false; } - return _str == that._str; + return str == that.str; } public override int GetHashCode() @@ -149,16 +148,13 @@ private Delegate SerDe(Delegate udf) return Deserialize(Serialize(udf)); } -#pragma warning disable SYSLIB0011 // Type or member is obsolete - // TODO: Replace BinaryFormatter with a new, secure serializer. private byte[] Serialize(Delegate udf) { UdfSerDe.UdfData udfData = UdfSerDe.Serialize(udf); using (var ms = new MemoryStream()) { - var bf = new BinaryFormatter(); - bf.Serialize(ms, udfData); + BinarySerDe.Serialize(ms, udfData); return ms.ToArray(); } } @@ -167,11 +163,9 @@ private Delegate Deserialize(byte[] serializedUdf) { using (var ms = new MemoryStream(serializedUdf, false)) { - var bf = new BinaryFormatter(); - UdfSerDe.UdfData udfData = (UdfSerDe.UdfData)bf.Deserialize(ms); + var udfData = BinarySerDe.Deserialize(ms); return UdfSerDe.Deserialize(udfData); } } -#pragma warning restore } } diff --git a/src/csharp/Microsoft.Spark.Worker.UnitTest/CommandExecutorTests.cs b/src/csharp/Microsoft.Spark.Worker.UnitTest/CommandExecutorTests.cs index cc43b9f37..e3082d8db 100644 --- a/src/csharp/Microsoft.Spark.Worker.UnitTest/CommandExecutorTests.cs +++ b/src/csharp/Microsoft.Spark.Worker.UnitTest/CommandExecutorTests.cs @@ -8,7 +8,6 @@ using System.IO; using System.Linq; using System.Reflection; -using System.Runtime.Serialization.Formatters.Binary; using System.Threading; using System.Threading.Tasks; using Apache.Arrow; @@ -1049,10 +1048,8 @@ public void TestRDDCommandExecutor(Version sparkVersion, IpcOptions ipcOptions) using var inputStream = new MemoryStream(); using var outputStream = new MemoryStream(); -#pragma warning disable SYSLIB0011 // Type or member is obsolete + // Write test data to the input stream. - var formatter = new BinaryFormatter(); -#pragma warning restore SYSLIB0011 // Type or member is obsolete var memoryStream = new MemoryStream(); var inputs = new int[] { 0, 1, 2, 3, 4 }; @@ -1061,10 +1058,7 @@ public void TestRDDCommandExecutor(Version sparkVersion, IpcOptions ipcOptions) foreach (int input in inputs) { memoryStream.Position = 0; -#pragma warning disable SYSLIB0011 // Type or member is obsolete - // TODO: Replace BinaryFormatter with a new, secure serializer. - formatter.Serialize(memoryStream, input); -#pragma warning restore SYSLIB0011 // Type or member is obsolete + BinarySerDe.Serialize(memoryStream, input); values.Add(memoryStream.ToArray()); } @@ -1094,12 +1088,9 @@ public void TestRDDCommandExecutor(Version sparkVersion, IpcOptions ipcOptions) for (int i = 0; i < inputs.Length; ++i) { Assert.True(SerDe.ReadInt32(outputStream) > 0); -#pragma warning disable SYSLIB0011 // Type or member is obsolete - // TODO: Replace BinaryFormatter with a new, secure serializer. Assert.Equal( mapUdf(i), - formatter.Deserialize(outputStream)); -#pragma warning restore SYSLIB0011 // Type or member is obsolete + BinarySerDe.Deserialize(outputStream)); } // Validate all the data on the stream is read. diff --git a/src/csharp/Microsoft.Spark.Worker.UnitTest/Microsoft.Spark.Worker.UnitTest.csproj b/src/csharp/Microsoft.Spark.Worker.UnitTest/Microsoft.Spark.Worker.UnitTest.csproj index 431f3765b..30ff60921 100644 --- a/src/csharp/Microsoft.Spark.Worker.UnitTest/Microsoft.Spark.Worker.UnitTest.csproj +++ b/src/csharp/Microsoft.Spark.Worker.UnitTest/Microsoft.Spark.Worker.UnitTest.csproj @@ -1,7 +1,6 @@  net8.0 - true diff --git a/src/csharp/Microsoft.Spark.Worker/Command/RDDCommandExecutor.cs b/src/csharp/Microsoft.Spark.Worker/Command/RDDCommandExecutor.cs index 7ad65bc40..adf08a6ad 100644 --- a/src/csharp/Microsoft.Spark.Worker/Command/RDDCommandExecutor.cs +++ b/src/csharp/Microsoft.Spark.Worker/Command/RDDCommandExecutor.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Generic; using System.IO; -using System.Runtime.Serialization.Formatters.Binary; using Microsoft.Spark.Interop.Ipc; using Microsoft.Spark.Utils; @@ -19,10 +18,6 @@ internal class RDDCommandExecutor { [ThreadStatic] private static MemoryStream s_writeOutputStream; - [ThreadStatic] -#pragma warning disable SYSLIB0011 // Type or member is obsolete - private static BinaryFormatter s_binaryFormatter; -#pragma warning restore SYSLIB0011 // Type or member is obsolete /// /// Executes the commands on the input data read from input stream @@ -113,11 +108,7 @@ private void Serialize( switch (serializerMode) { case CommandSerDe.SerializedMode.Byte: -#pragma warning disable SYSLIB0011 // Type or member is obsolete - BinaryFormatter formatter = s_binaryFormatter ??= new BinaryFormatter(); - // TODO: Replace BinaryFormatter with a new, secure serializer. - formatter.Serialize(stream, message); -#pragma warning restore SYSLIB0011 // Type or member is obsolete + BinarySerDe.Serialize(stream, message); break; case CommandSerDe.SerializedMode.None: case CommandSerDe.SerializedMode.String: diff --git a/src/csharp/Microsoft.Spark.Worker/Microsoft.Spark.Worker.csproj b/src/csharp/Microsoft.Spark.Worker/Microsoft.Spark.Worker.csproj index 4f371e869..5d242b179 100644 --- a/src/csharp/Microsoft.Spark.Worker/Microsoft.Spark.Worker.csproj +++ b/src/csharp/Microsoft.Spark.Worker/Microsoft.Spark.Worker.csproj @@ -4,7 +4,6 @@ net48;net8.0 net8.0 Microsoft.Spark.Worker - true true diff --git a/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs b/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs index 17b857256..a3a539d12 100644 --- a/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs +++ b/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs @@ -5,9 +5,9 @@ using System; using System.IO; using System.Net; -using System.Runtime.Serialization.Formatters.Binary; using Microsoft.Spark.Interop.Ipc; using Microsoft.Spark.Network; +using Microsoft.Spark.Utils; namespace Microsoft.Spark.Worker.Processor { @@ -45,9 +45,7 @@ internal BroadcastVariables Process(Stream stream) broadcastVars.Secret); } } -#pragma warning disable SYSLIB0011 // Type or member is obsolete - var formatter = new BinaryFormatter(); -#pragma warning restore SYSLIB0011 // Type or member is obsolete + for (int i = 0; i < broadcastVars.Count; ++i) { long bid = SerDe.ReadInt64(stream); @@ -62,21 +60,17 @@ internal BroadcastVariables Process(Stream stream) $"server {readBid} is different from the Broadcast Id received " + $"from the payload {bid}."); } -#pragma warning disable SYSLIB0011 // Type or member is obsolete - // TODO: Replace BinaryFormatter with a new, secure serializer. - object value = formatter.Deserialize(socket.InputStream); -#pragma warning restore SYSLIB0011 // Type or member is obsolete + + var value = BinarySerDe.Deserialize(socket.InputStream); BroadcastRegistry.Add(bid, value); } else { string path = SerDe.ReadString(stream); - using FileStream fStream = + using FileStream fStream = File.Open(path, FileMode.Open, FileAccess.Read, FileShare.Read); -#pragma warning disable SYSLIB0011 // Type or member is obsolete - // TODO: Replace BinaryFormatter with a new, secure serializer. - object value = formatter.Deserialize(fStream); -#pragma warning restore SYSLIB0011 // Type or member is obsolete + + var value = BinarySerDe.Deserialize(fStream); BroadcastRegistry.Add(bid, value); } } diff --git a/src/csharp/Microsoft.Spark/Broadcast.cs b/src/csharp/Microsoft.Spark/Broadcast.cs index c26ad5329..501833d28 100644 --- a/src/csharp/Microsoft.Spark/Broadcast.cs +++ b/src/csharp/Microsoft.Spark/Broadcast.cs @@ -3,13 +3,13 @@ using System.Collections.Generic; using System.IO; using System.Net; -using System.Runtime.Serialization; -using System.Runtime.Serialization.Formatters.Binary; using System.Threading; +using MessagePack; using Microsoft.Spark.Interop; using Microsoft.Spark.Interop.Ipc; using Microsoft.Spark.Network; using Microsoft.Spark.Services; +using Microsoft.Spark.Utils; namespace Microsoft.Spark { @@ -21,7 +21,9 @@ namespace Microsoft.Spark /// reduce communication cost. /// [Serializable] - public sealed class Broadcast : IJvmObjectReferenceProvider + public sealed class Broadcast + : IMessagePackSerializationCallbackReceiver, + IJvmObjectReferenceProvider { [NonSerialized] private readonly string _path; @@ -37,6 +39,11 @@ internal Broadcast(SparkContext sc, T value) _bid = (long)_jvmObject.Invoke("id"); } + // Default constructor, needed for deserialization + internal Broadcast() + { + } + public JvmObjectReference Reference => _jvmObject; /// @@ -79,17 +86,6 @@ public void Destroy() File.Delete(_path); } - /// - /// Serialization callback function that adds to the JvmBroadcastRegistry when the - /// Broadcast variable object is being serialized. - /// - /// The current StreaminContext being used - [OnSerialized] - internal void OnSerialized(StreamingContext context) - { - JvmBroadcastRegistry.Add(_jvmObject); - } - /// /// Function that creates a temporary directory inside the given directory and returns the /// absolute filepath of temporary file name in that directory. @@ -221,10 +217,23 @@ private void WriteToFile(object value) /// /// Serializable object /// Stream to which the object is serialized - private void Dump(object value, Stream stream) + private void Dump(object value, Stream stream) => + BinarySerDe.Serialize(stream, value); + + /// + /// Serialization callback function that adds to the JvmBroadcastRegistry when the + /// Broadcast variable object is being serialized. + /// + public void OnBeforeSerialize() + { + JvmBroadcastRegistry.Add(_jvmObject); + } + + /// + /// Deserialization callback function + /// + public void OnAfterDeserialize() { - var formatter = new BinaryFormatter(); - formatter.Serialize(stream, value); } } diff --git a/src/csharp/Microsoft.Spark/Microsoft.Spark.csproj b/src/csharp/Microsoft.Spark/Microsoft.Spark.csproj index 93bd2fbf5..26729e1e4 100644 --- a/src/csharp/Microsoft.Spark/Microsoft.Spark.csproj +++ b/src/csharp/Microsoft.Spark/Microsoft.Spark.csproj @@ -10,7 +10,6 @@ .NET for Apache Spark https://github.com/dotnet/spark/tree/master/docs/release-notes spark;dotnet;csharp - true @@ -30,6 +29,7 @@ + diff --git a/src/csharp/Microsoft.Spark/RDD/Collector.cs b/src/csharp/Microsoft.Spark/RDD/Collector.cs index 9acee09bc..15072481c 100644 --- a/src/csharp/Microsoft.Spark/RDD/Collector.cs +++ b/src/csharp/Microsoft.Spark/RDD/Collector.cs @@ -6,7 +6,6 @@ using System.Collections.Generic; using System.IO; using System.Linq; -using System.Runtime.Serialization.Formatters.Binary; using Microsoft.Spark.Interop.Ipc; using Microsoft.Spark.Sql; using Microsoft.Spark.Utils; @@ -70,12 +69,8 @@ internal interface IDeserializer /// private sealed class BinaryDeserializer : IDeserializer { - private readonly BinaryFormatter _formater = new BinaryFormatter(); - - public object Deserialize(Stream stream, int length) - { - return _formater.Deserialize(stream); - } + public object Deserialize(Stream stream, int length) => + BinarySerDe.Deserialize(stream, length); } /// diff --git a/src/csharp/Microsoft.Spark/SparkContext.cs b/src/csharp/Microsoft.Spark/SparkContext.cs index 248005b65..c7aa85704 100644 --- a/src/csharp/Microsoft.Spark/SparkContext.cs +++ b/src/csharp/Microsoft.Spark/SparkContext.cs @@ -4,10 +4,10 @@ using System.Collections.Generic; using System.IO; -using System.Runtime.Serialization.Formatters.Binary; using Microsoft.Spark.Hadoop.Conf; using Microsoft.Spark.Interop.Internal.Scala; using Microsoft.Spark.Interop.Ipc; +using Microsoft.Spark.Utils; using static Microsoft.Spark.Utils.CommandSerDe; namespace Microsoft.Spark @@ -225,13 +225,12 @@ public void ClearJobGroup() /// RDD representing distributed collection internal RDD Parallelize(IEnumerable seq, int? numSlices = null) { - var formatter = new BinaryFormatter(); using var memoryStream = new MemoryStream(); var values = new List(); foreach (T obj in seq) { - formatter.Serialize(memoryStream, obj); + BinarySerDe.Serialize(memoryStream, obj); values.Add(memoryStream.ToArray()); memoryStream.SetLength(0); } diff --git a/src/csharp/Microsoft.Spark/Utils/BinarySerDe.cs b/src/csharp/Microsoft.Spark/Utils/BinarySerDe.cs new file mode 100644 index 000000000..7a7896324 --- /dev/null +++ b/src/csharp/Microsoft.Spark/Utils/BinarySerDe.cs @@ -0,0 +1,113 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.IO; +using MessagePack; +using MessagePack.Resolvers; +using Microsoft.Spark.Interop.Ipc; + +namespace Microsoft.Spark.Utils; + +// If deserialization of untrusted data is required, extend this functionality to +// incorporate techniques such as using a Message Authentication Code (MAC) +// or whitelisting allowed types to mitigate security risks. + +/// +/// BinarySerDe (Serialization/Deserialization) is a utility class designed to handle +/// serialization and deserialization of objects to and from binary formats. +/// +/// +/// This implementation uses the MessagePack `Typeless` API, which embeds type +/// information into the serialized data. It adds overhead of 1-2 bytes for primitive +/// types, and serializes 'System.Type' entirely for complex objects. +/// Does not serialize type definition, so in order to deserialize a complex object, +/// declaring library should be available in app domain or at probing locations. +/// +/// +internal static class BinarySerDe +{ + private static MessagePackSerializerOptions _options = + new AllowStandardOrSerializableMessagePackSerializerOptions( + TypelessContractlessStandardResolver.Instance + ).WithSecurity(MessagePackSecurity.UntrustedData); + + /// + /// Deserializes a stream of binary data into an object of type T. + /// When using or shared streams, prefer the overloaded version + /// that accepts a `length` parameter to ensure no excess data is consumed. + /// + /// The expected type of the deserialized object. + /// The stream containing the serialized data. + /// An object of type T. + internal static T Deserialize(Stream stream) + { + return (T)MessagePackSerializer.Typeless.Deserialize(stream, _options); + } + + /// + /// Deserializes an object from stream, ensuring no excess data is read. + /// + /// The stream containing the serialized data. + /// The length of byte section to deserialize. + /// The deserialized object. + internal static object Deserialize(Stream stream, int length) + { + ReadOnlyMemory memory = SerDe.ReadBytes(stream, length); + + return MessagePackSerializer.Typeless.Deserialize(memory, _options); + } + + /// + /// Serializes an object into a binary stream + /// + /// The type of the object to serialize. + /// The target stream where the data will be written. + /// The object to serialize. + internal static void Serialize(Stream stream, T graph) + { + MessagePackSerializer.Typeless.Serialize(stream, graph, _options); + } +} + +/// +/// Additional security for MessagePack typeless serialization, that only allows +/// standard classes or classes marked with the 'System.Serializable' attribute. +/// +internal class AllowStandardOrSerializableMessagePackSerializerOptions + : MessagePackSerializerOptions +{ + public AllowStandardOrSerializableMessagePackSerializerOptions(IFormatterResolver resolver) + : base(resolver) { } + + protected AllowStandardOrSerializableMessagePackSerializerOptions( + MessagePackSerializerOptions copyFrom + ) + : base(copyFrom) { } + + public override void ThrowIfDeserializingTypeIsDisallowed(Type type) + { + // Check against predefined blacklist + base.ThrowIfDeserializingTypeIsDisallowed(type); + + // Check if MessagePack can handle this type safely + var formatter = StandardResolver.Instance.GetFormatterDynamic(type); + + if ( + formatter == null + && type.GetCustomAttributes(typeof(System.SerializableAttribute), true).Length == 0 + ) + { + throw new MessagePackSerializationException( + $"Deserialization attempted to create the type {type.FullName} which is not allowed." + + $" Add 'System.Serializable' attribute to allow serialization" + ); + } + } + + protected override MessagePackSerializerOptions Clone() + { + return new AllowStandardOrSerializableMessagePackSerializerOptions(this); + } +} diff --git a/src/csharp/Microsoft.Spark/Utils/CommandSerDe.cs b/src/csharp/Microsoft.Spark/Utils/CommandSerDe.cs index d2343fd29..e302b4bb7 100644 --- a/src/csharp/Microsoft.Spark/Utils/CommandSerDe.cs +++ b/src/csharp/Microsoft.Spark/Utils/CommandSerDe.cs @@ -8,7 +8,6 @@ using System.IO; using System.Linq; using System.Reflection; -using System.Runtime.Serialization.Formatters.Binary; using System.Text; using Microsoft.Spark.Interop.Ipc; using Microsoft.Spark.Sql; @@ -159,10 +158,9 @@ internal static byte[] Serialize( Udfs = udfs.ToArray() }; - var formatter = new BinaryFormatter(); using (var stream = new MemoryStream()) { - formatter.Serialize(stream, udfWrapperData); + BinarySerDe.Serialize(stream, udfWrapperData); byte[] udfBytes = stream.ToArray(); byte[] udfBytesLengthAsBytes = BitConverter.GetBytes(udfBytes.Length); @@ -291,10 +289,9 @@ private static UdfWrapperData GetUdfWrapperDataFromStream( byte[] serializedCommand = SerDe.ReadBytes(stream); - var bf = new BinaryFormatter(); var ms = new MemoryStream(serializedCommand, false); - return (UdfWrapperData)bf.Deserialize(ms); + return BinarySerDe.Deserialize(ms); } internal static T Deserialize( diff --git a/src/csharp/Microsoft.Spark/Utils/DependencyProviderUtils.cs b/src/csharp/Microsoft.Spark/Utils/DependencyProviderUtils.cs index 3954151d1..88b2271ad 100644 --- a/src/csharp/Microsoft.Spark/Utils/DependencyProviderUtils.cs +++ b/src/csharp/Microsoft.Spark/Utils/DependencyProviderUtils.cs @@ -4,7 +4,6 @@ using System; using System.IO; -using System.Runtime.Serialization.Formatters.Binary; namespace Microsoft.Spark.Utils { @@ -74,15 +73,13 @@ public override bool Equals(object obj) internal static Metadata Deserialize(string path) { using FileStream fileStream = File.OpenRead(path); - var formatter = new BinaryFormatter(); - return (Metadata)formatter.Deserialize(fileStream); + return BinarySerDe.Deserialize(fileStream); } internal void Serialize(string path) { using FileStream fileStream = File.OpenWrite(path); - var formatter = new BinaryFormatter(); - formatter.Serialize(fileStream, this); + BinarySerDe.Serialize(fileStream, this); } private bool Equals(Metadata other)