|
| 1 | +#!/usr/bin/env python3 |
| 2 | +import itertools |
| 3 | +import sys |
| 4 | + |
| 5 | +MASK = 0xffffff |
| 6 | + |
| 7 | + |
| 8 | +def forward(key, blk): |
| 9 | + assert tuple(map(len, (key, blk))) == (9, 6) |
| 10 | + def S(j, v): return (v << j | (v & MASK) >> 24-j) & MASK |
| 11 | + ws = blk[:3], blk[3:], key[:3], key[3:6], key[6:] |
| 12 | + x, y, l1, l0, k0 = (int.from_bytes(w, 'big') for w in ws) |
| 13 | + l, k = [l0, l1], [k0] |
| 14 | + for i in range(21): |
| 15 | + l.append((S(16, l[i]) + k[i] ^ i) & MASK) |
| 16 | + k.append(S(3, k[i]) ^ l[-1]) |
| 17 | + for i in range(22): |
| 18 | + x = (S(16, x) + y ^ k[i]) & MASK |
| 19 | + y = (S(3, y) ^ x) & MASK |
| 20 | + return b''.join(z.to_bytes(3, 'big') for z in (x, y)) |
| 21 | + |
| 22 | + |
| 23 | +def backward(key, cipher): |
| 24 | + assert tuple(map(len, (key, cipher))) == (9, 6) |
| 25 | + def S(j, v): return (v << j | (v & MASK) >> 24-j) & MASK |
| 26 | + ws = cipher[:3], cipher[3:], key[:3], key[3:6], key[6:] |
| 27 | + x, y, l1, l0, k0 = (int.from_bytes(w, 'big') for w in ws) |
| 28 | + l, k = [l0, l1], [k0] |
| 29 | + for i in range(21): |
| 30 | + l.append((S(16, l[i]) + k[i] ^ i) & MASK) |
| 31 | + k.append(S(3, k[i]) ^ l[-1]) |
| 32 | + for i in range(21, -1, -1): |
| 33 | + y = S(21, y ^ x) |
| 34 | + x = S(8, (x ^ k[i]) - y) |
| 35 | + x, y = (z & 0xffffff for z in (x, y)) |
| 36 | + return b''.join(z.to_bytes(3, 'big') for z in (x, y)) |
| 37 | + |
| 38 | + |
| 39 | +# did I implement this correctly? |
| 40 | +assert forward(*map(bytes.fromhex, ('1211100a0908020100', |
| 41 | + '20796c6c6172'))) == b'\xc0\x49\xa5\x38\x5a\xdc' |
| 42 | + |
| 43 | + |
| 44 | +def H(m): |
| 45 | + s = bytes(6) |
| 46 | + v = m + bytes(-len(m) % 9) + len(m).to_bytes(9, 'big') |
| 47 | + for i in range(0, len(v), 9): |
| 48 | + s = forward(v[i:i+9], s) |
| 49 | + return s |
| 50 | + |
| 51 | + |
| 52 | +if len(sys.argv) < 2: |
| 53 | + print(f"Usage: python3 {sys.argv[0]} <hex>") |
| 54 | + exit(1) |
| 55 | + |
| 56 | +target = bytes.fromhex(sys.argv[1]) |
| 57 | + |
| 58 | + |
| 59 | +key = (18).to_bytes(9, 'big') |
| 60 | + |
| 61 | +start = bytes(6) |
| 62 | +end = backward(key, target) |
| 63 | + |
| 64 | +assert(forward(key, end) == target) |
| 65 | + |
| 66 | +forward_dict = {} |
| 67 | +backward_dict = {} |
| 68 | + |
| 69 | +all_bytes = [i.to_bytes(1, 'big') for i in range(256)] |
| 70 | +for k in itertools.product(all_bytes, repeat=9): |
| 71 | + key = b''.join(k) |
| 72 | + f = forward(key, start) |
| 73 | + forward_dict[f] = key |
| 74 | + if f in backward_dict: |
| 75 | + ans = key + backward_dict[f] |
| 76 | + break |
| 77 | + b = backward(key, end) |
| 78 | + backward_dict[b] = key |
| 79 | + if b in forward_dict: |
| 80 | + ans = forward_dict[b] + key |
| 81 | + break |
| 82 | + |
| 83 | +print(ans.hex()) |
| 84 | +assert H(ans) == target |
0 commit comments