Skip to content

gh-135056: Add a --cors CLI argument to http.server #135057

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion Doc/library/http.server.rst
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ instantiation, of which this module provides three different variants:
delays, it now always returns the IP address.


.. class:: SimpleHTTPRequestHandler(request, client_address, server, directory=None)
.. class:: SimpleHTTPRequestHandler(request, client_address, server, directory=None, response_headers=None)

This class serves files from the directory *directory* and below,
or the current directory if *directory* is not provided, directly
Expand All @@ -374,6 +374,10 @@ instantiation, of which this module provides three different variants:
.. versionchanged:: 3.9
The *directory* parameter accepts a :term:`path-like object`.

.. versionchanged:: next
The *response_headers* parameter accepts an optional dictionary of
additional HTTP headers to add to each response.

A lot of the work, such as parsing the request, is done by the base class
:class:`BaseHTTPRequestHandler`. This class implements the :func:`do_GET`
and :func:`do_HEAD` functions.
Expand Down Expand Up @@ -428,6 +432,9 @@ instantiation, of which this module provides three different variants:
followed by a ``'Content-Length:'`` header with the file's size and a
``'Last-Modified:'`` header with the file's modification time.

The headers specified in the dictionary instance argument
``response_headers`` are each individually sent in the response.

Then follows a blank line signifying the end of the headers, and then the
contents of the file are output.

Expand All @@ -437,6 +444,9 @@ instantiation, of which this module provides three different variants:
.. versionchanged:: 3.7
Support of the ``'If-Modified-Since'`` header.

.. versionchanged:: next
Support ``response_headers`` as an instance argument.

The :class:`SimpleHTTPRequestHandler` class can be used in the following
manner in order to create a very basic webserver serving files relative to
the current directory::
Expand Down Expand Up @@ -543,6 +553,14 @@ The following options are accepted:

.. versionadded:: 3.14

.. option:: --cors
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As Hugo said, since we're anyway exposing response-headers, I think we should also expose it from the CLI. It could be useful for users in general (e.g., --add-header NAME VALUE with the -H alias).


Adds an additional CORS (Cross-Origin Resource sharing) header to each response::

Access-Control-Allow-Origin: *

.. versionadded:: next


.. _http.server-security:

Expand Down
51 changes: 44 additions & 7 deletions Lib/http/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,24 @@ class HTTPServer(socketserver.TCPServer):
allow_reuse_address = True # Seems to make sense in testing environment
allow_reuse_port = True

def __init__(self, *args, response_headers=None, **kwargs):
self.response_headers = response_headers
super().__init__(*args, **kwargs)

def server_bind(self):
"""Override server_bind to store the server name."""
socketserver.TCPServer.server_bind(self)
host, port = self.server_address[:2]
self.server_name = socket.getfqdn(host)
self.server_port = port

def finish_request(self, request, client_address):
"""Finish one request by instantiating RequestHandlerClass."""
args = (request, client_address, self)
kwargs = {}
if hasattr(self, 'response_headers'):
kwargs['response_headers'] = self.response_headers
self.RequestHandlerClass(request, client_address, self, **kwargs)

class ThreadingHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
daemon_threads = True
Expand All @@ -132,7 +143,7 @@ class ThreadingHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
class HTTPSServer(HTTPServer):
def __init__(self, server_address, RequestHandlerClass,
bind_and_activate=True, *, certfile, keyfile=None,
password=None, alpn_protocols=None):
password=None, alpn_protocols=None, **http_server_kwargs):
try:
import ssl
except ImportError:
Expand All @@ -150,7 +161,8 @@ def __init__(self, server_address, RequestHandlerClass,

super().__init__(server_address,
RequestHandlerClass,
bind_and_activate)
bind_and_activate,
**http_server_kwargs)

def server_activate(self):
"""Wrap the socket in SSLSocket."""
Expand Down Expand Up @@ -692,10 +704,11 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
'.xz': 'application/x-xz',
}

