Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add socket path to scope["server"] #2561

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions tests/protocols/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@


class MockSocket:
def __init__(self, family, peername=None, sockname=None):
def __init__(
self,
family: socket.AddressFamily,
peername: tuple[str, int] | None = None,
sockname: tuple[str, int] | str | None = None,
):
self.peername = peername
self.sockname = sockname
self.family = family
Expand Down Expand Up @@ -41,8 +46,8 @@ def test_get_local_addr_with_socket():
assert get_local_addr(transport) == ("123.45.6.7", 123)

if hasattr(socket, "AF_UNIX"): # pragma: no cover
transport = MockTransport({"socket": MockSocket(family=socket.AF_UNIX, sockname=("127.0.0.1", 8000))})
assert get_local_addr(transport) == ("127.0.0.1", 8000)
transport = MockTransport({"socket": MockSocket(family=socket.AF_UNIX, sockname="/tmp/test.sock")})
assert get_local_addr(transport) == ("/tmp/test.sock", None)


def test_get_remote_addr_with_socket():
Expand All @@ -62,7 +67,7 @@ def test_get_remote_addr_with_socket():

def test_get_local_addr():
transport = MockTransport({"sockname": "path/to/unix-domain-socket"})
assert get_local_addr(transport) is None
assert get_local_addr(transport) == ("path/to/unix-domain-socket", None)

transport = MockTransport({"sockname": ("123.45.6.7", 123)})
assert get_local_addr(transport) == ("123.45.6.7", 123)
Expand All @@ -81,5 +86,5 @@ def test_get_remote_addr():
[({"client": ("127.0.0.1", 36000)}, "127.0.0.1:36000"), ({"client": None}, "")],
ids=["ip:port client", "None client"],
)
def test_get_client_addr(scope, expected_client):
def test_get_client_addr(scope: Any, expected_client: str):
assert get_client_addr(scope) == expected_client
2 changes: 1 addition & 1 deletion uvicorn/protocols/http/h11_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
# Per-connection state
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.flow: FlowControl = None # type: ignore[assignment]
self.server: tuple[str, int] | None = None
self.server: tuple[str, int | None] | None = None
self.client: tuple[str, int] | None = None
self.scheme: Literal["http", "https"] | None = None

Expand Down
2 changes: 1 addition & 1 deletion uvicorn/protocols/http/httptools_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
# Per-connection state
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.flow: FlowControl = None # type: ignore[assignment]
self.server: tuple[str, int] | None = None
self.server: tuple[str, int | None] | None = None
self.client: tuple[str, int] | None = None
self.scheme: Literal["http", "https"] | None = None
self.pipeline: deque[tuple[RequestResponseCycle, ASGI3Application]] = deque()
Expand Down
16 changes: 11 additions & 5 deletions uvicorn/protocols/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import socket
import urllib.parse

from uvicorn._types import WWWScope
Expand All @@ -10,7 +11,7 @@ class ClientDisconnected(OSError): ...


def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
socket_info = transport.get_extra_info("socket")
socket_info: socket.socket | None = transport.get_extra_info("socket")
if socket_info is not None:
try:
info = socket_info.getpeername()
Expand All @@ -26,15 +27,20 @@ def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
return None


def get_local_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
socket_info = transport.get_extra_info("socket")
def get_local_addr(transport: asyncio.Transport) -> tuple[str, int | None] | None:
socket_info: socket.socket | None = transport.get_extra_info("socket")
if socket_info is not None:
info = socket_info.getsockname()

return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None
if isinstance(info, tuple):
return (str(info[0]), int(info[1]))
elif isinstance(info, str):
return (info, None)
return None
info = transport.get_extra_info("sockname")
if info is not None and isinstance(info, (list, tuple)) and len(info) == 2:
return (str(info[0]), int(info[1]))
elif isinstance(info, str):
return (info, None)
return None


Expand Down
2 changes: 1 addition & 1 deletion uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(

# Connection state
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.server: tuple[str, int] | None = None
self.server: tuple[str, int | None] | None = None
self.client: tuple[str, int] | None = None
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]

Expand Down
2 changes: 1 addition & 1 deletion uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(

# Connection state
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.server: tuple[str, int] | None = None
self.server: tuple[str, int | None] | None = None
self.client: tuple[str, int] | None = None
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]

Expand Down
Loading