Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OnPrepareResponseAsync #45062

Merged
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Middleware/StaticFiles/src/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
#nullable enable
Microsoft.AspNetCore.Builder.StaticFileOptions.OnPrepareResponseAsync.get -> System.Func<Microsoft.AspNetCore.StaticFiles.StaticFileResponseContext!, System.Threading.Tasks.Task!>!
Microsoft.AspNetCore.Builder.StaticFileOptions.OnPrepareResponseAsync.set -> void
20 changes: 11 additions & 9 deletions src/Middleware/StaticFiles/src/StaticFileContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ private void ComputeRange()
IsRangeRequest = isRangeRequest;
}

public void ApplyResponseHeaders(int statusCode)
public Task ApplyResponseHeadersAsync(int statusCode)
{
_response.StatusCode = statusCode;
if (statusCode < 400)
Expand All @@ -265,10 +265,13 @@ public void ApplyResponseHeaders(int statusCode)
_response.ContentLength = _length;
}

if (_options.OnPrepareResponse != StaticFileOptions._defaultOnPrepareResponse)
if (_options.OnPrepareResponse != StaticFileOptions._defaultOnPrepareResponse || _options.OnPrepareResponseAsync != StaticFileOptions._defaultOnPrepareResponseAsync)
{
_options.OnPrepareResponse(new StaticFileResponseContext(_context, _fileInfo));
var context = new StaticFileResponseContext(_context, _fileInfo);
_options.OnPrepareResponse(context);
return _options.OnPrepareResponseAsync(context);
}
return Task.CompletedTask;
}

