Skip to content

Commit

Permalink
feat: add SSL certificate validation for Druid (apache#9396)
Browse files Browse the repository at this point in the history
* feat: add SSL certificate feature

* Address comments

* don't mutate extras

* Address comments and add polish

* Add further polish
  • Loading branch information
villebro authored Mar 27, 2020
1 parent fd22788 commit 499f9c8
Show file tree
Hide file tree
Showing 16 changed files with 274 additions and 19 deletions.
6 changes: 6 additions & 0 deletions docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,12 @@ The native Druid connector (behind the ``DRUID_IS_ACTIVE`` feature flag)
is slowly getting deprecated in favor of the SQLAlchemy/DBAPI connector made
available in the ``pydruid`` library.

To use a custom SSL certificate to validate HTTPS requests, the certificate
contents can be entered in the ``Root Certificate`` field in the Database
dialog. When using a custom certificate, ``pydruid`` will automatically use
``https`` scheme. To disable SSL verification add the following to extras:
``engine_params": {"connect_args": {"scheme": "https", "ssl_verify_cert": false}}``

Dremio
------

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
known_third_party =alembic,backoff,bleach,celery,click,colorama,contextlib2,croniter,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
known_third_party =alembic,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
multi_line_output = 3
order_by_type = false

Expand Down
5 changes: 5 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,11 @@ class CeleryConfig: # pylint: disable=too-few-public-methods
# Typically these should not be allowed.
PREVENT_UNSAFE_DB_CONNECTIONS = True

# Path used to store SSL certificates that are generated when using custom certs.
# Defaults to temporary directory.
# Example: SSL_CERT_PATH = "/certs"
SSL_CERT_PATH: Optional[str] = None

# SIP-15 should be enabled for all new Superset deployments which ensures that the time
# range endpoints adhere to [start, end). For existing deployments admins should provide
# a dedicated period of time to allow chart producers to update their charts before
Expand Down
22 changes: 22 additions & 0 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
# pylint: disable=unused-argument
import hashlib
import json
import logging
import os
import re
from contextlib import closing
Expand Down Expand Up @@ -59,6 +61,8 @@
)
from superset.models.core import Database # pylint: disable=unused-import

logger = logging.getLogger()


class TimeGrain(NamedTuple): # pylint: disable=too-few-public-methods
name: str # TODO: redundant field, remove
Expand Down Expand Up @@ -959,3 +963,21 @@ def mutate_db_for_connection_test(database: "Database") -> None:
:param database: instance to be mutated
"""
return None

@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
"""
Some databases require adding elements to connection parameters,
like passing certificates to `extra`. This can be done here.
:param database: database instance from which to extract extras
:raises CertificateException: If certificate is not valid/unparseable
"""
extra: Dict[str, Any] = {}
if database.extra:
try:
extra = json.loads(database.extra)
except json.JSONDecodeError as e:
logger.error(e)
raise e
return extra
32 changes: 31 additions & 1 deletion superset/db_engine_specs/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import TYPE_CHECKING
import json
import logging
from typing import Any, Dict, TYPE_CHECKING

from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils

if TYPE_CHECKING:
from superset.connectors.sqla.models import ( # pylint: disable=unused-import
TableColumn,
)
from superset.models.core import Database # pylint: disable=unused-import

logger = logging.getLogger()


class DruidEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
Expand All @@ -47,3 +53,27 @@ class DruidEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
def alter_new_orm_column(cls, orm_col: "TableColumn") -> None:
if orm_col.column_name == "__time":
orm_col.is_dttm = True

@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
"""
For Druid, the path to a SSL certificate is placed in `connect_args`.
:param database: database instance from which to extract extras
:raises CertificateException: If certificate is not valid/unparseable
"""
try:
extra = json.loads(database.extra or "{}")
except json.JSONDecodeError as e:
logger.error(e)
raise e

if database.server_cert:
engine_params = extra.get("engine_params", {})
connect_args = engine_params.get("connect_args", {})
connect_args["scheme"] = "https"
path = utils.create_ssl_cert_file(database.server_cert)
connect_args["ssl_verify_cert"] = path
engine_params["connect_args"] = connect_args
extra["engine_params"] = engine_params
return extra
4 changes: 4 additions & 0 deletions superset/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,9 @@ class SpatialException(SupersetException):
pass


class CertificateException(SupersetException):
pass


class DatabaseNotFound(SupersetException):
status = 400
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""add certificate to dbs
Revision ID: b5998378c225
Revises: 72428d1ea401
Create Date: 2020-03-25 10:49:10.883065
"""

# revision identifiers, used by Alembic.
revision = "b5998378c225"
down_revision = "72428d1ea401"

from typing import Dict

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql.base import PGDialect
from sqlalchemy_utils import EncryptedType


def upgrade():
kwargs: Dict[str, str] = {}
bind = op.get_bind()
op.add_column(
"dbs",
sa.Column("server_cert", EncryptedType(sa.Text()), nullable=True, **kwargs),
)


