Skip to content

Commit

Permalink
Preparing to enable SSL/TLS for gRPC server (adap#842)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanertopal authored Jan 5, 2022
1 parent 8b56703 commit 5b92ef6
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 15 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ data/
doc/source/api_documentation
doc/source/_build
flwr_logs
.cache

# Datasets
cifar-10-python.tar.gz
Expand Down
20 changes: 20 additions & 0 deletions dev/certificates/certificate.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[req]
default_bits = 4096
prompt = no
default_md = sha256
req_extensions = req_ext
distinguished_name = dn

[dn]
C = DE
ST = HH
O = Flower
CN = localhost

[req_ext]
subjectAltName = @alt_names

[alt_names]
DNS.1 = localhost
IP.1 = ::1
IP.2 = 127.0.0.1
52 changes: 52 additions & 0 deletions dev/certificates/generate.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#!/bin/bash
# This script will generate all certificates if ca.crt does not exist

set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/../../

CA_PASSWORD=notsafe

CERT_DIR=.cache/certificates

# Generate directories if not exists
mkdir -p .cache/certificates

# if [ -f ".cache/certificates/ca.crt" ]; then
# echo "Skipping certificate generation as they already exist."
# exit 0
# fi

rm -f $CERT_DIR/*

# Generate the root certificate authority key and certificate based on key
openssl genrsa -out $CERT_DIR/ca.key 4096
openssl req \
-new \
-x509 \
-key $CERT_DIR/ca.key \
-sha256 \
-subj "/C=DE/ST=HH/O=CA, Inc." \
-days 365 -out $CERT_DIR/ca.crt

# Generate a new private key for the server
openssl genrsa -out $CERT_DIR/server.key 4096

# Create a signing CSR
openssl req \
-new \
-key $CERT_DIR/server.key \
-out $CERT_DIR/server.csr \
-config ./dev/certificates/certificate.conf

# Generate a certificate for the server
openssl x509 \
-req \
-in $CERT_DIR/server.csr \
-CA $CERT_DIR/ca.crt \
-CAkey $CERT_DIR/ca.key \
-CAcreateserial \
-out $CERT_DIR/server.pem \
-days 365 \
-sha256 \
-extfile ./dev/certificates/certificate.conf \
-extensions req_ext
4 changes: 2 additions & 2 deletions src/py/flwr/client/grpc_client/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
from flwr.server.client_manager import SimpleClientManager
from flwr.server.grpc_server.grpc_server import start_insecure_grpc_server
from flwr.server.grpc_server.grpc_server import start_grpc_server

from .connection import insecure_grpc_connection

Expand Down Expand Up @@ -82,7 +82,7 @@ def test_integration_connection() -> None:
# Prepare
port = unused_tcp_port()

server = start_insecure_grpc_server(
server = start_grpc_server(
client_manager=SimpleClientManager(), server_address=f"[::]:{port}"
)

Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
from flwr.common.logger import log
from flwr.server.client_manager import SimpleClientManager
from flwr.server.grpc_server.grpc_server import start_insecure_grpc_server
from flwr.server.grpc_server.grpc_server import start_grpc_server
from flwr.server.history import History
from flwr.server.server import Server
from flwr.server.strategy import FedAvg, Strategy
Expand Down Expand Up @@ -70,7 +70,7 @@ def start_server( # pylint: disable=too-many-arguments
initialized_server, initialized_config = _init_defaults(server, config, strategy)

# Start gRPC server
grpc_server = start_insecure_grpc_server(
grpc_server = start_grpc_server(
client_manager=initialized_server.client_manager(),
server_address=server_address,
max_message_length=grpc_max_message_length,
Expand Down
84 changes: 78 additions & 6 deletions src/py/flwr/server/grpc_server/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,45 +12,98 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements utility function to create a grpc server."""
"""Implements utility function to create a gRPC server."""


import concurrent.futures
import sys
from logging import ERROR
from typing import ByteString, Optional, Tuple

import grpc

from flwr.common import GRPC_MAX_MESSAGE_LENGTH
from flwr.common.logger import log
from flwr.proto import transport_pb2_grpc
from flwr.server.client_manager import ClientManager
from flwr.server.grpc_server import flower_service_servicer as fss

INVALID_SSL_FILES_ERR_MSG = """
When setting any of root_certificate, certificate, or private_key,
all of them need to be set.
"""


def valid_ssl_files(ssl_files: Tuple[ByteString, ByteString, ByteString]) -> bool:
"""Validate ssl_files tuple."""
is_valid = (
all(isinstance(ssl_file, bytes) for ssl_file in ssl_files)
and len(ssl_files) == 3
)

if not is_valid:
log(ERROR, INVALID_SSL_FILES_ERR_MSG)

return is_valid


def start_insecure_grpc_server(
def start_grpc_server(
client_manager: ClientManager,
server_address: str,
max_concurrent_workers: int = 1000,
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
ssl_files: Optional[Tuple[ByteString, ByteString, ByteString]] = None,
) -> grpc.Server:
"""Create grpc server and return registered FlowerServiceServicer instance.
"""Create gRPC server and return instance of grpc.Server.
If used in a main function server.wait_for_termination(timeout=None)
should be called as otherwise the server will immediately stop.
**SSL/TLS**
To enable SSL/TLS you have to pass all of root_certificate, certificate,
and private_key. Setting only some will make the process exit with code 1.
Parameters
----------
client_manager : ClientManager
Instance of ClientManager
server_address : str
Server address in the form of HOST:PORT e.g. "[::]:8080"
max_concurrent_workers : int
Set the maximum number of clients you want the server to process
before returning RESOURCE_EXHAUSTED status (default: 1000)
Maximum number of clients the server can process before returning
RESOURCE_EXHAUSTED status (default: 1000)
max_message_length : int
Maximum message length that the server can send or receive.
Int valued in bytes. -1 means unlimited. (default: GRPC_MAX_MESSAGE_LENGTH)
ssl_files : Tuple[ByteString, ByteString, ByteString]
Tuple containing root certificate, server certificate, and private key to start
a secure SSL/TLS server. The tuple is expected to have three byte string
elements in the following order:
* CA certificate.
* server certificate.
* server private key.
(default: None)
Returns
-------
server : grpc.Server
An instance of a gRPC server which is already started
Examples
--------
Starting a SSL/TLS enabled server.
>>> from pathlib import Path
>>> start_grpc_server(
>>> client_manager=ClientManager(),
>>> server_address="localhost:8080",
>>> ssl_files=(
>>> Path("/crts/root.pem").read_bytes(),
>>> Path("/crts/localhost.crt").read_bytes(),
>>> Path("/crts/localhost.key").read_bytes()
>>> )
"""
server = grpc.server(
concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrent_workers),
Expand All @@ -73,7 +126,26 @@ def start_insecure_grpc_server(
servicer = fss.FlowerServiceServicer(client_manager)
transport_pb2_grpc.add_FlowerServiceServicer_to_server(servicer, server)

server.add_insecure_port(server_address)
if ssl_files is not None:
if not valid_ssl_files(ssl_files):
sys.exit(1)

root_certificate_b, certificate_b, private_key_b = ssl_files

server_credentials = grpc.ssl_server_credentials(
((private_key_b, certificate_b),),
root_certificates=root_certificate_b,
# A boolean indicating whether or not to require clients to be
# authenticated. May only be True if root_certificates is not None.
# We are explicitly setting the current gRPC default to document
# the option. For further reference see:
# https://grpc.github.io/grpc/python/grpc.html#create-server-credentials
require_client_auth=False,
)
server.add_secure_port(server_address, server_credentials)
else:
server.add_insecure_port(server_address)

server.start()

return server
78 changes: 73 additions & 5 deletions src/py/flwr/server/grpc_server/grpc_server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,34 @@
# limitations under the License.
# ==============================================================================
"""Tests for module server."""

import socket
import subprocess
from contextlib import closing
from typing import cast
from os.path import abspath, dirname, join
from pathlib import Path
from typing import Tuple, cast

from flwr.server.client_manager import SimpleClientManager
from flwr.server.grpc_server.grpc_server import start_insecure_grpc_server
from flwr.server.grpc_server.grpc_server import start_grpc_server, valid_ssl_files

root_dir = dirname(abspath(join(__file__, "../../../../..")))


def load_certificates() -> Tuple[str, str, str]:
"""Generate and load SSL/TLS credentials/certificates.
Utility function for loading for SSL/TLS enabled gRPC servertests.
"""
# Trigger script which generates the certificates
subprocess.run(["bash", "./dev/certificates/generate.sh"], check=True, cwd=root_dir)

ssl_files = (
join(root_dir, ".cache/certificates/ca.crt"),
join(root_dir, ".cache/certificates/server.pem"),
join(root_dir, ".cache/certificates/server.key"),
)

return ssl_files


def unused_tcp_port() -> int:
Expand All @@ -30,16 +51,63 @@ def unused_tcp_port() -> int:
return cast(int, sock.getsockname()[1])


def test_integration_start_and_shutdown_server() -> None:
def test_valid_ssl_files_when_correct() -> None:
"""Test is validation function works correctly when passed valid list."""
# Prepare
ssl_files = (b"a_byte_string", b"a_byte_string", b"a_byte_string")

# Execute
is_valid = valid_ssl_files(ssl_files)

# Assert
assert is_valid


def test_valid_ssl_files_when_wrong() -> None:
"""Test is validation function works correctly when passed invalid list."""
# Prepare
ssl_files = ("not_a_byte_string", b"a_byte_string", b"a_byte_string")

# Execute
is_valid = valid_ssl_files(ssl_files) # type: ignore

# Assert
assert not is_valid


def test_integration_start_and_shutdown_insecure_server() -> None:
"""Create server and check if FlowerServiceServicer is returned."""
# Prepare
port = unused_tcp_port()
client_manager = SimpleClientManager()

# Execute
server = start_insecure_grpc_server(
server = start_grpc_server(
client_manager=client_manager, server_address=f"[::]:{port}"
)

# Teardown
server.stop(1)


def test_integration_start_and_shutdown_secure_server() -> None:
"""Create server and check if FlowerServiceServicer is returned."""
# Prepare
port = unused_tcp_port()
client_manager = SimpleClientManager()

ssl_files = load_certificates()

# Execute
server = start_grpc_server(
client_manager=client_manager,
server_address=f"[::]:{port}",
ssl_files=(
Path(ssl_files[0]).read_bytes(),
Path(ssl_files[1]).read_bytes(),
Path(ssl_files[2]).read_bytes(),
),
)

# Teardown
server.stop(1)

0 comments on commit 5b92ef6

Please sign in to comment.