| # This file is dual licensed under the terms of the Apache License, Version |
| # 2.0, and the BSD License. See the LICENSE file in the root of this repository |
| # for complete details. |
| |
| from __future__ import absolute_import, division, print_function |
| |
| import binascii |
| import os |
| import re |
| import struct |
| |
| import six |
| |
| from cryptography import utils |
| from cryptography.exceptions import UnsupportedAlgorithm |
| from cryptography.hazmat.backends import _get_backend |
| from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed25519, rsa |
| from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes |
| from cryptography.hazmat.primitives.serialization import ( |
| Encoding, |
| NoEncryption, |
| PrivateFormat, |
| PublicFormat, |
| ) |
| |
| try: |
| from bcrypt import kdf as _bcrypt_kdf |
| |
| _bcrypt_supported = True |
| except ImportError: |
| _bcrypt_supported = False |
| |
| def _bcrypt_kdf(*args, **kwargs): |
| raise UnsupportedAlgorithm("Need bcrypt module") |
| |
| |
| try: |
| from base64 import encodebytes as _base64_encode |
| except ImportError: |
| from base64 import encodestring as _base64_encode |
| |
# SSH key type identifiers as they appear in public key lines and key blobs.
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
# Suffix appended to a key type to denote an OpenSSH certificate.
_CERT_SUFFIX = b"-cert-v01@openssh.com"

# One-line public key: "<key-type> <base64-blob>"; anything after the blob
# (e.g. a comment) is not matched and therefore ignored.
_SSH_PUBKEY_RC = re.compile(br"\A(\S+)[ \t]+(\S+)")
# Magic prefix of the decoded openssh-key-v1 private key, and its PEM armour.
_SK_MAGIC = b"openssh-key-v1\0"
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
# KDF / cipher names found in the private key header.
_BCRYPT = b"bcrypt"
_NONE = b"none"
# Parameters used by serialize_ssh_private_key when a password is given.
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16
# Longest password (in bytes) serialize_ssh_private_key accepts.
_MAX_PASSWORD = 72

# re is used because it can search bytes-like input (e.g. memoryview)
# directly, without first copying it into a bytes object.
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)

# Deterministic padding bytes 1, 2, ..., 16: enough for the largest block
# length listed in _SSH_CIPHERS.
_PADDING = memoryview(bytearray(range(1, 1 + 16)))

# Ciphers supported for key wrapping:
#   name -> (algorithm class, key length, mode class, IV length)
# The IV length doubles as the block length for padding checks.
_SSH_CIPHERS = {
    b"aes256-ctr": (algorithms.AES, 32, modes.CTR, 16),
    b"aes256-cbc": (algorithms.AES, 32, modes.CBC, 16),
}

# Map cryptography curve name to SSH key type.
_ECDSA_KEY_TYPE = {
    "secp256r1": _ECDSA_NISTP256,
    "secp384r1": _ECDSA_NISTP384,
    "secp521r1": _ECDSA_NISTP521,
}

# Big-endian unsigned 32/64-bit integer codecs used by the wire format.
_U32 = struct.Struct(b">I")
_U64 = struct.Struct(b">Q")
| |
| |
def _ecdsa_key_type(public_key):
    """Return the SSH key type (e.g. ``b"ecdsa-sha2-nistp256"``) for an EC key.

    *public_key* must expose ``curve.name``.  Curves that have no entry in
    ``_ECDSA_KEY_TYPE`` raise ValueError.
    """
    curve_name = public_key.curve.name
    key_type = _ECDSA_KEY_TYPE.get(curve_name)
    if key_type is None:
        raise ValueError(
            "Unsupported curve for ssh private key: %r" % curve_name
        )
    return key_type
| |
| |
def _ssh_pem_encode(data, prefix=_SK_START + b"\n", suffix=_SK_END + b"\n"):
    """Base64-encode *data* (with line breaks) between *prefix* and *suffix*."""
    encoded = _base64_encode(data)
    return b"".join((prefix, encoded, suffix))
| |
| |
def _check_block_size(data, block_len):
    """Ensure *data* is non-empty and a whole number of *block_len* blocks.

    Raises ValueError otherwise.
    """
    whole_blocks = bool(data) and len(data) % block_len == 0
    if not whole_blocks:
        raise ValueError("Corrupt data: missing padding")
| |
| |
def _check_empty(data):
    """Raise ValueError if *data* still holds bytes that were not parsed."""
    if not data:
        return
    raise ValueError("Corrupt data: unparsed data")
