forked from benoitc/gunicorn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtreq.py
290 lines (250 loc) · 9.45 KB
/
treq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
# Copyright 2009 Paul J. Davis <[email protected]>
#
# This file is part of the pywebmachine package released
# under the MIT license.
import t
import inspect
import os
import random
from gunicorn._compat import execfile_
from gunicorn.config import Config
from gunicorn.http.parser import RequestParser
from gunicorn.six.moves.urllib.parse import urlparse
from gunicorn import six
# Directory that contains this module.
dirname = os.path.split(__file__)[0]
# Reseed the shared ``random`` generator; with no seed value Python uses OS
# randomness (or the current time) so chunk sizes differ between runs.
random.seed(None)
def uri(data):
    """Split a request URI into the pieces the parser tests compare against.

    Returns a dict with ``raw`` (the input unchanged), ``scheme``, ``host``
    (``None`` when there is no netloc), ``port`` (80 when absent), ``path``
    (including any ``;params`` segment), ``query`` and ``fragment``.
    Missing string components are returned as ``''``.
    """
    ret = {"raw": data}
    parts = urlparse(data)
    ret["scheme"] = parts.scheme or ''
    ret["host"] = parts.netloc.rsplit(":", 1)[0] or None
    ret["port"] = parts.port or 80
    # urlparse splits ";params" off the last path segment; glue it back on.
    # This also covers a params-only URI such as ";x" (empty path), which
    # previously lost its params.
    path = parts.path
    if parts.params:
        path = "%s;%s" % (path, parts.params)
    ret["path"] = path
    ret["query"] = parts.query or ''
    ret["fragment"] = parts.fragment or ''
    return ret
def load_py(fname):
    """Execute the Python file *fname* and return the namespace it ran in.

    The namespace starts as a copy of this module's globals, with ``uri``
    and ``cfg`` (a newly constructed ``Config()``) added so the executed
    file can use them. Whatever names it defines end up in the returned dict.
    """
    namespace = dict(globals(), uri=uri, cfg=Config())
    execfile_(fname, namespace)
    return namespace
