Skip to content

Commit

Permalink
Add crude tests for the auth module, and fix python3 issues with oauth1
Browse files Browse the repository at this point in the history
  • Loading branch information
bdarnell committed Sep 12, 2011
1 parent ef788bc commit eb5f2ce
Show file tree
Hide file tree
Showing 3 changed files with 224 additions and 30 deletions.
67 changes: 37 additions & 30 deletions tornado/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def _on_auth(self, user):

import base64
import binascii
import cgi
import hashlib
import hmac
import logging
Expand Down Expand Up @@ -85,7 +84,7 @@ def authenticate_redirect(self, callback_uri=None,
args = self._openid_args(callback_uri, ax_attrs=ax_attrs)
self.redirect(self._OPENID_ENDPOINT + "?" + urllib.urlencode(args))

def get_authenticated_user(self, callback):
def get_authenticated_user(self, callback, http_client=None):
"""Fetches the authenticated user data upon redirect.
This method should be called by the handler that receives the
Expand All @@ -96,8 +95,8 @@ def get_authenticated_user(self, callback):
args = dict((k, v[-1]) for k, v in self.request.arguments.iteritems())
args["openid.mode"] = u"check_authentication"
url = self._OPENID_ENDPOINT
http = httpclient.AsyncHTTPClient()
http.fetch(url, self.async_callback(
if http_client is None: http_client = httpclient.AsyncHTTPClient()
http_client.fetch(url, self.async_callback(
self._on_authentication_verified, callback),
method="POST", body=urllib.urlencode(args))

Expand Down Expand Up @@ -207,7 +206,8 @@ class OAuthMixin(object):
See TwitterMixin and FriendFeedMixin below for example implementations.
"""

def authorize_redirect(self, callback_uri=None, extra_params=None):
def authorize_redirect(self, callback_uri=None, extra_params=None,
http_client=None):
"""Redirects the user to obtain OAuth authorization for this service.
Twitter and FriendFeed both require that you register a Callback
Expand All @@ -222,20 +222,25 @@ def authorize_redirect(self, callback_uri=None, extra_params=None):
"""
if callback_uri and getattr(self, "_OAUTH_NO_CALLBACKS", False):
raise Exception("This service does not support oauth_callback")
http = httpclient.AsyncHTTPClient()
if http_client is None:
http_client = httpclient.AsyncHTTPClient()
if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a":
http.fetch(self._oauth_request_token_url(callback_uri=callback_uri,
extra_params=extra_params),
http_client.fetch(
self._oauth_request_token_url(callback_uri=callback_uri,
extra_params=extra_params),
self.async_callback(
self._on_request_token,
self._OAUTH_AUTHORIZE_URL,
callback_uri))
else:
http.fetch(self._oauth_request_token_url(), self.async_callback(
self._on_request_token, self._OAUTH_AUTHORIZE_URL, callback_uri))
http_client.fetch(
self._oauth_request_token_url(),
self.async_callback(
self._on_request_token, self._OAUTH_AUTHORIZE_URL,
callback_uri))


def get_authenticated_user(self, callback):
def get_authenticated_user(self, callback, http_client=None):
"""Gets the OAuth authorized user and access token on callback.
This method should be called from the handler for your registered
Expand All @@ -246,25 +251,27 @@ def get_authenticated_user(self, callback):
to this service on behalf of the user.
"""
request_key = self.get_argument("oauth_token")
request_key = escape.utf8(self.get_argument("oauth_token"))
oauth_verifier = self.get_argument("oauth_verifier", None)
request_cookie = self.get_cookie("_oauth_request_token")
if not request_cookie:
logging.warning("Missing OAuth request token cookie")
callback(None)
return
self.clear_cookie("_oauth_request_token")
cookie_key, cookie_secret = [base64.b64decode(i) for i in request_cookie.split("|")]
cookie_key, cookie_secret = [base64.b64decode(escape.utf8(i)) for i in request_cookie.split("|")]
if cookie_key != request_key:
logging.info((cookie_key, request_key, request_cookie))
logging.warning("Request token does not match cookie")
callback(None)
return
token = dict(key=cookie_key, secret=cookie_secret)
if oauth_verifier:
token["verifier"] = oauth_verifier
http = httpclient.AsyncHTTPClient()
http.fetch(self._oauth_access_token_url(token), self.async_callback(
self._on_access_token, callback))
token["verifier"] = oauth_verifier
if http_client is None:
http_client = httpclient.AsyncHTTPClient()
http_client.fetch(self._oauth_access_token_url(token),
self.async_callback(self._on_access_token, callback))

def _oauth_request_token_url(self, callback_uri= None, extra_params=None):
consumer_token = self._oauth_consumer_token()
Expand Down Expand Up @@ -292,8 +299,8 @@ def _on_request_token(self, authorize_url, callback_uri, response):
if response.error:
raise Exception("Could not get request token")
request_token = _oauth_parse_response(response.body)
data = "|".join([base64.b64encode(request_token["key"]),
base64.b64encode(request_token["secret"])])
data = (base64.b64encode(request_token["key"]) + b("|") +
base64.b64encode(request_token["secret"]))
self.set_cookie("_oauth_request_token", data)
args = dict(oauth_token=request_token["key"])
if callback_uri:
Expand Down Expand Up @@ -1078,11 +1085,11 @@ def _oauth_signature(consumer_token, method, url, parameters={}, token=None):
for k, v in sorted(parameters.items())))
base_string = "&".join(_oauth_escape(e) for e in base_elems)