| |
| |
def _init_cipher(ciphername, password, salt, rounds, backend):
    """Derive key material with bcrypt-pbkdf and build the matching Cipher.

    :param ciphername: key into ``_SSH_CIPHERS`` (e.g. ``b"aes256-ctr"``).
    :param password: bytes password; an empty or missing one raises
        ValueError.
    :param salt: KDF salt bytes.
    :param rounds: KDF round count.
    :param backend: backend passed through to ``Cipher``.
    :returns: an unkeyed-direction ``Cipher`` (caller picks encryptor or
        decryptor).
    """
    if not password:
        raise ValueError("Key is password-protected.")

    cipher_cls, key_size, mode_cls, iv_size = _SSH_CIPHERS[ciphername]
    # One KDF call yields key || iv back to back.
    material = _bcrypt_kdf(password, salt, key_size + iv_size, rounds, True)
    key, iv = material[:key_size], material[key_size:]
    return Cipher(cipher_cls(key), mode_cls(iv), backend)
| |
| |
def _get_u32(data):
    """Split a big-endian uint32 off the front of *data*.

    Returns ``(value, rest)``; raises ValueError when fewer than 4 bytes
    remain.
    """
    if len(data) < 4:
        raise ValueError("Invalid data")
    head, rest = data[:4], data[4:]
    (value,) = _U32.unpack(head)
    return value, rest
| |
| |
def _get_u64(data):
    """Split a big-endian uint64 off the front of *data*.

    Returns ``(value, rest)``; raises ValueError when fewer than 8 bytes
    remain.
    """
    if len(data) < 8:
        raise ValueError("Invalid data")
    head, rest = data[:8], data[8:]
    (value,) = _U64.unpack(head)
    return value, rest
| |
| |
def _get_sshstr(data):
    """Split a u32-length-prefixed string off the front of *data*.

    Returns ``(payload, rest)``.  Raises ValueError if the declared length
    exceeds what remains.
    """
    length, rest = _get_u32(data)
    if len(rest) < length:
        raise ValueError("Invalid data")
    return rest[:length], rest[length:]
| |
| |
def _get_mpint(data):
    """Read an SSH mpint (length-prefixed big-endian integer) from *data*.

    Only non-negative values are accepted: a first byte with the sign bit set
    raises ValueError.  Returns ``(value, rest)``.
    """
    raw, rest = _get_sshstr(data)
    if raw and six.indexbytes(raw, 0) & 0x80:
        raise ValueError("Invalid data")
    return utils.int_from_bytes(raw, "big"), rest
