From 0bd3e6495a9bf8548d069f7f2b7fe878c0f200c0 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Sun, 6 Feb 2022 21:43:36 -0800 Subject: [PATCH 01/14] Got Django working properly --- graphql_server/__init__.py | 71 +++- graphql_server/aiohttp/graphqlview.py | 177 ++++------ graphql_server/channels/__init__.py | 5 + graphql_server/channels/consumer.py | 134 ++++++++ graphql_server/channels/context.py | 36 ++ .../channels/graphql_transport_ws.py | 46 +++ graphql_server/channels/graphql_ws.py | 53 +++ graphql_server/channels/http.py | 143 ++++++++ graphql_server/django/__init__.py | 3 + graphql_server/django/views.py | 310 ++++++++++++++++++ graphql_server/graphiql_render_jinja.py | 19 ++ graphql_server/render_graphiql.py | 194 ++++------- graphql_server/sanic/graphqlview.py | 6 +- graphql_server/websockets/__init__.py | 0 graphql_server/websockets/constants.py | 2 + .../transport_ws_protocol/__init__.py | 3 + .../transport_ws_protocol/contstants.py | 8 + .../transport_ws_protocol/handlers.py | 223 +++++++++++++ .../websockets/transport_ws_protocol/types.py | 100 ++++++ .../websockets/ws_protocol/__init__.py | 4 + .../websockets/ws_protocol/constants.py | 10 + .../websockets/ws_protocol/handlers.py | 201 ++++++++++++ .../websockets/ws_protocol/types.py | 47 +++ 23 files changed, 1535 insertions(+), 260 deletions(-) create mode 100644 graphql_server/channels/__init__.py create mode 100644 graphql_server/channels/consumer.py create mode 100644 graphql_server/channels/context.py create mode 100644 graphql_server/channels/graphql_transport_ws.py create mode 100644 graphql_server/channels/graphql_ws.py create mode 100644 graphql_server/channels/http.py create mode 100644 graphql_server/django/__init__.py create mode 100644 graphql_server/django/views.py create mode 100644 graphql_server/graphiql_render_jinja.py create mode 100644 graphql_server/websockets/__init__.py create mode 100644 graphql_server/websockets/constants.py create mode 100644 graphql_server/websockets/transport_ws_protocol/__init__.py create mode 100644 graphql_server/websockets/transport_ws_protocol/contstants.py create mode 100644 graphql_server/websockets/transport_ws_protocol/handlers.py create mode 100644 graphql_server/websockets/transport_ws_protocol/types.py create mode 100644 graphql_server/websockets/ws_protocol/__init__.py create mode 100644 graphql_server/websockets/ws_protocol/constants.py create mode 100644 graphql_server/websockets/ws_protocol/handlers.py create mode 100644 graphql_server/websockets/ws_protocol/types.py diff --git a/graphql_server/__init__.py b/graphql_server/__init__.py index ee54cdb..b20f94b 100644 --- a/graphql_server/__init__.py +++ b/graphql_server/__init__.py @@ -9,6 +9,8 @@ import json from collections import namedtuple from collections.abc import MutableMapping +from dataclasses import dataclass + from typing import ( Any, Callable, @@ -56,13 +58,34 @@ # The public data structures -GraphQLParams = namedtuple("GraphQLParams", "query variables operation_name") -GraphQLResponse = namedtuple("GraphQLResponse", "results params") -ServerResponse = namedtuple("ServerResponse", "body status_code") +@dataclass +class GraphQLParams: + query: str + variables: Optional[Dict[str, Any]] = None + operation_name: Optional[str] = None + +@dataclass +class GraphQLResponse: + params: List[GraphQLParams] + results: List[AwaitableOrValue[ExecutionResult]] + +@dataclass +class ServerResponse: + body: Optional[str] + status_code: int + headers: Optional[Dict[str, str]] = None # The public helper functions +def get_schema(schema: GraphQLSchema): + if not isinstance(schema, GraphQLSchema): + # maybe the GraphQL schema is wrapped in a Graphene schema + schema = getattr(schema, "graphql_schema", None) + if not isinstance(schema, GraphQLSchema): + raise TypeError("A Schema is required to be provided to GraphQLView.") + return schema + def format_error_default(error: GraphQLError) -> Dict: """The default function for converting GraphQLError to a dictionary.""" @@ -138,7 +161,24 @@ def run_http_query( ) for params in all_params ] - return GraphQLResponse(results, all_params) + return GraphQLResponse(results=results, params=all_params) + + +def process_preflight(origin_header: Optional[str], request_method: Optional[str], accepted_methods: List[str], max_age: int) -> ServerResponse: + """ + Preflight request support for apollo-client + https://www.w3.org/TR/cors/#resource-preflight-requests + """ + if origin_header and request_method and request_method in accepted_methods: + return ServerResponse( + status_code=200, + headers={ + "Access-Control-Allow-Origin": origin_header, + "Access-Control-Allow-Methods": ", ".join(accepted_methods), + "Access-Control-Max-Age": str(max_age), + }, + ) + return ServerResponse(status_code=400) def json_encode(data: Union[Dict, List], pretty: bool = False) -> str: @@ -184,18 +224,31 @@ def encode_execution_results( if not is_batch: result = result[0] - return ServerResponse(encode(result), status_code) + return ServerResponse(body=encode(result), status_code=status_code) -def load_json_body(data): - # type: (str) -> Union[Dict, List] +def load_json_body(data: str, batch: bool = False) -> Union[Dict, List]: """Load the request body as a dictionary or a list. The body must be passed in a string and will be deserialized from JSON, raising an HttpQueryError in case of invalid JSON. """ try: - return json.loads(data) + request_json = json.loads(data) + if batch: + assert isinstance(request_json, list), ( + "Batch requests should receive a list, but received {}." + ).format(repr(request_json)) + assert ( + len(request_json) > 0 + ), "Received an empty list in the batch request." + else: + assert isinstance( + request_json, dict + ), "The received data is not a valid JSON query." + return request_json + except AssertionError as e: + raise HttpQueryError(400, str(e)) except Exception: raise HttpQueryError(400, "POST body sent invalid JSON.") @@ -222,7 +275,7 @@ def get_graphql_params(data: Dict, query_data: Dict) -> GraphQLParams: # document_id = data.get('documentId') operation_name = data.get("operationName") or query_data.get("operationName") - return GraphQLParams(query, load_json_variables(variables), operation_name) + return GraphQLParams(query=query, variables=load_json_variables(variables), operation_name=operation_name) def load_json_variables(variables: Optional[Union[str, Dict]]) -> Optional[Dict]: diff --git a/graphql_server/aiohttp/graphqlview.py b/graphql_server/aiohttp/graphqlview.py index d98becd..a7739b8 100644 --- a/graphql_server/aiohttp/graphqlview.py +++ b/graphql_server/aiohttp/graphqlview.py @@ -1,79 +1,65 @@ -import copy -from collections.abc import MutableMapping from functools import partial -from typing import List +from typing import Type, Any, Optional, Collection from aiohttp import web from graphql import ExecutionResult, GraphQLError, specified_rules +from graphql.execution import Middleware from graphql.type.schema import GraphQLSchema +from graphql.validation import ASTValidationRule from graphql_server import ( - GraphQLParams, HttpQueryError, + get_schema, encode_execution_results, format_error_default, json_encode, load_json_body, run_http_query, + process_preflight, ) from graphql_server.render_graphiql import ( - GraphiQLConfig, - GraphiQLData, GraphiQLOptions, - render_graphiql_async, + render_graphiql_sync, ) +from typing import Dict, Any class GraphQLView: - schema = None - root_value = None - context = None - pretty = False - graphiql = False - graphiql_version = None - graphiql_template = None - graphiql_html_title = None - middleware = None - validation_rules = None - batch = False - jinja_env = None - max_age = 86400 - enable_async = False - subscriptions = None - headers = None - default_query = None - header_editor_enabled = None - should_persist_headers = None accepted_methods = ["GET", "POST", "PUT", "DELETE"] format_error = staticmethod(format_error_default) encode = staticmethod(json_encode) - def __init__(self, **kwargs): - super(GraphQLView, self).__init__() - for key, value in kwargs.items(): - if hasattr(self, key): - setattr(self, key, value) - - if not isinstance(self.schema, GraphQLSchema): - # maybe the GraphQL schema is wrapped in a Graphene schema - self.schema = getattr(self.schema, "graphql_schema", None) - if not isinstance(self.schema, GraphQLSchema): - raise TypeError("A Schema is required to be provided to GraphQLView.") + def __init__(self, schema: GraphQLSchema, *, + root_value: Any = None, + pretty: bool = False, + graphiql: bool = True, + middleware: Optional[Middleware] = None, + validation_rules: Optional[Collection[Type[ASTValidationRule]]] = None, + batch: bool = False, + max_age: int = 86400, + enable_async: bool = False, + graphiql_options: Optional[GraphiQLOptions] = None, + ): + self.schema = get_schema(schema) + self.root_value = root_value + self.pretty = pretty + self.graphiql = graphiql + self.graphiql_options = graphiql_options + self.middleware = middleware + self.validation_rules = validation_rules + self.batch = batch + self.max_age = max_age + self.enable_async = enable_async + + render_graphiql = render_graphiql_sync def get_root_value(self): return self.root_value def get_context(self, request): - context = ( - copy.copy(self.context) - if self.context and isinstance(self.context, MutableMapping) - else {} - ) - if isinstance(context, MutableMapping) and "request" not in context: - context.update({"request": request}) - return context + return {"request": request} def get_middleware(self): return self.middleware @@ -107,45 +93,33 @@ async def parse_body(request): return {} - # TODO: - # use this method to replace flask and sanic - # checks as this is equivalent to `should_display_graphiql` and - # `request_wants_html` methods. - def is_graphiql(self, request): - return all( - [ - self.graphiql, - request.method.lower() == "get", - "raw" not in request.query, - any( - [ - "text/html" in request.headers.get("accept", {}), - "*/*" in request.headers.get("accept", {}), - ] - ), - ] + def is_graphiql(self, request_method, is_raw, accept_headers): + return (self.graphiql and request_method == "get" + and not is_raw and ("text/html" in accept_headers or "*/*" in accept_headers), ) - # TODO: Same stuff as above method. - def is_pretty(self, request): - return any( - [self.pretty, self.is_graphiql(request), request.query.get("pretty")] - ) + def should_prettify(self, is_graphiql, pretty_query): + return self.pretty or is_graphiql or pretty_query async def __call__(self, request): try: data = await self.parse_body(request) request_method = request.method.lower() - is_graphiql = self.is_graphiql(request) - is_pretty = self.is_pretty(request) + accept_headers = request.headers.get("accept", {}) + is_graphiql = self.is_graphiql(request_method, request.query.get("raw"), accept_headers) + is_pretty = self.should_prettify(is_graphiql, request.query.get("pretty")) - # TODO: way better than if-else so better - # implement this too on flask and sanic if request_method == "options": - return self.process_preflight(request) + headers = request.headers + origin = headers.get("Origin", "") + method = headers.get("Access-Control-Request-Method", "").upper() + response = process_preflight(origin, method, self.accepted_methods, self.max_age) + return web.Response( + status=response.status_code, + headers = response.headers + ) - all_params: List[GraphQLParams] - execution_results, all_params = run_http_query( + graphql_response = run_http_query( self.schema, request_method, data, @@ -163,12 +137,12 @@ async def __call__(self, request): exec_res = ( [ ex if ex is None or isinstance(ex, ExecutionResult) else await ex - for ex in execution_results + for ex in graphql_response.results ] if self.enable_async - else execution_results + else graphql_response.results ) - result, status_code = encode_execution_results( + response = encode_execution_results( exec_res, is_batch=isinstance(data, list), format_error=self.format_error, @@ -176,33 +150,16 @@ async def __call__(self, request): ) if is_graphiql: - graphiql_data = GraphiQLData( - result=result, - query=getattr(all_params[0], "query"), - variables=getattr(all_params[0], "variables"), - operation_name=getattr(all_params[0], "operation_name"), - subscription_url=self.subscriptions, - headers=self.headers, - ) - graphiql_config = GraphiQLConfig( - graphiql_version=self.graphiql_version, - graphiql_template=self.graphiql_template, - graphiql_html_title=self.graphiql_html_title, - jinja_env=self.jinja_env, - ) - graphiql_options = GraphiQLOptions( - default_query=self.default_query, - header_editor_enabled=self.header_editor_enabled, - should_persist_headers=self.should_persist_headers, - ) - source = await render_graphiql_async( - data=graphiql_data, config=graphiql_config, options=graphiql_options + source = self.render_graphiql( + result=response.body, + params=graphql_response.all_params[0], + options=self.graphiql_options ) return web.Response(text=source, content_type="text/html") return web.Response( - text=result, - status=status_code, + text=response.result, + status=response.status_code, content_type="application/json", ) @@ -215,26 +172,6 @@ async def __call__(self, request): content_type="application/json", ) - def process_preflight(self, request): - """ - Preflight request support for apollo-client - https://www.w3.org/TR/cors/#resource-preflight-requests - """ - headers = request.headers - origin = headers.get("Origin", "") - method = headers.get("Access-Control-Request-Method", "").upper() - - if method and method in self.accepted_methods: - return web.Response( - status=200, - headers={ - "Access-Control-Allow-Origin": origin, - "Access-Control-Allow-Methods": ", ".join(self.accepted_methods), - "Access-Control-Max-Age": str(self.max_age), - }, - ) - return web.Response(status=400) - @classmethod def attach(cls, app, *, route_path="/graphql", route_name="graphql", **kwargs): view = cls(**kwargs) diff --git a/graphql_server/channels/__init__.py b/graphql_server/channels/__init__.py new file mode 100644 index 0000000..fb2ba77 --- /dev/null +++ b/graphql_server/channels/__init__.py @@ -0,0 +1,5 @@ +from .consumer import GraphQLWSConsumer +from .context import GraphQLChannelsContext +from .http import GraphQLHttpConsumer + +__all__ = ["GraphQLWSConsumer", "GraphQLChannelsContext", "GraphQLHttpConsumer"] diff --git a/graphql_server/channels/consumer.py b/graphql_server/channels/consumer.py new file mode 100644 index 0000000..135a735 --- /dev/null +++ b/graphql_server/channels/consumer.py @@ -0,0 +1,134 @@ +"""GraphQLWebSocketRouter +This is a simple router class that might be better placed as part of Channels itself. +It's a simple "SubProtocolRouter" that selects the websocket subprotocol based +on preferences and client support. Then it hands off to the appropriate consumer. +""" +from datetime import timedelta +from typing import Any, Optional, Sequence, Union + +from django.http import HttpRequest, HttpResponse +from django.urls import re_path + +from channels.generic.websocket import ( + AsyncJsonWebsocketConsumer, + AsyncWebsocketConsumer, +) +from graphql import GraphQLSchema +from ..websockets.constants import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL + +from .context import GraphQLChannelsContext +from .graphql_transport_ws import GraphQLTransportWSHandler +from .graphql_ws import GraphQLWSHandler + + +class GraphQLWSConsumer(AsyncJsonWebsocketConsumer): + """ + A channels websocket consumer for GraphQL + + This handles the connections, then hands off to the appropriate + handler based on the subprotocol. + To use this, place it in your ProtocolTypeRouter for your channels project, e.g: + + ``` + from graphql_ws.channels import GraphQLWSConsumer + from channels.routing import ProtocolTypeRouter, URLRouter + from django.core.asgi import get_asgi_application + application = ProtocolTypeRouter({ + "http": URLRouter([ + re_path("^", get_asgi_application()), + ]), + "websocket": URLRouter([ + re_path("^ws/graphql", GraphQLWSConsumer(schema=schema)) + ]), + }) + ``` + """ + + graphql_transport_ws_handler_class = GraphQLTransportWSHandler + graphql_ws_handler_class = GraphQLWSHandler + _handler: Union[GraphQLWSHandler, GraphQLTransportWSHandler] + + def __init__( + self, + schema: GraphQLSchema, + graphiql: bool = True, + keep_alive: bool = False, + keep_alive_interval: float = 1, + debug: bool = False, + subscription_protocols=(GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL), + connection_init_wait_timeout: timedelta = None, + ): + if connection_init_wait_timeout is None: + connection_init_wait_timeout = timedelta(minutes=1) + self.connection_init_wait_timeout = connection_init_wait_timeout + self.schema = schema + self.graphiql = graphiql + self.keep_alive = keep_alive + self.keep_alive_interval = keep_alive_interval + self.debug = debug + self.protocols = subscription_protocols + + super().__init__() + + def pick_preferred_protocol( + self, accepted_subprotocols: Sequence[str] + ) -> Optional[str]: + intersection = set(accepted_subprotocols) & set(self.protocols) + sorted_intersection = sorted(intersection, key=accepted_subprotocols.index) + return next(iter(sorted_intersection), None) + + async def connect(self): + preferred_protocol = self.pick_preferred_protocol(self.scope["subprotocols"]) + + if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: + self._handler = self.graphql_transport_ws_handler_class( + schema=self.schema, + debug=self.debug, + connection_init_wait_timeout=self.connection_init_wait_timeout, + get_context=self.get_context, + get_root_value=self.get_root_value, + ws=self, + ) + elif preferred_protocol == GRAPHQL_WS_PROTOCOL: + self._handler = self.graphql_ws_handler_class( + schema=self.schema, + debug=self.debug, + keep_alive=self.keep_alive, + keep_alive_interval=self.keep_alive_interval, + get_context=self.get_context, + get_root_value=self.get_root_value, + ws=self, + ) + else: + # Subprotocol not acceptable + return await self.close(code=4406) + + await self._handler.handle() + + async def receive(self, text_data=None, bytes_data=None, **kwargs): + try: + await super().receive(text_data=text_data, bytes_data=bytes_data, **kwargs) + except ValueError: + await self._handler.handle_invalid_message( + "WebSocket message type must be text" + ) + + async def receive_json(self, content, **kwargs): + await self._handler.handle_message(content) + + async def disconnect(self, code): + await self._handler.handle_disconnect(code) + + async def get_root_value( + self, request: HttpRequest = None, consumer: AsyncWebsocketConsumer = None + ) -> Optional[Any]: + return None + + async def get_context( + self, + request: Union[HttpRequest, AsyncJsonWebsocketConsumer] = None, + response: Optional[HttpResponse] = None, + ) -> Optional[Any]: + return GraphQLChannelsContext( + request=request or self, response=response, scope=self.scope + ) diff --git a/graphql_server/channels/context.py b/graphql_server/channels/context.py new file mode 100644 index 0000000..caedf7a --- /dev/null +++ b/graphql_server/channels/context.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Union, Dict, Any + +from django.http import HttpRequest, HttpResponse + + +if TYPE_CHECKING: + from .consumer import GraphQLWSConsumer + + +@dataclass +class GraphQLChannelsContext: + """ + A Channels context for GraphQL + """ + + request: Optional[Union[HttpRequest, "GraphQLWSConsumer"]] + response: Optional[HttpResponse] + scope: Optional[Dict[str, Any]] + + @property + def ws(self): + return self.request + + def __getitem__(self, key): + # __getitem__ override needed to avoid issues for who's + # using info.context["request"] + return super().__getattribute__(key) + + def get(self, key): + """Enable .get notation for accessing the request""" + return super().__getattribute__(key) + + @property + def user(self): + return self.scope["user"] diff --git a/graphql_server/channels/graphql_transport_ws.py b/graphql_server/channels/graphql_transport_ws.py new file mode 100644 index 0000000..8b946ef --- /dev/null +++ b/graphql_server/channels/graphql_transport_ws.py @@ -0,0 +1,46 @@ +from datetime import timedelta +from typing import Any, Optional + +from channels.generic.websocket import AsyncJsonWebsocketConsumer +from graphql import GraphQLSchema +from graphql_server.websockets.transport_ws_protocol import ( + BaseGraphQLTransportWSHandler, +) + + +class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): + def __init__( + self, + schema: GraphQLSchema, + debug: bool, + connection_init_wait_timeout: timedelta, + get_context, + get_root_value, + ws: AsyncJsonWebsocketConsumer, + ): + super().__init__(schema, debug, connection_init_wait_timeout) + self._get_context = get_context + self._get_root_value = get_root_value + self._ws = ws + + async def get_context(self) -> Any: + return await self._get_context(request=self._ws) + + async def get_root_value(self) -> Any: + return await self._get_root_value(request=self._ws) + + async def send_json(self, data: dict) -> None: + await self._ws.send_json(data) + + async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: + # Close messages are not part of the ASGI ref yet + await self._ws.close(code=code) + + async def handle_request(self) -> Any: + await self._ws.accept( + subprotocol=BaseGraphQLTransportWSHandler.GRAPHQL_TRANSPORT_WS_PROTOCOL + ) + + async def handle_disconnect(self, code): + for operation_id in list(self.subscriptions.keys()): + await self.cleanup_operation(operation_id) diff --git a/graphql_server/channels/graphql_ws.py b/graphql_server/channels/graphql_ws.py new file mode 100644 index 0000000..45e08eb --- /dev/null +++ b/graphql_server/channels/graphql_ws.py @@ -0,0 +1,53 @@ +from contextlib import suppress +from typing import Any, Optional + +from channels.generic.websocket import AsyncJsonWebsocketConsumer + +from graphql import GraphQLSchema +from graphql_server.websockets.ws_protocol import BaseGraphQLWSHandler, OperationMessage + + +class GraphQLWSHandler(BaseGraphQLWSHandler): + def __init__( + self, + schema: GraphQLSchema, + debug: bool, + keep_alive: bool, + keep_alive_interval: float, + get_context, + get_root_value, + ws: AsyncJsonWebsocketConsumer, + ): + super().__init__(schema, debug, keep_alive, keep_alive_interval) + self._get_context = get_context + self._get_root_value = get_root_value + self._ws = ws + + async def get_context(self) -> Any: + return await self._get_context(request=self._ws) + + async def get_root_value(self) -> Any: + return await self._get_root_value(request=self._ws) + + async def send_json(self, data: OperationMessage) -> None: + await self._ws.send_json(data) + + async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: + # Close messages are not part of the ASGI ref yet + await self._ws.close(code=code) + + async def handle_request(self) -> Any: + await self._ws.accept(subprotocol=BaseGraphQLWSHandler.PROTOCOL) + + async def handle_disconnect(self, code): + if self.keep_alive_task: + self.keep_alive_task.cancel() + with suppress(BaseException): + await self.keep_alive_task + + for operation_id in list(self.subscriptions.keys()): + await self.cleanup_operation(operation_id) + + async def handle_invalid_message(self, error_message: str) -> None: + # Do nothing + return diff --git a/graphql_server/channels/http.py b/graphql_server/channels/http.py new file mode 100644 index 0000000..d4ba271 --- /dev/null +++ b/graphql_server/channels/http.py @@ -0,0 +1,143 @@ +"""GraphQLHTTPHandler +A consumer to provide a graphql endpoint, and optionally graphiql. +""" +import json +from pathlib import Path +from typing import Any, Optional + +from channels.generic.http import AsyncHttpConsumer + +from graphql import GraphQLSchema, ExecutionResult + + +class GraphQLHttpConsumer(AsyncHttpConsumer): + """ + A consumer to provide a view for GraphQL over HTTP. + To use this, place it in your ProtocolTypeRouter for your channels project: + + ``` + from graphql_ws.channels import GraphQLHttpConsumer + from channels.routing import ProtocolTypeRouter, URLRouter + from django.core.asgi import get_asgi_application + application = ProtocolTypeRouter({ + "http": URLRouter([ + re_path("^graphql", GraphQLHttpConsumer(schema=schema)), + re_path("^", get_asgi_application()), + ]), + }) + ``` + """ + + def __init__( + self, + schema: GraphQLSchema, + graphiql: bool = True, + ): + self.schema = schema + self.graphiql = graphiql + super().__init__() + + # def headers(self): + # return { + # header_name.decode("utf-8").lower(): header_value.decode("utf-8") + # for header_name, header_value in self.scope["headers"] + # } + + # async def parse_multipart_body(self, body): + # await self.send_response(500, "Unable to parse the multipart body") + # return None + + # async def get_graphql_params(self, data): + # query = data.get("query") + # variables = data.get("variables") + # id = data.get("id") + + # if variables and isinstance(variables, str): + # try: + # variables = json.loads(variables) + # except Exception: + # await self.send_response(500, b"Variables are invalid JSON.") + # return None + # operation_name = data.get("operationName") + + # return query, variables, operation_name, id + + # async def get_request_data(self, body) -> Optional[Any]: + # if self.headers.get("content-type", "").startswith("multipart/form-data"): + # data = await self.parse_multipart_body(body) + # if data is None: + # return None + # else: + # try: + # data = json.loads(body) + # except json.JSONDecodeError: + # await self.send_response(500, b"Unable to parse request body as JSON") + # return None + + # query, variables, operation_name, id = self.get_graphql_params(data) + # if not query: + # await self.send_response(500, b"No GraphQL query found in the request") + # return None + + # return query, variables, operation_name, id + + # async def post(self, body): + # request_data = await self.get_request_data(body) + # if request_data is None: + # return + # context = await self.get_context() + # root_value = await self.get_root_value() + + # result = await self.schema.execute( + # query=request_data.query, + # root_value=root_value, + # variable_values=request_data.variables, + # context_value=context, + # operation_name=request_data.operation_name, + # ) + + # response_data = self.process_result(result) + # await self.send_response( + # 200, + # json.dumps(response_data).encode("utf-8"), + # headers=[(b"Content-Type", b"application/json")], + # ) + + # def graphiql_html_file_path(self) -> Path: + # return Path(__file__).parent.parent.parent / "static" / "graphiql.html" + + # async def render_graphiql(self, body): + # html_string = self.graphiql_html_file_path.read_text() + # html_string = html_string.replace("{{ SUBSCRIPTION_ENABLED }}", "true") + # await self.send_response( + # 200, html_string.encode("utf-8"), headers=[(b"Content-Type", b"text/html")] + # ) + + # def should_render_graphiql(self): + # return bool(self.graphiql and "text/html" in self.headers.get("accept", "")) + + # async def get(self, body): + # # if self.should_render_graphiql(): + # # return await self.render_graphiql(body) + # # else: + # await self.send_response( + # 200, "{}", headers=[(b"Content-Type", b"text/json")] + # ) + + async def handle(self, body): + # if self.scope["method"] == "GET": + # return await self.get(body) + # if self.scope["method"] == "POST": + # return await self.post(body) + await self.send_response( + 200, b"Method not allowed", headers=[b"Allow", b"GET, POST"] + ) + + # async def get_root_value(self) -> Any: + # return None + + # async def get_context(self) -> Any: + # return None + + # def process_result(self, result: ExecutionResult): + # return result.formatted diff --git a/graphql_server/django/__init__.py b/graphql_server/django/__init__.py new file mode 100644 index 0000000..a95776c --- /dev/null +++ b/graphql_server/django/__init__.py @@ -0,0 +1,3 @@ +from .views import GraphQLView, AsyncGraphQLView + +__all__ = ["GraphQLView", "AsyncGraphQLView"] diff --git a/graphql_server/django/views.py b/graphql_server/django/views.py new file mode 100644 index 0000000..d787f09 --- /dev/null +++ b/graphql_server/django/views.py @@ -0,0 +1,310 @@ +import asyncio +import re +from functools import partial +from http.client import HTTPResponse +from typing import Type, Any, Optional, Collection + +from graphql import ExecutionResult, GraphQLError, specified_rules +from graphql.execution import Middleware +from graphql.type.schema import GraphQLSchema +from graphql.validation import ASTValidationRule +from django.views.generic import View +from django.http import HttpResponse, HttpRequest, HttpResponseBadRequest +from django.utils.decorators import classonlymethod, method_decorator +from django.views.decorators.csrf import csrf_exempt + +from graphql_server import ( + HttpQueryError, + get_schema, + encode_execution_results, + format_error_default, + json_encode, + load_json_body, + run_http_query, + process_preflight, +) +from graphql_server.render_graphiql import ( + GraphiQLOptions, + render_graphiql_sync, +) + + +def get_accepted_content_types(request): + def qualify(x): + parts = x.split(";", 1) + if len(parts) == 2: + match = re.match(r"(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)", parts[1]) + if match: + return parts[0].strip(), float(match.group(2)) + return parts[0].strip(), 1 + + raw_content_types = request.META.get("HTTP_ACCEPT", "*/*").split(",") + qualified_content_types = map(qualify, raw_content_types) + return list( + x[0] for x in sorted(qualified_content_types, key=lambda x: x[1], reverse=True) + ) + + +class GraphQLView(View): + + accepted_methods = ["GET", "POST", "PUT", "DELETE"] + + format_error = staticmethod(format_error_default) + encode = staticmethod(json_encode) + + schema: GraphQLSchema = None + root_value: Any = None + pretty: bool = False + graphiql: bool = True + middleware: Optional[Middleware] = None + validation_rules: Optional[Collection[Type[ASTValidationRule]]] = None + batch: bool = False + fetch_query_on_load: bool = True + max_age: int = 86400 + graphiql_options: Optional[GraphiQLOptions] = None + + def __init__(self, schema: GraphQLSchema, + root_value: Any = None, + pretty: bool = False, + graphiql: bool = True, + middleware: Optional[Middleware] = None, + validation_rules: Optional[Collection[Type[ASTValidationRule]]] = None, + batch: bool = False, + fetch_query_on_load: bool = True, + max_age: int = 86400, + graphiql_options: Optional[GraphiQLOptions] = None, + ): + self.schema = get_schema(schema) + self.root_value = root_value + self.pretty = pretty + self.graphiql = graphiql + self.graphiql_options = graphiql_options + self.middleware = middleware + self.validation_rules = validation_rules + self.batch = batch + self.fetch_query_on_load = fetch_query_on_load + self.max_age = max_age + + def render_graphiql(self, *args, **kwargs): + return render_graphiql_sync(*args, **kwargs) + + def get_root_value(self, request: HttpRequest): + return self.root_value + + def get_context(self, request: HttpRequest): + return request + + def get_middleware(self): + return self.middleware + + def get_validation_rules(self): + if self.validation_rules is None: + return specified_rules + return self.validation_rules + + def parse_body(self, request: HttpRequest): + content_type = request.content_type + + if content_type == "application/graphql": + return {"query": request.body.decode()} + + elif content_type == "application/json": + try: + body = request.body.decode("utf-8") + except Exception as e: + raise HttpQueryError(400, str(e)) + + return load_json_body(body, self.batch) + + elif content_type in [ + "application/x-www-form-urlencoded", + "multipart/form-data", + ]: + return request.POST + + return {} + + @classmethod + def request_prefers_html(cls, request: HttpRequest): + accepted = get_accepted_content_types(request) + accepted_length = len(accepted) + # the list will be ordered in preferred first - so we have to make + # sure the most preferred gets the highest number + html_priority = ( + accepted_length - accepted.index("text/html") + if "text/html" in accepted + else 0 + ) + json_priority = ( + accepted_length - accepted.index("application/json") + if "application/json" in accepted + else 0 + ) + + return html_priority > json_priority + + def is_graphiql(self, request_method: str, is_raw: bool, prefers_html: bool): + return (self.graphiql and request_method == "get" + and not is_raw and prefers_html + ) + + def should_prettify(self, is_graphiql: bool, pretty_in_query: bool): + return self.pretty or is_graphiql or pretty_in_query + + @method_decorator(csrf_exempt) + def dispatch(self, request: HttpRequest, *args, **kwargs): + try: + data = self.parse_body(request) + request_method = request.method.lower() + prefers_html = self.request_prefers_html(request) + is_graphiql = self.is_graphiql(request_method, "raw" in request.GET, prefers_html) + is_pretty = self.should_prettify(is_graphiql, request.GET.get("pretty")) + + if request_method == "options": + headers = request.headers + origin = headers.get("Origin", "") + method = headers.get("Access-Control-Request-Method", "").upper() + response = process_preflight(origin, method, self.accepted_methods, self.max_age) + return HTTPResponse( + status=response.status_code, + headers=response.headers + ) + + graphql_response = run_http_query( + self.schema, + request_method, + data, + query_data=request.GET, + batch_enabled=self.batch, + catch=is_graphiql, + # Execute options + run_sync=True, + root_value=self.get_root_value(request), + context_value=self.get_context(request), + middleware=self.get_middleware(), + validation_rules=self.get_validation_rules(), + ) + + response = encode_execution_results( + graphql_response.results, + is_batch=isinstance(data, list), + format_error=self.format_error, + encode=partial(self.encode, pretty=is_pretty), # noqa: ignore + ) + + if is_graphiql: + source = self.render_graphiql( + result=response.body, + params=graphql_response.params[0], + options=self.graphiql_options + ) + return HttpResponse( + content=source, + content_type="text/html" + ) + + return HttpResponse( + content=response.body, + content_type="application/json", + status=response.status_code, + ) + + except HttpQueryError as err: + parsed_error = GraphQLError(err.message) + return HttpResponse( + content=self.encode(dict(errors=[self.format_error(parsed_error)])), + content_type="application/json", + status=err.status_code, + headers=err.headers, + ) + + +class AsyncGraphQLView(GraphQLView): + @classonlymethod + def as_view(cls, **initkwargs): + # This code tells django that this view is async, see docs here: + # https://docs.djangoproject.com/en/3.1/topics/async/#async-views + + view = super().as_view(**initkwargs) + view._is_coroutine = asyncio.coroutines._is_coroutine + return view + + @method_decorator(csrf_exempt) + async def dispatch(self, request, *args, **kwargs): + try: + data = self.parse_body(request) + request_method = request.method.lower() + prefers_html = self.request_prefers_html(request) + is_graphiql = self.is_graphiql(request_method, "raw" in request.GET, prefers_html) + is_pretty = self.should_prettify(is_graphiql, request.GET.get("pretty")) + + if request_method == "options": + headers = request.headers + origin = headers.get("Origin", "") + method = headers.get("Access-Control-Request-Method", "").upper() + response = process_preflight(origin, method, self.accepted_methods, self.max_age) + return HTTPResponse( + status=response.status_code, + headers=response.headers + ) + + graphql_response = run_http_query( + self.schema, + request_method, + data, + query_data=request.GET, + batch_enabled=self.batch, + catch=is_graphiql, + # Execute options + run_sync=False, + root_value=await self.get_root_value(request), + context_value=await self.get_context(request), + middleware=self.get_middleware(), + validation_rules=self.get_validation_rules(), + ) + + exec_res = ( + [ + ex if ex is None or isinstance(ex, ExecutionResult) else await ex + for ex in graphql_response.results + ] + ) + + response = encode_execution_results( + exec_res, + is_batch=isinstance(data, list), + format_error=self.format_error, + encode=partial(self.encode, pretty=is_pretty), # noqa: ignore + ) + + if is_graphiql: + source = self.render_graphiql( + result=response.body, + params=graphql_response.params[0], + options=self.graphiql_options + ) + return HttpResponse( + content=source, + content_type="text/html" + ) + + return HttpResponse( + content=response.body, + content_type="application/json", + status=response.status_code, + ) + + except HttpQueryError as err: + parsed_error = GraphQLError(err.message) + return HttpResponse( + content=self.encode(dict(errors=[self.format_error(parsed_error)])), + content_type="application/json", + status=err.status_code, + headers=err.headers, + ) + + async def get_root_value(self, request: HttpRequest) -> Any: + return None + + async def get_context(self, request: HttpRequest) -> Any: + return request diff --git a/graphql_server/graphiql_render_jinja.py b/graphql_server/graphiql_render_jinja.py new file mode 100644 index 0000000..91ba416 --- /dev/null +++ b/graphql_server/graphiql_render_jinja.py @@ -0,0 +1,19 @@ + +async def render_graphiql_async( + data: GraphiQLData, + config: GraphiQLConfig, + options: Optional[GraphiQLOptions] = None, +) -> str: + graphiql_template, template_vars = _render_graphiql(data, config, options) + jinja_env: Optional[Environment] = config.get("jinja_env") + + if jinja_env: + # This method returns a Template. See https://jinja.palletsprojects.com/en/2.11.x/api/#jinja2.Template + template = jinja_env.from_string(graphiql_template) + if jinja_env.is_async: # type: ignore + source = await template.render_async(**template_vars) + else: + source = template.render(**template_vars) + else: + source = simple_renderer(graphiql_template, **template_vars) + return source diff --git a/graphql_server/render_graphiql.py b/graphql_server/render_graphiql.py index c942300..37f414e 100644 --- a/graphql_server/render_graphiql.py +++ b/graphql_server/render_graphiql.py @@ -3,8 +3,8 @@ import json import re from typing import Any, Dict, Optional, Tuple +from graphql_server import GraphQLParams -from jinja2 import Environment from typing_extensions import TypedDict GRAPHIQL_VERSION = "1.0.3" @@ -20,7 +20,7 @@ - {{graphiql_html_title}} + {{html_title}} @@ -77,9 +77,10 @@ } // Configure the subscription client let subscriptionsFetcher = null; - if ('{{subscription_url}}') { + let subscriptionUrl = {{subscription_url}}; + if (subscriptionUrl) { let subscriptionsClient = new SubscriptionsTransportWs.SubscriptionClient( - '{{ subscription_url }}', + subscriptionUrl, { reconnect: true } ); subscriptionsFetcher = GraphiQLSubscriptionsFetcher.graphQLFetcher( @@ -134,14 +135,14 @@ onEditVariables: onEditVariables, onEditHeaders: onEditHeaders, onEditOperationName: onEditOperationName, - query: {{query|tojson}}, - response: {{result|tojson}}, - variables: {{variables|tojson}}, - headers: {{headers|tojson}}, - operationName: {{operation_name|tojson}}, - defaultQuery: {{default_query|tojson}}, - headerEditorEnabled: {{header_editor_enabled|tojson}}, - shouldPersistHeaders: {{should_persist_headers|tojson}} + query: {{query}}, + response: {{result}}, + variables: {{variables}}, + headers: {{headers}}, + operationName: {{operation_name}}, + defaultQuery: {{default_query}}, + headerEditorEnabled: {{header_editor_enabled}}, + shouldPersistHeaders: {{should_persist_headers}} }), document.getElementById('graphiql') ); @@ -150,55 +151,15 @@ """ -class GraphiQLData(TypedDict): - """GraphiQL ReactDom Data - - Has the following attributes: - - subscription_url - The GraphiQL socket endpoint for using subscriptions in graphql-ws. - headers - An optional GraphQL string to use as the initial displayed request headers, - if None is provided, the stored headers will be used. - """ - - query: Optional[str] - variables: Optional[str] - operation_name: Optional[str] - result: Optional[str] - subscription_url: Optional[str] - headers: Optional[str] - - -class GraphiQLConfig(TypedDict): - """GraphiQL Extra Config +class GraphiQLOptions(TypedDict): + """GraphiQL options to display on the UI. Has the following attributes: graphiql_version The version of the provided GraphiQL package. - graphiql_template - Inject a Jinja template string to customize GraphiQL. graphiql_html_title Replace the default html title on the GraphiQL. - jinja_env - Sets jinja environment to be used to process GraphiQL template. - If Jinja’s async mode is enabled (by enable_async=True), - uses Template.render_async instead of Template.render. - If environment is not set, fallbacks to simple regex-based renderer. - """ - - graphiql_version: Optional[str] - graphiql_template: Optional[str] - graphiql_html_title: Optional[str] - jinja_env: Optional[Environment] - - -class GraphiQLOptions(TypedDict): - """GraphiQL options to display on the UI. - - Has the following attributes: - default_query An optional GraphQL string to use when no query is provided and no stored query exists from a previous session. If None is provided, GraphiQL @@ -209,11 +170,30 @@ class GraphiQLOptions(TypedDict): should_persist_headers An optional boolean which enables to persist headers to storage when true. Defaults to false. + subscription_url + The GraphiQL socket endpoint for using subscriptions in graphql-ws. + headers + An optional GraphQL string to use as the initial displayed request headers, + if None is provided, the stored headers will be used. """ + html_title: Optional[str] + graphiql_version: Optional[str] default_query: Optional[str] header_editor_enabled: Optional[bool] should_persist_headers: Optional[bool] + subscription_url: Optional[str] + headers: Optional[str] + +GRAPHIQL_DEFAULT_OPTIONS: GraphiQLOptions = { + "html_title": "GraphiQL", + "graphiql_version": GRAPHIQL_VERSION, + "default_query": "", + "header_editor_enabled": True, + "should_persist_headers": False, + "subscription_url": None, + "headers": "" +} def escape_js_value(value: Any) -> Any: @@ -229,44 +209,28 @@ def escape_js_value(value: Any) -> Any: return value -def process_var(template: str, name: str, value: Any, jsonify=False) -> str: - pattern = r"{{\s*" + name + r"(\s*|[^}]+)*\s*}}" - if jsonify and value not in ["null", "undefined"]: +def tojson(value): + if value not in ["true", "false", "null", "undefined"]: value = json.dumps(value) - value = escape_js_value(value) - - return re.sub(pattern, value, template) + # value = escape_js_value(value) + return value def simple_renderer(template: str, **values: Dict[str, Any]) -> str: - replace = [ - "graphiql_version", - "graphiql_html_title", - "subscription_url", - "header_editor_enabled", - "should_persist_headers", - ] - replace_jsonify = [ - "query", - "result", - "variables", - "operation_name", - "default_query", - "headers", - ] + def get_var(match_obj): + var_name = match_obj.group(1) + if var_name is not None: + return values[var_name] + return "" - for r in replace: - template = process_var(template, r, values.get(r, "")) + pattern = r"{{\s*([^}]+)\s*}}" - for r in replace_jsonify: - template = process_var(template, r, values.get(r, ""), True) + return re.sub(pattern, get_var, template) - return template - -def _render_graphiql( - data: GraphiQLData, - config: GraphiQLConfig, +def get_template_vars( + data: str, + params: GraphQLParams, options: Optional[GraphiQLOptions] = None, ) -> Tuple[str, Dict[str, Any]]: """When render_graphiql receives a request which does not Accept JSON, but does @@ -274,57 +238,31 @@ def _render_graphiql( When shown, it will be pre-populated with the result of having executed the requested query. """ - graphiql_version = config.get("graphiql_version") or GRAPHIQL_VERSION - graphiql_template = config.get("graphiql_template") or GRAPHIQL_TEMPLATE - graphiql_html_title = config.get("graphiql_html_title") or "GraphiQL" + options_with_defaults = dict(GRAPHIQL_DEFAULT_OPTIONS, **(options or {})) template_vars: Dict[str, Any] = { - "graphiql_version": graphiql_version, - "graphiql_html_title": graphiql_html_title, - "query": data.get("query"), - "variables": data.get("variables"), - "operation_name": data.get("operation_name"), - "result": data.get("result"), - "subscription_url": data.get("subscription_url") or "", - "headers": data.get("headers") or "", - "default_query": options and options.get("default_query") or "", - "header_editor_enabled": options - and options.get("header_editor_enabled") - or "true", - "should_persist_headers": options - and options.get("should_persist_headers") - or "false", + "result": tojson(data), + "query": tojson(params.query), + "variables": tojson(params.variables), + "operation_name": tojson(params.operation_name), + + "html_title": options_with_defaults["html_title"], + "graphiql_version": options_with_defaults["graphiql_version"], + "subscription_url": tojson(options_with_defaults["subscription_url"]), + "headers": tojson(options_with_defaults["headers"]), + "default_query": tojson(options_with_defaults["default_query"]), + "header_editor_enabled": tojson(options_with_defaults["header_editor_enabled"]), + "should_persist_headers": tojson(options_with_defaults["should_persist_headers"]) } - return graphiql_template, template_vars - - -async def render_graphiql_async( - data: GraphiQLData, - config: GraphiQLConfig, - options: Optional[GraphiQLOptions] = None, -) -> str: - graphiql_template, template_vars = _render_graphiql(data, config, options) - jinja_env: Optional[Environment] = config.get("jinja_env") - - if jinja_env: - # This method returns a Template. See https://jinja.palletsprojects.com/en/2.11.x/api/#jinja2.Template - template = jinja_env.from_string(graphiql_template) - if jinja_env.is_async: # type: ignore - source = await template.render_async(**template_vars) - else: - source = template.render(**template_vars) - else: - source = simple_renderer(graphiql_template, **template_vars) - return source + return template_vars def render_graphiql_sync( - data: GraphiQLData, - config: GraphiQLConfig, + result: str, + params: GraphQLParams, options: Optional[GraphiQLOptions] = None, ) -> str: - graphiql_template, template_vars = _render_graphiql(data, config, options) - - source = simple_renderer(graphiql_template, **template_vars) + template_vars = get_template_vars(result, params, options) + source = simple_renderer(GRAPHIQL_TEMPLATE, **template_vars) return source diff --git a/graphql_server/sanic/graphqlview.py b/graphql_server/sanic/graphqlview.py index 569db53..b93cb2b 100644 --- a/graphql_server/sanic/graphqlview.py +++ b/graphql_server/sanic/graphqlview.py @@ -133,9 +133,9 @@ async def dispatch_request(self, request, *args, **kwargs): if show_graphiql: graphiql_data = GraphiQLData( result=result, - query=getattr(all_params[0], "query"), - variables=getattr(all_params[0], "variables"), - operation_name=getattr(all_params[0], "operation_name"), + query=all_params[0].query, + variables=all_params[0].variables, + operation_name=all_params[0].operation_name, subscription_url=self.subscriptions, headers=self.headers, ) diff --git a/graphql_server/websockets/__init__.py b/graphql_server/websockets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/graphql_server/websockets/constants.py b/graphql_server/websockets/constants.py new file mode 100644 index 0000000..61a8c08 --- /dev/null +++ b/graphql_server/websockets/constants.py @@ -0,0 +1,2 @@ +GRAPHQL_TRANSPORT_WS_PROTOCOL = "graphql-transport-ws" +GRAPHQL_WS_PROTOCOL = "graphql-ws" diff --git a/graphql_server/websockets/transport_ws_protocol/__init__.py b/graphql_server/websockets/transport_ws_protocol/__init__.py new file mode 100644 index 0000000..fc07e86 --- /dev/null +++ b/graphql_server/websockets/transport_ws_protocol/__init__.py @@ -0,0 +1,3 @@ +from .handlers import BaseGraphQLTransportWSHandler + +__all__ = ["BaseGraphQLTransportWSHandler"] diff --git a/graphql_server/websockets/transport_ws_protocol/contstants.py b/graphql_server/websockets/transport_ws_protocol/contstants.py new file mode 100644 index 0000000..24ddfa7 --- /dev/null +++ b/graphql_server/websockets/transport_ws_protocol/contstants.py @@ -0,0 +1,8 @@ +GQL_CONNECTION_INIT = "connection_init" +GQL_CONNECTION_ACK = "connection_ack" +GQL_PING = "ping" +GQL_PONG = "pong" +GQL_SUBSCRIBE = "subscribe" +GQL_NEXT = "next" +GQL_ERROR = "error" +GQL_COMPLETE = "complete" diff --git a/graphql_server/websockets/transport_ws_protocol/handlers.py b/graphql_server/websockets/transport_ws_protocol/handlers.py new file mode 100644 index 0000000..8df0b56 --- /dev/null +++ b/graphql_server/websockets/transport_ws_protocol/handlers.py @@ -0,0 +1,223 @@ +import asyncio +from abc import ABC, abstractmethod +from contextlib import suppress +from datetime import timedelta +from typing import Any, AsyncGenerator, Dict, Optional, TypedDict, cast + +from graphql import ( + parse, + ExecutionResult as GraphQLExecutionResult, + GraphQLError, + GraphQLSchema, + subscribe, +) +from graphql.error import format_error as format_graphql_error +from ..constants import GRAPHQL_TRANSPORT_WS_PROTOCOL + +from .types import ( + CompleteMessage, + ConnectionAckMessage, + ConnectionInitMessage, + ErrorMessage, + NextMessage, + PingMessage, + PongMessage, + SubscribeMessage, +) + +# from .contstants import ( +# GQL_CONNECTION_INIT, +# GQL_CONNECTION_ACK, +# GQL_PING, +# GQL_PONG, +# GQL_SUBSCRIBE, +# GQL_NEXT, +# GQL_ERROR, +# GQL_COMPLETE, +# ) + + +class BaseGraphQLTransportWSHandler(ABC): + PROTOCOL = GRAPHQL_TRANSPORT_WS_PROTOCOL + + def __init__( + self, + schema: GraphQLSchema, + debug: bool, + connection_init_wait_timeout: timedelta, + ): + self.schema = schema + self.debug = debug + self.connection_init_wait_timeout = connection_init_wait_timeout + self.connection_init_timeout_task: Optional[asyncio.Task] = None + self.connection_init_received = False + self.connection_acknowledged = False + self.subscriptions: Dict[str, AsyncGenerator] = {} + self.tasks: Dict[str, asyncio.Task] = {} + + @abstractmethod + async def get_context(self) -> Any: + """Return the operations context""" + + @abstractmethod + async def get_root_value(self) -> Any: + """Return the schemas root value""" + + @abstractmethod + async def send_xjson(self, data: dict) -> None: + """Send the data JSON encoded to the WebSocket client""" + + @abstractmethod + async def close(self, code: int, reason: str) -> None: + """Close the WebSocket with the passed code and reason""" + + @abstractmethod + async def handle_request(self) -> Any: + """Handle the request this instance was created for""" + + async def handle(self) -> Any: + timeout_handler = self.handle_connection_init_timeout() + self.connection_init_timeout_task = asyncio.create_task(timeout_handler) + return await self.handle_request() + + async def handle_connection_init_timeout(self): + delay = self.connection_init_wait_timeout.total_seconds() + await asyncio.sleep(delay=delay) + + if self.connection_init_received: + return + + reason = "Connection initialisation timeout" + await self.close(code=4408, reason=reason) + + async def handle_message(self, message: dict): + try: + message_type = message.pop("type") + + if message_type == ConnectionInitMessage.type: + await self.handle_connection_init(cast(ConnectionInitMessage, message)) + + elif message_type == PingMessage.type: + await self.handle_ping(cast(PingMessage, message)) + + elif message_type == PongMessage.type: + await self.handle_pong(cast(PongMessage, message)) + + elif message_type == SubscribeMessage.type: + await self.handle_subscribe(cast(SubscribeMessage, message)) + + elif message_type == CompleteMessage.type: + await self.handle_complete(cast(CompleteMessage, message)) + + else: + error_message = f"Unknown message type: {message_type}" + await self.handle_invalid_message(error_message) + + except (KeyError, TypeError): + error_message = "Failed to parse message" + await self.handle_invalid_message(error_message) + + async def handle_connection_init(self, message: ConnectionInitMessage) -> None: + if self.connection_init_received: + reason = "Too many initialisation requests" + await self.close(code=4429, reason=reason) + return + + self.connection_init_received = True + await self.send_message(ConnectionAckMessage()) + self.connection_acknowledged = True + + async def handle_ping(self, message: PingMessage) -> None: + await self.send_message(PongMessage()) + + async def handle_pong(self, message: PongMessage) -> None: + pass + + async def handle_subscribe(self, message: SubscribeMessage) -> None: + if not self.connection_acknowledged: + await self.close(code=4401, reason="Unauthorized") + return + + if message.id in self.subscriptions.keys(): + reason = f"Subscriber for {message.id} already exists" + await self.close(code=4409, reason=reason) + return + + context = await self.get_context() + root_value = await self.get_root_value() + + try: + result_source = await subscribe( + document=parse(message.payload.query), + schema=self.schema, + variable_values=message.payload.variables, + operation_name=message.payload.operationName, + context_value=context, + root_value=root_value, + ) + except GraphQLError as error: + payload = [format_graphql_error(error)] + await self.send_message(ErrorMessage(id=message.id, payload=payload)) + self.process_errors([error]) + return + + if isinstance(result_source, GraphQLExecutionResult): + assert result_source.errors + payload = [format_graphql_error(result_source.errors[0])] + await self.send_message(ErrorMessage(id=message.id, payload=payload)) + self.process_errors(result_source.errors) + return + + handler = self.handle_async_results(result_source, message.id) + self.subscriptions[message.id] = result_source + self.tasks[message.id] = asyncio.create_task(handler) + + async def handle_async_results( + self, + result_source: AsyncGenerator, + operation_id: str, + ) -> None: + try: + async for result in result_source: + if result.errors: + error_payload = [format_graphql_error(err) for err in result.errors] + error_message = ErrorMessage(id=operation_id, payload=error_payload) + await self.send_message(error_message) + self.process_errors(result.errors) + return + else: + next_payload = {"data": result.data} + next_message = NextMessage(id=operation_id, payload=next_payload) + await self.send_message(next_message) + except asyncio.CancelledError: + # CancelledErrors are expected during task cleanup. + return + except Exception as error: + # GraphQLErrors are handled by graphql-core and included in the + # ExecutionResult + error = GraphQLError(str(error), original_error=error) + error_payload = [format_graphql_error(error)] + error_message = ErrorMessage(id=operation_id, payload=error_payload) + await self.send_message(error_message) + self.process_errors([error]) + return + + await self.send_message(CompleteMessage(id=operation_id)) + + async def handle_complete(self, message: CompleteMessage) -> None: + await self.cleanup_operation(operation_id=message.id) + + async def handle_invalid_message(self, error_message: str) -> None: + await self.close(code=4400, reason=error_message) + + async def send_message(self, data: TypedDict) -> None: + await self.send_message({**data, "type": data.type}) + + async def cleanup_operation(self, operation_id: str) -> None: + await self.subscriptions[operation_id].aclose() + del self.subscriptions[operation_id] + + self.tasks[operation_id].cancel() + with suppress(BaseException): + await self.tasks[operation_id] + del self.tasks[operation_id] diff --git a/graphql_server/websockets/transport_ws_protocol/types.py b/graphql_server/websockets/transport_ws_protocol/types.py new file mode 100644 index 0000000..3042681 --- /dev/null +++ b/graphql_server/websockets/transport_ws_protocol/types.py @@ -0,0 +1,100 @@ +from typing import Any, Dict, List, Optional + +try: + from typing import TypedDict +except ImportError: + from typing_extensions import TypedDict + +from .contstants import ( + GQL_CONNECTION_INIT, + GQL_CONNECTION_ACK, + GQL_PING, + GQL_PONG, + GQL_SUBSCRIBE, + GQL_NEXT, + GQL_ERROR, + GQL_COMPLETE, +) + + +class ConnectionInitMessage(TypedDict): + """ + Direction: Client -> Server + """ + + payload: Optional[Dict[str, Any]] + type = GQL_CONNECTION_INIT + + +class ConnectionAckMessage(TypedDict): + """ + Direction: Server -> Client + """ + + payload: Optional[Dict[str, Any]] + type = GQL_CONNECTION_ACK + + +class PingMessage(TypedDict): + """ + Direction: bidirectional + """ + + payload: Optional[Dict[str, Any]] + type = GQL_PING + + +class PongMessage(TypedDict): + """ + Direction: bidirectional + """ + + payload: Optional[Dict[str, Any]] + type = GQL_PONG + + +class SubscribeMessagePayload(TypedDict): + query: str + operationName: Optional[str] + variables: Optional[Dict[str, Any]] + extensions: Optional[Dict[str, Any]] + + +class SubscribeMessage(TypedDict): + """ + Direction: Client -> Server + """ + + id: str + payload: SubscribeMessagePayload + type = GQL_SUBSCRIBE + + +class NextMessage(TypedDict): + """ + Direction: Server -> Client + """ + + id: str + payload: Dict[str, Any] # TODO: shape like ExecutionResult + type = GQL_NEXT + + +class ErrorMessage(TypedDict): + """ + Direction: Server -> Client + """ + + id: str + payload: List[Dict[str, Any]] # TODO: shape like List[GraphQLError] + type = GQL_ERROR + + +class CompleteMessage(TypedDict): + """ + Direction: bidirectional + """ + + type = GQL_COMPLETE + + id: str diff --git a/graphql_server/websockets/ws_protocol/__init__.py b/graphql_server/websockets/ws_protocol/__init__.py new file mode 100644 index 0000000..211c95c --- /dev/null +++ b/graphql_server/websockets/ws_protocol/__init__.py @@ -0,0 +1,4 @@ +from .handlers import BaseGraphQLWSHandler +from .types import OperationMessage + +__all__ = ["BaseGraphQLWSHandler", "OperationMessage"] diff --git a/graphql_server/websockets/ws_protocol/constants.py b/graphql_server/websockets/ws_protocol/constants.py new file mode 100644 index 0000000..8b884cf --- /dev/null +++ b/graphql_server/websockets/ws_protocol/constants.py @@ -0,0 +1,10 @@ +GQL_CONNECTION_INIT = "connection_init" +GQL_CONNECTION_ACK = "connection_ack" +GQL_CONNECTION_ERROR = "connection_error" +GQL_CONNECTION_TERMINATE = "connection_terminate" +GQL_CONNECTION_KEEP_ALIVE = "ka" +GQL_START = "start" +GQL_DATA = "data" +GQL_ERROR = "error" +GQL_COMPLETE = "complete" +GQL_STOP = "stop" diff --git a/graphql_server/websockets/ws_protocol/handlers.py b/graphql_server/websockets/ws_protocol/handlers.py new file mode 100644 index 0000000..0499e93 --- /dev/null +++ b/graphql_server/websockets/ws_protocol/handlers.py @@ -0,0 +1,201 @@ +import asyncio +from abc import ABC, abstractmethod +from contextlib import suppress +from typing import Any, AsyncGenerator, Dict, Optional, cast, List + +from graphql import ( + parse, + ExecutionResult as GraphQLExecutionResult, + GraphQLError, + GraphQLSchema, + subscribe, +) +from graphql.error import format_error as format_graphql_error +from ..constants import GRAPHQL_WS_PROTOCOL + +from .constants import ( + GQL_COMPLETE, + GQL_CONNECTION_ACK, + GQL_CONNECTION_INIT, + GQL_CONNECTION_KEEP_ALIVE, + GQL_CONNECTION_TERMINATE, + GQL_DATA, + GQL_ERROR, + GQL_START, + GQL_STOP, +) +from .types import ( + OperationMessage, + OperationMessagePayload, + StartPayload, +) + + +class BaseGraphQLWSHandler(ABC): + PROTOCOL = GRAPHQL_WS_PROTOCOL + + def __init__( + self, + schema: GraphQLSchema, + debug: bool, + keep_alive: bool, + keep_alive_interval: float, + ): + self.schema = schema + self.debug = debug + self.keep_alive = keep_alive + self.keep_alive_interval = keep_alive_interval + self.keep_alive_task: Optional[asyncio.Task] = None + self.subscriptions: Dict[str, AsyncGenerator] = {} + self.tasks: Dict[str, asyncio.Task] = {} + + @abstractmethod + async def get_context(self) -> Any: + """Return the operations context""" + + @abstractmethod + async def get_root_value(self) -> Any: + """Return the schemas root value""" + + @abstractmethod + async def send_json(self, data: OperationMessage) -> None: + """Send the data JSON encoded to the WebSocket client""" + + @abstractmethod + async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: + """Close the WebSocket with the passed code and reason""" + + @abstractmethod + async def handle_request(self) -> Any: + """Handle the request this instance was created for""" + + async def handle(self) -> Any: + return await self.handle_request() + + async def handle_message( + self, + message: OperationMessage, + ) -> None: + message_type = message["type"] + + if message_type == GQL_CONNECTION_INIT: + await self.handle_connection_init(message) + elif message_type == GQL_CONNECTION_TERMINATE: + await self.handle_connection_terminate(message) + elif message_type == GQL_START: + await self.handle_start(message) + elif message_type == GQL_STOP: + await self.handle_stop(message) + + async def handle_connection_init(self, message: OperationMessage) -> None: + data: OperationMessage = {"type": GQL_CONNECTION_ACK} + await self.send_json(data) + + if self.keep_alive: + keep_alive_handler = self.handle_keep_alive() + self.keep_alive_task = asyncio.create_task(keep_alive_handler) + + async def handle_connection_terminate(self, message: OperationMessage) -> None: + await self.close() + + def process_errors(self, errors: List[Any]): + """Process the GraphQL response errors""" + + async def handle_start(self, message: OperationMessage) -> None: + operation_id = message["id"] + payload = cast(StartPayload, message["payload"]) + query = payload["query"] + operation_name = payload.get("operationName") + variables = payload.get("variables") + + context = await self.get_context() + root_value = await self.get_root_value() + + try: + result_source = await subscribe( + document=parse(query), + schema=self.schema, + variable_values=variables, + operation_name=operation_name, + context_value=context, + root_value=root_value, + ) + except GraphQLError as error: + error_payload = format_graphql_error(error) + await self.send_message(GQL_ERROR, operation_id, error_payload) + self.process_errors([error]) + return + + if isinstance(result_source, GraphQLExecutionResult): + assert result_source.errors + error_payload = format_graphql_error(result_source.errors[0]) + await self.send_message(GQL_ERROR, operation_id, error_payload) + self.process_errors(result_source.errors) + return + + self.subscriptions[operation_id] = result_source + result_handler = self.handle_async_results(result_source, operation_id) + self.tasks[operation_id] = asyncio.create_task(result_handler) + + async def handle_stop(self, message: OperationMessage) -> None: + operation_id = message["id"] + await self.cleanup_operation(operation_id) + + async def handle_keep_alive(self) -> None: + while True: + data: OperationMessage = {"type": GQL_CONNECTION_KEEP_ALIVE} + await self.send_json(data) + await asyncio.sleep(self.keep_alive_interval) + + async def handle_async_results( + self, + result_source: AsyncGenerator, + operation_id: str, + ) -> None: + try: + async for result in result_source: + payload = {"data": result.data} + if result.errors: + payload["errors"] = [ + format_graphql_error(err) for err in result.errors + ] + await self.send_message(GQL_DATA, operation_id, payload) + # log errors after send_message to prevent potential + # slowdown of sending result + if result.errors: + self.process_errors(result.errors) + except asyncio.CancelledError: + # CancelledErrors are expected during task cleanup. + pass + except Exception as error: + # GraphQLErrors are handled by graphql-core and included in the + # ExecutionResult + error = GraphQLError(str(error), original_error=error) + await self.send_message( + GQL_DATA, + operation_id, + {"data": None, "errors": [format_graphql_error(error)]}, + ) + self.process_errors([error]) + + await self.send_message(GQL_COMPLETE, operation_id, None) + + async def cleanup_operation(self, operation_id: str) -> None: + await self.subscriptions[operation_id].aclose() + del self.subscriptions[operation_id] + + self.tasks[operation_id].cancel() + with suppress(BaseException): + await self.tasks[operation_id] + del self.tasks[operation_id] + + async def send_message( + self, + type_: str, + operation_id: str, + payload: Optional[OperationMessagePayload] = None, + ) -> None: + data: OperationMessage = {"type": type_, "id": operation_id} + if payload is not None: + data["payload"] = payload + await self.send_json(data) diff --git a/graphql_server/websockets/ws_protocol/types.py b/graphql_server/websockets/ws_protocol/types.py new file mode 100644 index 0000000..94a0112 --- /dev/null +++ b/graphql_server/websockets/ws_protocol/types.py @@ -0,0 +1,47 @@ +from typing import Any, Dict, List, Optional, Union + +try: + from typing import TypedDict +except ImportError: + from typing_extensions import TypedDict + + +ConnectionInitPayload = Dict[str, Any] + + +ConnectionErrorPayload = Dict[str, Any] + + +class StartPayload(TypedDict, total=False): + query: str + variables: Optional[Dict[str, Any]] + operationName: Optional[str] + + +class DataPayload(TypedDict, total=False): + data: Any + + # Optional list of formatted graphql.GraphQLError objects + errors: Optional[List[Dict[str, Any]]] + + +class ErrorPayload(TypedDict): + id: str + + # Formatted graphql.GraphQLError object + payload: Dict[str, Any] + + +OperationMessagePayload = Union[ + ConnectionInitPayload, + ConnectionErrorPayload, + StartPayload, + DataPayload, + ErrorPayload, +] + + +class OperationMessage(TypedDict, total=False): + type: str + id: str + payload: OperationMessagePayload From c4a80d17fee217a13b9ebbccbdb1985ea9ea3563 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Sun, 6 Feb 2022 22:54:08 -0800 Subject: [PATCH 02/14] Updated version requirements --- graphql_server/__init__.py | 21 +++++++--- graphql_server/aiohttp/graphqlview.py | 26 ++++++++---- graphql_server/django/views.py | 54 ++++++++++++------------- graphql_server/graphiql_render_jinja.py | 1 - graphql_server/render_graphiql.py | 8 ++-- setup.py | 2 +- 6 files changed, 66 insertions(+), 46 deletions(-) diff --git a/graphql_server/__init__.py b/graphql_server/__init__.py index b20f94b..278dfb8 100644 --- a/graphql_server/__init__.py +++ b/graphql_server/__init__.py @@ -58,12 +58,14 @@ # The public data structures + @dataclass class GraphQLParams: query: str variables: Optional[Dict[str, Any]] = None operation_name: Optional[str] = None + @dataclass class GraphQLResponse: params: List[GraphQLParams] @@ -76,8 +78,10 @@ class ServerResponse: status_code: int headers: Optional[Dict[str, str]] = None + # The public helper functions + def get_schema(schema: GraphQLSchema): if not isinstance(schema, GraphQLSchema): # maybe the GraphQL schema is wrapped in a Graphene schema @@ -164,7 +168,12 @@ def run_http_query( return GraphQLResponse(results=results, params=all_params) -def process_preflight(origin_header: Optional[str], request_method: Optional[str], accepted_methods: List[str], max_age: int) -> ServerResponse: +def process_preflight( + origin_header: Optional[str], + request_method: Optional[str], + accepted_methods: List[str], + max_age: int, +) -> ServerResponse: """ Preflight request support for apollo-client https://www.w3.org/TR/cors/#resource-preflight-requests @@ -239,9 +248,7 @@ def load_json_body(data: str, batch: bool = False) -> Union[Dict, List]: assert isinstance(request_json, list), ( "Batch requests should receive a list, but received {}." ).format(repr(request_json)) - assert ( - len(request_json) > 0 - ), "Received an empty list in the batch request." + assert len(request_json) > 0, "Received an empty list in the batch request." else: assert isinstance( request_json, dict @@ -275,7 +282,11 @@ def get_graphql_params(data: Dict, query_data: Dict) -> GraphQLParams: # document_id = data.get('documentId') operation_name = data.get("operationName") or query_data.get("operationName") - return GraphQLParams(query=query, variables=load_json_variables(variables), operation_name=operation_name) + return GraphQLParams( + query=query, + variables=load_json_variables(variables), + operation_name=operation_name, + ) def load_json_variables(variables: Optional[Union[str, Dict]]) -> Optional[Dict]: diff --git a/graphql_server/aiohttp/graphqlview.py b/graphql_server/aiohttp/graphqlview.py index a7739b8..8f9ba36 100644 --- a/graphql_server/aiohttp/graphqlview.py +++ b/graphql_server/aiohttp/graphqlview.py @@ -24,6 +24,7 @@ from typing import Dict, Any + class GraphQLView: accepted_methods = ["GET", "POST", "PUT", "DELETE"] @@ -31,7 +32,10 @@ class GraphQLView: format_error = staticmethod(format_error_default) encode = staticmethod(json_encode) - def __init__(self, schema: GraphQLSchema, *, + def __init__( + self, + schema: GraphQLSchema, + *, root_value: Any = None, pretty: bool = False, graphiql: bool = True, @@ -94,8 +98,11 @@ async def parse_body(request): return {} def is_graphiql(self, request_method, is_raw, accept_headers): - return (self.graphiql and request_method == "get" - and not is_raw and ("text/html" in accept_headers or "*/*" in accept_headers), + return ( + self.graphiql + and request_method == "get" + and not is_raw + and ("text/html" in accept_headers or "*/*" in accept_headers), ) def should_prettify(self, is_graphiql, pretty_query): @@ -106,17 +113,20 @@ async def __call__(self, request): data = await self.parse_body(request) request_method = request.method.lower() accept_headers = request.headers.get("accept", {}) - is_graphiql = self.is_graphiql(request_method, request.query.get("raw"), accept_headers) + is_graphiql = self.is_graphiql( + request_method, request.query.get("raw"), accept_headers + ) is_pretty = self.should_prettify(is_graphiql, request.query.get("pretty")) if request_method == "options": headers = request.headers origin = headers.get("Origin", "") method = headers.get("Access-Control-Request-Method", "").upper() - response = process_preflight(origin, method, self.accepted_methods, self.max_age) + response = process_preflight( + origin, method, self.accepted_methods, self.max_age + ) return web.Response( - status=response.status_code, - headers = response.headers + status=response.status_code, headers=response.headers ) graphql_response = run_http_query( @@ -153,7 +163,7 @@ async def __call__(self, request): source = self.render_graphiql( result=response.body, params=graphql_response.all_params[0], - options=self.graphiql_options + options=self.graphiql_options, ) return web.Response(text=source, content_type="text/html") diff --git a/graphql_server/django/views.py b/graphql_server/django/views.py index d787f09..5eb0cad 100644 --- a/graphql_server/django/views.py +++ b/graphql_server/django/views.py @@ -63,7 +63,9 @@ class GraphQLView(View): max_age: int = 86400 graphiql_options: Optional[GraphiQLOptions] = None - def __init__(self, schema: GraphQLSchema, + def __init__( + self, + schema: GraphQLSchema, root_value: Any = None, pretty: bool = False, graphiql: bool = True, @@ -144,9 +146,7 @@ def request_prefers_html(cls, request: HttpRequest): return html_priority > json_priority def is_graphiql(self, request_method: str, is_raw: bool, prefers_html: bool): - return (self.graphiql and request_method == "get" - and not is_raw and prefers_html - ) + return self.graphiql and request_method == "get" and not is_raw and prefers_html def should_prettify(self, is_graphiql: bool, pretty_in_query: bool): return self.pretty or is_graphiql or pretty_in_query @@ -157,17 +157,20 @@ def dispatch(self, request: HttpRequest, *args, **kwargs): data = self.parse_body(request) request_method = request.method.lower() prefers_html = self.request_prefers_html(request) - is_graphiql = self.is_graphiql(request_method, "raw" in request.GET, prefers_html) + is_graphiql = self.is_graphiql( + request_method, "raw" in request.GET, prefers_html + ) is_pretty = self.should_prettify(is_graphiql, request.GET.get("pretty")) if request_method == "options": headers = request.headers origin = headers.get("Origin", "") method = headers.get("Access-Control-Request-Method", "").upper() - response = process_preflight(origin, method, self.accepted_methods, self.max_age) + response = process_preflight( + origin, method, self.accepted_methods, self.max_age + ) return HTTPResponse( - status=response.status_code, - headers=response.headers + status=response.status_code, headers=response.headers ) graphql_response = run_http_query( @@ -196,12 +199,9 @@ def dispatch(self, request: HttpRequest, *args, **kwargs): source = self.render_graphiql( result=response.body, params=graphql_response.params[0], - options=self.graphiql_options - ) - return HttpResponse( - content=source, - content_type="text/html" + options=self.graphiql_options, ) + return HttpResponse(content=source, content_type="text/html") return HttpResponse( content=response.body, @@ -235,17 +235,20 @@ async def dispatch(self, request, *args, **kwargs): data = self.parse_body(request) request_method = request.method.lower() prefers_html = self.request_prefers_html(request) - is_graphiql = self.is_graphiql(request_method, "raw" in request.GET, prefers_html) + is_graphiql = self.is_graphiql( + request_method, "raw" in request.GET, prefers_html + ) is_pretty = self.should_prettify(is_graphiql, request.GET.get("pretty")) if request_method == "options": headers = request.headers origin = headers.get("Origin", "") method = headers.get("Access-Control-Request-Method", "").upper() - response = process_preflight(origin, method, self.accepted_methods, self.max_age) + response = process_preflight( + origin, method, self.accepted_methods, self.max_age + ) return HTTPResponse( - status=response.status_code, - headers=response.headers + status=response.status_code, headers=response.headers ) graphql_response = run_http_query( @@ -263,12 +266,10 @@ async def dispatch(self, request, *args, **kwargs): validation_rules=self.get_validation_rules(), ) - exec_res = ( - [ - ex if ex is None or isinstance(ex, ExecutionResult) else await ex - for ex in graphql_response.results - ] - ) + exec_res = [ + ex if ex is None or isinstance(ex, ExecutionResult) else await ex + for ex in graphql_response.results + ] response = encode_execution_results( exec_res, @@ -281,12 +282,9 @@ async def dispatch(self, request, *args, **kwargs): source = self.render_graphiql( result=response.body, params=graphql_response.params[0], - options=self.graphiql_options - ) - return HttpResponse( - content=source, - content_type="text/html" + options=self.graphiql_options, ) + return HttpResponse(content=source, content_type="text/html") return HttpResponse( content=response.body, diff --git a/graphql_server/graphiql_render_jinja.py b/graphql_server/graphiql_render_jinja.py index 91ba416..08ff049 100644 --- a/graphql_server/graphiql_render_jinja.py +++ b/graphql_server/graphiql_render_jinja.py @@ -1,4 +1,3 @@ - async def render_graphiql_async( data: GraphiQLData, config: GraphiQLConfig, diff --git a/graphql_server/render_graphiql.py b/graphql_server/render_graphiql.py index 37f414e..2073d8d 100644 --- a/graphql_server/render_graphiql.py +++ b/graphql_server/render_graphiql.py @@ -185,6 +185,7 @@ class GraphiQLOptions(TypedDict): subscription_url: Optional[str] headers: Optional[str] + GRAPHIQL_DEFAULT_OPTIONS: GraphiQLOptions = { "html_title": "GraphiQL", "graphiql_version": GRAPHIQL_VERSION, @@ -192,7 +193,7 @@ class GraphiQLOptions(TypedDict): "header_editor_enabled": True, "should_persist_headers": False, "subscription_url": None, - "headers": "" + "headers": "", } @@ -245,14 +246,15 @@ def get_template_vars( "query": tojson(params.query), "variables": tojson(params.variables), "operation_name": tojson(params.operation_name), - "html_title": options_with_defaults["html_title"], "graphiql_version": options_with_defaults["graphiql_version"], "subscription_url": tojson(options_with_defaults["subscription_url"]), "headers": tojson(options_with_defaults["headers"]), "default_query": tojson(options_with_defaults["default_query"]), "header_editor_enabled": tojson(options_with_defaults["header_editor_enabled"]), - "should_persist_headers": tojson(options_with_defaults["should_persist_headers"]) + "should_persist_headers": tojson( + options_with_defaults["should_persist_headers"] + ), } return template_vars diff --git a/setup.py b/setup.py index e2dfcaf..70655a4 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from re import search from setuptools import setup, find_packages -install_requires = ["graphql-core>=3.2,<3.3", "typing-extensions>=4,<5"] +install_requires = ["graphql-core>=3.1,<3.3", "typing-extensions>=4,<5"] tests_requires = [ "pytest>=6.2,<6.3", From 33fbc489e8a1868be9af506b8c592e7f6144ac99 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Mon, 7 Feb 2022 13:29:28 -0800 Subject: [PATCH 03/14] Added django debug toolbar fix --- graphql_server/django/debug_toolbar.py | 197 +++++++++++++++++++++++++ 1 file changed, 197 insertions(+) create mode 100644 graphql_server/django/debug_toolbar.py diff --git a/graphql_server/django/debug_toolbar.py b/graphql_server/django/debug_toolbar.py new file mode 100644 index 0000000..7c464c7 --- /dev/null +++ b/graphql_server/django/debug_toolbar.py @@ -0,0 +1,197 @@ +# Based on https://github.com/flavors/django-graphiql-debug-toolbar + +import collections +import contextvars +import json +from typing import Optional +import weakref + +from debug_toolbar.middleware import DebugToolbarMiddleware as _DebugToolbarMiddleware +from debug_toolbar.middleware import _HTML_TYPES +from debug_toolbar.middleware import show_toolbar +from debug_toolbar.panels.sql import panel, tracking +from debug_toolbar.panels.templates import TemplatesPanel +from debug_toolbar.panels.templates import panel as tpanel +from debug_toolbar.toolbar import DebugToolbar +from django.core.serializers.json import DjangoJSONEncoder +from django.http.request import HttpRequest +from django.http.response import HttpResponse +from django.template.loader import render_to_string +from django.utils.encoding import force_str +from .views import GraphQLView + +_store_cache = weakref.WeakKeyDictionary() +_original_store = DebugToolbar.store +_recording = contextvars.ContextVar("debug-toolbar-recording", default=True) + + +def _store(toolbar: DebugToolbar): + _original_store(toolbar) + _store_cache[toolbar.request] = toolbar.store_id + + +def _get_payload(request: HttpRequest, response: HttpResponse): + store_id = _store_cache.get(request) + if not store_id: + return None + + toolbar: Optional[DebugToolbar] = DebugToolbar.fetch(store_id) + if not toolbar: + return None + + content = force_str(response.content, encoding=response.charset) + payload = json.loads(content, object_pairs_hook=collections.OrderedDict) + payload["debugToolbar"] = collections.OrderedDict([("panels", collections.OrderedDict())]) + payload["debugToolbar"]["storeId"] = toolbar.store_id + + for p in reversed(toolbar.enabled_panels): + if p.panel_id == "TemplatesPanel": + continue + + if p.has_content: + title = p.title + else: + title = None + + sub = p.nav_subtitle + payload["debugToolbar"]["panels"][p.panel_id] = { + "title": title() if callable(title) else title, + "subtitle": sub() if callable(sub) else sub, + } + + return payload + + +DebugToolbar.store = _store # type:ignore +# FIXME: This is breaking async views when it tries to render the user +# without being in an async safe context. How to properly handle this? +TemplatesPanel._store_template_info = lambda *args, **kwargs: None + + +def _wrap_cursor(connection, panel): + c = type(connection) + if hasattr(c, "_djdt_cursor"): + return None + + c._djdt_cursor = c.cursor + c._djdt_chunked_cursor = c.chunked_cursor + + def cursor(*args, **kwargs): + if _recording.get(): + wrapper = tracking.NormalCursorWrapper + else: + wrapper = tracking.ExceptionCursorWrapper + return wrapper(c._djdt_cursor(*args, **kwargs), args[0], panel) + + def chunked_cursor(*args, **kwargs): + cursor = c._djdt_chunked_cursor(*args, **kwargs) + if not isinstance(cursor, tracking.BaseCursorWrapper): + if _recording.get(): + wrapper = tracking.NormalCursorWrapper + else: + wrapper = tracking.ExceptionCursorWrapper + return wrapper(cursor, args[0], panel) + return cursor + + c.cursor = cursor + c.chunked_cursor = chunked_cursor + + return cursor + + +def _unwrap_cursor(connection): + c = type(connection) + if not hasattr(c, "_djdt_cursor"): + return + + c.cursor = c._djdt_cursor + c.chunked_cursor = c._djdt_chunked_cursor + del c._djdt_cursor + del c._djdt_chunked_cursor + + +# Patch wrap_cursor/unwrap_cursor so that they work with async views +# Are there any drawbacks to this? +tracking.wrap_cursor = _wrap_cursor +tracking.unwrap_cursor = _unwrap_cursor +panel.wrap_cursor = _wrap_cursor +panel.unwrap_cursor = _unwrap_cursor +tpanel.recording = _recording + + +class DebugToolbarMiddleware(_DebugToolbarMiddleware): + sync_capable = True + async_capable = True + + def __call__(self, request: HttpRequest): + response = super().__call__(request) + + if not show_toolbar(request) or DebugToolbar.is_toolbar_request(request): + return response + + content_type = response.get("Content-Type", "").split(";")[0] + is_html = content_type in _HTML_TYPES + is_graphiql = getattr(request, "_is_graphiql", False) + + if is_html and is_graphiql and response.status_code == 200: + response.write(""" + +""") + if "Content-Length" in response: + response["Content-Length"] = len(response.content) + + if is_html or not is_graphiql or content_type != "application/json": + return response + + payload = _get_payload(request, response) + if payload is None: + return response + + response.content = json.dumps(payload, cls=DjangoJSONEncoder) + if "Content-Length" in response: + response["Content-Length"] = len(response.content) + + return response + + def process_view(self, request: HttpRequest, view_func, *args, **kwargs): + view = getattr(view_func, "view_class", None) + request._is_graphiql = bool(view and issubclass(view, GraphQLView)) # type:ignore From aa6b278917a5d5729ec2c9fb441490a7f4ff03d7 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Mon, 7 Feb 2022 16:56:47 -0800 Subject: [PATCH 04/14] Reenabled HTTP ASGI Channel --- graphql_server/__init__.py | 17 +- graphql_server/channels/http.py | 213 ++++++++++++------------- graphql_server/django/debug_toolbar.py | 14 +- 3 files changed, 121 insertions(+), 123 deletions(-) diff --git a/graphql_server/__init__.py b/graphql_server/__init__.py index 278dfb8..62c19b7 100644 --- a/graphql_server/__init__.py +++ b/graphql_server/__init__.py @@ -74,8 +74,8 @@ class GraphQLResponse: @dataclass class ServerResponse: - body: Optional[str] status_code: int + body: Optional[str] = None headers: Optional[Dict[str, str]] = None @@ -158,6 +158,8 @@ def run_http_query( all_params: List[GraphQLParams] = [ get_graphql_params(entry, extra_data) for entry in data ] + # print("GET ROOT VALUE 0", type(request_method), all_params) + # print(dict(schema=schema, all_params=all_params, catch_exc=catch_exc, allow_only=allow_only_query, run_sync=run_sync)) results: List[Optional[AwaitableOrValue[ExecutionResult]]] = [ get_response( @@ -165,6 +167,8 @@ def run_http_query( ) for params in all_params ] + # print("GET ROOT VALUE 1") + return GraphQLResponse(results=results, params=all_params) @@ -179,13 +183,14 @@ def process_preflight( https://www.w3.org/TR/cors/#resource-preflight-requests """ if origin_header and request_method and request_method in accepted_methods: + headers = { + "Access-Control-Allow-Origin": origin_header, + "Access-Control-Allow-Methods": ", ".join(accepted_methods), + "Access-Control-Max-Age": str(max_age), + } return ServerResponse( status_code=200, - headers={ - "Access-Control-Allow-Origin": origin_header, - "Access-Control-Allow-Methods": ", ".join(accepted_methods), - "Access-Control-Max-Age": str(max_age), - }, + headers=headers, ) return ServerResponse(status_code=400) diff --git a/graphql_server/channels/http.py b/graphql_server/channels/http.py index d4ba271..098cdf3 100644 --- a/graphql_server/channels/http.py +++ b/graphql_server/channels/http.py @@ -9,25 +9,11 @@ from graphql import GraphQLSchema, ExecutionResult +from ..render_graphiql import render_graphiql_sync +from .. import GraphQLParams -class GraphQLHttpConsumer(AsyncHttpConsumer): - """ - A consumer to provide a view for GraphQL over HTTP. - To use this, place it in your ProtocolTypeRouter for your channels project: - - ``` - from graphql_ws.channels import GraphQLHttpConsumer - from channels.routing import ProtocolTypeRouter, URLRouter - from django.core.asgi import get_asgi_application - application = ProtocolTypeRouter({ - "http": URLRouter([ - re_path("^graphql", GraphQLHttpConsumer(schema=schema)), - re_path("^", get_asgi_application()), - ]), - }) - ``` - """ +class GraphQLHttpConsumer(AsyncHttpConsumer): def __init__( self, schema: GraphQLSchema, @@ -37,107 +23,108 @@ def __init__( self.graphiql = graphiql super().__init__() - # def headers(self): - # return { - # header_name.decode("utf-8").lower(): header_value.decode("utf-8") - # for header_name, header_value in self.scope["headers"] - # } - - # async def parse_multipart_body(self, body): - # await self.send_response(500, "Unable to parse the multipart body") - # return None - - # async def get_graphql_params(self, data): - # query = data.get("query") - # variables = data.get("variables") - # id = data.get("id") - - # if variables and isinstance(variables, str): - # try: - # variables = json.loads(variables) - # except Exception: - # await self.send_response(500, b"Variables are invalid JSON.") - # return None - # operation_name = data.get("operationName") - - # return query, variables, operation_name, id - - # async def get_request_data(self, body) -> Optional[Any]: - # if self.headers.get("content-type", "").startswith("multipart/form-data"): - # data = await self.parse_multipart_body(body) - # if data is None: - # return None - # else: - # try: - # data = json.loads(body) - # except json.JSONDecodeError: - # await self.send_response(500, b"Unable to parse request body as JSON") - # return None - - # query, variables, operation_name, id = self.get_graphql_params(data) - # if not query: - # await self.send_response(500, b"No GraphQL query found in the request") - # return None - - # return query, variables, operation_name, id - - # async def post(self, body): - # request_data = await self.get_request_data(body) - # if request_data is None: - # return - # context = await self.get_context() - # root_value = await self.get_root_value() - - # result = await self.schema.execute( - # query=request_data.query, - # root_value=root_value, - # variable_values=request_data.variables, - # context_value=context, - # operation_name=request_data.operation_name, - # ) - - # response_data = self.process_result(result) - # await self.send_response( - # 200, - # json.dumps(response_data).encode("utf-8"), - # headers=[(b"Content-Type", b"application/json")], - # ) - - # def graphiql_html_file_path(self) -> Path: - # return Path(__file__).parent.parent.parent / "static" / "graphiql.html" - - # async def render_graphiql(self, body): - # html_string = self.graphiql_html_file_path.read_text() - # html_string = html_string.replace("{{ SUBSCRIPTION_ENABLED }}", "true") - # await self.send_response( - # 200, html_string.encode("utf-8"), headers=[(b"Content-Type", b"text/html")] - # ) - - # def should_render_graphiql(self): - # return bool(self.graphiql and "text/html" in self.headers.get("accept", "")) - - # async def get(self, body): - # # if self.should_render_graphiql(): - # # return await self.render_graphiql(body) - # # else: - # await self.send_response( - # 200, "{}", headers=[(b"Content-Type", b"text/json")] - # ) + def headers(self): + return { + header_name.decode("utf-8").lower(): header_value.decode("utf-8") + for header_name, header_value in self.scope["headers"] + } + + async def parse_multipart_body(self, body): + await self.send_response(500, "Unable to parse the multipart body") + return None + + async def get_graphql_params(self, data): + query = data.get("query") + variables = data.get("variables") + id = data.get("id") + + if variables and isinstance(variables, str): + try: + variables = json.loads(variables) + except Exception: + await self.send_response(500, b"Variables are invalid JSON.") + return None + operation_name = data.get("operationName") + + return query, variables, operation_name, id + + async def get_request_data(self, body) -> Optional[Any]: + if self.headers.get("content-type", "").startswith("multipart/form-data"): + data = await self.parse_multipart_body(body) + if data is None: + return None + else: + try: + data = json.loads(body) + except json.JSONDecodeError: + await self.send_response(500, b"Unable to parse request body as JSON") + return None + + query, variables, operation_name, id = self.get_graphql_params(data) + if not query: + await self.send_response(500, b"No GraphQL query found in the request") + return None + + return query, variables, operation_name, id + + async def post(self, body): + request_data = await self.get_request_data(body) + if request_data is None: + return + context = await self.get_context() + root_value = await self.get_root_value() + + result = await self.schema.execute( + query=request_data.query, + root_value=root_value, + variable_values=request_data.variables, + context_value=context, + operation_name=request_data.operation_name, + ) + + response_data = self.process_result(result) + await self.send_response( + 200, + json.dumps(response_data).encode("utf-8"), + headers=[(b"Content-Type", b"application/json")], + ) + + def graphiql_html_file_path(self) -> Path: + return Path(__file__).parent.parent.parent / "static" / "graphiql.html" + + async def render_graphiql(self, body, params): + # html_string = self.graphiql_html_file_path.read_text() + # html_string = html_string.replace("{{ SUBSCRIPTION_ENABLED }}", "true") + html_string = render_graphiql_sync(body, params) + await self.send_response( + 200, html_string.encode("utf-8"), headers=[(b"Content-Type", b"text/html")] + ) + + def should_render_graphiql(self): + return bool(self.graphiql and "text/html" in self.headers.get("accept", "")) + + async def get(self, body): + if self.should_render_graphiql(): + return await self.render_graphiql(body, params=GraphQLParams(query="")) + else: + await self.send_response( + 200, "{}", headers=[(b"Content-Type", b"text/json")] + ) async def handle(self, body): - # if self.scope["method"] == "GET": - # return await self.get(body) - # if self.scope["method"] == "POST": - # return await self.post(body) + if self.scope["method"] == "GET": + return await self.get(body) + if self.scope["method"] == "POST": + return await self.post(body) await self.send_response( 200, b"Method not allowed", headers=[b"Allow", b"GET, POST"] ) - # async def get_root_value(self) -> Any: - # return None + async def get_root_value(self) -> Any: + return None - # async def get_context(self) -> Any: - # return None + async def get_context(self) -> Any: + return None - # def process_result(self, result: ExecutionResult): - # return result.formatted + def process_result(self, result: ExecutionResult): + return result.formatted diff --git a/graphql_server/django/debug_toolbar.py b/graphql_server/django/debug_toolbar.py index 7c464c7..197048e 100644 --- a/graphql_server/django/debug_toolbar.py +++ b/graphql_server/django/debug_toolbar.py @@ -41,7 +41,9 @@ def _get_payload(request: HttpRequest, response: HttpResponse): content = force_str(response.content, encoding=response.charset) payload = json.loads(content, object_pairs_hook=collections.OrderedDict) - payload["debugToolbar"] = collections.OrderedDict([("panels", collections.OrderedDict())]) + payload["debugToolbar"] = collections.OrderedDict( + [("panels", collections.OrderedDict())] + ) payload["debugToolbar"]["storeId"] = toolbar.store_id for p in reversed(toolbar.enabled_panels): @@ -134,7 +136,8 @@ def __call__(self, request: HttpRequest): is_graphiql = getattr(request, "_is_graphiql", False) if is_html and is_graphiql and response.status_code == 200: - response.write(""" + response.write( + """ -""") +""" + ) if "Content-Length" in response: response["Content-Length"] = len(response.content) @@ -194,4 +198,6 @@ def __call__(self, request: HttpRequest): def process_view(self, request: HttpRequest, view_func, *args, **kwargs): view = getattr(view_func, "view_class", None) - request._is_graphiql = bool(view and issubclass(view, GraphQLView)) # type:ignore + request._is_graphiql = bool( + view and issubclass(view, GraphQLView) + ) # type:ignore From 30301efc8384d4b6cdd036c6e1fd043b7f614106 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Tue, 8 Feb 2022 18:31:21 -0800 Subject: [PATCH 05/14] Updated channels http client --- graphql_server/channels/http.py | 322 +++++++++++++++++++++++--------- 1 file changed, 235 insertions(+), 87 deletions(-) diff --git a/graphql_server/channels/http.py b/graphql_server/channels/http.py index 098cdf3..6a8ee39 100644 --- a/graphql_server/channels/http.py +++ b/graphql_server/channels/http.py @@ -1,16 +1,47 @@ -"""GraphQLHTTPHandler +"""GraphQLHttpConsumer A consumer to provide a graphql endpoint, and optionally graphiql. """ -import json -from pathlib import Path -from typing import Any, Optional +import re +from functools import partial +from urllib.parse import parse_qsl +from typing import Type, Any, Optional, Collection +from graphql import ExecutionResult, GraphQLError, specified_rules +from graphql.execution import Middleware +from graphql.type.schema import GraphQLSchema +from graphql.validation import ASTValidationRule from channels.generic.http import AsyncHttpConsumer -from graphql import GraphQLSchema, ExecutionResult +from graphql_server import ( + HttpQueryError, + get_schema, + encode_execution_results, + format_error_default, + json_encode, + load_json_body, + run_http_query, + process_preflight, +) +from graphql_server.render_graphiql import ( + GraphiQLOptions, + render_graphiql_sync, +) -from ..render_graphiql import render_graphiql_sync -from .. import GraphQLParams + +def get_accepted_content_types(accept_header: str): + def qualify(x): + parts = x.split(";", 1) + if len(parts) == 2: + match = re.match(r"(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)", parts[1]) + if match: + return parts[0].strip(), float(match.group(2)) + return parts[0].strip(), 1 + + raw_content_types = accept_header.split(",") + qualified_content_types = map(qualify, raw_content_types) + return list( + x[0] for x in sorted(qualified_content_types, key=lambda x: x[1], reverse=True) + ) class GraphQLHttpConsumer(AsyncHttpConsumer): @@ -23,108 +54,225 @@ def __init__( self.graphiql = graphiql super().__init__() + @property def headers(self): return { header_name.decode("utf-8").lower(): header_value.decode("utf-8") for header_name, header_value in self.scope["headers"] } - async def parse_multipart_body(self, body): - await self.send_response(500, "Unable to parse the multipart body") + accepted_methods = ["GET", "POST", "PUT", "DELETE"] + + format_error = staticmethod(format_error_default) + encode = staticmethod(json_encode) + + schema: GraphQLSchema = None + root_value: Any = None + pretty: bool = False + graphiql: bool = True + middleware: Optional[Middleware] = None + validation_rules: Optional[Collection[Type[ASTValidationRule]]] = None + batch: bool = False + fetch_query_on_load: bool = True + max_age: int = 86400 + cors_allow_origin: Optional[str] = None + graphiql_options: Optional[GraphiQLOptions] = None + + def __init__( + self, + schema: GraphQLSchema, + root_value: Any = None, + pretty: bool = False, + graphiql: bool = True, + middleware: Optional[Middleware] = None, + validation_rules: Optional[Collection[Type[ASTValidationRule]]] = None, + batch: bool = False, + fetch_query_on_load: bool = True, + max_age: int = 86400, + cors_allow_origin: Optional[str] = None, + graphiql_options: Optional[GraphiQLOptions] = None, + ): + self.schema = get_schema(schema) + self.root_value = root_value + self.pretty = pretty + self.graphiql = graphiql + self.graphiql_options = graphiql_options + self.middleware = middleware + self.validation_rules = validation_rules + self.batch = batch + self.fetch_query_on_load = fetch_query_on_load + self.cors_allow_origin = cors_allow_origin + self.max_age = max_age + super().__init__() + + def render_graphiql(self, *args, **kwargs): + return render_graphiql_sync(*args, **kwargs) + + async def get_root_value(self, request) -> Any: return None - async def get_graphql_params(self, data): - query = data.get("query") - variables = data.get("variables") - id = data.get("id") + async def get_context(self, request) -> Any: + return None - if variables and isinstance(variables, str): - try: - variables = json.loads(variables) - except Exception: - await self.send_response(500, b"Variables are invalid JSON.") - return None - operation_name = data.get("operationName") - - return query, variables, operation_name, id - - async def get_request_data(self, body) -> Optional[Any]: - if self.headers.get("content-type", "").startswith("multipart/form-data"): - data = await self.parse_multipart_body(body) - if data is None: - return None - else: + def get_middleware(self): + return self.middleware + + def get_validation_rules(self): + if self.validation_rules is None: + return specified_rules + return self.validation_rules + + def parse_body(self, content_type, body): + if content_type == "application/graphql": + return {"query": body.decode()} + + elif content_type == "application/json": try: - data = json.loads(body) - except json.JSONDecodeError: - await self.send_response(500, b"Unable to parse request body as JSON") - return None + body = body.decode("utf-8") + except Exception as e: + raise HttpQueryError(400, str(e)) - query, variables, operation_name, id = self.get_graphql_params(data) - if not query: - await self.send_response(500, b"No GraphQL query found in the request") - return None + return load_json_body(body, self.batch) - return query, variables, operation_name, id + elif content_type in [ + "application/x-www-form-urlencoded", + # "multipart/form-data", + ]: + return dict(parse_qsl(body.decode("utf-8"))) - async def post(self, body): - request_data = await self.get_request_data(body) - if request_data is None: - return - context = await self.get_context() - root_value = await self.get_root_value() - - result = await self.schema.execute( - query=request_data.query, - root_value=root_value, - variable_values=request_data.variables, - context_value=context, - operation_name=request_data.operation_name, - ) + return {} - response_data = self.process_result(result) - await self.send_response( - 200, - json.dumps(response_data).encode("utf-8"), - headers=[(b"Content-Type", b"application/json")], + def request_prefers_html(self, accept): + + accepted = get_accepted_content_types(accept) + accepted_length = len(accepted) + # the list will be ordered in preferred first - so we have to make + # sure the most preferred gets the highest number + html_priority = ( + accepted_length - accepted.index("text/html") + if "text/html" in accepted + else 0 + ) + json_priority = ( + accepted_length - accepted.index("application/json") + if "application/json" in accepted + else 0 ) - def graphiql_html_file_path(self) -> Path: - return Path(__file__).parent.parent.parent / "static" / "graphiql.html" + return html_priority > json_priority - async def render_graphiql(self, body, params): - # html_string = self.graphiql_html_file_path.read_text() - # html_string = html_string.replace("{{ SUBSCRIPTION_ENABLED }}", "true") - html_string = render_graphiql_sync(body, params) - await self.send_response( - 200, html_string.encode("utf-8"), headers=[(b"Content-Type", b"text/html")] - ) + def is_graphiql(self, request_method: str, is_raw: bool, prefers_html: bool): + return self.graphiql and request_method == "get" and not is_raw and prefers_html - def should_render_graphiql(self): - return bool(self.graphiql and "text/html" in self.headers.get("accept", "")) + def should_prettify(self, is_graphiql: bool, pretty_in_query: bool): + return self.pretty or is_graphiql or pretty_in_query - async def get(self, body): - if self.should_render_graphiql(): - return await self.render_graphiql(body, params=GraphQLParams(query="")) + async def handle(self, body): + if self.cors_allow_origin: + base_cors_headers = [ + (b"Access-Control-Allow-Origin", self.cors_allow_origin) + ] else: - await self.send_response( - 200, "{}", headers=[(b"Content-Type", b"text/json")] + base_cors_headers = [] + try: + req_headers = self.headers + content_type = req_headers.get("content-type", "") + accept_header = req_headers.get("accept", "*/*") + data = self.parse_body(content_type, body) + request_method = self.scope["method"].lower() + prefers_html = self.request_prefers_html(accept_header) or True + query_data = dict(parse_qsl(self.scope.get("query_string", b"").decode("utf-8"))) + is_raw = "raw" in query_data + is_pretty = "pretty" in query_data + is_pretty = False + is_graphiql = self.is_graphiql(request_method, is_raw, prefers_html) + is_pretty = self.should_prettify(is_graphiql, is_pretty) + + if request_method == "options": + origin = req_headers.get("origin", "") + method = req_headers.get("access-control-request-method", "").upper() + response = process_preflight( + origin, method, self.accepted_methods, self.max_age + ) + headers = [ + (b"Content-Type", b"application/json"), + (b"Access-Control-Allow-Headers", b"*"), + ] + if response.headers: + headers += [ + (key.encode("utf-8"), value.encode("utf-8")) + for key, value in response.headers.items() + ] + else: + headers = [] + await self.send_response(response.status_code, b"{}", headers=headers) + return + + graphql_response = run_http_query( + self.schema, + request_method, + data, + query_data=query_data, + batch_enabled=self.batch, + catch=is_graphiql, + # Execute options + run_sync=False, + root_value=await self.get_root_value(self), + context_value=await self.get_context(self), + middleware=self.get_middleware(), + validation_rules=self.get_validation_rules(), ) - async def handle(self, body): - if self.scope["method"] == "GET": - return await self.get(body) - if self.scope["method"] == "POST": - return await self.post(body) - await self.send_response( - 200, b"Method not allowed", headers=[b"Allow", b"GET, POST"] - ) + exec_res = [ + ex if ex is None or isinstance(ex, ExecutionResult) else await ex + for ex in graphql_response.results + ] + response = encode_execution_results( + exec_res, + is_batch=isinstance(data, list), + format_error=self.format_error, + encode=partial(self.encode, pretty=is_pretty), # noqa: ignore + ) - async def get_root_value(self) -> Any: - return None + if is_graphiql: + source = self.render_graphiql( + result=response.body, + params=graphql_response.params[0], + options=self.graphiql_options, + ) + await self.send_response( + 200, + source.encode("utf-8"), + headers=base_cors_headers + [(b"Content-Type", b"text/html")], + ) + return - async def get_context(self) -> Any: - return None + await self.send_response( + response.status_code, + response.body.encode("utf-8"), + headers=base_cors_headers + [(b"Content-Type", b"application/json")], + ) + return - def process_result(self, result: ExecutionResult): - return result.formatted + except HttpQueryError as err: + parsed_error = GraphQLError(err.message) + data = self.encode(dict(errors=[self.format_error(parsed_error)])) + headers = [(b"Content-Type", b"application/json")] + if err.headers: + headers = headers + [(key, value) for key, value in err.headers.items()] + await self.send_response( + err.status_code, + data.encode("utf-8"), + headers=base_cors_headers + headers, + ) + return + except Exception as e: + parsed_error = GraphQLError(str(e)) + data = self.encode(dict(errors=[self.format_error(parsed_error)])) + headers = [(b"Content-Type", b"application/json")] + await self.send_response( + 400, + data.encode("utf-8"), + headers=headers, + ) + return From 994f77f27797acac36a82e3676191c1f9ca48d6a Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Tue, 8 Feb 2022 18:44:26 -0800 Subject: [PATCH 06/14] Fixed GraphiQL --- graphql_server/render_graphiql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphql_server/render_graphiql.py b/graphql_server/render_graphiql.py index 2073d8d..2dfc3ab 100644 --- a/graphql_server/render_graphiql.py +++ b/graphql_server/render_graphiql.py @@ -244,7 +244,7 @@ def get_template_vars( template_vars: Dict[str, Any] = { "result": tojson(data), "query": tojson(params.query), - "variables": tojson(params.variables), + "variables": tojson(json.dumps(params.variables)), "operation_name": tojson(params.operation_name), "html_title": options_with_defaults["html_title"], "graphiql_version": options_with_defaults["graphiql_version"], From 714764976efad28a92ec8dc66d75a3aa4cb39600 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Wed, 9 Feb 2022 15:07:21 -0800 Subject: [PATCH 07/14] Raise error on form-data --- graphql_server/channels/http.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/graphql_server/channels/http.py b/graphql_server/channels/http.py index 6a8ee39..a4f548b 100644 --- a/graphql_server/channels/http.py +++ b/graphql_server/channels/http.py @@ -122,7 +122,7 @@ def get_validation_rules(self): return specified_rules return self.validation_rules - def parse_body(self, content_type, body): + def parse_body(self, content_type: str, body: bytes): if content_type == "application/graphql": return {"query": body.decode()} @@ -139,7 +139,8 @@ def parse_body(self, content_type, body): # "multipart/form-data", ]: return dict(parse_qsl(body.decode("utf-8"))) - + elif content_type.startswith("multipart/form-data"): + raise HttpQueryError(400, "multipart/form-data is not supported in this GraphQL endpoint") return {} def request_prefers_html(self, accept): @@ -181,7 +182,9 @@ async def handle(self, body): data = self.parse_body(content_type, body) request_method = self.scope["method"].lower() prefers_html = self.request_prefers_html(accept_header) or True - query_data = dict(parse_qsl(self.scope.get("query_string", b"").decode("utf-8"))) + query_data = dict( + parse_qsl(self.scope.get("query_string", b"").decode("utf-8")) + ) is_raw = "raw" in query_data is_pretty = "pretty" in query_data is_pretty = False From 5608de6767a9fae4e05e3452ba510b84787f7a04 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Thu, 10 Feb 2022 17:23:37 -0800 Subject: [PATCH 08/14] Added partial support for multipart --- graphql_server/channels/http.py | 14 +-- graphql_server/multipart.py | 190 ++++++++++++++++++++++++++++++++ 2 files changed, 197 insertions(+), 7 deletions(-) create mode 100644 graphql_server/multipart.py diff --git a/graphql_server/channels/http.py b/graphql_server/channels/http.py index a4f548b..6994b90 100644 --- a/graphql_server/channels/http.py +++ b/graphql_server/channels/http.py @@ -26,7 +26,7 @@ GraphiQLOptions, render_graphiql_sync, ) - +from graphql_server.multipart import get_post_and_files def get_accepted_content_types(accept_header: str): def qualify(x): @@ -124,7 +124,7 @@ def get_validation_rules(self): def parse_body(self, content_type: str, body: bytes): if content_type == "application/graphql": - return {"query": body.decode()} + return {"query": body.decode()}, None elif content_type == "application/json": try: @@ -132,16 +132,16 @@ def parse_body(self, content_type: str, body: bytes): except Exception as e: raise HttpQueryError(400, str(e)) - return load_json_body(body, self.batch) + return load_json_body(body, self.batch), None elif content_type in [ "application/x-www-form-urlencoded", - # "multipart/form-data", ]: return dict(parse_qsl(body.decode("utf-8"))) elif content_type.startswith("multipart/form-data"): - raise HttpQueryError(400, "multipart/form-data is not supported in this GraphQL endpoint") - return {} + return get_post_and_files(body, content_type) + # raise HttpQueryError(400, "multipart/form-data is not supported in this GraphQL endpoint") + return {}, None def request_prefers_html(self, accept): @@ -179,7 +179,7 @@ async def handle(self, body): req_headers = self.headers content_type = req_headers.get("content-type", "") accept_header = req_headers.get("accept", "*/*") - data = self.parse_body(content_type, body) + data, files = self.parse_body(content_type, body) request_method = self.scope["method"].lower() prefers_html = self.request_prefers_html(accept_header) or True query_data = dict( diff --git a/graphql_server/multipart.py b/graphql_server/multipart.py new file mode 100644 index 0000000..b5d273e --- /dev/null +++ b/graphql_server/multipart.py @@ -0,0 +1,190 @@ +# -*- coding: utf-8 -*- +""" +This holds all the implementation details of the MultipartDecoder +""" + +# Code adapted from requests_toolbelt.multipart.decoder + + +from collections import defaultdict +from dataclasses import dataclass +import email.parser +from urllib.parse import unquote +from .error import HttpQueryError + +def _split_on_find(content, bound): + point = content.find(bound) + return content[:point], content[point + len(bound):] + + +def _header_parser(string): + + headers = email.parser.HeaderParser().parsestr(string.decode('ascii')).items() + return { + k: v.encode('ascii') + for k, v in headers + } + + +class BodyPart(object): + """ + The ``BodyPart`` object is a ``Response``-like interface to an individual + subpart of a multipart response. It is expected that these will + generally be created by objects of the ``MultipartDecoder`` class. + Like ``Response``, there is a ``dict`` object named headers, + ``content`` to access bytes, ``text`` to access unicode, and ``encoding`` + to access the unicode codec. + """ + + def __init__(self, content): + headers = {} + # Split into header section (if any) and the content + if b'\r\n\r\n' in content: + first, self.content = _split_on_find(content, b'\r\n\r\n') + if first != b'': + headers = _header_parser(first.lstrip()) + else: + raise HttpQueryError( + 400, + 'Multipart content does not contain CR-LF-CR-LF' + ) + self.headers = headers + + +class MultipartDecoder(object): + """ + The ``MultipartDecoder`` object parses the multipart payload of + a bytestring into a tuple of ``Response``-like ``BodyPart`` objects. + The basic usage is:: + import requests + from requests_toolbelt import MultipartDecoder + response = requests.get(url) + decoder = MultipartDecoder.from_response(response) + for part in decoder.parts: + print(part.headers['content-type']) + If the multipart content is not from a response, basic usage is:: + from requests_toolbelt import MultipartDecoder + decoder = MultipartDecoder(content, content_type) + for part in decoder.parts: + print(part.headers['content-type']) + For both these usages, there is an optional ``encoding`` parameter. This is + a string, which is the name of the unicode codec to use (default is + ``'utf-8'``). + """ + def __init__(self, content, content_type, encoding='utf-8'): + #: Original Content-Type header + self.content_type = content_type + #: Response body encoding + self.encoding = encoding + #: Parsed parts of the multipart response body + self.parts = tuple() + self._find_boundary() + self._parse_body(content) + + def _find_boundary(self): + ct_info = tuple(x.strip() for x in self.content_type.split(';')) + mimetype = ct_info[0] + if mimetype.split('/')[0].lower() != 'multipart': + raise HttpQueryError( + 400, + "Unexpected mimetype in content-type: '{}'".format(mimetype) + ) + for item in ct_info[1:]: + attr, value = _split_on_find( + item, + '=' + ) + if attr.lower() == 'boundary': + self.boundary = value.strip('"').encode('utf-8') + + @staticmethod + def _fix_first_part(part, boundary_marker): + bm_len = len(boundary_marker) + if boundary_marker == part[:bm_len]: + return part[bm_len:] + else: + return part + + def _parse_body(self, content): + boundary = b''.join((b'--', self.boundary)) + + def body_part(part): + fixed = MultipartDecoder._fix_first_part(part, boundary) + return BodyPart(fixed) + + def test_part(part): + return (part != b'' and + part != b'\r\n' and + part[:4] != b'--\r\n' and + part != b'--') + + parts = content.split(b''.join((b'\r\n', boundary))) + self.parts = tuple(body_part(x) for x in parts if test_part(x)) + + +@dataclass +class File: + content: bytes + filename: str + +def get_post_and_files(body, content_type): + post = {} + files = {} + parts = MultipartDecoder(body, content_type).parts + for part in parts: + for name, header_value in part.headers.items(): + value, params = parse_header(header_value) + if name.lower() == "content-disposition": + filename = params.get("filename") + if filename: + files[name.decode('utf-8')] = File(content=part.content, filename=filename) + else: + name = params.get("name") + post[name.decode('utf-8')] = part.content.decode('utf-8') + return post, files + + +def parse_header(line): + """ + Parse the header into a key-value. + Input (line): bytes, output: str for key/name, bytes for values which + will be decoded later. + """ + plist = _parse_header_params(b";" + line) + key = plist.pop(0).lower().decode("ascii") + pdict = {} + for p in plist: + i = p.find(b"=") + if i >= 0: + has_encoding = False + name = p[:i].strip().lower().decode("ascii") + if name.endswith("*"): + # Lang/encoding embedded in the value (like "filename*=UTF-8''file.ext") + # https://tools.ietf.org/html/rfc2231#section-4 + name = name[:-1] + if p.count(b"'") == 2: + has_encoding = True + value = p[i + 1 :].strip() + if len(value) >= 2 and value[:1] == value[-1:] == b'"': + value = value[1:-1] + value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"') + if has_encoding: + encoding, lang, value = value.split(b"'") + value = unquote(value.decode(), encoding=encoding.decode()) + pdict[name] = value + return key, pdict + + +def _parse_header_params(s): + plist = [] + while s[:1] == b";": + s = s[1:] + end = s.find(b";") + while end > 0 and s.count(b'"', 0, end) % 2: + end = s.find(b";", end + 1) + if end < 0: + end = len(s) + f = s[:end] + plist.append(f.strip()) + s = s[end:] + return plist From 58e087d1c933c793c08a2232456d2c74e4905e67 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Thu, 10 Feb 2022 22:39:13 -0800 Subject: [PATCH 09/14] Updated graphql server --- graphql_server/channels/consumer.py | 7 +-- graphql_server/channels/graphql_ws.py | 4 +- graphql_server/channels/http.py | 1 + graphql_server/django/debug_toolbar.py | 2 +- graphql_server/django/views.py | 60 ++++++++++++++++++++----- graphql_server/multipart.py | 61 ++++++++++++-------------- 6 files changed, 83 insertions(+), 52 deletions(-) diff --git a/graphql_server/channels/consumer.py b/graphql_server/channels/consumer.py index 135a735..5ea0b27 100644 --- a/graphql_server/channels/consumer.py +++ b/graphql_server/channels/consumer.py @@ -6,7 +6,7 @@ from datetime import timedelta from typing import Any, Optional, Sequence, Union -from django.http import HttpRequest, HttpResponse +from django.http import HttpRequest from django.urls import re_path from channels.generic.websocket import ( @@ -127,8 +127,5 @@ async def get_root_value( async def get_context( self, request: Union[HttpRequest, AsyncJsonWebsocketConsumer] = None, - response: Optional[HttpResponse] = None, ) -> Optional[Any]: - return GraphQLChannelsContext( - request=request or self, response=response, scope=self.scope - ) + return GraphQLChannelsContext(request=request or self, scope=self.scope) diff --git a/graphql_server/channels/graphql_ws.py b/graphql_server/channels/graphql_ws.py index 45e08eb..dfcd00b 100644 --- a/graphql_server/channels/graphql_ws.py +++ b/graphql_server/channels/graphql_ws.py @@ -24,10 +24,10 @@ def __init__( self._ws = ws async def get_context(self) -> Any: - return await self._get_context(request=self._ws) + return await self._get_context(self._ws) async def get_root_value(self) -> Any: - return await self._get_root_value(request=self._ws) + return await self._get_root_value(self._ws) async def send_json(self, data: OperationMessage) -> None: await self._ws.send_json(data) diff --git a/graphql_server/channels/http.py b/graphql_server/channels/http.py index 6994b90..bfe7202 100644 --- a/graphql_server/channels/http.py +++ b/graphql_server/channels/http.py @@ -28,6 +28,7 @@ ) from graphql_server.multipart import get_post_and_files + def get_accepted_content_types(accept_header: str): def qualify(x): parts = x.split(";", 1) diff --git a/graphql_server/django/debug_toolbar.py b/graphql_server/django/debug_toolbar.py index 197048e..f78aff1 100644 --- a/graphql_server/django/debug_toolbar.py +++ b/graphql_server/django/debug_toolbar.py @@ -123,7 +123,7 @@ def _unwrap_cursor(connection): class DebugToolbarMiddleware(_DebugToolbarMiddleware): sync_capable = True - async_capable = True + async_capable = False def __call__(self, request: HttpRequest): response = super().__call__(request) diff --git a/graphql_server/django/views.py b/graphql_server/django/views.py index 5eb0cad..0c6d89f 100644 --- a/graphql_server/django/views.py +++ b/graphql_server/django/views.py @@ -1,15 +1,14 @@ import asyncio import re from functools import partial -from http.client import HTTPResponse -from typing import Type, Any, Optional, Collection +from typing import Type, Any, Optional, Collection, Dict from graphql import ExecutionResult, GraphQLError, specified_rules from graphql.execution import Middleware from graphql.type.schema import GraphQLSchema from graphql.validation import ASTValidationRule from django.views.generic import View -from django.http import HttpResponse, HttpRequest, HttpResponseBadRequest +from django.http import HttpResponse, HttpRequest from django.utils.decorators import classonlymethod, method_decorator from django.views.decorators.csrf import csrf_exempt @@ -62,6 +61,7 @@ class GraphQLView(View): fetch_query_on_load: bool = True max_age: int = 86400 graphiql_options: Optional[GraphiQLOptions] = None + cors_allow_origin: Optional[str] = None def __init__( self, @@ -75,6 +75,7 @@ def __init__( fetch_query_on_load: bool = True, max_age: int = 86400, graphiql_options: Optional[GraphiQLOptions] = None, + cors_allow_origin: Optional[str] = None, ): self.schema = get_schema(schema) self.root_value = root_value @@ -86,6 +87,10 @@ def __init__( self.batch = batch self.fetch_query_on_load = fetch_query_on_load self.max_age = max_age + self.cors_allow_origin = cors_allow_origin + + def get_graphiql_options(self, request: HttpRequest): + return self.graphiql_options def render_graphiql(self, *args, **kwargs): return render_graphiql_sync(*args, **kwargs) @@ -104,6 +109,17 @@ def get_validation_rules(self): return specified_rules return self.validation_rules + def construct_headers(self, headers: Optional[Dict] = None): + if self.cors_allow_origin: + return dict( + headers or {}, + **{ + "Access-Control-Allow-Origin": self.cors_allow_origin, + } + ) + else: + return headers + def parse_body(self, request: HttpRequest): content_type = request.content_type @@ -169,8 +185,16 @@ def dispatch(self, request: HttpRequest, *args, **kwargs): response = process_preflight( origin, method, self.accepted_methods, self.max_age ) - return HTTPResponse( - status=response.status_code, headers=response.headers + return_headers = { + "Content-Type": "application/json", + "Access-Control-Allow-Headers": "*", + } + + return HttpResponse( + status=response.status_code, + headers=self.construct_headers( + dict(response.headers or {}, **return_headers) + ), ) graphql_response = run_http_query( @@ -199,12 +223,13 @@ def dispatch(self, request: HttpRequest, *args, **kwargs): source = self.render_graphiql( result=response.body, params=graphql_response.params[0], - options=self.graphiql_options, + options=self.get_graphiql_options(request), ) return HttpResponse(content=source, content_type="text/html") return HttpResponse( content=response.body, + headers=self.construct_headers(), content_type="application/json", status=response.status_code, ) @@ -215,7 +240,7 @@ def dispatch(self, request: HttpRequest, *args, **kwargs): content=self.encode(dict(errors=[self.format_error(parsed_error)])), content_type="application/json", status=err.status_code, - headers=err.headers, + headers=self.construct_headers(err.headers), ) @@ -247,8 +272,15 @@ async def dispatch(self, request, *args, **kwargs): response = process_preflight( origin, method, self.accepted_methods, self.max_age ) - return HTTPResponse( - status=response.status_code, headers=response.headers + return_headers = { + "Content-Type": "application/json", + "Access-Control-Allow-Headers": "*", + } + return HttpResponse( + status=response.status_code, + headers=self.construct_headers( + dict(response.headers or {}, **return_headers) + ), ) graphql_response = run_http_query( @@ -282,12 +314,13 @@ async def dispatch(self, request, *args, **kwargs): source = self.render_graphiql( result=response.body, params=graphql_response.params[0], - options=self.graphiql_options, + options=await self.get_graphiql_options(request), ) return HttpResponse(content=source, content_type="text/html") return HttpResponse( content=response.body, + headers=self.construct_headers(), content_type="application/json", status=response.status_code, ) @@ -298,10 +331,13 @@ async def dispatch(self, request, *args, **kwargs): content=self.encode(dict(errors=[self.format_error(parsed_error)])), content_type="application/json", status=err.status_code, - headers=err.headers, + headers=self.construct_headers(err.headers), ) - async def get_root_value(self, request: HttpRequest) -> Any: + async def get_graphiql_options(self, _request: HttpRequest) -> Any: + return self.graphiql_options + + async def get_root_value(self, _request: HttpRequest) -> Any: return None async def get_context(self, request: HttpRequest) -> Any: diff --git a/graphql_server/multipart.py b/graphql_server/multipart.py index b5d273e..190abb4 100644 --- a/graphql_server/multipart.py +++ b/graphql_server/multipart.py @@ -12,18 +12,16 @@ from urllib.parse import unquote from .error import HttpQueryError + def _split_on_find(content, bound): point = content.find(bound) - return content[:point], content[point + len(bound):] + return content[:point], content[point + len(bound) :] def _header_parser(string): - headers = email.parser.HeaderParser().parsestr(string.decode('ascii')).items() - return { - k: v.encode('ascii') - for k, v in headers - } + headers = email.parser.HeaderParser().parsestr(string.decode("ascii")).items() + return {k: v.encode("ascii") for k, v in headers} class BodyPart(object): @@ -39,15 +37,12 @@ class BodyPart(object): def __init__(self, content): headers = {} # Split into header section (if any) and the content - if b'\r\n\r\n' in content: - first, self.content = _split_on_find(content, b'\r\n\r\n') - if first != b'': + if b"\r\n\r\n" in content: + first, self.content = _split_on_find(content, b"\r\n\r\n") + if first != b"": headers = _header_parser(first.lstrip()) else: - raise HttpQueryError( - 400, - 'Multipart content does not contain CR-LF-CR-LF' - ) + raise HttpQueryError(400, "Multipart content does not contain CR-LF-CR-LF") self.headers = headers @@ -71,7 +66,8 @@ class MultipartDecoder(object): a string, which is the name of the unicode codec to use (default is ``'utf-8'``). """ - def __init__(self, content, content_type, encoding='utf-8'): + + def __init__(self, content, content_type, encoding="utf-8"): #: Original Content-Type header self.content_type = content_type #: Response body encoding @@ -82,20 +78,16 @@ def __init__(self, content, content_type, encoding='utf-8'): self._parse_body(content) def _find_boundary(self): - ct_info = tuple(x.strip() for x in self.content_type.split(';')) + ct_info = tuple(x.strip() for x in self.content_type.split(";")) mimetype = ct_info[0] - if mimetype.split('/')[0].lower() != 'multipart': + if mimetype.split("/")[0].lower() != "multipart": raise HttpQueryError( - 400, - "Unexpected mimetype in content-type: '{}'".format(mimetype) + 400, "Unexpected mimetype in content-type: '{}'".format(mimetype) ) for item in ct_info[1:]: - attr, value = _split_on_find( - item, - '=' - ) - if attr.lower() == 'boundary': - self.boundary = value.strip('"').encode('utf-8') + attr, value = _split_on_find(item, "=") + if attr.lower() == "boundary": + self.boundary = value.strip('"').encode("utf-8") @staticmethod def _fix_first_part(part, boundary_marker): @@ -106,19 +98,21 @@ def _fix_first_part(part, boundary_marker): return part def _parse_body(self, content): - boundary = b''.join((b'--', self.boundary)) + boundary = b"".join((b"--", self.boundary)) def body_part(part): fixed = MultipartDecoder._fix_first_part(part, boundary) return BodyPart(fixed) def test_part(part): - return (part != b'' and - part != b'\r\n' and - part[:4] != b'--\r\n' and - part != b'--') + return ( + part != b"" + and part != b"\r\n" + and part[:4] != b"--\r\n" + and part != b"--" + ) - parts = content.split(b''.join((b'\r\n', boundary))) + parts = content.split(b"".join((b"\r\n", boundary))) self.parts = tuple(body_part(x) for x in parts if test_part(x)) @@ -127,6 +121,7 @@ class File: content: bytes filename: str + def get_post_and_files(body, content_type): post = {} files = {} @@ -137,10 +132,12 @@ def get_post_and_files(body, content_type): if name.lower() == "content-disposition": filename = params.get("filename") if filename: - files[name.decode('utf-8')] = File(content=part.content, filename=filename) + files[name.decode("utf-8")] = File( + content=part.content, filename=filename + ) else: name = params.get("name") - post[name.decode('utf-8')] = part.content.decode('utf-8') + post[name.decode("utf-8")] = part.content.decode("utf-8") return post, files From 792777e1177e871eb8d09699a4233989efa8d191 Mon Sep 17 00:00:00 2001 From: Ayush Jha Date: Tue, 16 May 2023 19:36:15 +0545 Subject: [PATCH 10/14] Fix bugs with transport-ws-protocol (#111) I am guessing this branch was only used with GraphQLWSHandler, and not with GraphQLTransportWSHandler. That's probably why there were a bunch of bugs. This PR fixes those. --- .../channels/graphql_transport_ws.py | 8 +-- .../transport_ws_protocol/handlers.py | 12 +++- .../websockets/transport_ws_protocol/types.py | 71 ++++++++++++------- 3 files changed, 58 insertions(+), 33 deletions(-) diff --git a/graphql_server/channels/graphql_transport_ws.py b/graphql_server/channels/graphql_transport_ws.py index 8b946ef..caefcbb 100644 --- a/graphql_server/channels/graphql_transport_ws.py +++ b/graphql_server/channels/graphql_transport_ws.py @@ -24,10 +24,10 @@ def __init__( self._ws = ws async def get_context(self) -> Any: - return await self._get_context(request=self._ws) + return await self._get_context(self._ws) async def get_root_value(self) -> Any: - return await self._get_root_value(request=self._ws) + return await self._get_root_value(self._ws) async def send_json(self, data: dict) -> None: await self._ws.send_json(data) @@ -37,9 +37,7 @@ async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: await self._ws.close(code=code) async def handle_request(self) -> Any: - await self._ws.accept( - subprotocol=BaseGraphQLTransportWSHandler.GRAPHQL_TRANSPORT_WS_PROTOCOL - ) + await self._ws.accept(subprotocol=BaseGraphQLTransportWSHandler.PROTOCOL) async def handle_disconnect(self, code): for operation_id in list(self.subscriptions.keys()): diff --git a/graphql_server/websockets/transport_ws_protocol/handlers.py b/graphql_server/websockets/transport_ws_protocol/handlers.py index 8df0b56..c3a3ea6 100644 --- a/graphql_server/websockets/transport_ws_protocol/handlers.py +++ b/graphql_server/websockets/transport_ws_protocol/handlers.py @@ -15,6 +15,7 @@ from ..constants import GRAPHQL_TRANSPORT_WS_PROTOCOL from .types import ( + Message, CompleteMessage, ConnectionAckMessage, ConnectionInitMessage, @@ -137,7 +138,8 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: if not self.connection_acknowledged: await self.close(code=4401, reason="Unauthorized") return - + if isinstance(message, dict): + message = SubscribeMessage.from_dict(message) if message.id in self.subscriptions.keys(): reason = f"Subscriber for {message.id} already exists" await self.close(code=4409, reason=reason) @@ -210,8 +212,12 @@ async def handle_complete(self, message: CompleteMessage) -> None: async def handle_invalid_message(self, error_message: str) -> None: await self.close(code=4400, reason=error_message) - async def send_message(self, data: TypedDict) -> None: - await self.send_message({**data, "type": data.type}) + async def send_message(self, data: Message) -> None: + data = data.asdict() + assert ( + data.get("type") is not None + ), "expected dict with `type` field. Got {} instead".format(data) + await self.send_json(data) async def cleanup_operation(self, operation_id: str) -> None: await self.subscriptions[operation_id].aclose() diff --git a/graphql_server/websockets/transport_ws_protocol/types.py b/graphql_server/websockets/transport_ws_protocol/types.py index 3042681..206af93 100644 --- a/graphql_server/websockets/transport_ws_protocol/types.py +++ b/graphql_server/websockets/transport_ws_protocol/types.py @@ -5,6 +5,8 @@ except ImportError: from typing_extensions import TypedDict +from dataclasses import dataclass, asdict + from .contstants import ( GQL_CONNECTION_INIT, GQL_CONNECTION_ACK, @@ -17,84 +19,103 @@ ) -class ConnectionInitMessage(TypedDict): +class Message: + def asdict(self): + return {key: value for key, value in asdict(self).items() if value is not None} + + +@dataclass +class ConnectionInitMessage(Message): """ Direction: Client -> Server """ - payload: Optional[Dict[str, Any]] - type = GQL_CONNECTION_INIT + payload: Optional[Dict[str, Any]] = None + type: str = GQL_CONNECTION_INIT -class ConnectionAckMessage(TypedDict): +@dataclass +class ConnectionAckMessage(Message): """ Direction: Server -> Client """ - payload: Optional[Dict[str, Any]] - type = GQL_CONNECTION_ACK + payload: Optional[Dict[str, Any]] = None + type: str = GQL_CONNECTION_ACK -class PingMessage(TypedDict): +@dataclass +class PingMessage(Message): """ Direction: bidirectional """ - payload: Optional[Dict[str, Any]] - type = GQL_PING + payload: Optional[Dict[str, Any]] = None + type: str = GQL_PING -class PongMessage(TypedDict): +@dataclass +class PongMessage(Message): """ Direction: bidirectional """ - payload: Optional[Dict[str, Any]] - type = GQL_PONG + payload: Optional[Dict[str, Any]] = None + type: str = GQL_PONG -class SubscribeMessagePayload(TypedDict): +@dataclass +class SubscribeMessagePayload(Message): query: str - operationName: Optional[str] - variables: Optional[Dict[str, Any]] - extensions: Optional[Dict[str, Any]] + operationName: Optional[str] = None + variables: Optional[Dict[str, Any]] = None + extensions: Optional[Dict[str, Any]] = None -class SubscribeMessage(TypedDict): +@dataclass +class SubscribeMessage(Message): """ Direction: Client -> Server """ id: str payload: SubscribeMessagePayload - type = GQL_SUBSCRIBE + type: str = GQL_SUBSCRIBE + @classmethod + def from_dict(cls, message: dict): + subscribe_message = cls(**message) + subscribe_message.payload = SubscribeMessagePayload(**subscribe_message.payload) + return subscribe_message -class NextMessage(TypedDict): + +@dataclass +class NextMessage(Message): """ Direction: Server -> Client """ id: str payload: Dict[str, Any] # TODO: shape like ExecutionResult - type = GQL_NEXT + type: str = GQL_NEXT -class ErrorMessage(TypedDict): +@dataclass +class ErrorMessage(Message): """ Direction: Server -> Client """ id: str payload: List[Dict[str, Any]] # TODO: shape like List[GraphQLError] - type = GQL_ERROR + type: str = GQL_ERROR -class CompleteMessage(TypedDict): +@dataclass +class CompleteMessage(Message): """ Direction: bidirectional """ - type = GQL_COMPLETE - id: str + type: str = GQL_COMPLETE From 08b676ffcf2a6e6c91537e61500aabc1b5820bf3 Mon Sep 17 00:00:00 2001 From: Ayush Jha Date: Mon, 4 Dec 2023 01:05:12 +0545 Subject: [PATCH 11/14] Implement the `send_xjson` abstract method defined in parent class This abstract method is defined in BaseGraphQLTransportWSHandler and is used instead of `send_json` like in graphql_ws protocol implementation. --- graphql_server/channels/graphql_transport_ws.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/graphql_server/channels/graphql_transport_ws.py b/graphql_server/channels/graphql_transport_ws.py index caefcbb..5f10d6e 100644 --- a/graphql_server/channels/graphql_transport_ws.py +++ b/graphql_server/channels/graphql_transport_ws.py @@ -32,6 +32,9 @@ async def get_root_value(self) -> Any: async def send_json(self, data: dict) -> None: await self._ws.send_json(data) + async def send_xjson(self, data: dict) -> None: + await self._ws.send_json(data) + async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: # Close messages are not part of the ASGI ref yet await self._ws.close(code=code) From 3b0df43afb5c96e32ba5ee170e2e77cb55f81986 Mon Sep 17 00:00:00 2001 From: Ayush Jha Date: Fri, 12 Jan 2024 13:15:12 +0545 Subject: [PATCH 12/14] bugfix: multipart uploads now work This makes multi-part post requests work in cases where one or more files are being sent with the graphql query. --- graphql_server/multipart.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/graphql_server/multipart.py b/graphql_server/multipart.py index 190abb4..93d161d 100644 --- a/graphql_server/multipart.py +++ b/graphql_server/multipart.py @@ -127,16 +127,14 @@ def get_post_and_files(body, content_type): files = {} parts = MultipartDecoder(body, content_type).parts for part in parts: - for name, header_value in part.headers.items(): + for header_name, header_value in part.headers.items(): value, params = parse_header(header_value) - if name.lower() == "content-disposition": + if header_name.lower() == "content-disposition": + name = params.get("name") filename = params.get("filename") if filename: - files[name.decode("utf-8")] = File( - content=part.content, filename=filename - ) + files[name] = File(content=part.content, filename=filename) else: - name = params.get("name") post[name.decode("utf-8")] = part.content.decode("utf-8") return post, files From 637b36f8312d2ceb642e9ec7a519de6187098abd Mon Sep 17 00:00:00 2001 From: Ayush Jha Date: Fri, 12 Jan 2024 18:08:53 +0545 Subject: [PATCH 13/14] Accept utf-8 encoded header values when parsing multipart requests This is done to enable passing in utf-8 filenames in headers. --- graphql_server/multipart.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/graphql_server/multipart.py b/graphql_server/multipart.py index 93d161d..4c8d85e 100644 --- a/graphql_server/multipart.py +++ b/graphql_server/multipart.py @@ -19,9 +19,8 @@ def _split_on_find(content, bound): def _header_parser(string): - - headers = email.parser.HeaderParser().parsestr(string.decode("ascii")).items() - return {k: v.encode("ascii") for k, v in headers} + headers = email.parser.HeaderParser().parsestr(string.decode("utf-8")).items() + return {k: v.encode("utf-8") for k, v in headers} class BodyPart(object): @@ -146,13 +145,13 @@ def parse_header(line): will be decoded later. """ plist = _parse_header_params(b";" + line) - key = plist.pop(0).lower().decode("ascii") + key = plist.pop(0).lower().decode("utf-8") pdict = {} for p in plist: i = p.find(b"=") if i >= 0: has_encoding = False - name = p[:i].strip().lower().decode("ascii") + name = p[:i].strip().lower().decode("utf-8") if name.endswith("*"): # Lang/encoding embedded in the value (like "filename*=UTF-8''file.ext") # https://tools.ietf.org/html/rfc2231#section-4 From 6959cb0254b76752578bb6ae10b8d145ded92c12 Mon Sep 17 00:00:00 2001 From: Ayush Jha Date: Mon, 26 Feb 2024 12:38:28 +0545 Subject: [PATCH 14/14] Allow passing DocumentNode to `get_response` to bypass validation to add support for persisted queries, we want to bypass repeated validation of queries that have already been parsed and validated. This allows that by letting the user pass in the `DocumentNode` directly which bypasses much of the parsing and validation, making the query processing faster. --- graphql_server/__init__.py | 104 ++++++++++++++++++++++--------------- 1 file changed, 62 insertions(+), 42 deletions(-) diff --git a/graphql_server/__init__.py b/graphql_server/__init__.py index 62c19b7..6bb44d5 100644 --- a/graphql_server/__init__.py +++ b/graphql_server/__init__.py @@ -25,7 +25,7 @@ from graphql.error import GraphQLError from graphql.execution import ExecutionResult, execute -from graphql.language import OperationType, parse +from graphql.language import OperationType, parse, DocumentNode from graphql.pyutils import AwaitableOrValue from graphql.type import GraphQLSchema, validate_schema from graphql.utilities import get_operation_ast @@ -61,7 +61,7 @@ @dataclass class GraphQLParams: - query: str + query: str | DocumentNode variables: Optional[Dict[str, Any]] = None operation_name: Optional[str] = None @@ -158,8 +158,6 @@ def run_http_query( all_params: List[GraphQLParams] = [ get_graphql_params(entry, extra_data) for entry in data ] - # print("GET ROOT VALUE 0", type(request_method), all_params) - # print(dict(schema=schema, all_params=all_params, catch_exc=catch_exc, allow_only=allow_only_query, run_sync=run_sync)) results: List[Optional[AwaitableOrValue[ExecutionResult]]] = [ get_response( @@ -167,7 +165,6 @@ def run_http_query( ) for params in all_params ] - # print("GET ROOT VALUE 1") return GraphQLResponse(results=results, params=all_params) @@ -314,6 +311,55 @@ def assume_not_awaitable(_value: Any) -> bool: return False +def parse_document( + schema: GraphQLSchema, + params: GraphQLParams, + allow_only_query: bool = False, + validation_rules: Optional[Collection[Type[ASTValidationRule]]] = None, + max_errors: Optional[int] = None, +) -> Optional[Dict]: + if not params.query: + raise HttpQueryError(400, "Must provide query string.") + + if not isinstance(params.query, str) and not isinstance(params.query, DocumentNode): + raise HttpQueryError(400, "Unexpected query type.") + + if isinstance(params.query, DocumentNode): + return params.query + schema_validation_errors = validate_schema(schema) + if schema_validation_errors: + return ExecutionResult(data=None, errors=schema_validation_errors) + + try: + document = parse(params.query) + except GraphQLError as e: + return ExecutionResult(data=None, errors=[e]) + except Exception as e: + e = GraphQLError(str(e), original_error=e) + return ExecutionResult(data=None, errors=[e]) + + if allow_only_query: + operation_ast = get_operation_ast(document, params.operation_name) + if operation_ast: + operation = operation_ast.operation.value + if operation != OperationType.QUERY.value: + raise HttpQueryError( + 405, + f"Can only perform a {operation} operation" " from a POST request.", + headers={"Allow": "POST"}, + ) + + validation_errors = validate( + schema, + document, + rules=validation_rules, + max_errors=max_errors, + ) + if validation_errors: + return ExecutionResult(data=None, errors=validation_errors) + return document + + def get_response( schema: GraphQLSchema, params: GraphQLParams, @@ -334,44 +380,18 @@ def get_response( belong to an exception class specified by catch_exc. """ # noinspection PyBroadException + document = parse_document( + schema, + params, + allow_only_query, + validation_rules, + max_errors, + ) + if isinstance(document, ExecutionResult): + return document + if not isinstance(document, DocumentNode): + raise Exception("GraphQL query could not be parsed properly.") try: - if not params.query: - raise HttpQueryError(400, "Must provide query string.") - - # Sanity check query - if not isinstance(params.query, str): - raise HttpQueryError(400, "Unexpected query type.") - - schema_validation_errors = validate_schema(schema) - if schema_validation_errors: - return ExecutionResult(data=None, errors=schema_validation_errors) - - try: - document = parse(params.query) - except GraphQLError as e: - return ExecutionResult(data=None, errors=[e]) - except Exception as e: - e = GraphQLError(str(e), original_error=e) - return ExecutionResult(data=None, errors=[e]) - - if allow_only_query: - operation_ast = get_operation_ast(document, params.operation_name) - if operation_ast: - operation = operation_ast.operation.value - if operation != OperationType.QUERY.value: - raise HttpQueryError( - 405, - f"Can only perform a {operation} operation" - " from a POST request.", - headers={"Allow": "POST"}, - ) - - validation_errors = validate( - schema, document, rules=validation_rules, max_errors=max_errors - ) - if validation_errors: - return ExecutionResult(data=None, errors=validation_errors) - execution_result = execute( schema, document,