1
+ from __future__ import annotations
2
+
1
3
import binascii
2
4
import io
3
5
import os
4
6
import ssl
5
- from typing import Tuple , Union
7
+ from typing import TYPE_CHECKING , Any , Callable , ClassVar
6
8
7
9
from .compat import decode_from_bytes , encode_to_bytes
8
10
from .exceptions import StrictMocketException
9
11
12
+ if TYPE_CHECKING :
13
+ from _typeshed import ReadableBuffer
14
+ from typing_extensions import NoReturn
15
+
10
16
SSL_PROTOCOL = ssl .PROTOCOL_TLSv1_2
11
17
12
18
13
19
class MocketSocketCore (io .BytesIO ):
14
- def write (self , content ):
20
+ def write ( # type: ignore[override] # BytesIO returns int
21
+ self ,
22
+ content : ReadableBuffer ,
23
+ ) -> None :
15
24
super (MocketSocketCore , self ).write (content )
16
25
17
26
from mocket import Mocket
@@ -20,7 +29,7 @@ def write(self, content):
20
29
os .write (Mocket .w_fd , content )
21
30
22
31
23
- def hexdump (binary_string ) :
32
+ def hexdump (binary_string : bytes ) -> str :
24
33
r"""
25
34
>>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F"))
26
35
True
@@ -29,7 +38,7 @@ def hexdump(binary_string):
29
38
return " " .join (a + b for a , b in zip (bs [::2 ], bs [1 ::2 ]))
30
39
31
40
32
- def hexload (string ) :
41
+ def hexload (string : str ) -> bytes :
33
42
r"""
34
43
>>> hexload("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F") == encode_to_bytes("bar foobar foo")
35
44
True
@@ -38,39 +47,40 @@ def hexload(string):
38
47
return encode_to_bytes (binascii .unhexlify (string_no_spaces ))
39
48
40
49
41
- def get_mocketize (wrapper_ ) :
50
+ def get_mocketize (wrapper_ : Callable ) -> Callable :
42
51
import decorator
43
52
44
- if decorator .__version__ < "5" : # pragma: no cover
53
+ if decorator .__version__ < "5" : # type: ignore[attr-defined] # pragma: no cover
45
54
return decorator .decorator (wrapper_ )
46
- return decorator .decorator (wrapper_ , kwsyntax = True )
55
+ return decorator .decorator ( # type: ignore[call-arg] # kwsyntax
56
+ wrapper_ ,
57
+ kwsyntax = True ,
58
+ )
47
59
48
60
49
61
class MocketMode :
50
- __shared_state = {}
51
- STRICT = None
52
- STRICT_ALLOWED = None
62
+ __shared_state : ClassVar [ dict [ str , Any ]] = {}
63
+ STRICT : ClassVar = None
64
+ STRICT_ALLOWED : ClassVar = None
53
65
54
- def __init__ (self ):
66
+ def __init__ (self ) -> None :
55
67
self .__dict__ = self .__shared_state
56
68
57
- def is_allowed (self , location : Union [ str , Tuple [str , int ] ]) -> bool :
69
+ def is_allowed (self , location : str | tuple [str , int ]) -> bool :
58
70
"""
59
71
Checks if (`host`, `port`) or at least `host`
60
72
are allowed locations to perform real `socket` calls
61
73
"""
62
74
if not self .STRICT :
63
75
return True
64
- try :
65
- host , _ = location
66
- except ValueError :
67
- host = None
68
- return location in self .STRICT_ALLOWED or (
69
- host is not None and host in self .STRICT_ALLOWED
70
- )
76
+
77
+ host_allowed = False
78
+ if isinstance (location , tuple ):
79
+ host_allowed = location [0 ] in self .STRICT_ALLOWED
80
+ return host_allowed or location in self .STRICT_ALLOWED
71
81
72
82
@staticmethod
73
- def raise_not_allowed ():
83
+ def raise_not_allowed () -> NoReturn :
74
84
from .mocket import Mocket
75
85
76
86
current_entries = [
0 commit comments