| |
| |
def _to_mpint(val):
    """Encode a non-negative int as SSH mpint payload bytes (no length prefix).

    Zero encodes as the empty string.  Otherwise the big-endian magnitude is
    emitted with one spare bit so the sign bit stays clear; a value whose top
    bit is set therefore gains a leading zero byte.  Negative values raise
    ValueError.
    """
    if val < 0:
        raise ValueError("negative mpint not allowed")
    if val:
        return utils.int_to_bytes(val, val.bit_length() // 8 + 1)
    return b""
| |
| |
class _FragList(object):
    """Accumulate byte fragments that are concatenated only on output.

    Fragments are kept by reference until ``render`` or ``tobytes`` writes
    them out.  Another ``_FragList`` can be embedded as a length-prefixed
    string through ``put_sshstr``, splicing its fragments in without copying.
    """

    def __init__(self, init=None):
        # Optional iterable of initial fragments (copied into a new list).
        self.flist = list(init) if init else []

    def put_raw(self, val):
        """Append *val* verbatim."""
        self.flist.append(val)

    def put_u32(self, val):
        """Append *val* encoded as a big-endian uint32."""
        self.put_raw(_U32.pack(val))

    def put_sshstr(self, val):
        """Append *val* preceded by its u32 byte length.

        *val* is either a bytes-like object or another ``_FragList``.
        """
        if not isinstance(val, (bytes, memoryview, bytearray)):
            # Nested _FragList: length first, then its fragments by reference.
            self.put_u32(val.size())
            self.flist.extend(val.flist)
            return
        self.put_u32(len(val))
        self.flist.append(val)

    def put_mpint(self, val):
        """Append the non-negative integer *val* as an SSH mpint string."""
        self.put_sshstr(_to_mpint(val))

    def size(self):
        """Return the total length in bytes of all fragments."""
        total = 0
        for frag in self.flist:
            total += len(frag)
        return total

    def render(self, dstbuf, pos=0):
        """Copy every fragment into *dstbuf* starting at offset *pos*.

        Returns the offset just past the last byte written.
        """
        for frag in self.flist:
            end = pos + len(frag)
            dstbuf[pos:end] = frag
            pos = end
        return pos

    def tobytes(self):
        """Return all fragments concatenated into one bytes object."""
        out = memoryview(bytearray(self.size()))
        self.render(out)
        return out.tobytes()
| |
| |
class _SSHFormatRSA(object):
    """Format for RSA keys.

    Public:
        mpint e, n
    Private:
        mpint n, e, d, iqmp, p, q

    The private section stores n before e, the reverse of the public
    section.  The CRT exponents dmp1/dmq1 are not stored and are recomputed
    on load.
    """

    def get_public(self, data):
        """Parse the public fields; return ``((e, n), rest)``."""
        e, rest = _get_mpint(data)
        n, rest = _get_mpint(rest)
        return (e, n), rest

    def load_public(self, key_type, data, backend):
        """Build an RSA public key from *data*; return ``(key, rest)``.

        *key_type* is accepted for interface parity with the other formats
        and is not used.
        """
        (e, n), rest = self.get_public(data)
        public_key = rsa.RSAPublicNumbers(e, n).public_key(backend)
        return public_key, rest

    def load_private(self, data, pubfields, backend):
        """Build an RSA private key from *data*; return ``(key, rest)``.

        *pubfields* is the ``(e, n)`` pair parsed from the public section; a
        mismatch raises ValueError.
        """
        values = []
        for _ in range(6):
            value, data = _get_mpint(data)
            values.append(value)
        n, e, d, iqmp, p, q = values

        if (e, n) != pubfields:
            raise ValueError("Corrupt data: rsa field mismatch")
        dmp1 = rsa.rsa_crt_dmp1(d, p)
        dmq1 = rsa.rsa_crt_dmq1(d, q)
        private_numbers = rsa.RSAPrivateNumbers(
            p, q, d, dmp1, dmq1, iqmp, rsa.RSAPublicNumbers(e, n)
        )
        return private_numbers.private_key(backend), data

    def encode_public(self, public_key, f_pub):
        """Append ``mpint e, mpint n`` to *f_pub*."""
        numbers = public_key.public_numbers()
        for value in (numbers.e, numbers.n):
            f_pub.put_mpint(value)

    def encode_private(self, private_key, f_priv):
        """Append ``mpint n, e, d, iqmp, p, q`` to *f_priv*."""
        priv = private_key.private_numbers()
        pub = priv.public_numbers
        for value in (pub.n, pub.e, priv.d, priv.iqmp, priv.p, priv.q):
            f_priv.put_mpint(value)
| |
| |
class _SSHFormatDSA(object):
    """Format for DSA keys.

    Public:
        mpint p, q, g, y
    Private:
        mpint p, q, g, y, x

    Only keys whose modulus p is exactly 1024 bits are accepted, both when
    loading and when encoding.
    """

    def get_public(self, data):
        """Parse ``p, q, g, y``; return ``((p, q, g, y), rest)``."""
        fields = []
        for _ in range(4):
            value, data = _get_mpint(data)
            fields.append(value)
        return tuple(fields), data

    def load_public(self, key_type, data, backend):
        """Build a DSA public key from *data*; return ``(key, rest)``.

        *key_type* is accepted for interface parity and is not used.
        """
        (p, q, g, y), rest = self.get_public(data)
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        return public_numbers.public_key(backend), rest

    def load_private(self, data, pubfields, backend):
        """Build a DSA private key from *data*; return ``(key, rest)``.

        *pubfields* is the ``(p, q, g, y)`` tuple from the public section; a
        mismatch raises ValueError.
        """
        fields, rest = self.get_public(data)
        x, rest = _get_mpint(rest)

        if fields != pubfields:
            raise ValueError("Corrupt data: dsa field mismatch")
        p, q, g, y = fields
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        private_numbers = dsa.DSAPrivateNumbers(x, public_numbers)
        return private_numbers.private_key(backend), rest

    def encode_public(self, public_key, f_pub):
        """Append ``mpint p, q, g, y`` to *f_pub* (1024-bit keys only)."""
        public_numbers = public_key.public_numbers()
        self._validate(public_numbers)
        params = public_numbers.parameter_numbers
        for value in (params.p, params.q, params.g, public_numbers.y):
            f_pub.put_mpint(value)

    def encode_private(self, private_key, f_priv):
        """Append the public fields followed by ``mpint x`` to *f_priv*."""
        public_key = private_key.public_key()
        self.encode_public(public_key, f_priv)
        secret = private_key.private_numbers().x
        f_priv.put_mpint(secret)

    def _validate(self, public_numbers):
        """Reject DSA keys whose modulus p is not exactly 1024 bits."""
        if public_numbers.parameter_numbers.p.bit_length() != 1024:
            raise ValueError("SSH supports only 1024 bit DSA keys")
| |
| |
class _SSHFormatECDSA(object):
    """Format for ECDSA keys.

    Public:
        str curve
        bytes point
    Private:
        str curve
        bytes point
        mpint secret

    ``curve`` is the SSH curve identifier (e.g. ``b"nistp256"``) and ``point``
    is the X9.62-encoded public point.  Only uncompressed points (leading byte
    0x04) are accepted.
    """

    def __init__(self, ssh_curve_name, curve):
        # ssh_curve_name: curve identifier as written inside SSH key blobs.
        # curve: the matching cryptography EllipticCurve instance.
        self.ssh_curve_name = ssh_curve_name
        self.curve = curve

    def get_public(self, data):
        """ECDSA public fields.

        Parse ``str curve, bytes point`` from the front of *data* and return
        ``((curve, point), rest)``.

        :raises ValueError: curve name does not match this format, or the
            point is empty.
        :raises NotImplementedError: point is not in uncompressed form.
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        # Without this check an empty point makes indexbytes raise
        # IndexError, which escapes callers expecting ValueError for
        # malformed input.
        if not point:
            raise ValueError("Invalid data")
        if six.indexbytes(point, 0) != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(self, key_type, data, backend):
        """Make ECDSA public key from data; return ``(key, rest)``.

        *key_type* and *backend* are accepted for interface parity with the
        other formats and are not used here.
        """
        (curve_name, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(self, data, pubfields, backend):
        """Make ECDSA private key from data; return ``(key, rest)``.

        *pubfields* is the ``(curve, point)`` pair from the public section; a
        mismatch raises ValueError.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)

        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        # NOTE(review): the key is derived from the secret alone; it is not
        # cross-checked against the stored public point.
        private_key = ec.derive_private_key(secret, self.curve, backend)
        return private_key, data

    def encode_public(self, public_key, f_pub):
        """Append ``str curve, bytes point`` (uncompressed) to *f_pub*."""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(self, private_key, f_priv):
        """Append the public fields followed by ``mpint secret``."""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()

        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)