class request(object):
    """A valid HTTP request fixture together with its expected parse results.

    The fixture file's real newlines are layout only and are stripped. The
    literal escape sequences backslash-r-backslash-n and backslash-0 are
    turned into CRLF and NUL bytes. ``expect`` is one expected-request dict,
    or a list of them when the file holds several requests back to back.

    The ``send_*``, ``size_*`` and ``match_*`` method families are combined
    by ``gen_cases`` into one test per permutation.
    """

    def __init__(self, fname, expect):
        self.fname = fname
        self.name = os.path.basename(fname)
        self.expect = expect
        if not isinstance(self.expect, list):
            self.expect = [self.expect]
        with open(self.fname, 'rb') as handle:
            self.data = handle.read()
        self.data = self.data.replace(b"\n", b"").replace(b"\\r\\n", b"\r\n")
        self.data = self.data.replace(b"\\0", b"\000")

    # Functions for sending data to the parser.
    # These functions mock out reading from a
    # socket or other data source that might
    # be used in real life.

    def send_all(self):
        """Yield the whole request as a single chunk."""
        yield self.data

    def send_lines(self):
        """Yield the request one CRLF-terminated line at a time.

        Trailing bytes without a CRLF (e.g. a body) form the final chunk.
        """
        lines = self.data
        pos = lines.find(b"\r\n")
        # find() returns -1 when absent. A hit at 0 is an empty line (such as
        # the blank line ending the headers) and must be yielded on its own.
        while pos >= 0:
            yield lines[:pos + 2]
            lines = lines[pos + 2:]
            pos = lines.find(b"\r\n")
        if lines:
            yield lines

    def send_bytes(self):
        """Yield the request one byte at a time."""
        for i in range(len(self.data)):
            yield self.data[i:i + 1]

    def send_random(self):
        """Yield the request in random chunks of 1 to ~len(data)/10 bytes."""
        # Clamp to 1: for data shorter than 5 bytes the rounded tenth is 0,
        # and random.randint(1, 0) raises ValueError.
        maxs = max(1, int(round(len(self.data) / 10.0)))
        read = 0
        while read < len(self.data):
            chunk = random.randint(1, maxs)
            yield self.data[read:read + chunk]
            read += chunk

    def send_special_chunks(self):
        """Meant to test the request line length check.

        Sends the request data in two chunks, one having a
        length of 1 byte, which ensures that no CRLF is included,
        and a second chunk containing the rest of the request data.

        If the request line length check is not done properly,
        testing the ``tests/requests/valid/099.http`` request
        fails with a ``LimitRequestLine`` exception.
        """
        if self.data:
            yield self.data[:1]
        if len(self.data) > 1:
            yield self.data[1:]

    # These functions define the sizes that the
    # read functions will read with.

    def size_all(self):
        """Read everything at once (-1 means 'no limit')."""
        return -1

    def size_bytes(self):
        """Read one byte per call."""
        return 1

    def size_small_random(self):
        """Read 1-4 bytes per call."""
        return random.randint(1, 4)

    def size_random(self):
        """Read 1-4096 bytes per call."""
        return random.randint(1, 4096)

    # Match a body against various ways of reading
    # a message. Pass in the request, expected body
    # and one of the size functions.

    def szread(self, func, sizes):
        """Call ``func`` with a size from ``sizes()``, enforcing the limit.

        Raises AssertionError if a non-negative size was exceeded.
        """
        sz = sizes()
        data = func(sz)
        if sz >= 0 and len(data) > sz:
            raise AssertionError("Read more than %d bytes: %s" % (sz, data))
        return data

    def match_read(self, req, body, sizes):
        """Check that ``req.body.read`` yields exactly *body*, then EOF."""
        data = self.szread(req.body.read, sizes)
        # Budget of empty reads allowed before the body is consumed, so a
        # stalled reader fails the test instead of looping forever.
        count = 1000
        while len(body):
            if body[:len(data)] != data:
                raise AssertionError("Invalid body data read: %r != %r" % (
                    data, body[:len(data)]))
            body = body[len(data):]
            data = self.szread(req.body.read, sizes)
            if not data:
                count -= 1
            if count <= 0:
                raise AssertionError("Unexpected apparent EOF")
        # The loop only exits once *body* is exhausted, so any data left in
        # hand at this point was read beyond the expected body.
        if len(data):
            raise AssertionError("Read beyond expected body: %r" % data)
        data = req.body.read(sizes())
        if data:
            raise AssertionError("Read after body finished: %r" % data)

    def match_readline(self, req, body, sizes):
        """Check ``req.body.readline`` yields *body* without embedded newlines."""
        data = self.szread(req.body.readline, sizes)
        # Same empty-read budget as match_read.
        count = 1000
        while len(body):
            if body[:len(data)] != data:
                raise AssertionError("Invalid data read: %r" % data)
            # Only the final byte of a line may be a newline.
            if b'\n' in data[:-1]:
                raise AssertionError("Embedded new line: %r" % data)
            body = body[len(data):]
            data = self.szread(req.body.readline, sizes)
            if not data:
                count -= 1
            if count <= 0:
                raise AssertionError("Apparent unexpected EOF")
        # As in match_read, *body* is exhausted once the loop exits.
        if len(data):
            raise AssertionError("Read beyond expected body: %r" % data)
        data = req.body.readline(sizes())
        if data:
            raise AssertionError("Read data after body finished: %r" % data)

    def match_readlines(self, req, body, sizes):
        """\
        This skips the sizes checks as we don't implement it.
        """
        data = req.body.readlines()
        for line in data:
            if b'\n' in line[:-1]:
                raise AssertionError("Embedded new line: %r" % line)
            if line != body[:len(line)]:
                raise AssertionError("Invalid body data read: %r != %r" % (
                    line, body[:len(line)]))
            body = body[len(line):]
        if len(body):
            raise AssertionError("Failed to read entire body: %r" % body)
        data = req.body.readlines(sizes())
        if data:
            raise AssertionError("Read data after body finished: %r" % data)

    def match_iter(self, req, body, sizes):
        """\
        This skips sizes because it's not part of the iter API.
        """
        for line in req.body:
            if b'\n' in line[:-1]:
                raise AssertionError("Embedded new line: %r" % line)
            if line != body[:len(line)]:
                raise AssertionError("Invalid body data read: %r != %r" % (
                    line, body[:len(line)]))
            body = body[len(line):]
        if len(body):
            raise AssertionError("Failed to read entire body: %r" % body)
        try:
            data = next(iter(req.body))
        except StopIteration:
            pass
        else:
            raise AssertionError("Read data after body finished: %r" % data)

    # Construct a series of test cases from the permutations of
    # send, size, and match functions.

    def gen_cases(self, cfg):
        """Build one test case per (matcher, sizer, sender) permutation.

        Returns a list of ``(test_req, sender, sizer, matcher)`` tuples.
        Calling ``test_req(sender, sizer, matcher)`` runs ``check`` with
        *cfg*, and ``test_req.description`` names the permutation.
        """
        def get_funs(prefix):
            return [v for k, v in inspect.getmembers(self)
                    if k.startswith(prefix)]

        senders = get_funs("send_")
        sizers = get_funs("size_")
        matchers = get_funs("match_")
        ret = []
        for mt in matchers:
            for sz in sizers:
                for sn in senders:
                    # Drop the family prefix, e.g. "match_read" -> "read".
                    mtn = mt.__name__[len("match_"):]
                    szn = sz.__name__[len("size_"):]
                    snn = sn.__name__[len("send_"):]

                    def test_req(sn, sz, mt):
                        self.check(cfg, sn, sz, mt)
                    test_req.description = "%s: MT: %s SZ: %s SN: %s" % (
                        self.name, mtn, szn, snn)
                    ret.append((test_req, sn, sz, mt))
        return ret

    def check(self, cfg, sender, sizer, matcher):
        """Parse the data fed by *sender* and compare each request to ``expect``."""
        cases = self.expect[:]
        p = RequestParser(cfg, sender())
        for req in p:
            if not cases:
                # Report a surplus request as a test failure, not IndexError.
                raise AssertionError(
                    "Parsed more requests than expected: %r" % req.uri)
            self.same(req, sizer, matcher, cases.pop(0))
        assert len(cases) == 0, "Expected %d more request(s)" % len(cases)

    def same(self, req, sizer, matcher, exp):
        """Assert that parsed *req* matches the expected dict *exp*."""
        assert req.method == exp["method"]
        assert req.uri == exp["uri"]["raw"]
        assert req.path == exp["uri"]["path"]
        assert req.query == exp["uri"]["query"]
        assert req.fragment == exp["uri"]["fragment"]
        assert req.version == exp["version"]
        assert req.headers == exp["headers"]
        matcher(req, exp["body"], sizer)
        assert req.trailers == exp.get("trailers", [])
