Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrated to MessagePack instead of BinaryFormatter #1202

Merged
merged 2 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<RootNamespace>Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest</RootNamespace>
<EnableUnsafeBinaryFormatterSerialization>true</EnableUnsafeBinaryFormatterSerialization>
</PropertyGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ namespace Microsoft.Spark.E2ETest.ExternalLibrary
[Serializable]
public class ExternalClass
{
private string _s;
private string s;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise MessagePack cannot instantiate an ExternalClass instance, because it doesn't know which value to pass for the constructor argument.
We can either add a default parameterless constructor or rename the field to match the constructor argument name.


public ExternalClass(string s)
{
_s = s;
this.s = s;
}

public static string HelloWorld()
Expand All @@ -23,7 +23,7 @@ public static string HelloWorld()

public string Concat(string s)
{
return _s + s;
return this.s + s;
}
}
}
8 changes: 5 additions & 3 deletions src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Linq;
using MessagePack;
using Microsoft.Spark.Sql;
using Xunit;
using static Microsoft.Spark.Sql.Functions;
Expand All @@ -12,10 +13,11 @@ public class TestBroadcastVariable
public int IntValue { get; private set; }
public string StringValue { get; private set; }

public TestBroadcastVariable(int intVal, string stringVal)
[SerializationConstructor]
public TestBroadcastVariable(int intValue, string stringValue)
{
IntValue = intVal;
StringValue = stringVal;
IntValue = intValue;
StringValue = stringValue;
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<EnableUnsafeBinaryFormatterSerialization>true</EnableUnsafeBinaryFormatterSerialization>
</PropertyGroup>
<ItemGroup>
<Content Include="Resources\*">
Expand Down
154 changes: 154 additions & 0 deletions src/csharp/Microsoft.Spark.UnitTest/BinarySerDeTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using MessagePack;
using Microsoft.Spark.Utils;
using Xunit;

namespace Microsoft.Spark.UnitTest;

/// <summary>
/// Round-trip tests for <see cref="BinarySerDe"/>. Each test pairs
/// <see cref="BinarySerDe"/> with <c>MessagePackSerializer.Typeless</c> on the
/// opposite side of the round trip (or uses <see cref="BinarySerDe"/> on both
/// sides) and checks the value that comes back.
/// </summary>
[Collection("Spark Unit Tests")]
public class BinarySerDeTests
{
    /// <summary>
    /// Verifies that a value written by <c>BinarySerDe.Serialize</c> can be read
    /// back with the typeless MessagePack deserializer and compares equal to the input.
    /// </summary>
    /// <param name="input">Boxed sample value (an int, a string and a double are supplied).</param>
    [Theory]
    [InlineData(42)]
    [InlineData("Test")]
    [InlineData(99.99)]
    public void Serialize_ShouldWriteObjectToStream(object input)
    {
        using var memoryStream = new MemoryStream();

        BinarySerDe.Serialize(memoryStream, input);
        // Rewind so the reader starts at the beginning of the payload.
        memoryStream.Position = 0;
        object deserializedObject = MessagePackSerializer.Typeless.Deserialize(memoryStream);

        Assert.Equal(input, deserializedObject);
    }

    /// <summary>
    /// Verifies that a typeless MessagePack payload of an <c>Employee</c> is
    /// returned by <c>BinarySerDe.Deserialize</c> with its property values intact.
    /// </summary>
    [Fact]
    public void Deserialize_ShouldReturnExpectedObject_WhenTypeMatches()
    {
        var employee = new Employee { Id = 101, Name = "John Doe" };
        using var memoryStream = new MemoryStream();
        MessagePackSerializer.Typeless.Serialize(memoryStream, employee);
        memoryStream.Position = 0;

        Employee result = BinarySerDe.Deserialize<Employee>(memoryStream);

        Assert.Equal(employee.Id, result.Id);
        Assert.Equal(employee.Name, result.Name);
    }

    /// <summary>
    /// Verifies that requesting a type other than the one stored in the payload
    /// (a <c>Department</c> from an <c>Employee</c> payload) raises
    /// <see cref="InvalidCastException"/>.
    /// </summary>
    [Fact]
    public void Deserialize_ShouldThrowInvalidCastEx_WhenTypeDoesNotMatch()
    {
        var employee = new Employee { Id = 101, Name = "John Doe" };
        using var memoryStream = new MemoryStream();
        MessagePackSerializer.Typeless.Serialize(memoryStream, employee);
        memoryStream.Position = 0;

        Assert.Throws<InvalidCastException>(
            () => BinarySerDe.Deserialize<Department>(memoryStream));
    }

    /// <summary>
    /// Verifies that an object stored in an <c>object</c>-typed property keeps
    /// its concrete type and its property values through the round trip.
    /// </summary>
    [Fact]
    public void Serialize_CustomFunctionAndObject_ShouldBeSerializable()
    {
        var department = new Department { Name = "HR", EmployeeCount = 27 };
        var employeeStub = new Employee
        {
            EmbeddedObject = department,
            Id = 11,
            Name = "Derek",
        };
        using var memoryStream = new MemoryStream();
        MessagePackSerializer.Typeless.Serialize(memoryStream, employeeStub);
        memoryStream.Position = 0;

        Employee deserializedEmployee = BinarySerDe.Deserialize<Employee>(memoryStream);

        // Assert.IsType returns the typed instance, so no further casts are needed.
        Department embedded = Assert.IsType<Department>(deserializedEmployee.EmbeddedObject);
        Assert.Equal(27, embedded.EmployeeCount);
        Assert.Equal("HR", embedded.Name);
    }

    /// <summary>
    /// Verifies that reading back a type that is not marked <c>[Serializable]</c>
    /// raises <c>MessagePackSerializationException</c>.
    /// </summary>
    /// <remarks>
    /// Despite the "Serialize_" prefix in the test name, the exception is asserted
    /// on the deserialization step only: the serialization call runs outside the
    /// assertion, so it must complete without throwing for this test to pass.
    /// </remarks>
    [Fact]
    public void Serialize_ClassWithoutSerializableAttribute_ShouldThrowException()
    {
        var nonSerializableClass = new NonSerializableClass { Value = 123 };
        using var memoryStream = new MemoryStream();
        BinarySerDe.Serialize(memoryStream, nonSerializableClass);
        memoryStream.Position = 0;

        Assert.Throws<MessagePackSerializationException>(
            () => BinarySerDe.Deserialize<NonSerializableClass>(memoryStream));
    }

    /// <summary>
    /// Verifies that a <see cref="List{T}"/> and a <see cref="Dictionary{TKey, TValue}"/>
    /// written by <c>BinarySerDe.Serialize</c> come back from the typeless
    /// deserializer as the same collection type with the same contents.
    /// </summary>
    [Fact]
    public void Serialize_CollectionAndDictionary_ShouldBeSerializable()
    {
        var list = new List<int> { 1, 2, 3 };
        var dictionary = new Dictionary<string, int> { { "one", 1 }, { "two", 2 } };

        // Each round trip gets its own stream so the two payloads cannot interfere.
        using var listStream = new MemoryStream();
        BinarySerDe.Serialize(listStream, list);
        listStream.Position = 0;

        // IsType fails with a clear type-mismatch message instead of silently
        // producing null the way an 'as' cast would.
        List<int> deserializedList =
            Assert.IsType<List<int>>(MessagePackSerializer.Typeless.Deserialize(listStream));

        Assert.Equal(list, deserializedList);

        using var dictionaryStream = new MemoryStream();
        BinarySerDe.Serialize(dictionaryStream, dictionary);
        dictionaryStream.Position = 0;

        Dictionary<string, int> deserializedDictionary =
            Assert.IsType<Dictionary<string, int>>(
                MessagePackSerializer.Typeless.Deserialize(dictionaryStream));

        Assert.Equal(dictionary, deserializedDictionary);
    }

    /// <summary>
    /// Verifies that an instance of a derived type (<c>Manager</c>) referenced
    /// through its base type (<c>Employee</c>) is restored as the derived type,
    /// including the property declared only on the derived type.
    /// </summary>
    [Fact]
    public void Serialize_PolymorphicObject_ShouldBeSerializable()
    {
        Employee manager = new Manager { Id = 1, Name = "Alice", Role = "Account manager" };
        using var memoryStream = new MemoryStream();
        BinarySerDe.Serialize(memoryStream, manager);
        memoryStream.Position = 0;

        Employee deserializedEmployee = BinarySerDe.Deserialize<Employee>(memoryStream);

        Manager deserializedManager = Assert.IsType<Manager>(deserializedEmployee);
        Assert.Equal("Alice", deserializedManager.Name);
        Assert.Equal("Account manager", deserializedManager.Role);
    }

    /// <summary>
    /// Sample payload type; <see cref="EmbeddedObject"/> is typed as
    /// <c>object</c> so tests can place an arbitrary nested value in it.
    /// </summary>
    [Serializable]
    private class Employee
    {
        public int Id { get; set; }

        public string Name { get; set; }

        public object EmbeddedObject { get; set; }
    }

    /// <summary>
    /// Sample payload type used as a nested value and as the mismatching
    /// target type in the invalid-cast test.
    /// </summary>
    [Serializable]
    private class Department
    {
        public string Name { get; set; }

        public int EmployeeCount { get; set; }
    }

    /// <summary>
    /// Derived payload type used by the polymorphism test.
    /// </summary>
    [Serializable]
    private class Manager : Employee
    {
        public string Role { get; set; }
    }

    /// <summary>
    /// Payload type that deliberately lacks the <c>[Serializable]</c> attribute.
    /// </summary>
    private class NonSerializableClass
    {
        public int Value { get; init; }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<RootNamespace>Microsoft.Spark.UnitTest</RootNamespace>
<EnableUnsafeBinaryFormatterSerialization>true</EnableUnsafeBinaryFormatterSerialization>
</PropertyGroup>

<ItemGroup>
Expand Down
22 changes: 8 additions & 14 deletions src/csharp/Microsoft.Spark.UnitTest/UdfSerDeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System;
using System.IO;
using System.Reflection;
using System.Runtime.Serialization.Formatters.Binary;
using Microsoft.Spark.Utils;
using Xunit;

Expand All @@ -17,21 +16,21 @@ public class UdfSerDeTests
[Serializable]
private class TestClass
{
private readonly string _str;
private readonly string str;

public TestClass(string s)
public TestClass(string str)
{
_str = s;
this.str = str;
}

public string Concat(string s)
{
if (_str == null)
if (str == null)
{
return s + s;
}

return _str + s;
return str + s;
}

public override bool Equals(object obj)
Expand All @@ -43,7 +42,7 @@ public override bool Equals(object obj)
return false;
}

return _str == that._str;
return str == that.str;
}

public override int GetHashCode()
Expand Down Expand Up @@ -149,16 +148,13 @@ private Delegate SerDe(Delegate udf)
return Deserialize(Serialize(udf));
}

#pragma warning disable SYSLIB0011 // Type or member is obsolete
// TODO: Replace BinaryFormatter with a new, secure serializer.
private byte[] Serialize(Delegate udf)
{
UdfSerDe.UdfData udfData = UdfSerDe.Serialize(udf);

using (var ms = new MemoryStream())
{
var bf = new BinaryFormatter();
bf.Serialize(ms, udfData);
BinarySerDe.Serialize(ms, udfData);
return ms.ToArray();
}
}
Expand All @@ -167,11 +163,9 @@ private Delegate Deserialize(byte[] serializedUdf)
{
using (var ms = new MemoryStream(serializedUdf, false))
{
var bf = new BinaryFormatter();
UdfSerDe.UdfData udfData = (UdfSerDe.UdfData)bf.Deserialize(ms);
var udfData = BinarySerDe.Deserialize<UdfSerDe.UdfData>(ms);
return UdfSerDe.Deserialize(udfData);
}
}
#pragma warning restore
}
}
15 changes: 3 additions & 12 deletions src/csharp/Microsoft.Spark.Worker.UnitTest/CommandExecutorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.Serialization.Formatters.Binary;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow;
Expand Down Expand Up @@ -1049,10 +1048,8 @@ public void TestRDDCommandExecutor(Version sparkVersion, IpcOptions ipcOptions)

