Commit 8f289ba5 authored by Fabio Utzig's avatar Fabio Utzig Committed by Fabio Utzig
Browse files

imgtool: fix getpriv format type for keys



A previous change was added to allow the `getpriv` command to dump ec256
keys in both openssl and pkcs8. That PR did not touch other key file
types which resulted in errors using that command with RSA, X25519, etc.

This commit generalizes the passing of the `format` parameter, so each
key type can decide which format it allows a dump to be produced in,
and what default to use.

Fixes #1529

Signed-off-by: default avatarFabio Utzig <utzig@apache.org>
parent 4a748bfe
Loading
Loading
Loading
Loading
+12 −10
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.hashes import SHA256

from .general import KeyClass
from .privatebytes import PrivateBytesMixin


class ECDSAUsageError(Exception):
@@ -41,7 +42,7 @@ class ECDSA256P1Public(KeyClass):
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo)

    def get_private_bytes(self, minimal):
    def get_private_bytes(self, minimal, format):
        self._unsupported('get_private_bytes')

    def export_private(self, path, passwd=None):
@@ -85,7 +86,7 @@ class ECDSA256P1Public(KeyClass):
                        signature_algorithm=ec.ECDSA(SHA256()))


class ECDSA256P1(ECDSA256P1Public):
class ECDSA256P1(ECDSA256P1Public, PrivateBytesMixin):
    """
    Wrapper around an ECDSA private key.
    """
@@ -149,16 +150,17 @@ class ECDSA256P1(ECDSA256P1Public):

        return b

    def get_private_bytes(self, minimal, format):
        formats = {'pkcs8': serialization.PrivateFormat.PKCS8,
    _VALID_FORMATS = {
        'pkcs8': serialization.PrivateFormat.PKCS8,
        'openssl': serialization.PrivateFormat.TraditionalOpenSSL
    }
        priv = self.key.private_bytes(
                encoding=serialization.Encoding.DER,
                format=formats[format],
                encryption_algorithm=serialization.NoEncryption())
    _DEFAULT_FORMAT='pkcs8'

    def get_private_bytes(self, minimal, format):
        format, priv = self._get_private_bytes(minimal, format, ECDSAUsageError)
        if minimal:
            priv = self._build_minimal_ecdsa_privkey(priv, formats[format])
            priv = self._build_minimal_ecdsa_privkey(priv,
                                                     self._VALID_FORMATS[format])
        return priv

    def export_private(self, path, passwd=None):
+2 −2
Original line number Diff line number Diff line
@@ -34,7 +34,7 @@ class Ed25519Public(KeyClass):
                encoding=serialization.Encoding.DER,
                format=serialization.PublicFormat.SubjectPublicKeyInfo)

    def get_private_bytes(self, minimal):
    def get_private_bytes(self, minimal, format):
        self._unsupported('get_private_bytes')

    def export_private(self, path, passwd=None):
@@ -75,7 +75,7 @@ class Ed25519(Ed25519Public):
    def _get_public(self):
        return self.key.public_key()

    def get_private_bytes(self, minimal):
    def get_private_bytes(self, minimal, format):
        raise Ed25519UsageError("Operation not supported with {} keys".format(
            self.shortname()))

+16 −0
Original line number Diff line number Diff line
# SPDX-License-Identifier: Apache-2.0

from cryptography.hazmat.primitives import serialization


class PrivateBytesMixin():
    def _get_private_bytes(self, minimal, format, exclass):
        if format is None:
            format = self._DEFAULT_FORMAT
        if format not in self._VALID_FORMATS:
            raise exclass("{} does not support {}".format(
                self.shortname(), format))
        return format, self.key.private_bytes(
                encoding=serialization.Encoding.DER,
                format=self._VALID_FORMATS[format],
                encryption_algorithm=serialization.NoEncryption())
+10 −7
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ from cryptography.hazmat.primitives.asymmetric.padding import PSS, MGF1
from cryptography.hazmat.primitives.hashes import SHA256

from .general import KeyClass
from .privatebytes import PrivateBytesMixin


# Sizes that bootutil will recognize
@@ -49,7 +50,7 @@ class RSAPublic(KeyClass):
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo)

    def get_private_bytes(self, minimal):
    def get_private_bytes(self, minimal, format):
        self._unsupported('get_private_bytes')

    def export_private(self, path, passwd=None):
@@ -81,7 +82,7 @@ class RSAPublic(KeyClass):
                        algorithm=SHA256())


class RSA(RSAPublic):
class RSA(RSAPublic, PrivateBytesMixin):
    """
    Wrapper around an RSA key, with imgtool support.
    """
@@ -138,11 +139,13 @@ class RSA(RSAPublic):
        b[3] = (off - 4) & 0xff
        return b[:off]

    def get_private_bytes(self, minimal):
        priv = self.key.private_bytes(
                encoding=serialization.Encoding.DER,
                format=serialization.PrivateFormat.TraditionalOpenSSL,
                encryption_algorithm=serialization.NoEncryption())
    _VALID_FORMATS = {
        'openssl': serialization.PrivateFormat.TraditionalOpenSSL
    }
    _DEFAULT_FORMAT = 'openssl'

    def get_private_bytes(self, minimal, format):
        _, priv = self._get_private_bytes(minimal, format, RSAUsageError)
        if minimal:
            priv = self._build_minimal_rsa_privkey(priv)
        return priv
+12 −7
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import x25519

from .general import KeyClass
from .privatebytes import PrivateBytesMixin


class X25519UsageError(Exception):
@@ -39,7 +40,7 @@ class X25519Public(KeyClass):
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo)

    def get_private_bytes(self, minimal):
    def get_private_bytes(self, minimal, format):
        self._unsupported('get_private_bytes')

    def export_private(self, path, passwd=None):
@@ -63,7 +64,7 @@ class X25519Public(KeyClass):
        return 32


class X25519(X25519Public):
class X25519(X25519Public, PrivateBytesMixin):
    """
    Wrapper around an X25519 private key.
    """
@@ -80,11 +81,15 @@ class X25519(X25519Public):
    def _get_public(self):
        return self.key.public_key()

    def get_private_bytes(self, minimal):
        return self.key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption())
    _VALID_FORMATS = {
        'pkcs8': serialization.PrivateFormat.PKCS8
    }
    _DEFAULT_FORMAT = 'pkcs8'

    def get_private_bytes(self, minimal, format):
        _, priv = self._get_private_bytes(minimal, format,
                                          X25519UsageError)
        return priv

    def export_private(self, path, passwd=None):
        """
Loading