def __init__(self, *args, directory=None, **kwargs):
def __init__(self, *args, directory=None, response_headers=None, **kwargs):
if directory is None:
directory = os.getcwd()
self.directory = os.fspath(directory)
self.response_headers = response_headers
super().__init__(*args, **kwargs)

def do_GET(self):
Expand Down Expand Up @@ -736,6 +749,10 @@ def send_head(self):
new_url = urllib.parse.urlunsplit(new_parts)
self.send_header("Location", new_url)
self.send_header("Content-Length", "0")
# User specified response_headers
if self.response_headers is not None:
for header, value in self.response_headers.items():
self.send_header(header, value)
Comment on lines +753 to +755
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's make it a private method, say self._add_custom_response_headers or something like that

self.end_headers()
return None
for index in self.index_pages:
Expand Down Expand Up @@ -795,6 +812,9 @@ def send_head(self):
self.send_header("Content-Length", str(fs[6]))
self.send_header("Last-Modified",
self.date_time_string(fs.st_mtime))
if self.response_headers is not None:
for header, value in self.response_headers.items():
self.send_header(header, value)
self.end_headers()
return f
except:
Expand Down Expand Up @@ -970,7 +990,8 @@ def _get_best_family(*address):
def test(HandlerClass=BaseHTTPRequestHandler,
ServerClass=ThreadingHTTPServer,
protocol="HTTP/1.0", port=8000, bind=None,
tls_cert=None, tls_key=None, tls_password=None):
tls_cert=None, tls_key=None, tls_password=None,
response_headers=None):
"""Test the HTTP request handler class.

This runs an HTTP server on port 8000 (or the port argument).
Expand All @@ -981,9 +1002,10 @@ def test(HandlerClass=BaseHTTPRequestHandler,

if tls_cert:
server = ServerClass(addr, HandlerClass, certfile=tls_cert,
keyfile=tls_key, password=tls_password)
keyfile=tls_key, password=tls_password,
response_headers=response_headers)
else:
server = ServerClass(addr, HandlerClass)
server = ServerClass(addr, HandlerClass, response_headers=response_headers)

with server as httpd:
host, port = httpd.socket.getsockname()[:2]
Expand Down Expand Up @@ -1024,6 +1046,13 @@ def _main(args=None):
parser.add_argument('port', default=8000, type=int, nargs='?',
help='bind to this port '
'(default: %(default)s)')
parser.add_argument('--cors', action='store_true',
help='Enable Access-Control-Allow-Origin: * header')
parser.add_argument('-H', '--header', nargs=2, action='append',
# metavar='HEADER VALUE',
metavar=('HEADER', 'VALUE'),
help='Add a custom response header '
'(can be used multiple times)')
args = parser.parse_args(args)

if not args.tls_cert and args.tls_key:
Expand Down Expand Up @@ -1052,14 +1081,21 @@ def server_bind(self):

def finish_request(self, request, client_address):
self.RequestHandlerClass(request, client_address, self,
directory=args.directory)
directory=args.directory,
response_headers=self.response_headers)

class HTTPDualStackServer(DualStackServerMixin, ThreadingHTTPServer):
pass
class HTTPSDualStackServer(DualStackServerMixin, ThreadingHTTPSServer):
pass

ServerClass = HTTPSDualStackServer if args.tls_cert else HTTPDualStackServer
response_headers = {}
if args.cors:
response_headers['Access-Control-Allow-Origin'] = '*'
for header, value in args.header or []:
response_headers[header] = value


test(
HandlerClass=SimpleHTTPRequestHandler,
Expand All @@ -1070,6 +1106,7 @@ class HTTPSDualStackServer(DualStackServerMixin, ThreadingHTTPSServer):
tls_cert=args.tls_cert,
tls_key=args.tls_key,
tls_password=tls_key_password,
response_headers=response_headers or None
)


Expand Down
2 changes: 1 addition & 1 deletion Lib/socketserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ class BaseRequestHandler:

"""

