Skip to content

Commit

Permalink
Merge pull request #42 from friedcell/extendable
Browse files Browse the repository at this point in the history
Easier extending/replacing of key algorithms
  • Loading branch information
mpdavis authored Mar 5, 2017
2 parents 8556fc2 + fae83ff commit 1a23c39
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 61 deletions.
18 changes: 12 additions & 6 deletions jose/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import hashlib

class ALGORITHMS(object):

class Algorithms(object):
NONE = 'none'
HS256 = 'HS256'
HS384 = 'HS384'
Expand All @@ -12,13 +13,13 @@ class ALGORITHMS(object):
ES384 = 'ES384'
ES512 = 'ES512'

HMAC = (HS256, HS384, HS512)
RSA = (RS256, RS384, RS512)
EC = (ES256, ES384, ES512)
HMAC = set([HS256, HS384, HS512])
RSA = set([RS256, RS384, RS512])
EC = set([ES256, ES384, ES512])

SUPPORTED = HMAC + RSA + EC
SUPPORTED = HMAC.union(RSA).union(EC)

ALL = SUPPORTED + (NONE, )
ALL = SUPPORTED.union([NONE])

HASHES = {
HS256: hashlib.sha256,
Expand All @@ -31,3 +32,8 @@ class ALGORITHMS(object):
ES384: hashlib.sha384,
ES512: hashlib.sha512,
}

KEYS = {}


ALGORITHMS = Algorithms()
64 changes: 32 additions & 32 deletions jose/jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,26 @@ def base64_to_long(data):
return int_arr_to_long(struct.unpack('%sB' % len(_d), _d))


def get_key(algorithm):
if algorithm in ALGORITHMS.KEYS:
return ALGORITHMS.KEYS[algorithm]
elif algorithm in ALGORITHMS.HMAC:
return HMACKey
elif algorithm in ALGORITHMS.RSA:
return RSAKey
elif algorithm in ALGORITHMS.EC:
return ECKey
return None


def register_key(algorithm, key_class):
if not issubclass(key_class, Key):
raise TypeError("Key class not a subclass of jwk.Key")
ALGORITHMS.KEYS[algorithm] = key_class
ALGORITHMS.SUPPORTED.add(algorithm)
return True


