12
12
except ImportError :
13
13
pass
14
14
15
+ from ssl import SSLContext , create_default_context
15
16
from errno import EAGAIN , ECONNRESET , ETIMEDOUT
16
17
from sys import implementation
17
18
from time import monotonic , sleep
33
34
from .route import Route
34
35
from .status import BAD_REQUEST_400 , UNAUTHORIZED_401 , FORBIDDEN_403 , NOT_FOUND_404
35
36
37
+ if implementation .name != "circuitpython" :
38
+ from ssl import Purpose , CERT_NONE , SSLError # pylint: disable=ungrouped-imports
39
+
36
40
37
41
NO_REQUEST = "no_request"
38
42
CONNECTION_TIMED_OUT = "connection_timed_out"
39
43
REQUEST_HANDLED_NO_RESPONSE = "request_handled_no_response"
40
44
REQUEST_HANDLED_RESPONSE_SENT = "request_handled_response_sent"
41
45
46
+ # CircuitPython does not have these error codes
47
+ MBEDTLS_ERR_SSL_FATAL_ALERT_MESSAGE = - 30592
48
+
42
49
43
50
class Server : # pylint: disable=too-many-instance-attributes
44
51
"""A basic socket-based HTTP server."""
@@ -52,25 +59,81 @@ class Server: # pylint: disable=too-many-instance-attributes
52
59
root_path : str
53
60
"""Root directory to serve files from. ``None`` if serving files is disabled."""
54
61
62
+ @staticmethod
63
+ def _validate_https_cert_provided (
64
+ certfile : Union [str , None ], keyfile : Union [str , None ]
65
+ ) -> None :
66
+ if certfile is None or keyfile is None :
67
+ raise ValueError ("Both certfile and keyfile must be specified for HTTPS" )
68
+
69
+ @staticmethod
70
+ def _create_circuitpython_ssl_context (certfile : str , keyfile : str ) -> SSLContext :
71
+ ssl_context = create_default_context ()
72
+
73
+ ssl_context .load_verify_locations (cadata = "" )
74
+ ssl_context .load_cert_chain (certfile , keyfile )
75
+
76
+ return ssl_context
77
+
78
+ @staticmethod
79
+ def _create_cpython_ssl_context (certfile : str , keyfile : str ) -> SSLContext :
80
+ ssl_context = create_default_context (purpose = Purpose .CLIENT_AUTH )
81
+
82
+ ssl_context .load_cert_chain (certfile , keyfile )
83
+
84
+ ssl_context .verify_mode = CERT_NONE
85
+ ssl_context .check_hostname = False
86
+
87
+ return ssl_context
88
+
89
+ @classmethod
90
+ def _create_ssl_context (cls , certfile : str , keyfile : str ) -> SSLContext :
91
+ return (
92
+ cls ._create_circuitpython_ssl_context (certfile , keyfile )
93
+ if implementation .name == "circuitpython"
94
+ else cls ._create_cpython_ssl_context (certfile , keyfile )
95
+ )
96
+
55
97
def __init__ (
56
- self , socket_source : _ISocketPool , root_path : str = None , * , debug : bool = False
98
+ self ,
99
+ socket_source : _ISocketPool ,
100
+ root_path : str = None ,
101
+ * ,
102
+ https : bool = False ,
103
+ certfile : str = None ,
104
+ keyfile : str = None ,
105
+ debug : bool = False ,
57
106
) -> None :
58
107
"""Create a server, and get it ready to run.
59
108
60
109
:param socket: An object that is a source of sockets. This could be a `socketpool`
61
110
in CircuitPython or the `socket` module in CPython.
62
111
:param str root_path: Root directory to serve files from
63
112
:param bool debug: Enables debug messages useful during development
113
+ :param bool https: If True, the server will use HTTPS
114
+ :param str certfile: Path to the certificate file, required if ``https`` is True
115
+ :param str keyfile: Path to the private key file, required if ``https`` is True
64
116
"""
65
- self ._auths = []
66
117
self ._buffer = bytearray (1024 )
67
118
self ._timeout = 1
119
+
120
+ self ._auths = []
68
121
self ._routes : "List[Route]" = []
122
+ self .headers = Headers ()
123
+
69
124
self ._socket_source = socket_source
70
125
self ._sock = None
71
- self . headers = Headers ()
126
+
72
127
self .host , self .port = None , None
73
128
self .root_path = root_path
129
+ self .https = https
130
+
131
+ if https :
132
+ self ._validate_https_cert_provided (certfile , keyfile )
133
+ self ._ssl_context = self ._create_ssl_context (certfile , keyfile )
134
+ else :
135
+ self ._ssl_context = None
136
+
74
137
if root_path in ["" , "/" ] and debug :
75
138
_debug_warning_exposed_files (root_path )
76
139
self .stopped = True
@@ -197,6 +260,7 @@ def serve_forever(
197
260
@staticmethod
198
261
def _create_server_socket (
199
262
socket_source : _ISocketPool ,
263
+ ssl_context : "SSLContext | None" ,
200
264
host : str ,
201
265
port : int ,
202
266
) -> _ISocket :
@@ -206,6 +270,9 @@ def _create_server_socket(
206
270
if implementation .version >= (9 ,) or implementation .name != "circuitpython" :
207
271
sock .setsockopt (socket_source .SOL_SOCKET , socket_source .SO_REUSEADDR , 1 )
208
272
273
+ if ssl_context is not None :
274
+ sock = ssl_context .wrap_socket (sock , server_side = True )
275
+
209
276
sock .bind ((host , port ))
210
277
sock .listen (10 )
211
278
sock .setblocking (False ) # Non-blocking socket
@@ -225,7 +292,9 @@ def start(self, host: str = "0.0.0.0", port: int = 5000) -> None:
225
292
self .host , self .port = host , port
226
293
227
294
self .stopped = False
228
- self ._sock = self ._create_server_socket (self ._socket_source , host , port )
295
+ self ._sock = self ._create_server_socket (
296
+ self ._socket_source , self ._ssl_context , host , port
297
+ )
229
298
230
299
if self .debug :
231
300
_debug_started_server (self )
@@ -386,7 +455,9 @@ def _set_default_server_headers(self, response: Response) -> None:
386
455
name , value
387
456
)
388
457
389
- def poll (self ) -> str :
458
+ def poll ( # pylint: disable=too-many-branches,too-many-return-statements
459
+ self ,
460
+ ) -> str :
390
461
"""
391
462
Call this method inside your main loop to get the server to check for new incoming client
392
463
requests. When a request comes in, it will be handled by the handler function.
@@ -399,11 +470,12 @@ def poll(self) -> str:
399
470
400
471
conn = None
401
472
try :
473
+ if self .debug :
474
+ _debug_start_time = monotonic ()
475
+
402
476
conn , client_address = self ._sock .accept ()
403
477
conn .settimeout (self ._timeout )
404
478
405
- _debug_start_time = monotonic ()
406
-
407
479
# Receive the whole request
408
480
if (request := self ._receive_request (conn , client_address )) is None :
409
481
conn .close ()
@@ -424,9 +496,8 @@ def poll(self) -> str:
424
496
# Send the response
425
497
response ._send () # pylint: disable=protected-access
426
498
427
- _debug_end_time = monotonic ()
428
-
429
499
if self .debug :
500
+ _debug_end_time = monotonic ()
430
501
_debug_response_sent (response , _debug_end_time - _debug_start_time )
431
502
432
503
return REQUEST_HANDLED_RESPONSE_SENT
@@ -439,6 +510,15 @@ def poll(self) -> str:
439
510
# Connection reset by peer, try again later.
440
511
if error .errno == ECONNRESET :
441
512
return NO_REQUEST
513
+ # Handshake failed, try again later.
514
+ if error .errno == MBEDTLS_ERR_SSL_FATAL_ALERT_MESSAGE :
515
+ return NO_REQUEST
516
+
517
+ # CPython specific SSL related errors
518
+ if implementation .name != "circuitpython" and isinstance (error , SSLError ):
519
+ # Ignore unknown SSL certificate errors
520
+ if getattr (error , "reason" , None ) == "SSLV3_ALERT_CERTIFICATE_UNKNOWN" :
521
+ return NO_REQUEST
442
522
443
523
if self .debug :
444
524
_debug_exception_in_handler (error )
@@ -547,9 +627,10 @@ def _debug_warning_exposed_files(root_path: str):
547
627
548
628
def _debug_started_server (server : "Server" ):
549
629
"""Prints a message when the server starts."""
630
+ scheme = "https" if server .https else "http"
550
631
host , port = server .host , server .port
551
632
552
- print (f"Started development server on http ://{ host } :{ port } " )
633
+ print (f"Started development server on { scheme } ://{ host } :{ port } " )
553
634
554
635
555
636
def _debug_response_sent (response : "Response" , time_elapsed : float ):
0 commit comments