16
16
from enum import Enum
17
17
from abc import ABC , abstractmethod
18
18
from types import ModuleType
19
- from typing import Literal
19
+ from typing import Literal , Type
20
20
from functools import lru_cache
21
21
from ..errors import InvalidArgsError
22
22
from .. import utils
26
26
__all__ = [
27
27
"PQAVariant" ,
28
28
"BasePQAParamSizes" ,
29
- "BasePQAlgorithm" ,
29
+ "BasePQAlgorithm"
30
30
]
31
31
32
32
33
33
class PQAVariant (Enum ):
34
- CLEAN = "clean"
35
- AVX2 = "avx2"
34
+ """
35
+ Available binaries:
36
+
37
+ REF - Clean reference binaries for the x86_64 architecture.
38
+
39
+ OPT - Speed-optimized binaries for the x86_64 architecture.
40
+
41
+ ARM - Binaries for the aarch64 architecture.
42
+ """
43
+ REF = "clean"
44
+ OPT = "avx2"
45
+ ARM = "aarch64"
36
46
37
47
38
48
class BasePQAParamSizes :
@@ -71,18 +81,18 @@ def _import(self, variant: PQAVariant) -> ModuleType:
71
81
def __init__ (self , variant : PQAVariant = None ):
72
82
# variant is None -> auto-select mode
73
83
try :
74
- _var = variant or PQAVariant .AVX2
84
+ _var = variant or PQAVariant .OPT
75
85
self ._lib = self ._import (_var )
76
86
self .variant = _var
77
87
except ModuleNotFoundError as ex :
78
88
if variant is None :
79
89
try :
80
- self ._lib = self ._import (PQAVariant .CLEAN )
81
- self .variant = PQAVariant .CLEAN
90
+ self ._lib = self ._import (PQAVariant .REF )
91
+ self .variant = PQAVariant .REF
82
92
return
83
93
except ModuleNotFoundError : # pragma: no cover
84
94
pass
85
- elif variant == PQAVariant .AVX2 : # pragma: no cover
95
+ elif variant == PQAVariant .OPT : # pragma: no cover
86
96
raise ex
87
97
raise SystemExit ( # pragma: no cover
88
98
"Quantcrypt Fatal Error:\n "
@@ -95,15 +105,19 @@ def _upper_name(self) -> str:
95
105
pattern = '.[^A-Z]*'
96
106
)).upper ()
97
107
98
- def _keygen (self , algo_type : Literal ["kem" , "sign" ]) -> tuple [bytes , bytes ]:
108
+ def _keygen (
109
+ self ,
110
+ algo_type : Literal ["kem" , "sign" ],
111
+ error_cls : Type [errors .PQAError ]
112
+ ) -> tuple [bytes , bytes ]:
99
113
ffi , params = FFI (), self .param_sizes
100
114
public_key = ffi .new (f"uint8_t [{ params .pk_size } ]" )
101
115
secret_key = ffi .new (f"uint8_t [{ params .sk_size } ]" )
102
116
103
117
name = f"_crypto_{ algo_type } _keypair"
104
118
func = getattr (self ._lib , self ._namespace + name )
105
119
if func (public_key , secret_key ) != 0 : # pragma: no cover
106
- return tuple ()
120
+ raise error_cls
107
121
108
122
pk = ffi .buffer (public_key , params .pk_size )
109
123
sk = ffi .buffer (secret_key , params .sk_size )
0 commit comments