using var inputStream = new MemoryStream();
using var outputStream = new MemoryStream();
#pragma warning disable SYSLIB0011 // Type or member is obsolete

// Write test data to the input stream.
var formatter = new BinaryFormatter();
#pragma warning restore SYSLIB0011 // Type or member is obsolete
var memoryStream = new MemoryStream();

var inputs = new int[] { 0, 1, 2, 3, 4 };
Expand All @@ -1061,10 +1058,7 @@ public void TestRDDCommandExecutor(Version sparkVersion, IpcOptions ipcOptions)
foreach (int input in inputs)
{
memoryStream.Position = 0;
#pragma warning disable SYSLIB0011 // Type or member is obsolete
// TODO: Replace BinaryFormatter with a new, secure serializer.
formatter.Serialize(memoryStream, input);
#pragma warning restore SYSLIB0011 // Type or member is obsolete
BinarySerDe.Serialize(memoryStream, input);
values.Add(memoryStream.ToArray());
}

Expand Down Expand Up @@ -1094,12 +1088,9 @@ public void TestRDDCommandExecutor(Version sparkVersion, IpcOptions ipcOptions)
for (int i = 0; i < inputs.Length; ++i)
{
Assert.True(SerDe.ReadInt32(outputStream) > 0);
#pragma warning disable SYSLIB0011 // Type or member is obsolete
// TODO: Replace BinaryFormatter with a new, secure serializer.
Assert.Equal(
mapUdf(i),
formatter.Deserialize(outputStream));
#pragma warning restore SYSLIB0011 // Type or member is obsolete
BinarySerDe.Deserialize<object>(outputStream));
}

