Skip to content

Commit

Permalink
Merge pull request tornadoweb#907 from MrTravisB/master
Browse files Browse the repository at this point in the history
Support distinguishing argument origin between query and body.
  • Loading branch information
bdarnell committed Oct 6, 2013
2 parents 38330de + 010728a commit f4f8595
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 9 deletions.
4 changes: 4 additions & 0 deletions docs/web.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@

.. automethod:: RequestHandler.get_argument
.. automethod:: RequestHandler.get_arguments
.. automethod:: RequestHandler.get_query_argument
.. automethod:: RequestHandler.get_query_arguments
.. automethod:: RequestHandler.get_body_argument
.. automethod:: RequestHandler.get_body_arguments
.. automethod:: RequestHandler.decode_argument
.. attribute:: RequestHandler.request

Expand Down
Empty file modified tornado/escape.py
100644 → 100755
Empty file.
8 changes: 7 additions & 1 deletion tornado/httpserver.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class except to start a server at the beginning of the process
import socket
import ssl
import time
import copy

from tornado.escape import native_str, parse_qs_bytes
from tornado import httputil
Expand Down Expand Up @@ -336,7 +337,10 @@ def _on_request_body(self, data):
if self._request.method in ("POST", "PATCH", "PUT"):
httputil.parse_body_arguments(
self._request.headers.get("Content-Type", ""), data,
self._request.arguments, self._request.files)
self._request.body_arguments, self._request.files)

for k, v in self._request.body_arguments.items():
self._request.arguments.setdefault(k, []).extend(v)
self.request_callback(self._request)


Expand Down Expand Up @@ -457,6 +461,8 @@ def __init__(self, method, uri, version="HTTP/1.0", headers=None,

self.path, sep, self.query = uri.partition('?')
self.arguments = parse_qs_bytes(self.query, keep_blank_values=True)
self.query_arguments = copy.deepcopy(self.arguments)
self.body_arguments = {}

def supports_http_1_1(self):
"""Returns True if this request supports HTTP/1.1 semantics"""
Expand Down
Empty file modified tornado/httputil.py
100644 → 100755
Empty file.
Empty file modified tornado/test/httputil_test.py
100644 → 100755
Empty file.
Empty file modified tornado/test/runtests.py
100644 → 100755
Empty file.
50 changes: 50 additions & 0 deletions tornado/test/web_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
import socket
import sys

try:
import urllib.parse as urllib_parse # py3
except ImportError:
import urllib as urllib_parse # py2

wsgi_safe_tests = []

relpath = lambda *a: os.path.join(os.path.dirname(__file__), *a)
Expand Down Expand Up @@ -482,6 +487,19 @@ class GetArgumentHandler(RequestHandler):
def get(self):
self.write(self.get_argument("foo", "default"))

def post(self):
self.write(self.get_argument("foo", "default"))


class GetQueryArgumentHandler(RequestHandler):
def post(self):
self.write(self.get_query_argument("foo", "default"))


class GetBodyArgumentHandler(RequestHandler):
def post(self):
self.write(self.get_body_argument("foo", "default"))


# This test is shared with wsgi_test.py
@wsgi_safe
Expand Down Expand Up @@ -521,6 +539,8 @@ def get_handlers(self):
url("/redirect", RedirectHandler),
url("/header_injection", HeaderInjectionHandler),
url("/get_argument", GetArgumentHandler),
url("/get_query_argument", GetQueryArgumentHandler),
url("/get_body_argument", GetBodyArgumentHandler),
]
return urls

Expand Down Expand Up @@ -647,6 +667,36 @@ def test_get_argument(self):
response = self.fetch("/get_argument")
self.assertEqual(response.body, b"default")

# test merging of query and body arguments
# body arguments overwrite query arguments
body = urllib_parse.urlencode(dict(foo="hello"))
response = self.fetch("/get_argument?foo=bar", method="POST", body=body)
self.assertEqual(response.body, b"hello")

def test_get_query_arguments(self):
# send as a post so we can ensure the separation between query
# string and body arguments.
body = urllib_parse.urlencode(dict(foo="hello"))
response = self.fetch("/get_query_argument?foo=bar", method="POST", body=body)
self.assertEqual(response.body, b"bar")
response = self.fetch("/get_query_argument?foo=", method="POST", body=body)
self.assertEqual(response.body, b"")
response = self.fetch("/get_query_argument", method="POST", body=body)
self.assertEqual(response.body, b"default")

