Skip to content

Commit 3575eed

Browse files
committed
Update solution to .NET 8
using the native System.NET to parse IPAddress and IPNetwork updated tests
1 parent 1a939be commit 3575eed

10 files changed

+213
-239
lines changed

MyWebApp.IntegrationTests/CustomRemoteIpAddressMiddleware.cs

+17-32
Original file line numberDiff line numberDiff line change
@@ -3,42 +3,27 @@
33
using Microsoft.AspNetCore.Hosting;
44
using Microsoft.AspNetCore.Http;
55

6-
namespace MyWebApp.IntegrationTests
7-
{
8-
public class CustomRemoteIpAddressMiddleware
9-
{
10-
private readonly RequestDelegate _next;
11-
private readonly IPAddress _fakeIpAddress;
6+
namespace MyWebApp.IntegrationTests;
127

13-
public CustomRemoteIpAddressMiddleware(RequestDelegate next, IPAddress? fakeIpAddress = null)
14-
{
15-
_next = next;
16-
_fakeIpAddress = fakeIpAddress ?? IPAddress.Parse("127.0.0.1");
17-
}
8+
public class CustomRemoteIpAddressMiddleware(RequestDelegate next, IPAddress? fakeIpAddress = null)
9+
{
10+
private readonly IPAddress _fakeIpAddress = fakeIpAddress ?? IPAddress.Parse("127.0.0.1");
1811

19-
public async Task Invoke(HttpContext httpContext)
20-
{
21-
httpContext.Connection.RemoteIpAddress = _fakeIpAddress;
22-
await _next(httpContext);
23-
}
12+
public async Task Invoke(HttpContext httpContext)
13+
{
14+
httpContext.Connection.RemoteIpAddress = _fakeIpAddress;
15+
await next(httpContext);
2416
}
17+
}
2518

26-
public class CustomRemoteIpStartupFilter : IStartupFilter
19+
public class CustomRemoteIpStartupFilter(IPAddress? remoteIp = null) : IStartupFilter
20+
{
21+
public Action<IApplicationBuilder> Configure(Action<IApplicationBuilder> next)
2722
{
28-
private readonly IPAddress? _remoteIp;
29-
30-
public CustomRemoteIpStartupFilter(IPAddress? remoteIp = null)
23+
return app =>
3124
{
32-
_remoteIp = remoteIp;
33-
}
34-
35-
public Action<IApplicationBuilder> Configure(Action<IApplicationBuilder> next)
36-
{
37-
return app =>
38-
{
39-
app.UseMiddleware<CustomRemoteIpAddressMiddleware>(_remoteIp);
40-
next(app);
41-
};
42-
}
25+
app.UseMiddleware<CustomRemoteIpAddressMiddleware>(remoteIp);
26+
next(app);
27+
};
4328
}
44-
}
29+
}

MyWebApp.IntegrationTests/IpRestrictionTests.cs

+65-64
Original file line numberDiff line numberDiff line change
@@ -7,86 +7,87 @@
77
using Microsoft.Extensions.Logging;
88
using Microsoft.VisualStudio.TestTools.UnitTesting;
99

