diff --git a/quantcrypt/internal/cipher/krypton.py b/quantcrypt/internal/cipher/krypton.py index bb2eddd..d0fb42c 100644 --- a/quantcrypt/internal/cipher/krypton.py +++ b/quantcrypt/internal/cipher/krypton.py @@ -18,6 +18,7 @@ from Cryptodome.Util.strxor import strxor from Cryptodome.Cipher import AES from collections.abc import Callable +from dataclasses import dataclass from ..errors import InvalidArgsError from ..kdf.kmac_kdf import KKDF from .common import ( @@ -28,7 +29,13 @@ from . import errors -__all__ = ["Krypton"] +__all__ = ["DecryptedData", "Krypton"] + + +@dataclass +class DecryptedData: + plaintext: bytes = None, + header: bytes = None class Krypton: @@ -293,7 +300,8 @@ def encrypt_file( :param plaintext_file: Path to the plaintext file, which must exist. :param output_file: Path to the ciphertext file. If the file exists, it will be overwritten. - :param header: Associated Authenticated Data + :param header: Associated Authenticated Data, which is included + unencrypted into the metadata field of the generated ciphertext file. :param context: Optional field to describe the ciphers purpose. Alters the output of internal hash functions. Not a secret. :param chunk_size: By default, the chunk size is automatically determined @@ -320,7 +328,8 @@ def encrypt_file( output_file.unlink(missing_ok=True) output_file.touch() with open(output_file, 'r+b') as out_file: - out_file.write(b'0' * 170) # reserved space + reserved_space = b'0' * (180 + len(header)) + out_file.write(reserved_space) while True: chunk = in_file.read(chunk_size.value) if not chunk: @@ -329,10 +338,14 @@ def encrypt_file( out_file.write(ciphertext) # chunk_size + 1 byte if callback: callback() - vdp = krypton.finish_encryption() # 160 bytes - cs = f"{chunk_size.value:0>10}".encode("utf-8") # 10 bytes out_file.seek(0) - out_file.write(vdp + cs) # 170 bytes + + h_len = f"{len(header):0>10}".encode("utf-8") # 10 bytes + cs = f"{chunk_size.value:0>10}".encode("utf-8") # 10 bytes + vdp = krypton.finish_encryption() # 160 bytes + + out_file.write(h_len + cs + vdp) # 180 bytes + out_file.write(header) # len(header) bytes @classmethod @utils.input_validator() @@ -341,12 +354,11 @@ def decrypt_file( secret_key: Annotated[bytes, Field(min_length=64, max_length=64)], ciphertext_file: Path, output_file: Path | None = None, - header: bytes = b'', context: Annotated[Optional[bytes], Field(default=b'')] = b'', callback: Callable = None, *, into_memory: bool = False - ) -> bytes | None: + ) -> DecryptedData: """ Decrypts a file of any size on disk in chunks. The user must provide either a path for the `output_file` parameter, where the decrypted plaintext will be @@ -358,7 +370,6 @@ def decrypt_file( :param ciphertext_file: Path to the ciphertext file, which must exist. :param output_file: Path to the plaintext file, optional. If the file exists, it will be overwritten. - :param header: Associated Authenticated Data :param context: Optional field to describe the ciphers purpose. Alters the output of internal hash functions. Not a secret. :param callback: This callback, when provided, will be called for each @@ -388,25 +399,31 @@ def _decrypted_chunk() -> Generator[bytes, None, None]: yield _pt with open(ciphertext_file, 'rb') as in_file: - vdp_cs = in_file.read(170) - vdp, cs = vdp_cs[:160], vdp_cs[160:] + data = in_file.read(180) + h_len, cs, vdp = data[:10], data[10:20], data[20:180] + h_len_int = int(h_len.decode("utf-8")) + header = in_file.read(h_len_int) cs_int = int(cs.decode("utf-8")) krypton = cls(secret_key, context, None) setattr(krypton, '_chunk_size', cs_int) krypton.begin_decryption(vdp, header) + plaintext = None if into_memory: plaintext = bytes() for chunk in _decrypted_chunk(): plaintext += chunk - krypton.finish_decryption() - return plaintext - - output_file.unlink(missing_ok=True) - output_file.touch() - with output_file.open("wb") as out_file: - for chunk in _decrypted_chunk(): - out_file.write(chunk) - krypton.finish_decryption() + else: + output_file.unlink(missing_ok=True) + output_file.touch() + with output_file.open("wb") as out_file: + for chunk in _decrypted_chunk(): + out_file.write(chunk) + + krypton.finish_decryption() + return DecryptedData( + plaintext=plaintext, + header=header + ) diff --git a/quantcrypt/internal/cli/main.py b/quantcrypt/internal/cli/main.py index ff215ce..346b3d5 100644 --- a/quantcrypt/internal/cli/main.py +++ b/quantcrypt/internal/cli/main.py @@ -34,7 +34,7 @@ @app.callback() -def main(version: VersionAtd = False, info: InfoAtd = False): +def main(version: VersionAtd = False, info: InfoAtd = False) -> None: if version: pkg_info = PackageInfo() print(pkg_info.Version) diff --git a/tests/test_cipher/test_krypton.py b/tests/test_cipher/test_krypton.py index 24dabaa..1ee02f0 100644 --- a/tests/test_cipher/test_krypton.py +++ b/tests/test_cipher/test_krypton.py @@ -19,7 +19,7 @@ from quantcrypt.internal.cipher import errors -@pytest.fixture(name="file_data") +@pytest.fixture(name="file_data", scope="function") def fixture_file_data(tmp_path: Path) -> DotMap: orig_pt = os.urandom(1024 * 16) pt_file = tmp_path / "test_file.bin" @@ -219,13 +219,21 @@ def _cb(): assert sum(counters[1]) == 4 +def test_krypton_file_enc_dec_header(file_data: DotMap): + header = b'z' * 32 + Krypton.encrypt_file(file_data.sk, file_data.pt_file, file_data.ct_file, header) + dec_data = Krypton.decrypt_file(file_data.sk, file_data.ct_file, file_data.pt2_file) + assert dec_data.plaintext is None + assert dec_data.header == header + + def test_krypton_file_enc_dec_into_memory(file_data: DotMap): Krypton.encrypt_file(file_data.sk, file_data.pt_file, file_data.ct_file) - pt2 = Krypton.decrypt_file( + dec_data = Krypton.decrypt_file( file_data.sk, file_data.ct_file, file_data.pt2_file, into_memory=True ) - assert pt2 == file_data.orig_pt + assert dec_data.plaintext == file_data.orig_pt def test_krypton_file_enc_dec_chunk_size_override(file_data: DotMap): @@ -238,12 +246,12 @@ def callback(): file_data.sk, file_data.pt_file, file_data.ct_file, chunk_size=ChunkSize.KB(1) ) - pt2 = Krypton.decrypt_file( + dec_data = Krypton.decrypt_file( file_data.sk, file_data.ct_file, file_data.pt2_file, callback=callback, into_memory=True ) assert sum(counter) == 16 - assert pt2 == file_data.orig_pt + assert dec_data.plaintext == file_data.orig_pt def test_krypton_file_enc_dec_errors(tmp_path: Path):