Skip to content

Commit

Permalink
make algorithms extendable
Browse files Browse the repository at this point in the history
  • Loading branch information
Marko Mrdjenovic committed Jan 10, 2017
1 parent 8556fc2 commit 98e0bea
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 20 deletions.
32 changes: 26 additions & 6 deletions jose/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import hashlib

class ALGORITHMS(object):

class Algorithms(object):
NONE = 'none'
HS256 = 'HS256'
HS384 = 'HS384'
Expand All @@ -12,13 +13,13 @@ class ALGORITHMS(object):
ES384 = 'ES384'
ES512 = 'ES512'

HMAC = (HS256, HS384, HS512)
RSA = (RS256, RS384, RS512)
EC = (ES256, ES384, ES512)
HMAC = set([HS256, HS384, HS512])
RSA = set([RS256, RS384, RS512])
EC = set([ES256, ES384, ES512])

SUPPORTED = HMAC + RSA + EC
SUPPORTED = HMAC.union(RSA).union(EC)

ALL = SUPPORTED + (NONE, )
ALL = SUPPORTED.union([NONE])

HASHES = {
HS256: hashlib.sha256,
Expand All @@ -31,3 +32,22 @@ class ALGORITHMS(object):
ES384: hashlib.sha384,
ES512: hashlib.sha512,
}

KEYS = {}

def get_key(self, algorithm):
from jose.jwk import HMACKey, RSAKey, ECKey
if algorithm in self.KEYS:
return self.KEYS[algorithm]
elif algorithm in self.HMAC:
return HMACKey
elif algorithm in self.RSA:
return RSAKey
elif algorithm in self.EC:
return ECKey

def register_key(self, algorithm, key_class):
self.KEYS[algorithm] = key_class
self.SUPPORTED.add(algorithm)

ALGORITHMS = Algorithms()
21 changes: 7 additions & 14 deletions jose/jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,10 @@ def construct(key_data, algorithm=None):
if not algorithm:
raise JWKError('Unable to find a algorithm for key: %s' % key_data)

if algorithm in ALGORITHMS.HMAC:
return HMACKey(key_data, algorithm)

if algorithm in ALGORITHMS.RSA:
return RSAKey(key_data, algorithm)

if algorithm in ALGORITHMS.EC:
return ECKey(key_data, algorithm)
key_class = ALGORITHMS.get_key(algorithm)
if not key_class:
raise JWKError('Unable to find a algorithm for key: %s' % key_data)
return key_class(key_data, algorithm)


def get_algorithm_object(algorithm):
Expand Down Expand Up @@ -112,13 +108,12 @@ class HMACKey(Key):
SHA256 = hashlib.sha256
SHA384 = hashlib.sha384
SHA512 = hashlib.sha512
valid_hash_algs = ALGORITHMS.HMAC

prepared_key = None
hash_alg = None

def __init__(self, key, algorithm):
if algorithm not in self.valid_hash_algs:
if algorithm not in ALGORITHMS.HMAC:
raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm)
self.hash_alg = get_algorithm_object(algorithm)

Expand Down Expand Up @@ -174,14 +169,13 @@ class RSAKey(Key):
SHA256 = Crypto.Hash.SHA256
SHA384 = Crypto.Hash.SHA384
SHA512 = Crypto.Hash.SHA512
valid_hash_algs = ALGORITHMS.RSA

prepared_key = None
hash_alg = None

def __init__(self, key, algorithm):

if algorithm not in self.valid_hash_algs:
if algorithm not in ALGORITHMS.RSA:
raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm)
self.hash_alg = get_algorithm_object(algorithm)

Expand Down Expand Up @@ -257,7 +251,6 @@ class ECKey(Key):
SHA256 = hashlib.sha256
SHA384 = hashlib.sha384
SHA512 = hashlib.sha512
valid_hash_algs = ALGORITHMS.EC

curve_map = {
SHA256: ecdsa.curves.NIST256p,
Expand All @@ -270,7 +263,7 @@ class ECKey(Key):
curve = None

def __init__(self, key, algorithm):
if algorithm not in self.valid_hash_algs:
if algorithm not in ALGORITHMS.EC:
raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm)
self.hash_alg = get_algorithm_object(algorithm)

Expand Down

0 comments on commit 98e0bea

Please sign in to comment.