key_elems = [consumer_token["secret"]]
key_elems.append(token["secret"] if token else "")
key = "&".join(key_elems)
key_elems = [escape.utf8(consumer_token["secret"])]
key_elems.append(escape.utf8(token["secret"] if token else ""))
key = b("&").join(key_elems)

hash = hmac.new(key, base_string, hashlib.sha1)
hash = hmac.new(key, escape.utf8(base_string), hashlib.sha1)
return binascii.b2a_base64(hash.digest())[:-1]

def _oauth10a_signature(consumer_token, method, url, parameters={}, token=None):
Expand All @@ -1101,11 +1108,11 @@ def _oauth10a_signature(consumer_token, method, url, parameters={}, token=None):
for k, v in sorted(parameters.items())))

base_string = "&".join(_oauth_escape(e) for e in base_elems)
key_elems = [urllib.quote(consumer_token["secret"], safe='~')]
key_elems.append(urllib.quote(token["secret"], safe='~') if token else "")
key = "&".join(key_elems)
key_elems = [escape.utf8(urllib.quote(consumer_token["secret"], safe='~'))]
key_elems.append(escape.utf8(urllib.quote(token["secret"], safe='~') if token else ""))
key = b("&").join(key_elems)

hash = hmac.new(key, base_string, hashlib.sha1)
hash = hmac.new(key, escape.utf8(base_string), hashlib.sha1)
return binascii.b2a_base64(hash.digest())[:-1]

def _oauth_escape(val):
Expand All @@ -1115,11 +1122,11 @@ def _oauth_escape(val):


def _oauth_parse_response(body):
p = cgi.parse_qs(body, keep_blank_values=False)
token = dict(key=p["oauth_token"][0], secret=p["oauth_token_secret"][0])
p = escape.parse_qs(body, keep_blank_values=False)
token = dict(key=p[b("oauth_token")][0], secret=p[b("oauth_token_secret")][0])

# Add the extra parameters the Provider included to the token
special = ("oauth_token", "oauth_token_secret")
special = (b("oauth_token"), b("oauth_token_secret"))
token.update((k, p[k][0]) for k in p if k not in special)
return token

Expand Down
186 changes: 186 additions & 0 deletions tornado/test/auth_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# These tests do not currently do much to verify the correct implementation
# of the openid/oauth protocols, they just exercise the major code paths
# and ensure that it doesn't blow up (e.g. with unicode/bytes issues in
# python 3)

from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin
from tornado.escape import json_decode
from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase
from tornado.util import b
from tornado.web import RequestHandler, Application, asynchronous

class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin):
def initialize(self, test):
self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')

@asynchronous
def get(self):
if self.get_argument('openid.mode', None):
self.get_authenticated_user(
self.on_user, http_client=self.settings['http_client'])
return
self.authenticate_redirect()

def on_user(self, user):
assert user is not None
self.finish(user)

class OpenIdServerAuthenticateHandler(RequestHandler):
def post(self):
assert self.get_argument('openid.mode') == 'check_authentication'
self.write('is_valid:true')

class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
def initialize(self, test, version):
self._OAUTH_VERSION = version
self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token')

def _oauth_consumer_token(self):
return dict(key='asdf', secret='qwer')

@asynchronous
def get(self):
if self.get_argument('oauth_token', None):
self.get_authenticated_user(
self.on_user, http_client=self.settings['http_client'])
return
self.authorize_redirect(http_client=self.settings['http_client'])

