Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Elliptic Curve keys #8

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 94 additions & 7 deletions src/josepy/jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes # type: ignore
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec # type: ignore
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric import rsa

from josepy import errors, json_util, util
Expand Down Expand Up @@ -121,27 +121,114 @@ def load(cls, data, password=None, backend=None):


@JWK.register
class JWKES(JWK): # pragma: no cover
class JWKEC(JWK): # pragma: no cover
# pylint: disable=abstract-class-not-used
"""ES JWK.
"""EC JWK.

.. warning:: This is not yet implemented!

"""
typ = 'ES'
typ = 'EC'
__slots__ = ('key',)
cryptography_key_types = (
ec.EllipticCurvePublicKey, ec.EllipticCurvePrivateKey)
required = ('crv', JWK.type_field_name, 'x', 'y')

def __init__(self, *args, **kwargs):
if 'key' in kwargs and not isinstance(
kwargs['key'], util.ComparableECKey):
kwargs['key'] = util.ComparableECKey(kwargs['key'])
super(JWKEC, self).__init__(*args, **kwargs)

@classmethod
def _encode_param(cls, data):
"""Encode Base64urlUInt.

:type data: long
:rtype: unicode

"""
def _leading_zeros(arg):
if len(arg) % 2:
return '0' + arg
return arg

return json_util.encode_b64jose(binascii.unhexlify(
_leading_zeros(hex(data)[2:].rstrip('L'))))

@classmethod
def _decode_param(cls, data, name, expected_length):
"""Decode Base64urlUInt."""
try:
binary = json_util.decode_b64jose(data)
if len(binary) != expected_length:
raise errors.Error(
'Expected {name} to be {expected_length} bytes after base64-decoding; got {length}',
name=name, expected_length=expected_length, length=len(binary))
return int(binascii.hexlify(binary), 16)
except ValueError: # invalid literal for long() with base 16
raise errors.DeserializationError()

def fields_to_partial_json(self):
raise NotImplementedError()
params = {}
if isinstance(self.key._wrapped, ec.EllipticCurvePublicKey):
public = self.key.public_numbers()
elif isinstance(self.key._wrapped, ec.EllipticCurvePrivateKey):
private = self.key.private_numbers()
public = self.key.public_key().public_numbers()
params.update({
'd': private.private_value,
})
else: raise AssertionError(
"key was not an EllipticCurvePublicKey or EllipticCurvePrivateKey")

params.update({
'x': public.x,
'y': public.y,
})
params = dict((key, self._encode_param(value))
for key, value in six.iteritems(params))
params['crv'] = self._curve_name_to_crv(public.curve.name)
return params

@classmethod
def _curve_name_to_crv(cls, curve_name):
if curve_name == "secp256r1": return "P-256"
if curve_name == "secp384r1": return "P-384"
if curve_name == "secp521r1": return "P-521"
raise errors.SerializationError()

@classmethod
def _crv_to_curve(cls, crv):
# crv is case-sensitive
if crv == "P-256": return ec.SECP256R1()
if crv == "P-384": return ec.SECP384R1()
if crv == "P-521": return ec.SECP521R1()
raise errors.DeserializationError()

@classmethod
def fields_from_json(cls, jobj):
raise NotImplementedError()
# pylint: disable=invalid-name
curve = cls._crv_to_curve(jobj['crv'])
coord_length = (curve.key_size+7)//8
x, y = (cls._decode_param(jobj[n], n, coord_length) for n in ('x', 'y'))
public_numbers = ec.EllipticCurvePublicNumbers(x=x, y=y, curve=curve)
if 'd' not in jobj: # public key
key = public_numbers.public_key(default_backend())
else: # private key
exp_length = (curve.key_size.bit_length()+7)//8
d = cls._decode_param(jobj['d'], 'd', exp_length)
key = ec.EllipticCurvePrivateNumbers(d, public_numbers).private_key(
default_backend())
return cls(key=key)

def public_key(self):
raise NotImplementedError()
# Unlike RSAPrivateKey, EllipticCurvePrivateKey does not contain public_key()
if hasattr(self.key, 'public_key'):
key = self.key.public_key()
else:
key = self.key.public_numbers().public_key(default_backend())
return type(self)(key=key)


@JWK.register
Expand Down
31 changes: 30 additions & 1 deletion src/josepy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import OpenSSL
import six
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec, rsa


class abstractclassmethod(classmethod):
Expand Down Expand Up @@ -134,6 +135,34 @@ def __hash__(self):
pub = self.public_numbers()
return hash((self.__class__, pub.n, pub.e))

class ComparableECKey(ComparableKey): # pylint: disable=too-few-public-methods
"""Wrapper for ``cryptography`` RSA keys.

Wraps around:

- :class:`~cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey`
- :class:`~cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey`

"""

def __hash__(self):
# public_numbers() hasn't got stable hash!
# https://github.com/pyca/cryptography/issues/2143
if isinstance(self._wrapped, ec.EllipticCurvePrivateKeyWithSerialization):
priv = self.private_numbers()
pub = priv.public_numbers
return hash((self.__class__, pub.curve.name, pub.x, pub.y, priv.d))
elif isinstance(self._wrapped, ec.EllipticCurvePublicKeyWithSerialization):
pub = self.public_numbers()
return hash((self.__class__, pub.curve.name, pub.x, pub.y))
def public_key(self):
"""Get wrapped public key."""
# Unlike RSAPrivateKey, EllipticCurvePrivateKey does not have public_key()
if hasattr(self._wrapped, 'public_key'):
key = self._wrapped.public_key()
else:
key = self._wrapped.public_numbers().public_key(default_backend())
return self.__class__(key)

class ImmutableMap(collections.Mapping, collections.Hashable): # type: ignore
# pylint: disable=too-few-public-methods
Expand Down