| |
| |
class _SSHFormatEd25519(object):
    """Format for Ed25519 keys.

    Public:
        bytes point
    Private:
        bytes point
        bytes secret_and_point

    ``secret_and_point`` is the 32-byte secret followed by another copy of
    the public point.
    """

    def get_public(self, data):
        """Parse the public point; return ``((point,), rest)``."""
        point, rest = _get_sshstr(data)
        return (point,), rest

    def load_public(self, key_type, data, backend):
        """Build an Ed25519 public key from *data*; return ``(key, rest)``.

        *key_type* and *backend* are accepted for interface parity and are
        not used.
        """
        (point,), rest = self.get_public(data)
        raw_point = point.tobytes()
        return ed25519.Ed25519PublicKey.from_public_bytes(raw_point), rest

    def load_private(self, data, pubfields, backend):
        """Build an Ed25519 private key from *data*; return ``(key, rest)``.

        Both embedded copies of the point must equal each other and the
        public-section *pubfields*, otherwise ValueError is raised.
        """
        (point,), rest = self.get_public(data)
        keypair, rest = _get_sshstr(rest)

        secret, point2 = keypair[:32], keypair[32:]
        if point != point2 or (point,) != pubfields:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        return ed25519.Ed25519PrivateKey.from_private_bytes(secret), rest

    def encode_public(self, public_key, f_pub):
        """Append the raw public point as a length-prefixed string."""
        f_pub.put_sshstr(public_key.public_bytes(Encoding.Raw, PublicFormat.Raw))

    def encode_private(self, private_key, f_priv):
        """Append the public point, then ``secret || point`` as one string."""
        public_key = private_key.public_key()
        self.encode_public(public_key, f_priv)

        keypair = _FragList()
        keypair.put_raw(
            private_key.private_bytes(
                Encoding.Raw, PrivateFormat.Raw, NoEncryption()
            )
        )
        keypair.put_raw(public_key.public_bytes(Encoding.Raw, PublicFormat.Raw))
        f_priv.put_sshstr(keypair)