10-
namespace MyWebApp.IntegrationTests
10+
namespace MyWebApp.IntegrationTests;
11+
12+
[TestClass]
13+
public class IpRestrictionTests
1114
{
12-
[TestClass]
13-
public class IpRestrictionTests
15+
[TestMethod]
16+
public void ParseConfigurations()
1417
{
15-
[TestMethod]
16-
public async Task HttpRequestWithAllowedIpAddressShouldReturn200()
17-
{
18-
var factory = new WebApplicationFactory<Startup>().WithWebHostBuilder(builder =>
19-
{
20-
builder.UseSetting("https_port", "5001").ConfigureLogging(c => c.AddConsole());
21-
builder.ConfigureTestServices(services =>
22-
{
23-
services.AddSingleton<IStartupFilter>(new CustomRemoteIpStartupFilter(IPAddress.Parse("127.0.0.1")));
24-
});
25-
});
26-
var client = factory.CreateClient();
27-
var response = await client.GetAsync("values");
28-
29-
Assert.AreEqual(HttpStatusCode.OK, response.StatusCode);
30-
Assert.AreEqual("application/json; charset=utf-8", response.Content.Headers.ContentType?.ToString());
31-
32-
var json = await response.Content.ReadAsStringAsync();
33-
Assert.AreEqual("[\"value1\",\"value2\"]", json);
34-
}
18+
var b = IPNetwork.TryParse("2001:0db8::/64", out var ip);
19+
Assert.IsTrue(b);
20+
Assert.AreEqual(new IPNetwork(IPAddress.Parse("2001:0db8::"), 64), ip);
21+
}
3522

36-
[TestMethod]
37-
public async Task HttpRequestWithForbiddenIpAddressShouldReturn403()
23+
[TestMethod]
24+
public async Task HttpRequestWithAllowedIpAddressShouldReturn200()
25+
{
26+
var factory = new WebApplicationFactory<Startup>().WithWebHostBuilder(builder =>
3827
{
39-
var factory = new WebApplicationFactory<Startup>().WithWebHostBuilder(builder =>
28+
builder.UseSetting("https_port", "5001").ConfigureLogging(c => c.AddConsole());
29+
builder.ConfigureTestServices(services =>
4030
{
41-
builder.UseSetting("https_port", "5001").ConfigureLogging(c => c.AddConsole());
42-
builder.ConfigureTestServices(services =>
43-
{
44-
services.AddSingleton<IStartupFilter>(new CustomRemoteIpStartupFilter(IPAddress.Parse("127.168.1.32")));
45-
});
31+
services.AddSingleton<IStartupFilter>(new CustomRemoteIpStartupFilter(IPAddress.Parse("127.0.0.1")));
4632
});
47-
var client = factory.CreateClient();
48-
var response = await client.GetAsync("values");
33+
});
34+
var client = factory.CreateClient();
35+
var response = await client.GetAsync("values");
4936

50-
Assert.AreEqual(HttpStatusCode.Forbidden, response.StatusCode);
51-
}
37+
Assert.AreEqual(HttpStatusCode.OK, response.StatusCode);
38+
Assert.AreEqual("application/json; charset=utf-8", response.Content.Headers.ContentType?.ToString());
5239

53-
[Ignore("I haven't figured out how to test in this way.")]
54-
[TestMethod]
55-
public async Task HttpRequestWithLocalHostIpAddressShouldReturn200()
40+
var json = await response.Content.ReadAsStringAsync();
41+
Assert.AreEqual("[\"value1\",\"value2\"]", json);
42+
}
43+
44+
[TestMethod]
45+
public async Task HttpRequestWithForbiddenIpAddressShouldReturn403()
46+
{
47+
var factory = new WebApplicationFactory<Startup>().WithWebHostBuilder(builder =>
5648
{
57-
var factory = new WebApplicationFactory<Startup>()
58-
.WithWebHostBuilder(builder => builder.UseSetting("https_port", "5001").ConfigureLogging(c => c.AddConsole()));
59-
var client = factory.CreateClient(new WebApplicationFactoryClientOptions
49+
builder.UseSetting("https_port", "5001").ConfigureLogging(c => c.AddConsole());
50+
builder.ConfigureTestServices(services =>
6051
{
61-
AllowAutoRedirect = true
52+
services.AddSingleton<IStartupFilter>(new CustomRemoteIpStartupFilter(IPAddress.Parse("127.168.1.32")));
6253
});
54+
});
55+
var client = factory.CreateClient();
56+
var response = await client.GetAsync("values");
6357

64-
// tried to use TestServer.SendAsync to set RemoteIP, but failed
65-
// https://github.com/aspnet/Hosting/issues/1135
66-
// https://github.com/aspnet/Hosting/pull/1248
67-
// If without setting the RemoteIP, the test server will simply use an empty IP, which breaks the test.
68-
69-
//var context = await factory.Server.SendAsync((c) =>
70-
//{
71-
// c.Connection.RemoteIpAddress = IPAddress.Parse("127.168.1.32");
72-
// c.Request.Method = HttpMethods.Get;
73-
// c.Request.Path = new PathString("/values");
58+
Assert.AreEqual(HttpStatusCode.Forbidden, response.StatusCode);
59+
}
7460

75-
//});
76-
//var response = context.Response;
61+
[TestMethod]
62+
public async Task HttpRequestWithLocalHostIpAddressShouldReturn200()
63+
{
64+
var factory = new WebApplicationFactory<Startup>()
65+
.WithWebHostBuilder(builder => builder.UseSetting("https_port", "5001").ConfigureLogging(c => c.AddConsole()));
7766

78-
//Assert.AreEqual(HttpStatusCode.OK, response.StatusCode);
79-
//Assert.AreEqual("application/json; charset=utf-8", response.ContentType);
67+
var context = await factory.Server.SendAsync((c) =>
68+
{
69+
c.Connection.RemoteIpAddress = IPAddress.Parse("127.168.1.32");
70+
c.Request.Method = HttpMethods.Get;
71+
c.Request.Path = new PathString("/values");
72+
c.Request.IsHttps = true;
8073

81-
//var json = response.Body.ToString();
74+
});
75+
Assert.AreEqual((int)HttpStatusCode.Forbidden, context.Response.StatusCode);
8276

83-
var response = await client.GetAsync("values");
77+
context = await factory.Server.SendAsync((c) =>
78+
{
79+
c.Connection.RemoteIpAddress = IPAddress.Parse("192.168.1.9");
80+
c.Request.Method = HttpMethods.Get;
81+
c.Request.Path = new PathString("/values");
82+
c.Request.IsHttps = true;
8483

85-
Assert.AreEqual(HttpStatusCode.OK, response.StatusCode);
86-
Assert.AreEqual("application/json; charset=utf-8", response.Content.Headers.ContentType?.ToString());
84+
});
85+
var response = context.Response;
86+
Assert.AreEqual((int)HttpStatusCode.OK, response.StatusCode);
87+
Assert.AreEqual("application/json; charset=utf-8", response.ContentType);
8788

88-
var json = await response.Content.ReadAsStringAsync();
89-
Assert.AreEqual("[\"value1\",\"value2\"]", json);
90-
}
89+
using var sr = new StreamReader(response.Body);
90+
var json = await sr.ReadToEndAsync();
91+
Assert.AreEqual("[\"value1\",\"value2\"]", json);
9192
}
92-
}
93+
}

