|
8 | 8 | from fastapi.exceptions import RequestValidationError
|
9 | 9 | from fastapi.middleware.cors import CORSMiddleware
|
10 | 10 | from fastapi.responses import JSONResponse
|
| 11 | +from starlette.datastructures import Headers, MutableHeaders |
11 | 12 | from starlette.exceptions import HTTPException as StarletteHTTPException
|
| 13 | +from starlette.responses import PlainTextResponse, Response |
| 14 | +from starlette.types import ASGIApp, Message, Receive, Scope, Send |
12 | 15 |
|
13 | 16 | from sapporo.auth import get_auth_config
|
14 | 17 | from sapporo.config import (LOGGER, PKG_DIR, add_openapi_info, get_config,
|
@@ -59,6 +62,52 @@ async def generic_exception_handler(_request: Request, _exc: Exception) -> JSONR
|
59 | 62 | )
|
60 | 63 |
|
61 | 64 |
|
| 65 | +class CustomCORSMiddleware(CORSMiddleware): |
| 66 | + """\ |
| 67 | + CORSMiddleware that returns CORS headers even if the Origin header is not present |
| 68 | + """ |
| 69 | + |
| 70 | + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: |
| 71 | + if scope["type"] != "http": |
| 72 | + await self.app(scope, receive, send) |
| 73 | + return |
| 74 | + |
| 75 | + method = scope["method"] |
| 76 | + headers = Headers(scope=scope) |
| 77 | + |
| 78 | + if method == "OPTIONS" and "access-control-request-method" in headers: |
| 79 | + response = self.preflight_response(request_headers=headers) |
| 80 | + await response(scope, receive, send) |
| 81 | + return |
| 82 | + |
| 83 | + await self.simple_response(scope, receive, send, request_headers=headers) |
| 84 | + |
| 85 | + async def send( |
| 86 | + self, message: Message, send: Send, request_headers: Headers |
| 87 | + ) -> None: |
| 88 | + if message["type"] != "http.response.start": |
| 89 | + await send(message) |
| 90 | + return |
| 91 | + |
| 92 | + message.setdefault("headers", []) |
| 93 | + headers = MutableHeaders(scope=message) |
| 94 | + headers.update(self.simple_headers) |
| 95 | + origin = request_headers.get("Origin", "*") |
| 96 | + has_cookie = "cookie" in request_headers |
| 97 | + |
| 98 | + # If request includes any cookie headers, then we must respond |
| 99 | + # with the specific origin instead of '*'. |
| 100 | + if self.allow_all_origins and has_cookie: |
| 101 | + self.allow_explicit_origin(headers, origin) |
| 102 | + |
| 103 | + # If we only allow specific origins, then we have to mirror back |
| 104 | + # the Origin header in the response. |
| 105 | + elif not self.allow_all_origins and self.is_allowed_origin(origin=origin): |
| 106 | + self.allow_explicit_origin(headers, origin) |
| 107 | + |
| 108 | + await send(message) |
| 109 | + |
| 110 | + |
62 | 111 | def init_app_state() -> None:
|
63 | 112 | """
|
64 | 113 | Perform validation, initialize the cache, and log the configuration contents.
|
@@ -146,9 +195,8 @@ def create_app() -> FastAPI:
|
146 | 195 | )
|
147 | 196 |
|
148 | 197 | app.add_middleware(
|
149 |
| - CORSMiddleware, |
| 198 | + CustomCORSMiddleware, |
150 | 199 | allow_origins=[app_config.allow_origin],
|
151 |
| - allow_credentials=True, |
152 | 200 | allow_methods=["*"],
|
153 | 201 | allow_headers=["*"],
|
154 | 202 | )
|
|
0 commit comments