5
5
import ipaddress
6
6
import sys
7
7
import unittest
8
+ from abc import ABC , abstractmethod
8
9
from collections import defaultdict
9
- from typing import cast
10
+ from typing import cast , Callable , Union
10
11
11
12
import pytest
12
13
import pytest_httpserver
26
27
from geoip2 .webservice import AsyncClient , Client
27
28
28
29
29
- class TestBaseClient (unittest .TestCase ):
30
+ class TestBaseClient (unittest .TestCase , ABC ):
31
+ client : Union [AsyncClient , Client ]
32
+ client_class : Callable [[int , str ], Union [AsyncClient , Client ]]
33
+
30
34
country = {
31
35
"continent" : {"code" : "NA" , "geoname_id" : 42 , "names" : {"en" : "North America" }},
32
36
"country" : {
@@ -54,6 +58,9 @@ class TestBaseClient(unittest.TestCase):
54
58
insights ["traits" ]["user_count" ] = 2
55
59
insights ["traits" ]["static_ip_score" ] = 1.3
56
60
61
+ @abstractmethod
62
+ def run_client (self , v ): ...
63
+
57
64
def _content_type (self , endpoint ):
58
65
return (
59
66
"application/vnd.maxmind.com-"
@@ -319,7 +326,7 @@ def user_agent_compare(actual: str, expected: str) -> bool:
319
326
header_value_matcher = HeaderValueMatcher (
320
327
defaultdict (
321
328
lambda : HeaderValueMatcher .default_header_value_matcher ,
322
- {"User-Agent" : user_agent_compare },
329
+ {"User-Agent" : user_agent_compare }, # type: ignore[dict-item]
323
330
),
324
331
),
325
332
).respond_with_json (
@@ -374,19 +381,22 @@ def test_insights_ok(self) -> None:
374
381
def test_named_constructor_args (self ) -> None :
375
382
id = 47
376
383
key = "1234567890ab"
377
- client = self .client_class (account_id = id , license_key = key )
384
+ client = self .client_class (id , key )
378
385
self .assertEqual (client ._account_id , str (id ))
379
386
self .assertEqual (client ._license_key , key )
380
387
381
388
def test_missing_constructor_args (self ) -> None :
382
389
with self .assertRaises (TypeError ):
383
- self .client_class (license_key = "1234567890ab" )
390
+
391
+ self .client_class (license_key = "1234567890ab" ) # type: ignore[call-arg]
384
392
385
393
with self .assertRaises (TypeError ):
386
- self .client_class ("47" )
394
+ self .client_class ("47" ) # type: ignore
387
395
388
396
389
397
class TestClient (TestBaseClient ):
398
+ client : Client
399
+
390
400
def setUp (self ) -> None :
391
401
self .client_class = Client
392
402
self .client = Client (42 , "abcdef123456" )
@@ -398,6 +408,8 @@ def run_client(self, v):
398
408
399
409
400
410
class TestAsyncClient (TestBaseClient ):
411
+ client : AsyncClient
412
+
401
413
def setUp (self ) -> None :
402
414
self ._loop = asyncio .new_event_loop ()
403
415
self .client_class = AsyncClient
0 commit comments