// Validate all the data on the stream is read.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<EnableUnsafeBinaryFormatterSerialization>true</EnableUnsafeBinaryFormatterSerialization>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Moq" Version="4.20.70" />
Expand Down
11 changes: 1 addition & 10 deletions src/csharp/Microsoft.Spark.Worker/Command/RDDCommandExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.Serialization.Formatters.Binary;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Utils;

Expand All @@ -19,10 +18,6 @@ internal class RDDCommandExecutor
{
[ThreadStatic]
private static MemoryStream s_writeOutputStream;
[ThreadStatic]
#pragma warning disable SYSLIB0011 // Type or member is obsolete
private static BinaryFormatter s_binaryFormatter;
#pragma warning restore SYSLIB0011 // Type or member is obsolete

/// <summary>
/// Executes the commands on the input data read from input stream
Expand Down Expand Up @@ -113,11 +108,7 @@ private void Serialize(
switch (serializerMode)
{
case CommandSerDe.SerializedMode.Byte:
#pragma warning disable SYSLIB0011 // Type or member is obsolete
BinaryFormatter formatter = s_binaryFormatter ??= new BinaryFormatter();
// TODO: Replace BinaryFormatter with a new, secure serializer.
formatter.Serialize(stream, message);
#pragma warning restore SYSLIB0011 // Type or member is obsolete
BinarySerDe.Serialize(stream, message);
break;
case CommandSerDe.SerializedMode.None:
case CommandSerDe.SerializedMode.String:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
<TargetFrameworks>net48;net8.0</TargetFrameworks>
<TargetFrameworks Condition="'$(OS)' != 'Windows_NT'">net8.0</TargetFrameworks>
<RootNamespace>Microsoft.Spark.Worker</RootNamespace>
<EnableUnsafeBinaryFormatterSerialization>true</EnableUnsafeBinaryFormatterSerialization>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
</PropertyGroup>
<ItemGroup>
Expand Down
Loading