def __init__(self, request, client_address, server):
def __init__(self, request, client_address, server, **kwargs):
self.request = request
self.client_address = client_address
self.server = server
Expand Down
46 changes: 43 additions & 3 deletions Lib/test/test_httpservers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,21 +81,24 @@ def test_https_server_raises_runtime_error(self):


class TestServerThread(threading.Thread):
def __init__(self, test_object, request_handler, tls=None):
def __init__(self, test_object, request_handler, tls=None, server_kwargs=None):
threading.Thread.__init__(self)
self.request_handler = request_handler
self.test_object = test_object
self.tls = tls
self.server_kwargs = server_kwargs or {}

def run(self):
if self.tls:
certfile, keyfile, password = self.tls
self.server = create_https_server(
certfile, keyfile, password,
request_handler=self.request_handler,
**self.server_kwargs
)
else:
self.server = HTTPServer(('localhost', 0), self.request_handler)
self.server = HTTPServer(('localhost', 0), self.request_handler,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You must also modify create_https_server appropriately

**self.server_kwargs)
self.test_object.HOST, self.test_object.PORT = self.server.socket.getsockname()
self.test_object.server_started.set()
self.test_object = None
Expand All @@ -113,12 +116,14 @@ class BaseTestCase(unittest.TestCase):

# Optional tuple (certfile, keyfile, password) to use for HTTPS servers.
tls = None
server_kwargs = None

def setUp(self):
self._threads = threading_helper.threading_setup()
os.environ = os_helper.EnvironmentVarGuard()
self.server_started = threading.Event()
self.thread = TestServerThread(self, self.request_handler, self.tls)
self.thread = TestServerThread(self, self.request_handler, self.tls,
self.server_kwargs)
self.thread.start()
self.server_started.wait()

Expand Down Expand Up @@ -824,6 +829,17 @@ def test_path_without_leading_slash(self):
self.tempdir_name + "/?hi=1")


class CorsHTTPServerTestCase(SimpleHTTPServerTestCase):
server_kwargs = {
'response_headers': {'Access-Control-Allow-Origin': '*'}
}

def test_cors(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_cors(self):
def test_cors(self):

response = self.request(self.base_url + '/test')
self.check_status_and_reason(response, HTTPStatus.OK)
self.assertEqual(response.getheader('Access-Control-Allow-Origin'), '*')


class SocketlessRequestHandler(SimpleHTTPRequestHandler):
def __init__(self, directory=None):
request = mock.Mock()
Expand Down Expand Up @@ -1306,6 +1322,7 @@ class CommandLineTestCase(unittest.TestCase):
'tls_cert': None,
'tls_key': None,
'tls_password': None,
'response_headers': None,
}

def setUp(self):
Expand Down Expand Up @@ -1371,6 +1388,29 @@ def test_protocol_flag(self, mock_func):
mock_func.assert_called_once_with(**call_args)
mock_func.reset_mock()

@mock.patch('http.server.test')
def test_cors_flag(self, mock_func):
self.invoke_httpd('--cors')
call_args = self.args | dict(
response_headers={
'Access-Control-Allow-Origin': '*'
}
)
mock_func.assert_called_once_with(**call_args)
mock_func.reset_mock()

@mock.patch('http.server.test')
def test_header_flag(self, mock_func):
self.invoke_httpd('--header', 'h1', 'v1', '-H', 'h2', 'v2')
call_args = self.args | dict(
response_headers={
'h1': 'v1',
'h2': 'v2'
}
)
mock_func.assert_called_once_with(**call_args)
mock_func.reset_mock()

@unittest.skipIf(ssl is None, "requires ssl")
@mock.patch('http.server.test')
def test_tls_cert_and_key_flags(self, mock_func):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Add a ``--cors`` cli option to :program:`python -m http.server`. Contributed by
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also update What's New/3.15.rst

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used blurb to make this entry in NEWS.d, not knowing when it's appropriate to edit the main 3.15.rst file. I think once we know if we're doing --cors / --header , or both, I can make the appropriate update to What's New/3.15.rst

Anton I. Sipos.
Loading