Skip to content

Commit 9b06f5a

Browse files
committed
Modified Server to create and use SSLContext
1 parent b00f70f commit 9b06f5a

File tree

1 file changed

+50
-5
lines changed

1 file changed

+50
-5
lines changed

Diff for: adafruit_httpserver/server.py

+50-5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
except ImportError:
1313
pass
1414

15+
from ssl import SSLContext, create_default_context
1516
from errno import EAGAIN, ECONNRESET, ETIMEDOUT
1617
from sys import implementation
1718
from time import monotonic, sleep
@@ -39,6 +40,9 @@
3940
REQUEST_HANDLED_NO_RESPONSE = "request_handled_no_response"
4041
REQUEST_HANDLED_RESPONSE_SENT = "request_handled_response_sent"
4142

43+
# CircuitPython does not have these error codes
44+
MBEDTLS_ERR_SSL_FATAL_ALERT_MESSAGE = -30592
45+
4246

4347
class Server: # pylint: disable=too-many-instance-attributes
4448
"""A basic socket-based HTTP server."""
@@ -52,25 +56,57 @@ class Server: # pylint: disable=too-many-instance-attributes
5256
root_path: str
5357
"""Root directory to serve files from. ``None`` if serving files is disabled."""
5458

59+
@staticmethod
60+
def _validate_https_cert_provided(certfile: str, keyfile: str) -> None:
61+
if not certfile or not keyfile:
62+
raise ValueError("Both certfile and keyfile must be specified for HTTPS")
63+
64+
@staticmethod
65+
def _create_ssl_context(certfile: str, keyfile: str) -> SSLContext:
66+
ssl_context = create_default_context()
67+
ssl_context.load_verify_locations(cadata="")
68+
ssl_context.load_cert_chain(certfile, keyfile)
69+
70+
return ssl_context
71+
5572
def __init__(
56-
self, socket_source: _ISocketPool, root_path: str = None, *, debug: bool = False
73+
self,
74+
socket_source: _ISocketPool,
75+
root_path: str = None,
76+
*,
77+
https: bool = False,
78+
certfile: str = None,
79+
keyfile: str = None,
80+
debug: bool = False,
5781
) -> None:
5882
"""Create a server, and get it ready to run.
5983
6084
:param socket: An object that is a source of sockets. This could be a `socketpool`
6185
in CircuitPython or the `socket` module in CPython.
6286
:param str root_path: Root directory to serve files from
6387
:param bool debug: Enables debug messages useful during development
88+
:param bool https: If True, the server will use HTTPS
89+
:param str certfile: Path to the certificate file, required if ``https`` is True
90+
:param str keyfile: Path to the private key file, required if ``https`` is True
6491
"""
65-
self._auths = []
6692
self._buffer = bytearray(1024)
6793
self._timeout = 1
94+
95+
self._auths = []
6896
self._routes: "List[Route]" = []
97+
self.headers = Headers()
98+
6999
self._socket_source = socket_source
70100
self._sock = None
71-
self.headers = Headers()
101+
72102
self.host, self.port = None, None
73103
self.root_path = root_path
104+
self.https = https
105+
106+
if https:
107+
self._validate_https_cert_provided(certfile, keyfile)
108+
self._ssl_context = self._create_ssl_context(certfile, keyfile)
109+
74110
if root_path in ["", "/"] and debug:
75111
_debug_warning_exposed_files(root_path)
76112
self.stopped = True
@@ -197,6 +233,7 @@ def serve_forever(
197233
@staticmethod
198234
def _create_server_socket(
199235
socket_source: _ISocketPool,
236+
ssl_context: SSLContext,
200237
host: str,
201238
port: int,
202239
) -> _ISocket:
@@ -206,6 +243,9 @@ def _create_server_socket(
206243
if implementation.version >= (9,) or implementation.name != "circuitpython":
207244
sock.setsockopt(socket_source.SOL_SOCKET, socket_source.SO_REUSEADDR, 1)
208245

246+
if ssl_context is not None:
247+
sock = ssl_context.wrap_socket(sock, server_side=True)
248+
209249
sock.bind((host, port))
210250
sock.listen(10)
211251
sock.setblocking(False) # Non-blocking socket
@@ -225,7 +265,9 @@ def start(self, host: str = "0.0.0.0", port: int = 5000) -> None:
225265
self.host, self.port = host, port
226266

227267
self.stopped = False
228-
self._sock = self._create_server_socket(self._socket_source, host, port)
268+
self._sock = self._create_server_socket(
269+
self._socket_source, self._ssl_context, host, port
270+
)
229271

230272
if self.debug:
231273
_debug_started_server(self)
@@ -439,6 +481,8 @@ def poll(self) -> str:
439481
# Connection reset by peer, try again later.
440482
if error.errno == ECONNRESET:
441483
return NO_REQUEST
484+
if error.errno == MBEDTLS_ERR_SSL_FATAL_ALERT_MESSAGE:
485+
return NO_REQUEST
442486

443487
if self.debug:
444488
_debug_exception_in_handler(error)
@@ -547,9 +591,10 @@ def _debug_warning_exposed_files(root_path: str):
547591

548592
def _debug_started_server(server: "Server"):
549593
"""Prints a message when the server starts."""
594+
scheme = "https" if server.https else "http"
550595
host, port = server.host, server.port
551596

552-
print(f"Started development server on http://{host}:{port}")
597+
print(f"Started development server on {scheme}://{host}:{port}")
553598

554599

555600
def _debug_response_sent(response: "Response", time_elapsed: float):

0 commit comments

Comments
 (0)