MyWebApp.IntegrationTests/MyWebApp.IntegrationTests.csproj

+6-9
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,18 @@
11
<Project Sdk="Microsoft.NET.Sdk">
22

33
<PropertyGroup>
4-
<TargetFramework>net6.0</TargetFramework>
5-
4+
<TargetFramework>net8.0</TargetFramework>
65
<IsPackable>false</IsPackable>
7-
86
<Nullable>enable</Nullable>
9-
107
<ImplicitUsings>enable</ImplicitUsings>
118
</PropertyGroup>
129

1310
<ItemGroup>
14-
<PackageReference Include="Microsoft.AspNetCore.Mvc.Testing" Version="6.0.8" />
15-
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="6.0.0" />
16-
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.1" />
17-
<PackageReference Include="MSTest.TestAdapter" Version="2.2.10" />
18-
<PackageReference Include="MSTest.TestFramework" Version="2.2.10" />
11+
<PackageReference Include="Microsoft.AspNetCore.Mvc.Testing" Version="8.0.0" />
12+
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="8.0.0" />
13+
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.8.0" />
14+
<PackageReference Include="MSTest.TestAdapter" Version="3.1.1" />
15+
<PackageReference Include="MSTest.TestFramework" Version="3.1.1" />
1916
</ItemGroup>
2017