public PreconditionState GetPreconditionState()
Expand All @@ -289,10 +292,9 @@ private static PreconditionState GetMaxPreconditionState(params PreconditionStat

public Task SendStatusAsync(int statusCode)
{
ApplyResponseHeaders(statusCode);

_logger.Handled(statusCode, SubPath);
return Task.CompletedTask;

return ApplyResponseHeadersAsync(statusCode);
}

public async Task ServeStaticFile(HttpContext context, RequestDelegate next)
Expand Down Expand Up @@ -344,7 +346,7 @@ public async Task ServeStaticFile(HttpContext context, RequestDelegate next)
public async Task SendAsync()
{
SetCompressionMode();
ApplyResponseHeaders(StatusCodes.Status200OK);
await ApplyResponseHeadersAsync(StatusCodes.Status200OK);
try
{
await _context.Response.SendFileAsync(_fileInfo, 0, _length, _context.RequestAborted);
Expand All @@ -365,7 +367,7 @@ internal async Task SendRangeAsync()
// SHOULD include a Content-Range field with a byte-range-resp-spec of "*". The instance-length specifies
// the current length of the selected resource. e.g. */length
ResponseHeaders.ContentRange = new ContentRangeHeaderValue(_length);
ApplyResponseHeaders(StatusCodes.Status416RangeNotSatisfiable);
await ApplyResponseHeadersAsync(StatusCodes.Status416RangeNotSatisfiable);

_logger.RangeNotSatisfiable(SubPath);
return;
Expand All @@ -374,7 +376,7 @@ internal async Task SendRangeAsync()
ResponseHeaders.ContentRange = ComputeContentRange(_range, out var start, out var length);
_response.ContentLength = length;
SetCompressionMode();
ApplyResponseHeaders(StatusCodes.Status206PartialContent);
await ApplyResponseHeadersAsync(StatusCodes.Status206PartialContent);

try
{
Expand Down
14 changes: 14 additions & 0 deletions src/Middleware/StaticFiles/src/StaticFileOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace Microsoft.AspNetCore.Builder;
public class StaticFileOptions : SharedOptionsBase
{
internal static readonly Action<StaticFileResponseContext> _defaultOnPrepareResponse = _ => { };
internal static readonly Func<StaticFileResponseContext, Task> _defaultOnPrepareResponseAsync = _ => Task.CompletedTask;

/// <summary>
/// Defaults to all request paths
Expand All @@ -28,6 +29,7 @@ public StaticFileOptions() : this(new SharedOptions())
public StaticFileOptions(SharedOptions sharedOptions) : base(sharedOptions)
{
OnPrepareResponse = _defaultOnPrepareResponse;
OnPrepareResponseAsync = _defaultOnPrepareResponseAsync;
}

/// <summary>
Expand Down Expand Up @@ -61,5 +63,17 @@ public StaticFileOptions(SharedOptions sharedOptions) : base(sharedOptions)
/// Called after the status code and headers have been set, but before the body has been written.
/// This can be used to add or change the response headers.
/// </summary>
/// <remarks>
/// <see cref="OnPrepareResponse" /> is called before <see cref="OnPrepareResponseAsync" />.
/// </remarks>
public Action<StaticFileResponseContext> OnPrepareResponse { get; set; }

/// <summary>
/// Called after the status code and headers have been set, but before the body has been written.
/// This can be used to add or change the response headers.
/// </summary>
/// <remarks>
/// <see cref="OnPrepareResponse" /> is called before <see cref="OnPrepareResponseAsync" />.
/// </remarks>
public Func<StaticFileResponseContext, Task> OnPrepareResponseAsync { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,142 @@ private async Task FoundFile_Served(string baseUrl, string baseDir, string reque
}
}

[Fact]
public async Task OnPrepareResponse_Executed_Test()
{
var baseUrl = "";
var baseDir = @".";
var requestUrl = "/testDocument.Txt";

var onPrepareResponseExecuted = false;

using (var fileProvider = new PhysicalFileProvider(Path.Combine(AppContext.BaseDirectory, baseDir)))
{
using var host = await StaticFilesTestServer.Create(app => app.UseStaticFiles(new StaticFileOptions
{
RequestPath = new PathString(baseUrl),
FileProvider = fileProvider,
OnPrepareResponse = context =>
{
onPrepareResponseExecuted = true;
}
}));
using var server = host.GetTestServer();
var fileInfo = fileProvider.GetFileInfo(Path.GetFileName(requestUrl));
var response = await server.CreateRequest(requestUrl).GetAsync();
var responseContent = await response.Content.ReadAsByteArrayAsync();

Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("text/plain", response.Content.Headers.ContentType.ToString());
Assert.True(response.Content.Headers.ContentLength == fileInfo.Length);
Assert.Equal(response.Content.Headers.ContentLength, responseContent.Length);
Assert.NotNull(response.Headers.ETag);

using (var stream = fileInfo.CreateReadStream())
{
var fileContents = new byte[stream.Length];
stream.Read(fileContents, 0, (int)stream.Length);
Assert.True(responseContent.SequenceEqual(fileContents));
}
}

Assert.True(onPrepareResponseExecuted);
}

[Fact]
public async Task OnPrepareResponseAsync_Executed_Test()
{
var baseUrl = "";
var baseDir = @".";
var requestUrl = "/testDocument.Txt";

var onPrepareResponseExecuted = false;

using (var fileProvider = new PhysicalFileProvider(Path.Combine(AppContext.BaseDirectory, baseDir)))
{
using var host = await StaticFilesTestServer.Create(app => app.UseStaticFiles(new StaticFileOptions
{
RequestPath = new PathString(baseUrl),
FileProvider = fileProvider,
OnPrepareResponseAsync = context =>
{
onPrepareResponseExecuted = true;

return Task.CompletedTask;
}
}));
using var server = host.GetTestServer();
var fileInfo = fileProvider.GetFileInfo(Path.GetFileName(requestUrl));
var response = await server.CreateRequest(requestUrl).GetAsync();
var responseContent = await response.Content.ReadAsByteArrayAsync();

Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("text/plain", response.Content.Headers.ContentType.ToString());
Assert.True(response.Content.Headers.ContentLength == fileInfo.Length);
Assert.Equal(response.Content.Headers.ContentLength, responseContent.Length);
Assert.NotNull(response.Headers.ETag);

using (var stream = fileInfo.CreateReadStream())
{
var fileContents = new byte[stream.Length];
stream.Read(fileContents, 0, (int)stream.Length);
Assert.True(responseContent.SequenceEqual(fileContents));
}
}

Assert.True(onPrepareResponseExecuted);
}

[Fact]
public async Task OnPrepareResponse_Execution_Order_Test()
{
var baseUrl = "";
var baseDir = @".";
var requestUrl = "/testDocument.Txt";

DateTime onPrepareExecutionTime = DateTime.MinValue;
DateTime onPrepareAsyncExecutionTime = DateTime.MinValue;

using (var fileProvider = new PhysicalFileProvider(Path.Combine(AppContext.BaseDirectory, baseDir)))
{
using var host = await StaticFilesTestServer.Create(app => app.UseStaticFiles(new StaticFileOptions
{
RequestPath = new PathString(baseUrl),
FileProvider = fileProvider,
OnPrepareResponse = context =>
{
onPrepareExecutionTime = DateTime.Now;
},
OnPrepareResponseAsync = context =>
{
onPrepareAsyncExecutionTime = DateTime.Now;
return Task.CompletedTask;
}
}));
using var server = host.GetTestServer();
var fileInfo = fileProvider.GetFileInfo(Path.GetFileName(requestUrl));
var response = await server.CreateRequest(requestUrl).GetAsync();
var responseContent = await response.Content.ReadAsByteArrayAsync();

Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("text/plain", response.Content.Headers.ContentType.ToString());
Assert.True(response.Content.Headers.ContentLength == fileInfo.Length);
Assert.Equal(response.Content.Headers.ContentLength, responseContent.Length);
Assert.NotNull(response.Headers.ETag);

using (var stream = fileInfo.CreateReadStream())
{
var fileContents = new byte[stream.Length];
stream.Read(fileContents, 0, (int)stream.Length);
Assert.True(responseContent.SequenceEqual(fileContents));
}
}
Assert.NotEqual(DateTime.MinValue, onPrepareExecutionTime);
Assert.NotEqual(DateTime.MinValue, onPrepareAsyncExecutionTime);

Assert.True(onPrepareExecutionTime.Ticks < onPrepareAsyncExecutionTime.Ticks);
}

[Fact]
public async Task File_Served_If_Endpoint_With_Null_RequestDelegate_Is_Active()
{
Expand Down