class badrequest(object):
    """An invalid HTTP request fixture that is fed to the parser.

    The fixture's real newlines are layout only and are stripped. The
    literal escape sequences backslash-r-backslash-n and backslash-0 become
    CRLF and NUL bytes.
    """

    def __init__(self, fname):
        self.fname = fname
        self.name = os.path.basename(fname)
        # Read raw bytes rather than decoding with the locale's default
        # encoding and re-encoding as latin-1. The old approach made the
        # result environment-dependent and could fail on non-ASCII fixtures.
        with open(self.fname, 'rb') as handle:
            self.data = handle.read()
        self.data = self.data.replace(b"\n", b"").replace(b"\\r\\n", b"\r\n")
        self.data = self.data.replace(b"\\0", b"\000")

    def send(self):
        """Yield the request in random chunks of 1 to ~len(data)/10 bytes."""
        # Clamp to 1 so fixtures shorter than 5 bytes do not make
        # random.randint(1, 0) raise ValueError.
        maxs = max(1, int(round(len(self.data) / 10.0)))
        read = 0
        while read < len(self.data):
            chunk = random.randint(1, maxs)
            yield self.data[read:read + chunk]
            read += chunk

    def check(self, cfg):
        """Feed the fixture to ``RequestParser`` and parse the first request.

        No exception is caught here, so any parser error propagates to the
        caller (which presumably expects one for a bad request).
        """
        p = RequestParser(cfg, self.send())
        next(p)