2118
<ItemGroup>
+11-19
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,17 @@
11
using Microsoft.AspNetCore.Mvc;
22
using MyWebApp.Infrastructure;
33

4-
namespace MyWebApp.Controllers
4+
namespace MyWebApp.Controllers;
5+
6+
[ApiController]
7+
[Route("[controller]")]
8+
public class ValuesController(ILogger<ValuesController> logger) : ControllerBase
59
{
6-
[ApiController]
7-
[Route("[controller]")]
8-
public class ValuesController : ControllerBase
10+
[ServiceFilter(typeof(ClientIpCheckActionFilter))]
11+
[HttpGet]
12+
public IEnumerable<string> Get()
913
{
10-
private readonly ILogger<ValuesController> _logger;
11-
12-
public ValuesController(ILogger<ValuesController> logger)
13-
{
14-
_logger = logger;
15-
}
16-
17-
[ServiceFilter(typeof(ClientIpCheckActionFilter))]
18-
[HttpGet]
19-
public IEnumerable<string> Get()
20-
{
21-
_logger.LogInformation("Client IP: {remoteIpAddress}", HttpContext.Connection.RemoteIpAddress?.ToString());
22-
return new[] { "value1", "value2" };
23-
}
14+
logger.LogInformation("Client IP: {remoteIpAddress}", HttpContext.Connection.RemoteIpAddress?.ToString());
15+
return new[] { "value1", "value2" };
2416
}
25-
}
17+
}
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,32 @@
1-
using System.Net;
2-
3-
namespace MyWebApp.Infrastructure
1+
namespace MyWebApp.Infrastructure
42
{
5-
public class AdminSafeListMiddleware
3+
public class AdminSafeListMiddleware(
4+
RequestDelegate next,
5+
ILogger<AdminSafeListMiddleware> logger,
6+
IpSafeList safeList)
67
{
7-
private readonly RequestDelegate _next;
8-
private readonly ILogger<AdminSafeListMiddleware> _logger;
9-
private readonly List<IPAddress> _ipAddresses;
10-
private readonly List<IPNetwork> _ipNetworks;
11-
12-
public AdminSafeListMiddleware(RequestDelegate next, ILogger<AdminSafeListMiddleware> logger, IpSafeList safeList)
13-
{
14-
_ipAddresses = safeList.IpAddresses.Split(';').Select(IPAddress.Parse).ToList();
15-
_ipNetworks = safeList.IpNetworks.Split(';').Select(IPNetwork.Parse).ToList();
16-
17-
_next = next;
18-
_logger = logger;
19-
}
20-
21-
public async Task Invoke(HttpContext context)
8+
public Task Invoke(HttpContext context)
229
{
2310
var remoteIp = context.Connection.RemoteIpAddress;
2411
if (remoteIp == null)
2512
{
2613
throw new ArgumentException("Remote IP is NULL, may due to missing ForwardedHeaders.");
2714
}
28-
_logger.LogDebug("Remote IpAddress: {RemoteIp}", remoteIp);
15+
logger.LogDebug("Remote IpAddress: {RemoteIp}", remoteIp);
2916

3017
if (remoteIp.IsIPv4MappedToIPv6)
3118
{
3219
remoteIp = remoteIp.MapToIPv4();
3320
}
3421

35-
if (!_ipAddresses.Contains(remoteIp) && !_ipNetworks.Any(x => x.Contains(remoteIp)))
22+
if (!safeList.IsSafeIp(remoteIp))
3623
{
37-
_logger.LogWarning("Forbidden Request from IP: {remoteIp}", remoteIp);
24+
logger.LogWarning("Forbidden Request from IP: {remoteIp}", remoteIp);
3825
context.Response.StatusCode = StatusCodes.Status403Forbidden;
39-
return;
26+
return Task.CompletedTask;
4027
}
4128

42-
await _next.Invoke(context);
29+
return next.Invoke(context);
4330
}
4431
}
4532
}

0 commit comments

Comments
 (0)