blob: 5ecae59f8aa67502a8a27b135e9e993347535f1c [file] [log] [blame]
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.
from __future__ import absolute_import, division, print_function
import binascii
import os
import re
import struct
import six
from cryptography import utils
from cryptography.exceptions import UnsupportedAlgorithm
from cryptography.hazmat.backends import _get_backend
from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed25519, rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.serialization import (
Encoding,
NoEncryption,
PrivateFormat,
PublicFormat,
)
try:
from bcrypt import kdf as _bcrypt_kdf
_bcrypt_supported = True
except ImportError:
_bcrypt_supported = False
def _bcrypt_kdf(*args, **kwargs):
raise UnsupportedAlgorithm("Need bcrypt module")
try:
from base64 import encodebytes as _base64_encode
except ImportError:
from base64 import encodestring as _base64_encode
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
_CERT_SUFFIX = b"-cert-v01@openssh.com"
_SSH_PUBKEY_RC = re.compile(br"\A(\S+)[ \t]+(\S+)")
_SK_MAGIC = b"openssh-key-v1\0"
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
_BCRYPT = b"bcrypt"
_NONE = b"none"
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16
_MAX_PASSWORD = 72
# re is only way to work on bytes-like data
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)
# padding for max blocksize
_PADDING = memoryview(bytearray(range(1, 1 + 16)))
# ciphers that are actually used in key wrapping
_SSH_CIPHERS = {
b"aes256-ctr": (algorithms.AES, 32, modes.CTR, 16),
b"aes256-cbc": (algorithms.AES, 32, modes.CBC, 16),
}
# map local curve name to key type
_ECDSA_KEY_TYPE = {
"secp256r1": _ECDSA_NISTP256,
"secp384r1": _ECDSA_NISTP384,
"secp521r1": _ECDSA_NISTP521,
}
_U32 = struct.Struct(b">I")
_U64 = struct.Struct(b">Q")
def _ecdsa_key_type(public_key):
"""Return SSH key_type and curve_name for private key."""
curve = public_key.curve
if curve.name not in _ECDSA_KEY_TYPE:
raise ValueError(
"Unsupported curve for ssh private key: %r" % curve.name
)
return _ECDSA_KEY_TYPE[curve.name]
def _ssh_pem_encode(data, prefix=_SK_START + b"\n", suffix=_SK_END + b"\n"):
return b"".join([prefix, _base64_encode(data), suffix])
def _check_block_size(data, block_len):
"""Require data to be full blocks"""
if not data or len(data) % block_len != 0:
raise ValueError("Corrupt data: missing padding")
def _check_empty(data):
"""All data should have been parsed."""
if data:
raise ValueError("Corrupt data: unparsed data")
def _init_cipher(ciphername, password, salt, rounds, backend):
"""Generate key + iv and return cipher."""
if not password:
raise ValueError("Key is password-protected.")
algo, key_len, mode, iv_len = _SSH_CIPHERS[ciphername]
seed = _bcrypt_kdf(password, salt, key_len + iv_len, rounds, True)
return Cipher(algo(seed[:key_len]), mode(seed[key_len:]), backend)
def _get_u32(data):
"""Uint32"""
if len(data) < 4:
raise ValueError("Invalid data")
return _U32.unpack(data[:4])[0], data[4:]
def _get_u64(data):
"""Uint64"""
if len(data) < 8:
raise ValueError("Invalid data")
return _U64.unpack(data[:8])[0], data[8:]
def _get_sshstr(data):
"""Bytes with u32 length prefix"""
n, data = _get_u32(data)
if n > len(data):
raise ValueError("Invalid data")
return data[:n], data[n:]
def _get_mpint(data):
"""Big integer."""
val, data = _get_sshstr(data)
if val and six.indexbytes(val, 0) > 0x7F:
raise ValueError("Invalid data")
return utils.int_from_bytes(val, "big"), data
def _to_mpint(val):
"""Storage format for signed bigint."""
if val < 0:
raise ValueError("negative mpint not allowed")
if not val:
return b""
nbytes = (val.bit_length() + 8) // 8
return utils.int_to_bytes(val, nbytes)
class _FragList(object):
"""Build recursive structure without data copy."""
def __init__(self, init=None):
self.flist = []
if init:
self.flist.extend(init)
def put_raw(self, val):
"""Add plain bytes"""
self.flist.append(val)
def put_u32(self, val):
"""Big-endian uint32"""
self.flist.append(_U32.pack(val))
def put_sshstr(self, val):
"""Bytes prefixed with u32 length"""
if isinstance(val, (bytes, memoryview, bytearray)):
self.put_u32(len(val))
self.flist.append(val)
else:
self.put_u32(val.size())
self.flist.extend(val.flist)
def put_mpint(self, val):
"""Big-endian bigint prefixed with u32 length"""
self.put_sshstr(_to_mpint(val))
def size(self):
"""Current number of bytes"""
return sum(map(len, self.flist))
def render(self, dstbuf, pos=0):
"""Write into bytearray"""
for frag in self.flist:
flen = len(frag)
start, pos = pos, pos + flen
dstbuf[start:pos] = frag
return pos
def tobytes(self):
"""Return as bytes"""
buf = memoryview(bytearray(self.size()))
self.render(buf)
return buf.tobytes()
class _SSHFormatRSA(object):
"""Format for RSA keys.
Public:
mpint e, n
Private:
mpint n, e, d, iqmp, p, q
"""
def get_public(self, data):
"""RSA public fields"""
e, data = _get_mpint(data)
n, data = _get_mpint(data)
return (e, n), data
def load_public(self, key_type, data, backend):
"""Make RSA public key from data."""
(e, n), data = self.get_public(data)
public_numbers = rsa.RSAPublicNumbers(e, n)
public_key = public_numbers.public_key(backend)
return public_key, data
def load_private(self, data, pubfields, backend):
"""Make RSA private key from data."""
n, data = _get_mpint(data)
e, data = _get_mpint(data)
d, data = _get_mpint(data)
iqmp, data = _get_mpint(data)
p, data = _get_mpint(data)
q, data = _get_mpint(data)
if (e, n) != pubfields:
raise ValueError("Corrupt data: rsa field mismatch")
dmp1 = rsa.rsa_crt_dmp1(d, p)
dmq1 = rsa.rsa_crt_dmq1(d, q)
public_numbers = rsa.RSAPublicNumbers(e, n)
private_numbers = rsa.RSAPrivateNumbers(
p, q, d, dmp1, dmq1, iqmp, public_numbers
)
private_key = private_numbers.private_key(backend)
return private_key, data
def encode_public(self, public_key, f_pub):
"""Write RSA public key"""
pubn = public_key.public_numbers()
f_pub.put_mpint(pubn.e)
f_pub.put_mpint(pubn.n)
def encode_private(self, private_key, f_priv):
"""Write RSA private key"""
private_numbers = private_key.private_numbers()
public_numbers = private_numbers.public_numbers
f_priv.put_mpint(public_numbers.n)
f_priv.put_mpint(public_numbers.e)
f_priv.put_mpint(private_numbers.d)
f_priv.put_mpint(private_numbers.iqmp)
f_priv.put_mpint(private_numbers.p)
f_priv.put_mpint(private_numbers.q)
class _SSHFormatDSA(object):
"""Format for DSA keys.
Public:
mpint p, q, g, y
Private:
mpint p, q, g, y, x
"""
def get_public(self, data):
"""DSA public fields"""
p, data = _get_mpint(data)
q, data = _get_mpint(data)
g, data = _get_mpint(data)
y, data = _get_mpint(data)
return (p, q, g, y), data
def load_public(self, key_type, data, backend):
"""Make DSA public key from data."""
(p, q, g, y), data = self.get_public(data)
parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers)
self._validate(public_numbers)
public_key = public_numbers.public_key(backend)
return public_key, data
def load_private(self, data, pubfields, backend):
"""Make DSA private key from data."""
(p, q, g, y), data = self.get_public(data)
x, data = _get_mpint(data)
if (p, q, g, y) != pubfields:
raise ValueError("Corrupt data: dsa field mismatch")
parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers)
self._validate(public_numbers)
private_numbers = dsa.DSAPrivateNumbers(x, public_numbers)
private_key = private_numbers.private_key(backend)
return private_key, data
def encode_public(self, public_key, f_pub):
"""Write DSA public key"""
public_numbers = public_key.public_numbers()
parameter_numbers = public_numbers.parameter_numbers
self._validate(public_numbers)
f_pub.put_mpint(parameter_numbers.p)
f_pub.put_mpint(parameter_numbers.q)
f_pub.put_mpint(parameter_numbers.g)
f_pub.put_mpint(public_numbers.y)
def encode_private(self, private_key, f_priv):
"""Write DSA private key"""
self.encode_public(private_key.public_key(), f_priv)
f_priv.put_mpint(private_key.private_numbers().x)
def _validate(self, public_numbers):
parameter_numbers = public_numbers.parameter_numbers
if parameter_numbers.p.bit_length() != 1024:
raise ValueError("SSH supports only 1024 bit DSA keys")
class _SSHFormatECDSA(object):
"""Format for ECDSA keys.
Public:
str curve
bytes point
Private:
str curve
bytes point
mpint secret
"""
def __init__(self, ssh_curve_name, curve):
self.ssh_curve_name = ssh_curve_name
self.curve = curve
def get_public(self, data):
"""ECDSA public fields"""
curve, data = _get_sshstr(data)
point, data = _get_sshstr(data)
if curve != self.ssh_curve_name:
raise ValueError("Curve name mismatch")
if six.indexbytes(point, 0) != 4:
raise NotImplementedError("Need uncompressed point")
return (curve, point), data
def load_public(self, key_type, data, backend):
"""Make ECDSA public key from data."""
(curve_name, point), data = self.get_public(data)
public_key = ec.EllipticCurvePublicKey.from_encoded_point(
self.curve, point.tobytes()
)
return public_key, data
def load_private(self, data, pubfields, backend):
"""Make ECDSA private key from data."""
(curve_name, point), data = self.get_public(data)
secret, data = _get_mpint(data)
if (curve_name, point) != pubfields:
raise ValueError("Corrupt data: ecdsa field mismatch")
private_key = ec.derive_private_key(secret, self.curve, backend)
return private_key, data
def encode_public(self, public_key, f_pub):
"""Write ECDSA public key"""
point = public_key.public_bytes(
Encoding.X962, PublicFormat.UncompressedPoint
)
f_pub.put_sshstr(self.ssh_curve_name)
f_pub.put_sshstr(point)
def encode_private(self, private_key, f_priv):
"""Write ECDSA private key"""
public_key = private_key.public_key()
private_numbers = private_key.private_numbers()
self.encode_public(public_key, f_priv)
f_priv.put_mpint(private_numbers.private_value)
class _SSHFormatEd25519(object):
"""Format for Ed25519 keys.
Public:
bytes point
Private:
bytes point
bytes secret_and_point
"""
def get_public(self, data):
"""Ed25519 public fields"""
point, data = _get_sshstr(data)
return (point,), data
def load_public(self, key_type, data, backend):
"""Make Ed25519 public key from data."""
(point,), data = self.get_public(data)
public_key = ed25519.Ed25519PublicKey.from_public_bytes(
point.tobytes()
)
return public_key, data
def load_private(self, data, pubfields, backend):
"""Make Ed25519 private key from data."""
(point,), data = self.get_public(data)
keypair, data = _get_sshstr(data)
secret = keypair[:32]
point2 = keypair[32:]
if point != point2 or (point,) != pubfields:
raise ValueError("Corrupt data: ed25519 field mismatch")
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(secret)
return private_key, data
def encode_public(self, public_key, f_pub):
"""Write Ed25519 public key"""
raw_public_key = public_key.public_bytes(
Encoding.Raw, PublicFormat.Raw
)
f_pub.put_sshstr(raw_public_key)
def encode_private(self, private_key, f_priv):
"""Write Ed25519 private key"""
public_key = private_key.public_key()
raw_private_key = private_key.private_bytes(
Encoding.Raw, PrivateFormat.Raw, NoEncryption()
)
raw_public_key = public_key.public_bytes(
Encoding.Raw, PublicFormat.Raw
)
f_keypair = _FragList([raw_private_key, raw_public_key])
self.encode_public(public_key, f_priv)
f_priv.put_sshstr(f_keypair)
_KEY_FORMATS = {
_SSH_RSA: _SSHFormatRSA(),
_SSH_DSA: _SSHFormatDSA(),
_SSH_ED25519: _SSHFormatEd25519(),
_ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
_ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
_ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
}
def _lookup_kformat(key_type):
"""Return valid format or throw error"""
if not isinstance(key_type, bytes):
key_type = memoryview(key_type).tobytes()
if key_type in _KEY_FORMATS:
return _KEY_FORMATS[key_type]
raise UnsupportedAlgorithm("Unsupported key type: %r" % key_type)
def load_ssh_private_key(data, password, backend=None):
"""Load private key from OpenSSH custom encoding."""
utils._check_byteslike("data", data)
backend = _get_backend(backend)
if password is not None:
utils._check_bytes("password", password)
m = _PEM_RC.search(data)
if not m:
raise ValueError("Not OpenSSH private key format")
p1 = m.start(1)
p2 = m.end(1)
data = binascii.a2b_base64(memoryview(data)[p1:p2])
if not data.startswith(_SK_MAGIC):
raise ValueError("Not OpenSSH private key format")
data = memoryview(data)[len(_SK_MAGIC) :]
# parse header
ciphername, data = _get_sshstr(data)
kdfname, data = _get_sshstr(data)
kdfoptions, data = _get_sshstr(data)
nkeys, data = _get_u32(data)
if nkeys != 1:
raise ValueError("Only one key supported")
# load public key data
pubdata, data = _get_sshstr(data)
pub_key_type, pubdata = _get_sshstr(pubdata)
kformat = _lookup_kformat(pub_key_type)
pubfields, pubdata = kformat.get_public(pubdata)
_check_empty(pubdata)
# load secret data
edata, data = _get_sshstr(data)
_check_empty(data)
if (ciphername, kdfname) != (_NONE, _NONE):
ciphername = ciphername.tobytes()
if ciphername not in _SSH_CIPHERS:
raise UnsupportedAlgorithm("Unsupported cipher: %r" % ciphername)
if kdfname != _BCRYPT:
raise UnsupportedAlgorithm("Unsupported KDF: %r" % kdfname)
blklen = _SSH_CIPHERS[ciphername][3]
_check_block_size(edata, blklen)
salt, kbuf = _get_sshstr(kdfoptions)
rounds, kbuf = _get_u32(kbuf)
_check_empty(kbuf)
ciph = _init_cipher(
ciphername, password, salt.tobytes(), rounds, backend
)
edata = memoryview(ciph.decryptor().update(edata))
else:
blklen = 8
_check_block_size(edata, blklen)
ck1, edata = _get_u32(edata)
ck2, edata = _get_u32(edata)
if ck1 != ck2:
raise ValueError("Corrupt data: broken checksum")
# load per-key struct
key_type, edata = _get_sshstr(edata)
if key_type != pub_key_type:
raise ValueError("Corrupt data: key type mismatch")
private_key, edata = kformat.load_private(edata, pubfields, backend)
comment, edata = _get_sshstr(edata)
# yes, SSH does padding check *after* all other parsing is done.
# need to follow as it writes zero-byte padding too.
if edata != _PADDING[: len(edata)]:
raise ValueError("Corrupt data: invalid padding")
return private_key
def serialize_ssh_private_key(private_key, password=None):
"""Serialize private key with OpenSSH custom encoding."""
if password is not None:
utils._check_bytes("password", password)
if password and len(password) > _MAX_PASSWORD:
raise ValueError(
"Passwords longer than 72 bytes are not supported by "
"OpenSSH private key format"
)
if isinstance(private_key, ec.EllipticCurvePrivateKey):
key_type = _ecdsa_key_type(private_key.public_key())
elif isinstance(private_key, rsa.RSAPrivateKey):
key_type = _SSH_RSA
elif isinstance(private_key, dsa.DSAPrivateKey):
key_type = _SSH_DSA
elif isinstance(private_key, ed25519.Ed25519PrivateKey):
key_type = _SSH_ED25519
else:
raise ValueError("Unsupported key type")
kformat = _lookup_kformat(key_type)
# setup parameters
f_kdfoptions = _FragList()
if password:
ciphername = _DEFAULT_CIPHER
blklen = _SSH_CIPHERS[ciphername][3]
kdfname = _BCRYPT
rounds = _DEFAULT_ROUNDS
salt = os.urandom(16)
f_kdfoptions.put_sshstr(salt)
f_kdfoptions.put_u32(rounds)
backend = _get_backend(None)
ciph = _init_cipher(ciphername, password, salt, rounds, backend)
else:
ciphername = kdfname = _NONE
blklen = 8
ciph = None
nkeys = 1
checkval = os.urandom(4)
comment = b""
# encode public and private parts together
f_public_key = _FragList()
f_public_key.put_sshstr(key_type)
kformat.encode_public(private_key.public_key(), f_public_key)
f_secrets = _FragList([checkval, checkval])
f_secrets.put_sshstr(key_type)
kformat.encode_private(private_key, f_secrets)
f_secrets.put_sshstr(comment)
f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])
# top-level structure
f_main = _FragList()
f_main.put_raw(_SK_MAGIC)
f_main.put_sshstr(ciphername)
f_main.put_sshstr(kdfname)
f_main.put_sshstr(f_kdfoptions)
f_main.put_u32(nkeys)
f_main.put_sshstr(f_public_key)
f_main.put_sshstr(f_secrets)
# copy result info bytearray
slen = f_secrets.size()
mlen = f_main.size()
buf = memoryview(bytearray(mlen + blklen))
f_main.render(buf)
ofs = mlen - slen
# encrypt in-place
if ciph is not None:
ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])
txt = _ssh_pem_encode(buf[:mlen])
buf[ofs:mlen] = bytearray(slen)
return txt
def load_ssh_public_key(data, backend=None):
"""Load public key from OpenSSH one-line format."""
backend = _get_backend(backend)
utils._check_byteslike("data", data)
m = _SSH_PUBKEY_RC.match(data)
if not m:
raise ValueError("Invalid line format")
key_type = orig_key_type = m.group(1)
key_body = m.group(2)
with_cert = False
if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]:
with_cert = True
key_type = key_type[: -len(_CERT_SUFFIX)]
kformat = _lookup_kformat(key_type)
try:
data = memoryview(binascii.a2b_base64(key_body))
except (TypeError, binascii.Error):
raise ValueError("Invalid key format")
inner_key_type, data = _get_sshstr(data)
if inner_key_type != orig_key_type:
raise ValueError("Invalid key format")
if with_cert:
nonce, data = _get_sshstr(data)
public_key, data = kformat.load_public(key_type, data, backend)
if with_cert:
serial, data = _get_u64(data)
cctype, data = _get_u32(data)
key_id, data = _get_sshstr(data)
principals, data = _get_sshstr(data)
valid_after, data = _get_u64(data)
valid_before, data = _get_u64(data)
crit_options, data = _get_sshstr(data)
extensions, data = _get_sshstr(data)
reserved, data = _get_sshstr(data)
sig_key, data = _get_sshstr(data)
signature, data = _get_sshstr(data)
_check_empty(data)
return public_key
def serialize_ssh_public_key(public_key):
"""One-line public key format for OpenSSH"""
if isinstance(public_key, ec.EllipticCurvePublicKey):
key_type = _ecdsa_key_type(public_key)
elif isinstance(public_key, rsa.RSAPublicKey):
key_type = _SSH_RSA
elif isinstance(public_key, dsa.DSAPublicKey):
key_type = _SSH_DSA
elif isinstance(public_key, ed25519.Ed25519PublicKey):
key_type = _SSH_ED25519
else:
raise ValueError("Unsupported key type")
kformat = _lookup_kformat(key_type)
f_pub = _FragList()
f_pub.put_sshstr(key_type)
kformat.encode_public(public_key, f_pub)
pub = binascii.b2a_base64(f_pub.tobytes()).strip()
return b"".join([key_type, b" ", pub])