Skip to content

Commit 41da3a9

Browse files
committed
Fix issue where CORS headers were not being returned properly
1 parent 8c8d672 commit 41da3a9

File tree

1 file changed

+50
-2
lines changed

1 file changed

+50
-2
lines changed

sapporo/app.py

+50-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
from fastapi.exceptions import RequestValidationError
99
from fastapi.middleware.cors import CORSMiddleware
1010
from fastapi.responses import JSONResponse
11+
from starlette.datastructures import Headers, MutableHeaders
1112
from starlette.exceptions import HTTPException as StarletteHTTPException
13+
from starlette.responses import PlainTextResponse, Response
14+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1215

1316
from sapporo.auth import get_auth_config
1417
from sapporo.config import (LOGGER, PKG_DIR, add_openapi_info, get_config,
@@ -59,6 +62,52 @@ async def generic_exception_handler(_request: Request, _exc: Exception) -> JSONR
5962
)
6063

6164

65+
class CustomCORSMiddleware(CORSMiddleware):
66+
"""\
67+
CORSMiddleware that returns CORS headers even if the Origin header is not present
68+
"""
69+
70+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
71+
if scope["type"] != "http":
72+
await self.app(scope, receive, send)
73+
return
74+
75+
method = scope["method"]
76+
headers = Headers(scope=scope)
77+
78+
if method == "OPTIONS" and "access-control-request-method" in headers:
79+
response = self.preflight_response(request_headers=headers)
80+
await response(scope, receive, send)
81+
return
82+
83+
await self.simple_response(scope, receive, send, request_headers=headers)
84+
85+
async def send(
86+
self, message: Message, send: Send, request_headers: Headers
87+
) -> None:
88+
if message["type"] != "http.response.start":
89+
await send(message)
90+
return
91+
92+
message.setdefault("headers", [])
93+
headers = MutableHeaders(scope=message)
94+
headers.update(self.simple_headers)
95+
origin = request_headers.get("Origin", "*")
96+
has_cookie = "cookie" in request_headers
97+
98+
# If request includes any cookie headers, then we must respond
99+
# with the specific origin instead of '*'.
100+
if self.allow_all_origins and has_cookie:
101+
self.allow_explicit_origin(headers, origin)
102+
103+
# If we only allow specific origins, then we have to mirror back
104+
# the Origin header in the response.
105+
elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
106+
self.allow_explicit_origin(headers, origin)
107+
108+
await send(message)
109+
110+
62111
def init_app_state() -> None:
63112
"""
64113
Perform validation, initialize the cache, and log the configuration contents.
@@ -146,9 +195,8 @@ def create_app() -> FastAPI:
146195
)
147196

148197
app.add_middleware(
149-
CORSMiddleware,
198+
CustomCORSMiddleware,
150199
allow_origins=[app_config.allow_origin],
151-
allow_credentials=True,
152200
allow_methods=["*"],
153201
allow_headers=["*"],
154202
)

0 commit comments

Comments
 (0)