|
1 | 1 | #include "blob/hash.h"
|
2 |
| -#include "blob/wolfssl_init.h" |
3 | 2 |
|
4 | 3 | #include <expected>
|
5 | 4 |
|
6 | 5 |
|
7 |
| -namespace Blob |
| 6 | +#include <mbedtls/md.h> |
| 7 | + |
| 8 | + |
| 9 | +using namespace Blob; |
| 10 | + |
| 11 | +HMAC::HMAC(const Bytes& key, HashType hashType) |
8 | 12 | {
|
9 |
| - std::expected<Bytes, Error> HMACSum(const Bytes& key, const Bytes& msg, HashType type) |
10 |
| - { |
11 |
| - HMAC hmac(key, type); |
12 |
| - Error err = hmac.Write(msg); |
13 |
| - if (err != Error::None) { |
14 |
| - return std::unexpected<Error>(err); |
15 |
| - } |
16 |
| - |
17 |
| - return hmac.Sum(); |
| 13 | + const mbedtls_md_info_t* mdInfo = nullptr; |
| 14 | + switch (hashType) { |
| 15 | + case HashType::SHA1: |
| 16 | + mdInfo = mbedtls_md_info_from_type(MBEDTLS_MD_SHA1); |
| 17 | + break; |
| 18 | + case HashType::SHA256: |
| 19 | + mdInfo = mbedtls_md_info_from_type(MBEDTLS_MD_SHA256); |
| 20 | + break; |
| 21 | + case HashType::SHA512: |
| 22 | + mdInfo = mbedtls_md_info_from_type(MBEDTLS_MD_SHA512); |
| 23 | + break; |
| 24 | + default: |
| 25 | + err = Error::Unsupported; |
| 26 | + return; |
18 | 27 | }
|
19 | 28 |
|
20 |
| - std::expected<Bytes, Error> HKDF(const Bytes& key, const Bytes& salt, const Bytes& info, size_t outLen, HashType hashType = HashType::SHA256) |
21 |
| - { |
22 |
| - int wcType; |
23 |
| - switch (hashType) |
24 |
| - { |
25 |
| - case HashType::SHA1: |
26 |
| - wcType = WC_SHA; |
27 |
| - break; |
28 |
| - case HashType::SHA256: |
29 |
| - wcType = WC_SHA256; |
30 |
| - break; |
31 |
| - case HashType::SHA512: |
32 |
| - wcType = WC_SHA512; |
33 |
| - break; |
34 |
| - |
35 |
| - default: |
36 |
| - return std::unexpected<Error>{Error::Unsupported}; |
37 |
| - } |
38 |
| - |
39 |
| - Bytes out(outLen, '\xff'); |
40 |
| - int ret = wc_HKDF(wcType, key.c_str(), key.size(), salt.c_str(), salt.size(), info.c_str(), info.size(), out.data(), outLen); |
41 |
| - if (ret != 0) { |
42 |
| - return std::unexpected<Error>{Error::ShitHappens}; |
43 |
| - } |
44 |
| - |
45 |
| - return out; |
| 29 | + if (mdInfo == nullptr) { |
| 30 | + err = Error::Unsupported; |
| 31 | + return; |
46 | 32 | }
|
47 | 33 |
|
48 |
| - HMAC::HMAC(const Bytes& key, HashType hashType) |
49 |
| - { |
50 |
| - int wcType; |
51 |
| - switch (hashType) |
52 |
| - { |
53 |
| - case HashType::SHA1: |
54 |
| - wcType = WC_SHA; |
55 |
| - break; |
56 |
| - case HashType::SHA256: |
57 |
| - wcType = WC_SHA256; |
58 |
| - break; |
59 |
| - case HashType::SHA512: |
60 |
| - wcType = WC_SHA512; |
61 |
| - break; |
62 |
| - |
63 |
| - default: |
64 |
| - err = Error::Unsupported; |
65 |
| - return; |
66 |
| - } |
67 |
| - |
68 |
| - int ret = wc_HmacInit(&ctx, nullptr, INVALID_DEVID); |
69 |
| - if (ret != 0) { |
70 |
| - err = Error::InitFailed; |
71 |
| - return; |
72 |
| - } |
73 |
| - |
74 |
| - ret = wc_HmacSetKey(&ctx, wcType, key.c_str(), key.size()); |
75 |
| - if (ret != 0) { |
76 |
| - err = Error::InvalidKey; |
77 |
| - wc_HmacFree(&ctx); |
78 |
| - return; |
79 |
| - } |
| 34 | + ctx = static_cast<mbedtls_md_context_t*>(malloc(sizeof(mbedtls_md_context_t))); |
| 35 | + if (ctx == nullptr) { |
| 36 | + err = Error::ShitHappens; |
| 37 | + return; |
80 | 38 | }
|
81 | 39 |
|
82 |
| - Error HMAC::Write(const Bytes& data) |
83 |
| - { |
84 |
| - if (err != Error::None) { |
85 |
| - return err; |
86 |
| - } |
| 40 | + mbedtls_md_init(ctx); |
| 41 | + int rc = mbedtls_md_setup(ctx, mdInfo, 1); |
| 42 | + if (rc != 0) { |
| 43 | + mbedtls_md_free(ctx); |
| 44 | + free(ctx); |
| 45 | + err = Error::InitFailed; |
| 46 | + return; |
| 47 | + } |
87 | 48 |
|
88 |
| - int ret = wc_HmacUpdate(&ctx, data.c_str(), data.size()); |
89 |
| - if (ret != 0) { |
90 |
| - return Error::ShitHappens; |
91 |
| - } |
| 49 | + rc = mbedtls_md_hmac_starts(ctx, key.c_str(), key.size()); |
| 50 | + if (rc != 0) { |
| 51 | + mbedtls_md_free(ctx); |
| 52 | + free(ctx); |
| 53 | + err = Error::InitFailed; |
| 54 | + return; |
| 55 | + } |
| 56 | +} |
92 | 57 |
|
93 |
| - return Error::None; |
| 58 | +Error HMAC::Write(const Bytes& data) |
| 59 | +{ |
| 60 | + if (err != Error::None) { |
| 61 | + return err; |
94 | 62 | }
|
95 | 63 |
|
96 |
| - std::expected<Bytes, Error> HMAC::Sum() |
97 |
| - { |
98 |
| - if (err != Error::None) { |
99 |
| - return std::unexpected<Error>{err}; |
100 |
| - } |
| 64 | + if (mbedtls_md_hmac_update(ctx, data.c_str(), data.length())) { |
| 65 | + return Error::ShitHappens; |
| 66 | + } |
101 | 67 |
|
102 |
| - size_t hashLen = wc_HmacSizeByType(ctx.macType); |
103 |
| - if (hashLen <= 0) { |
104 |
| - return std::unexpected<Error>{Error::ShitHappens}; |
105 |
| - } |
| 68 | + return Error::None; |
| 69 | +} |
106 | 70 |
|
107 |
| - Bytes out(hashLen, '\xff'); |
108 |
| - int ret = wc_HmacFinal(&ctx, out.data()); |
109 |
| - if (ret != 0) { |
110 |
| - return std::unexpected<Error>{Error::ShitHappens}; |
111 |
| - } |
| 71 | +std::expected<Bytes, Error> HMAC::Sum() |
| 72 | +{ |
| 73 | + if (err != Error::None) { |
| 74 | + return std::unexpected<Error>{err}; |
| 75 | + } |
112 | 76 |
|
113 |
| - return out; |
| 77 | + size_t hashLen = static_cast<size_t>(mbedtls_md_get_size(mbedtls_md_info_from_ctx(ctx))); |
| 78 | + if (hashLen == 0) { |
| 79 | + return std::unexpected<Error>{Error::ShitHappens}; |
114 | 80 | }
|
115 | 81 |
|
116 |
| - HMAC::~HMAC() |
117 |
| - { |
118 |
| - if (err != Error::None) { |
119 |
| - return; |
120 |
| - } |
| 82 | + Bytes out(hashLen, '\xff'); |
| 83 | + int ret = mbedtls_md_hmac_finish(ctx, out.data()); |
| 84 | + if (ret != 0) { |
| 85 | + return std::unexpected<Error>{Error::ShitHappens}; |
| 86 | + } |
| 87 | + |
| 88 | + return out; |
| 89 | +} |
| 90 | + |
| 91 | +HMAC::~HMAC() |
| 92 | +{ |
| 93 | + if (ctx == nullptr) { |
| 94 | + return; |
| 95 | + } |
| 96 | + |
| 97 | + mbedtls_md_free(ctx); |
| 98 | + free(ctx); |
| 99 | + ctx = nullptr; |
| 100 | +} |
121 | 101 |
|
122 |
| - wc_HmacFree(&ctx); |
| 102 | +std::expected<Bytes, Error> HMAC::Sum(const Bytes& key, const Bytes& msg, HashType type) |
| 103 | +{ |
| 104 | + HMAC hmac(key, type); |
| 105 | + Error err = hmac.Write(msg); |
| 106 | + if (err != Error::None) { |
| 107 | + return std::unexpected<Error>(err); |
123 | 108 | }
|
| 109 | + |
| 110 | + return hmac.Sum(); |
124 | 111 | }
|
0 commit comments