def construct(key_data, algorithm=None):
"""
Construct a Key object for the given algorithm with the given
Expand All @@ -60,14 +80,10 @@ def construct(key_data, algorithm=None):
if not algorithm:
raise JWKError('Unable to find a algorithm for key: %s' % key_data)

if algorithm in ALGORITHMS.HMAC:
return HMACKey(key_data, algorithm)

if algorithm in ALGORITHMS.RSA:
return RSAKey(key_data, algorithm)

if algorithm in ALGORITHMS.EC:
return ECKey(key_data, algorithm)
key_class = get_key(algorithm)
if not key_class:
raise JWKError('Unable to find a algorithm for key: %s' % key_data)
return key_class(key_data, algorithm)


def get_algorithm_object(algorithm):
Expand All @@ -91,11 +107,8 @@ class Key(object):
"""
A simple interface for implementing JWK keys.
"""
prepared_key = None
hash_alg = None

def _process_jwk(self, jwk_dict):
raise NotImplementedError()
def __init__(self, key, algorithm):
pass

def sign(self, msg):
raise NotImplementedError()
Expand All @@ -112,13 +125,9 @@ class HMACKey(Key):
SHA256 = hashlib.sha256
SHA384 = hashlib.sha384
SHA512 = hashlib.sha512
valid_hash_algs = ALGORITHMS.HMAC

prepared_key = None
hash_alg = None

def __init__(self, key, algorithm):
if algorithm not in self.valid_hash_algs:
if algorithm not in ALGORITHMS.HMAC:
raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm)
self.hash_alg = get_algorithm_object(algorithm)

Expand Down Expand Up @@ -174,14 +183,10 @@ class RSAKey(Key):
SHA256 = Crypto.Hash.SHA256
SHA384 = Crypto.Hash.SHA384
SHA512 = Crypto.Hash.SHA512
valid_hash_algs = ALGORITHMS.RSA

prepared_key = None
hash_alg = None

def __init__(self, key, algorithm):

if algorithm not in self.valid_hash_algs:
if algorithm not in ALGORITHMS.RSA:
raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm)
self.hash_alg = get_algorithm_object(algorithm)

Expand Down Expand Up @@ -242,7 +247,7 @@ def verify(self, msg, sig):
try:
return PKCS1_v1_5.new(self.prepared_key).verify(self.hash_alg.new(msg), sig)
except Exception as e:
raise JWKError(e)
return False


class ECKey(Key):
Expand All @@ -257,24 +262,19 @@ class ECKey(Key):
SHA256 = hashlib.sha256
SHA384 = hashlib.sha384
SHA512 = hashlib.sha512
valid_hash_algs = ALGORITHMS.EC

curve_map = {
CURVE_MAP = {
SHA256: ecdsa.curves.NIST256p,
SHA384: ecdsa.curves.NIST384p,
SHA512: ecdsa.curves.NIST521p,
}

prepared_key = None
hash_alg = None
curve = None

def __init__(self, key, algorithm):
if algorithm not in self.valid_hash_algs:
if algorithm not in ALGORITHMS.EC:
raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm)
self.hash_alg = get_algorithm_object(algorithm)

self.curve = self.curve_map.get(self.hash_alg)
self.curve = self.CURVE_MAP.get(self.hash_alg)

if isinstance(key, (ecdsa.SigningKey, ecdsa.VerifyingKey)):
self.prepared_key = key
Expand Down
29 changes: 12 additions & 17 deletions tests/algorithms/test_base.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
from jose.jwk import Key

# from jose.jwk import Key
# from jose.exceptions import JOSEError
import pytest

# import pytest

@pytest.fixture
def alg():
return Key("key", "ALG")

# @pytest.fixture
# def alg():
# return Key()

class TestBaseAlgorithm:

# class TestBaseAlgorithm:
def test_sign_is_interface(self, alg):
with pytest.raises(NotImplementedError):
alg.sign('msg')

# def test_prepare_key_is_interface(self, alg):
# with pytest.raises(JOSEError):
# alg.prepare_key('secret')
def test_verify_is_interface(self, alg):
with pytest.raises(NotImplementedError):
alg.verify('msg', 'sig')

# def test_sign_is_interface(self, alg):
# with pytest.raises(JOSEError):
# alg.sign('msg', 'secret')

# def test_verify_is_interface(self, alg):
# with pytest.raises(JOSEError):
# alg.verify('msg', 'secret', 'sig')
24 changes: 18 additions & 6 deletions tests/test_jwk.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@

from jose import jwk
from jose.exceptions import JWKError

import pytest


hmac_key = {
"kty": "oct",
"kid": "018c0ae5-4d9b-471b-bfd6-eef314bc7037",
Expand Down Expand Up @@ -35,10 +33,7 @@ class TestJWK:

def test_interface(self):

key = jwk.Key()

with pytest.raises(NotImplementedError):
key._process_jwk(None)
key = jwk.Key("key", "ALG")

with pytest.raises(NotImplementedError):
key.sign('')
Expand Down Expand Up @@ -115,3 +110,20 @@ def test_construct_from_jwk_missing_alg(self):

with pytest.raises(JWKError):
key = jwk.construct(hmac_key)

with pytest.raises(JWKError):
key = jwk.construct("key", algorithm="NONEXISTENT")

def test_get_key(self):
assert jwk.get_key("HS256") == jwk.HMACKey
assert jwk.get_key("RS256") == jwk.RSAKey
assert jwk.get_key("ES256") == jwk.ECKey

assert jwk.get_key("NONEXISTENT") == None

def test_register_key(self):
assert jwk.register_key("ALG", jwk.Key)
assert jwk.get_key("ALG") == jwk.Key

with pytest.raises(TypeError):
assert jwk.register_key("ALG", object)

0 comments on commit 1a23c39

Please sign in to comment.