| |
| |
# SSH key type -> handler that parses and encodes that key's fields
# (get_public / load_public / load_private / encode_public / encode_private).
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
}
| |
| |
def _lookup_kformat(key_type):
    """Return the key-format handler registered for *key_type*.

    *key_type* may be any bytes-like object; it is normalised to ``bytes``
    before the lookup.  Unknown key types raise UnsupportedAlgorithm.
    """
    if isinstance(key_type, bytes):
        normalized = key_type
    else:
        normalized = memoryview(key_type).tobytes()
    kformat = _KEY_FORMATS.get(normalized)
    if kformat is None:
        raise UnsupportedAlgorithm("Unsupported key type: %r" % normalized)
    return kformat
| |
| |
def load_ssh_private_key(data, password, backend=None):
    """Load private key from OpenSSH custom encoding.

    Parses the ``-----BEGIN OPENSSH PRIVATE KEY-----`` armoured
    ``openssh-key-v1`` format and returns an RSA, DSA, ECDSA or Ed25519
    private key object.

    :param data: bytes-like object containing the armoured key.
    :param password: bytes password for encrypted keys, or None.  It is
        not consulted for unencrypted keys.
    :param backend: optional backend, resolved via ``_get_backend``.
    :raises ValueError: malformed or corrupt key data, a key count other
        than one, or an encrypted key loaded without a password.
    :raises UnsupportedAlgorithm: unknown key type, cipher or KDF, or an
        encrypted key when the bcrypt module is unavailable.
    """
    utils._check_byteslike("data", data)
    backend = _get_backend(backend)
    if password is not None:
        utils._check_bytes("password", password)

    # locate the base64 payload between the BEGIN/END armour lines
    m = _PEM_RC.search(data)
    if not m:
        raise ValueError("Not OpenSSH private key format")
    p1 = m.start(1)
    p2 = m.end(1)
    data = binascii.a2b_base64(memoryview(data)[p1:p2])
    if not data.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    # continue on a memoryview so later slicing does not copy
    data = memoryview(data)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    # load secret data
    edata, data = _get_sshstr(data)
    _check_empty(data)

    if (ciphername, kdfname) != (_NONE, _NONE):
        # encrypted: cipher must be one we know and the KDF must be bcrypt
        ciphername = ciphername.tobytes()
        if ciphername not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm("Unsupported cipher: %r" % ciphername)
        if kdfname != _BCRYPT:
            raise UnsupportedAlgorithm("Unsupported KDF: %r" % kdfname)
        blklen = _SSH_CIPHERS[ciphername][3]
        _check_block_size(edata, blklen)
        # kdfoptions holds: string salt, uint32 rounds
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(
            ciphername, password, salt.tobytes(), rounds, backend
        )
        edata = memoryview(ciph.decryptor().update(edata))
    else:
        # unencrypted: the secret section is still padded to 8-byte blocks
        blklen = 8
        _check_block_size(edata, blklen)
    # the section starts with the same check value twice; a mismatch means
    # corrupt data (and presumably a wrong password for encrypted keys)
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(edata, pubfields, backend)
    # the comment is read to advance past it but is not returned
    comment, edata = _get_sshstr(edata)

    # yes, SSH does padding check *after* all other parsing is done.
    # need to follow as it writes zero-byte padding too.
    # Remaining bytes must be exactly 1, 2, 3, ... (at most 16 of them).
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    return private_key
| |
| |
def serialize_ssh_private_key(private_key, password=None):
    """Serialize private key with OpenSSH custom encoding.

    Returns the armoured ``openssh-key-v1`` encoding of *private_key* (RSA,
    DSA, ECDSA on a supported curve, or Ed25519) as bytes.  When a non-empty
    *password* is given, the secret section is encrypted with aes256-ctr.
    The key is derived with bcrypt's kdf using 16 rounds and a random
    16-byte salt.  Otherwise the key is written unencrypted.

    :raises ValueError: password longer than 72 bytes, or unsupported key
        type / curve.
    """
    if password is not None:
        utils._check_bytes("password", password)
    if password and len(password) > _MAX_PASSWORD:
        raise ValueError(
            "Passwords longer than 72 bytes are not supported by "
            "OpenSSH private key format"
        )

    if isinstance(private_key, ec.EllipticCurvePrivateKey):
        key_type = _ecdsa_key_type(private_key.public_key())
    elif isinstance(private_key, rsa.RSAPrivateKey):
        key_type = _SSH_RSA
    elif isinstance(private_key, dsa.DSAPrivateKey):
        key_type = _SSH_DSA
    elif isinstance(private_key, ed25519.Ed25519PrivateKey):
        key_type = _SSH_ED25519
    else:
        raise ValueError("Unsupported key type")
    kformat = _lookup_kformat(key_type)

    # setup parameters
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername][3]
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        salt = os.urandom(16)
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        backend = _get_backend(None)
        ciph = _init_cipher(ciphername, password, salt, rounds, backend)
    else:
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    nkeys = 1
    # random check value, written twice so the loader can detect bad data
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # pad to a whole number of blocks with bytes 1, 2, 3, ...; an already
    # aligned section still receives a full block of padding
    f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # copy result into bytearray
    slen = f_secrets.size()
    mlen = f_main.size()
    # NOTE(review): the extra blklen bytes presumably give update_into the
    # output headroom it requires beyond the input length
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    # secrets are the last field of f_main, so they occupy buf[ofs:mlen]
    ofs = mlen - slen

    # encrypt in-place
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    txt = _ssh_pem_encode(buf[:mlen])
    # scrub the secret section from the working buffer (the fragments held
    # by f_secrets are not cleared)
    buf[ofs:mlen] = bytearray(slen)
    return txt