def on_user(self, user):
assert user is not None
self.finish(user)

def _oauth_get_user(self, access_token, callback):
assert access_token == dict(key=b('uiop'), secret=b('5678')), access_token
callback(dict(email='[email protected]'))

class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin):
def initialize(self, version):
self._OAUTH_VERSION = version

def _oauth_consumer_token(self):
return dict(key='asdf', secret='qwer')

def get(self):
params = self._oauth_request_parameters(
'http://www.example.com/api/asdf',
dict(key='uiop', secret='5678'),
parameters=dict(foo='bar'))
import urllib; urllib.urlencode(params)
self.write(params)

class OAuth1ServerRequestTokenHandler(RequestHandler):
def get(self):
self.write('oauth_token=zxcv&oauth_token_secret=1234')

class OAuth1ServerAccessTokenHandler(RequestHandler):
def get(self):
self.write('oauth_token=uiop&oauth_token_secret=5678')

class OAuth2ClientLoginHandler(RequestHandler, OAuth2Mixin):
def initialize(self, test):
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth2/server/authorize')

def get(self):
self.authorize_redirect()


class AuthTest(AsyncHTTPTestCase, LogTrapTestCase):
def get_app(self):
return Application(
[
# test endpoints
('/openid/client/login', OpenIdClientLoginHandler, dict(test=self)),
('/oauth10/client/login', OAuth1ClientLoginHandler,
dict(test=self, version='1.0')),
('/oauth10/client/request_params',
OAuth1ClientRequestParametersHandler,
dict(version='1.0')),
('/oauth10a/client/login', OAuth1ClientLoginHandler,
dict(test=self, version='1.0a')),
('/oauth10a/client/request_params',
OAuth1ClientRequestParametersHandler,
dict(version='1.0a')),
('/oauth2/client/login', OAuth2ClientLoginHandler, dict(test=self)),

# simulated servers
('/openid/server/authenticate', OpenIdServerAuthenticateHandler),
('/oauth1/server/request_token', OAuth1ServerRequestTokenHandler),
('/oauth1/server/access_token', OAuth1ServerAccessTokenHandler),
],
http_client=self.http_client)

def test_openid_redirect(self):
response = self.fetch('/openid/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(
'/openid/server/authenticate?' in response.headers['Location'])

def test_openid_get_user(self):
response = self.fetch('/openid/client/login?openid.mode=blah&openid.ns.ax=http://openid.net/srv/ax/1.0&openid.ax.type.email=http://axschema.org/contact/email&[email protected]')
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["email"], "[email protected]")

def test_oauth10_redirect(self):
response = self.fetch('/oauth10/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(response.headers['Location'].endswith(
'/oauth1/server/authorize?oauth_token=zxcv'))
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])

def test_oauth10_get_user(self):
response = self.fetch(
'/oauth10/client/login?oauth_token=zxcv',
headers={'Cookie':'_oauth_request_token=enhjdg==|MTIzNA=='})
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['email'], '[email protected]')
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))

def test_oauth10_request_parameters(self):
response = self.fetch('/oauth10/client/request_params')
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
self.assertEqual(parsed['oauth_token'], 'uiop')
self.assertTrue('oauth_nonce' in parsed)
self.assertTrue('oauth_signature' in parsed)

def test_oauth10a_redirect(self):
response = self.fetch('/oauth10a/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(response.headers['Location'].endswith(
'/oauth1/server/authorize?oauth_token=zxcv'))
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])

def test_oauth10a_get_user(self):
response = self.fetch(
'/oauth10a/client/login?oauth_token=zxcv',
headers={'Cookie':'_oauth_request_token=enhjdg==|MTIzNA=='})
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['email'], '[email protected]')
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))

def test_oauth10a_request_parameters(self):
response = self.fetch('/oauth10a/client/request_params')
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
self.assertEqual(parsed['oauth_token'], 'uiop')
self.assertTrue('oauth_nonce' in parsed)
self.assertTrue('oauth_signature' in parsed)

def test_oauth2_redirect(self):
response = self.fetch('/oauth2/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue('/oauth2/server/authorize?' in response.headers['Location'])
1 change: 1 addition & 0 deletions tornado/test/runtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
'tornado.httputil.doctests',
'tornado.iostream.doctests',
'tornado.util.doctests',
'tornado.test.auth_test',
'tornado.test.curl_httpclient_test',
'tornado.test.escape_test',
'tornado.test.gen_test',
Expand Down

0 comments on commit eb5f2ce

Please sign in to comment.