|
| 1 | +import base64 |
| 2 | +import binascii |
| 3 | +import hashlib |
| 4 | +import struct |
| 5 | +import ecdsa |
| 6 | +import sys |
| 7 | + |
| 8 | +from Crypto.PublicKey import RSA, DSA |
| 9 | + |
| 10 | +INT_LEN = 4 |
| 11 | + |
| 12 | +class InvalidKeyException(Exception): |
| 13 | + pass |
| 14 | + |
| 15 | +class TooShortKeyException(InvalidKeyException): |
| 16 | + pass |
| 17 | + |
| 18 | +class InvalidTypeException(InvalidKeyException): |
| 19 | + pass |
| 20 | + |
| 21 | +class MalformedDataException(InvalidKeyException): |
| 22 | + pass |
| 23 | + |
| 24 | +class SSHKey: |
| 25 | + def __init__(self, keydata): |
| 26 | + self.keydata = keydata |
| 27 | + self.current_position = 0 |
| 28 | + self.decoded_key = None |
| 29 | + self.parse() |
| 30 | + |
| 31 | + def hash(self): |
| 32 | + """ Calculates fingerprint hash. |
| 33 | +
|
| 34 | + Shamelessly copied from http://stackoverflow.com/questions/6682815/deriving-an-ssh-fingerprint-from-a-public-key-in-python |
| 35 | + """ |
| 36 | + fp_plain = hashlib.md5(self.decoded_key).hexdigest() |
| 37 | + return ':'.join(a+b for a, b in zip(fp_plain[::2], fp_plain[1::2])) |
| 38 | + |
| 39 | + def unpack_by_int(self): |
| 40 | + """ Returns next data field. """ |
| 41 | + # Unpack length of data field |
| 42 | + try: |
| 43 | + requested_data_length = struct.unpack('>I', self.decoded_key[self.current_position:self.current_position+INT_LEN])[0] |
| 44 | + except struct.error: |
| 45 | + raise MalformedDataException("Unable to unpack %s bytes from the data" % INT_LEN) |
| 46 | + |
| 47 | + # Move pointer to the beginning of the data field |
| 48 | + self.current_position += INT_LEN |
| 49 | + remaining_data_length = len(self.decoded_key[self.current_position:]) |
| 50 | + |
| 51 | + if remaining_data_length < requested_data_length: |
| 52 | + raise MalformedDataException("Requested %s bytes, but only %s bytes available." % (requested_data_length, remaining_data_length)) |
| 53 | + |
| 54 | + next_data = self.decoded_key[self.current_position:self.current_position+requested_data_length] |
| 55 | + # Move pointer to the end of the data field |
| 56 | + self.current_position += requested_data_length |
| 57 | + return next_data |
| 58 | + |
| 59 | + @classmethod |
| 60 | + def parse_long(cls, data): |
| 61 | + """ Calculate two's complement """ |
| 62 | + if sys.version < '3': |
| 63 | + ret = long(0) |
| 64 | + for byte in data: |
| 65 | + ret = (ret << 8) + ord(byte) |
| 66 | + return ret |
| 67 | + ret = 0 |
| 68 | + for byte in data: |
| 69 | + ret = (ret << 8) + byte |
| 70 | + return ret |
| 71 | + |
| 72 | + |
| 73 | + @classmethod |
| 74 | + def split_key(cls, data): |
| 75 | + key_parts = data.strip().split(None, 3) |
| 76 | + if len(key_parts) < 2: # Key type and content are mandatory fields. |
| 77 | + raise InvalidKeyException("Unexpected key format: at least type and base64 encoded value is required") |
| 78 | + return key_parts |
| 79 | + |
| 80 | + @classmethod |
| 81 | + def decode_key(cls, pubkey_content): |
| 82 | + # Decode base64 coded part. |
| 83 | + try: |
| 84 | + decoded_key = base64.b64decode(pubkey_content.encode("ascii")) |
| 85 | + except (TypeError, binascii.Error): |
| 86 | + raise InvalidKeyException("Unable to decode the key") |
| 87 | + return decoded_key |
| 88 | + |
| 89 | + def parse(self): |
| 90 | + self.current_position = 0 |
| 91 | + key_parts = self.split_key(self.keydata) |
| 92 | + |
| 93 | + key_type = key_parts[0] |
| 94 | + pubkey_content = key_parts[1] |
| 95 | + |
| 96 | + self.decoded_key = self.decode_key(pubkey_content) |
| 97 | + |
| 98 | + # Check key type |
| 99 | + unpacked_key_type = self.unpack_by_int() |
| 100 | + if key_type != unpacked_key_type.decode(): |
| 101 | + raise InvalidTypeException("Keytype mismatch: %s != %s" % (key_type, unpacked_key_type)) |
| 102 | + |
| 103 | + self.key_type = unpacked_key_type |
| 104 | + |
| 105 | + if self.key_type == b"ssh-rsa": |
| 106 | + |
| 107 | + raw_e = self.unpack_by_int() |
| 108 | + raw_n = self.unpack_by_int() |
| 109 | + |
| 110 | + unpacked_e = self.parse_long(raw_e) |
| 111 | + unpacked_n = self.parse_long(raw_n) |
| 112 | + |
| 113 | + self.rsa = RSA.construct((unpacked_n, unpacked_e)) |
| 114 | + self.bits = self.rsa.size() + 1 |
| 115 | + |
| 116 | + elif self.key_type == b"ssh-dss": |
| 117 | + data_fields = {} |
| 118 | + for expected_length, item in [(309, "p"), (48, "q"), (309, "g"), (309, "y")]: |
| 119 | + data_fields[item] = self.parse_long(self.unpack_by_int()) |
| 120 | + item_length = len(str(data_fields[item])) |
| 121 | + if item_length != expected_length: |
| 122 | + raise MalformedDataException("DSA parameter %s has invalid length (%s, expected %s)" % (item, item_length, expected_length)) |
| 123 | + |
| 124 | + self.dsa = DSA.construct((data_fields["y"], data_fields["g"], data_fields["p"], data_fields["q"])) |
| 125 | + self.bits = self.dsa.size() + 1 |
| 126 | + if self.bits != 1024: |
| 127 | + raise InvalidKeyException("ssh-dss keys must be 1024 bits (was %s)" % self.bits) |
| 128 | + |
| 129 | + elif self.key_type.strip().startswith(b"ecdsa-sha"): |
| 130 | + curve_information = self.unpack_by_int() |
| 131 | + curve_data = {b"nistp256": (ecdsa.curves.NIST256p, hashlib.sha256), |
| 132 | + b"nistp192": (ecdsa.curves.NIST192p, hashlib.sha256), |
| 133 | + b"nistp224": (ecdsa.curves.NIST224p, hashlib.sha256), |
| 134 | + b"nistp384": (ecdsa.curves.NIST384p, hashlib.sha384), |
| 135 | + b"nistp521": (ecdsa.curves.NIST521p, hashlib.sha512)} |
| 136 | + if curve_information not in curve_data: |
| 137 | + raise NotImplementedError("Invalid curve type: %s" % curve_information) |
| 138 | + curve, hash_algorithm = curve_data[curve_information] |
| 139 | + |
| 140 | + data = self.unpack_by_int() |
| 141 | + |
| 142 | + key = ecdsa.VerifyingKey.from_string(data[1:], curve, hash_algorithm) |
| 143 | + self.bits = int(curve_information.replace(b"nistp", b"")) # TODO |
| 144 | + self.ecdsa = ecdsa |
| 145 | + else: |
| 146 | + raise NotImplementedError("Invalid key type: %s" % self.key_type) |
0 commit comments