| |
| |
def load_ssh_public_key(data, backend=None):
    """Load a public key from the OpenSSH one-line format.

    Accepts ``<key-type> <base64-blob>`` optionally followed by a comment.
    For certificate key types (``...-cert-v01@openssh.com``) the certificate
    fields are read to check that the blob is well-formed, then the embedded
    public key is returned.  The certificate signature is not verified.

    :raises ValueError: malformed line, base64 or key blob.
    :raises UnsupportedAlgorithm: unknown key type.
    """
    backend = _get_backend(backend)
    utils._check_byteslike("data", data)

    match = _SSH_PUBKEY_RC.match(data)
    if not match:
        raise ValueError("Invalid line format")
    orig_key_type, key_body = match.group(1), match.group(2)

    key_type = orig_key_type
    with_cert = key_type[-len(_CERT_SUFFIX) :] == _CERT_SUFFIX
    if with_cert:
        key_type = key_type[: -len(_CERT_SUFFIX)]
    kformat = _lookup_kformat(key_type)

    try:
        blob = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid key format")

    # The blob repeats the full (possibly -cert) key type from the line.
    inner_key_type, rest = _get_sshstr(blob)
    if inner_key_type != orig_key_type:
        raise ValueError("Invalid key format")
    if with_cert:
        _nonce, rest = _get_sshstr(rest)
    public_key, rest = kformat.load_public(key_type, rest, backend)

    if with_cert:
        # Certificate fields after the public key, in wire order.
        cert_field_readers = (
            _get_u64,  # serial
            _get_u32,  # certificate type
            _get_sshstr,  # key id
            _get_sshstr,  # valid principals
            _get_u64,  # valid after
            _get_u64,  # valid before
            _get_sshstr,  # critical options
            _get_sshstr,  # extensions
            _get_sshstr,  # reserved
            _get_sshstr,  # signature key
            _get_sshstr,  # signature
        )
        for read_field in cert_field_readers:
            _, rest = read_field(rest)

    _check_empty(rest)
    return public_key
| |
| |
def serialize_ssh_public_key(public_key):
    """Return *public_key* in the OpenSSH one-line format.

    The result is ``b"<key-type> <base64-blob>"`` with no comment and no
    trailing newline.  Raises ValueError for unsupported key classes or
    curves.
    """
    if isinstance(public_key, ec.EllipticCurvePublicKey):
        key_type = _ecdsa_key_type(public_key)
    else:
        simple_types = (
            (rsa.RSAPublicKey, _SSH_RSA),
            (dsa.DSAPublicKey, _SSH_DSA),
            (ed25519.Ed25519PublicKey, _SSH_ED25519),
        )
        for key_class, ssh_type in simple_types:
            if isinstance(public_key, key_class):
                key_type = ssh_type
                break
        else:
            raise ValueError("Unsupported key type")
    kformat = _lookup_kformat(key_type)

    blob = _FragList()
    blob.put_sshstr(key_type)
    kformat.encode_public(public_key, blob)

    encoded = binascii.b2a_base64(blob.tobytes()).strip()
    return b" ".join([key_type, encoded])