def test_get_body_arguments(self):
body = urllib_parse.urlencode(dict(foo="bar"))
response = self.fetch("/get_body_argument?foo=hello", method="POST", body=body)
self.assertEqual(response.body, b"bar")

body = urllib_parse.urlencode(dict(foo=""))
response = self.fetch("/get_body_argument?foo=hello", method="POST", body=body)
self.assertEqual(response.body, b"")

body = urllib_parse.urlencode(dict())
response = self.fetch("/get_body_argument?foo=hello", method="POST", body=body)
self.assertEqual(response.body, b"default")

def test_no_gzip(self):
response = self.fetch('/get_argument')
self.assertNotIn('Accept-Encoding', response.headers.get('Vary', ''))
Expand Down
65 changes: 58 additions & 7 deletions tornado/web.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -348,12 +348,7 @@ def get_argument(self, name, default=_ARG_DEFAULT, strip=True):
The returned value is always unicode.
"""
args = self.get_arguments(name, strip=strip)
if not args:
if default is self._ARG_DEFAULT:
raise MissingArgumentError(name)
return default
return args[-1]
return self._get_argument(name, default, self.request.arguments, strip)

def get_arguments(self, name, strip=True):
"""Returns a list of the arguments with the given name.
Expand All @@ -362,9 +357,65 @@ def get_arguments(self, name, strip=True):
The returned values are always unicode.
"""
return self._get_arguments(name, self.request.arguments, strip)

def get_body_argument(self, name, default=_ARG_DEFAULT, strip=True):
"""Returns the value of the argument with the given name
from the request body.
If default is not provided, the argument is considered to be
required, and we raise a `MissingArgumentError` if it is missing.
If the argument appears in the url more than once, we return the
last value.
The returned value is always unicode.
"""
return self._get_argument(name, default, self.request.body_arguments, strip)

def get_body_arguments(self, name, strip=True):
"""Returns a list of the body arguments with the given name.
If the argument is not present, returns an empty list.
The returned values are always unicode.
"""
return self._get_arguments(name, self.request.body_arguments, strip)

def get_query_argument(self, name, default=_ARG_DEFAULT, strip=True):
"""Returns the value of the argument with the given name
from the request query string.
If default is not provided, the argument is considered to be
required, and we raise a `MissingArgumentError` if it is missing.
If the argument appears in the url more than once, we return the
last value.
The returned value is always unicode.
"""
return self._get_argument(name, default, self.request.query_arguments, strip)

def get_query_arguments(self, name, strip=True):
"""Returns a list of the query arguments with the given name.
If the argument is not present, returns an empty list.
The returned values are always unicode.
"""
return self._get_arguments(name, self.request.query_arguments, strip)

def _get_argument(self, name, default, source, strip=True):
args = self._get_arguments(name, source, strip=strip)
if not args:
if default is self._ARG_DEFAULT:
raise MissingArgumentError(name)
return default
return args[-1]

def _get_arguments(self, name, source, strip=True):
values = []
for v in self.request.arguments.get(name, []):
for v in source.get(name, []):
v = self.decode_argument(v, name=name)
if isinstance(v, unicode_type):
# Get rid of any weird control chars (unless decoding gave
Expand Down
9 changes: 8 additions & 1 deletion tornado/wsgi.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

import sys
import time
import copy
import tornado

from tornado import escape
Expand Down Expand Up @@ -142,11 +143,14 @@ def __init__(self, environ):
self.path += urllib_parse.quote(from_wsgi_str(environ.get("PATH_INFO", "")))
self.uri = self.path
self.arguments = {}
self.query_arguments = {}
self.body_arguments = {}
self.query = environ.get("QUERY_STRING", "")
if self.query:
self.uri += "?" + self.query
self.arguments = parse_qs_bytes(native_str(self.query),
keep_blank_values=True)
self.query_arguments = copy.deepcopy(self.arguments)
self.version = "HTTP/1.1"
self.headers = httputil.HTTPHeaders()
if environ.get("CONTENT_TYPE"):
Expand All @@ -171,7 +175,10 @@ def __init__(self, environ):
# Parse request body
self.files = {}
httputil.parse_body_arguments(self.headers.get("Content-Type", ""),
self.body, self.arguments, self.files)
self.body, self.body_arguments, self.files)

for k, v in self.body_arguments.items():
self.arguments.setdefault(k, []).extend(v)

self._start_time = time.time()
self._finish_time = None
Expand Down

0 comments on commit f4f8595

Please sign in to comment.