Skip to content

Commit 9e27ad8

Browse files
committed
Optimized PQA keygen code
1 parent 3b349db commit 9e27ad8

File tree

3 files changed

+26
-18
lines changed

3 files changed

+26
-18
lines changed

quantcrypt/internal/pqa/common.py

+24-10
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from enum import Enum
1717
from abc import ABC, abstractmethod
1818
from types import ModuleType
19-
from typing import Literal
19+
from typing import Literal, Type
2020
from functools import lru_cache
2121
from ..errors import InvalidArgsError
2222
from .. import utils
@@ -26,13 +26,23 @@
2626
__all__ = [
2727
"PQAVariant",
2828
"BasePQAParamSizes",
29-
"BasePQAlgorithm",
29+
"BasePQAlgorithm"
3030
]
3131

3232

3333
class PQAVariant(Enum):
34-
CLEAN = "clean"
35-
AVX2 = "avx2"
34+
"""
35+
Available binaries:
36+
37+
REF - Clean reference binaries for the x86_64 architecture.
38+
39+
OPT - Speed-optimized binaries for the x86_64 architecture.
40+
41+
ARM - Binaries for the aarch64 architecture.
42+
"""
43+
REF = "clean"
44+
OPT = "avx2"
45+
ARM = "aarch64"
3646

3747

3848
class BasePQAParamSizes:
@@ -71,18 +81,18 @@ def _import(self, variant: PQAVariant) -> ModuleType:
7181
def __init__(self, variant: PQAVariant = None):
7282
# variant is None -> auto-select mode
7383
try:
74-
_var = variant or PQAVariant.AVX2
84+
_var = variant or PQAVariant.OPT
7585
self._lib = self._import(_var)
7686
self.variant = _var
7787
except ModuleNotFoundError as ex:
7888
if variant is None:
7989
try:
80-
self._lib = self._import(PQAVariant.CLEAN)
81-
self.variant = PQAVariant.CLEAN
90+
self._lib = self._import(PQAVariant.REF)
91+
self.variant = PQAVariant.REF
8292
return
8393
except ModuleNotFoundError: # pragma: no cover
8494
pass
85-
elif variant == PQAVariant.AVX2: # pragma: no cover
95+
elif variant == PQAVariant.OPT: # pragma: no cover
8696
raise ex
8797
raise SystemExit( # pragma: no cover
8898
"Quantcrypt Fatal Error:\n"
@@ -95,15 +105,19 @@ def _upper_name(self) -> str:
95105
pattern='.[^A-Z]*'
96106
)).upper()
97107

98-
def _keygen(self, algo_type: Literal["kem", "sign"]) -> tuple[bytes, bytes]:
108+
def _keygen(
109+
self,
110+
algo_type: Literal["kem", "sign"],
111+
error_cls: Type[errors.PQAError]
112+
) -> tuple[bytes, bytes]:
99113
ffi, params = FFI(), self.param_sizes
100114
public_key = ffi.new(f"uint8_t [{params.pk_size}]")
101115
secret_key = ffi.new(f"uint8_t [{params.sk_size}]")
102116

103117
name = f"_crypto_{algo_type}_keypair"
104118
func = getattr(self._lib, self._namespace + name)
105119
if func(public_key, secret_key) != 0: # pragma: no cover
106-
return tuple()
120+
raise error_cls
107121

108122
pk = ffi.buffer(public_key, params.pk_size)
109123
sk = ffi.buffer(secret_key, params.sk_size)

quantcrypt/internal/pqa/dss.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,7 @@ def keygen(self) -> tuple[bytes, bytes]:
7070
library has failed to generate the keys for the current
7171
DSS algorithm for any reason.
7272
"""
73-
result = self._keygen("sign")
74-
if not result: # pragma: no cover
75-
raise errors.DSSKeygenFailedError
76-
return result
73+
return self._keygen("sign", errors.DSSKeygenFailedError)
7774

7875
def sign(self, secret_key: bytes, message: bytes) -> bytes:
7976
"""

quantcrypt/internal/pqa/kem.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,7 @@ def keygen(self) -> tuple[bytes, bytes]:
4646
library has failed to generate the keys for the current
4747
KEM algorithm for any reason.
4848
"""
49-
result = self._keygen("kem")
50-
if not result: # pragma: no cover
51-
raise errors.KEMKeygenFailedError
52-
return result
49+
return self._keygen("kem", errors.KEMKeygenFailedError)
5350

5451
def encaps(self, public_key: bytes) -> tuple[bytes, bytes]:
5552
"""

0 commit comments

Comments
 (0)