def downgrade():
op.drop_column("dbs", "server_cert")
11 changes: 3 additions & 8 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ class Database(
encrypted_extra = Column(EncryptedType(Text, config["SECRET_KEY"]), nullable=True)
perm = Column(String(1000))
impersonate_user = Column(Boolean, default=False)
server_cert = Column(EncryptedType(Text, config["SECRET_KEY"]), nullable=True)
export_fields = [
"database_name",
"sqlalchemy_uri",
Expand Down Expand Up @@ -309,6 +310,7 @@ def get_sqla_engine(
)
if configuration:
connect_args["configuration"] = configuration
if connect_args:
params["connect_args"] = connect_args

params.update(self.get_encrypted_extra())
Expand Down Expand Up @@ -555,14 +557,7 @@ def grains(self) -> Tuple[TimeGrain, ...]:
return self.db_engine_spec.get_time_grains()

def get_extra(self) -> Dict[str, Any]:
extra: Dict[str, Any] = {}
if self.extra:
try:
extra = json.loads(self.extra)
except json.JSONDecodeError as e:
logger.error(e)
raise e
return extra
return self.db_engine_spec.get_extra_params(self)

def get_encrypted_extra(self):
encrypted_extra = {}
Expand Down
1 change: 1 addition & 0 deletions superset/templates/superset/models/database/add.html
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@
{{ macros.testconn() }}
{{ macros.expand_extra_textarea() }}
{{ macros.expand_encrypted_extra_textarea() }}
{{ macros.expand_server_cert_textarea() }}
{% endblock %}
1 change: 1 addition & 0 deletions superset/templates/superset/models/database/edit.html
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@
{{ macros.testconn() }}
{{ macros.expand_extra_textarea() }}
{{ macros.expand_encrypted_extra_textarea() }}
{{ macros.expand_server_cert_textarea() }}
{% endblock %}
7 changes: 7 additions & 0 deletions superset/templates/superset/models/database/macros.html
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
impersonate_user: $('#impersonate_user').is(':checked'),
extras: extra ? JSON.parse(extra) : {},
encrypted_extra: encryptedExtra ? JSON.parse(encryptedExtra) : {},
server_cert: $("#server_cert").val(),
})
} catch(parse_error){
alert("Malformed JSON in the extras field: " + parse_error);
Expand Down Expand Up @@ -81,3 +82,9 @@
$('#encrypted_extra').attr('rows', '5');
</script>
{% endmacro %}

{% macro expand_server_cert_textarea() %}
<script>
$('#server_cert').attr('rows', '5');
</script>
{% endmacro %}
52 changes: 50 additions & 2 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@
import decimal
import errno
import functools
import hashlib
import json
import logging
import os
import re
import signal
import smtplib
import tempfile
import traceback
import uuid
import zlib
Expand All @@ -45,6 +46,9 @@
import pandas as pd
import parsedatetime
import sqlalchemy as sa
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.backends.openssl.x509 import _Certificate
from dateutil.parser import parse
from dateutil.relativedelta import relativedelta
from flask import current_app, flash, Flask, g, Markup, render_template
Expand All @@ -56,7 +60,11 @@
from sqlalchemy.sql.type_api import Variant
from sqlalchemy.types import TEXT, TypeDecorator

from superset.exceptions import SupersetException, SupersetTimeoutException
from superset.exceptions import (
CertificateException,
SupersetException,
SupersetTimeoutException,
)
from superset.utils.dates import datetime_to_epoch, EPOCH

try:
Expand Down Expand Up @@ -1163,6 +1171,46 @@ def get_username() -> Optional[str]:
return None


def parse_ssl_cert(certificate: str) -> _Certificate:
"""
Parses the contents of a certificate and returns a valid certificate object
if valid.
:param certificate: Contents of certificate file
:return: Valid certificate instance
:raises CertificateException: If certificate is not valid/unparseable
"""
try:
return x509.load_pem_x509_certificate(
certificate.encode("utf-8"), default_backend()
)
except ValueError as e:
raise CertificateException("Invalid certificate")


def create_ssl_cert_file(certificate: str) -> str:
"""
This creates a certificate file that can be used to validate HTTPS
sessions. A certificate is only written to disk once; on subsequent calls,
only the path of the existing certificate is returned.
:param certificate: The contents of the certificate
:return: The path to the certificate file
:raises CertificateException: If certificate is not valid/unparseable
"""
filename = f"{hashlib.md5(certificate.encode('utf-8')).hexdigest()}.crt"
cert_dir = current_app.config["SSL_CERT_PATH"]
path = cert_dir if cert_dir else tempfile.gettempdir()
path = os.path.join(path, filename)
if not os.path.exists(path):
# Validate certificate prior to persisting to temporary directory
parse_ssl_cert(certificate)
cert_file = open(path, "w")
cert_file.write(certificate)
cert_file.close()
return path


def MediumText() -> Variant:
return Text().with_variant(MEDIUMTEXT(), "mysql")

Expand Down
13 changes: 13 additions & 0 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from superset.connectors.sqla.models import AnnotationDatasource
from superset.constants import RouteMethod
from superset.exceptions import (
CertificateException,
DatabaseNotFound,
SupersetException,
SupersetSecurityException,
Expand Down Expand Up @@ -1353,6 +1354,7 @@ def testconn(self):
# this is the database instance that will be tested
database = models.Database(
# extras is sent as json, but required to be a string in the Database model
server_cert=request.json.get("server_cert"),
extra=json.dumps(request.json.get("extras", {})),
impersonate_user=request.json.get("impersonate_user"),
encrypted_extra=json.dumps(request.json.get("encrypted_extra", {})),
Expand All @@ -1366,6 +1368,17 @@ def testconn(self):
with closing(engine.connect()) as conn:
conn.scalar(select([1]))
return json_success('"OK"')
except CertificateException as e:
logger.info("Invalid certificate %s", e)
return json_error_response(
_(
"Invalid certificate. "
"Please make sure the certificate begins with\n"
"-----BEGIN CERTIFICATE-----\n"
"and ends with \n"
"-----END CERTIFICATE-----"
)
)
except NoSuchModuleError as e:
logger.info("Invalid driver %s", e)
driver_name = make_url(uri).drivername
Expand Down
Loading

0 comments on commit 499f9c8

Please sign in to comment.