Skip to content

Commit

Permalink
Implement table name extraction. (apache#1598)
Browse files Browse the repository at this point in the history
* Implement table name extraction tests.

* Address comments.

* Fix tests and reimplement the token processing.

* Exclude aliases.

* Clean up print statements and code.

* Reverse select test.

* Fix failing test.

* Test JOINs

* refactore as a class

* Check for permissions in SQL Lab.

* Implement permissions check for the datasources in sql_lab

* Address comments.
  • Loading branch information
bkyryliuk authored Nov 29, 2016
1 parent fcb8707 commit dc98c67
Show file tree
Hide file tree
Showing 6 changed files with 465 additions and 22 deletions.
2 changes: 2 additions & 0 deletions superset/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,7 @@ class Database(Model, AuditMixinNullable):
"""An ORM object that stores Database related information"""

__tablename__ = 'dbs'
type = "table"

id = Column(Integer, primary_key=True)
database_name = Column(String(250), unique=True)
Expand Down Expand Up @@ -1524,6 +1525,7 @@ class DruidCluster(Model, AuditMixinNullable):
"""ORM object referencing the Druid clusters"""

__tablename__ = 'clusters'
type = "druid"

id = Column(Integer, primary_key=True)
cluster_name = Column(String(250), unique=True)
Expand Down
21 changes: 21 additions & 0 deletions superset/source_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,27 @@ def get_datasource_by_name(cls, session, datasource_type, datasource_name,
d.name == datasource_name and schema == schema]
return db_ds[0]

@classmethod
def query_datasources_by_name(
cls, session, database, datasource_name, schema=None):
datasource_class = SourceRegistry.sources[database.type]
if database.type == 'table':
query = (
session.query(datasource_class)
.filter_by(database_id=database.id)
.filter_by(table_name=datasource_name))
if schema:
query = query.filter_by(schema=schema)
return query.all()
if database.type == 'druid':
return (
session.query(datasource_class)
.filter_by(cluster_name=database.id)
.filter_by(datasource_name=datasource_name)
.all()
)
return None

@classmethod
def get_eager_datasource(cls, session, datasource_type, datasource_id):
"""Returns datasource with columns and metrics."""
Expand Down
25 changes: 9 additions & 16 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,20 @@
from sqlalchemy.orm import sessionmaker

from superset import (
app, db, models, utils, dataframe, results_backend)
app, db, models, utils, dataframe, results_backend, sql_parse, sm)
from superset.db_engine_specs import LimitMethod
from superset.jinja_context import get_template_processor
QueryStatus = models.QueryStatus

celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG'))


def is_query_select(sql):
return sql.upper().startswith('SELECT')


def create_table_as(sql, table_name, schema=None, override=False):
"""Reformats the query into the create table as query.
Works only for the single select SQL statements, in all other cases
the sql query is not modified.
:param sql: string, sql query that will be executed
:param superset_query: string, sql query that will be executed
:param table_name: string, will contain the results of the query execution
:param override, boolean, table table_name will be dropped if true
:return: string, create table as query
Expand All @@ -41,12 +37,9 @@ def create_table_as(sql, table_name, schema=None, override=False):
if schema:
table_name = schema + '.' + table_name
exec_sql = ''
if is_query_select(sql):
if override:
exec_sql = 'DROP TABLE IF EXISTS {table_name};\n'
exec_sql += "CREATE TABLE {table_name} AS \n{sql}"
else:
raise Exception("Could not generate CREATE TABLE statement")
if override:
exec_sql = 'DROP TABLE IF EXISTS {table_name};\n'
exec_sql += "CREATE TABLE {table_name} AS \n{sql}"
return exec_sql.format(**locals())


Expand Down Expand Up @@ -76,12 +69,12 @@ def handle_error(msg):
raise Exception(query.error_message)

# Limit enforced only for retrieving the data, not for the CTA queries.
is_select = is_query_select(executed_sql);
if not is_select and not database.allow_dml:
superset_query = sql_parse.SupersetQuery(executed_sql)
if not superset_query.is_select() and not database.allow_dml:
handle_error(
"Only `SELECT` statements are allowed against this database")
if query.select_as_cta:
if not is_select:
if not superset_query.is_select():
handle_error(
"Only `SELECT` statements can be used with the CREATE TABLE "
"feature.")
Expand All @@ -94,7 +87,7 @@ def handle_error(msg):
executed_sql, query.tmp_table_name, database.force_ctas_schema)
query.select_as_cta_used = True
elif (
query.limit and is_select and
query.limit and superset_query.is_select() and
db_engine_spec.limit_method == LimitMethod.WRAP_SQL):
executed_sql = database.wrap_sql_limit(executed_sql, query.limit)
query.limit_used = True
Expand Down
101 changes: 101 additions & 0 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import sqlparse
from sqlparse.sql import IdentifierList, Identifier
from sqlparse.tokens import Keyword, Name

RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT'}
PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}


# TODO: some sql_lab logic here.
class SupersetQuery(object):
def __init__(self, sql_statement):
self._tokens = []
self.sql = sql_statement
self._table_names = set()
self._alias_names = set()
# TODO: multistatement support
for statement in sqlparse.parse(self.sql):
self.__extract_from_token(statement)
self._table_names = self._table_names - self._alias_names

@property
def tables(self):
return self._table_names

# TODO: use sqlparse for this check.
def is_select(self):
return self.sql.upper().startswith('SELECT')

@staticmethod
def __precedes_table_name(token_value):
for keyword in PRECEDES_TABLE_NAME:
if keyword in token_value:
return True
return False

@staticmethod
def __get_full_name(identifier):
if len(identifier.tokens) > 1 and identifier.tokens[1].value == '.':
return "{}.{}".format(identifier.tokens[0].value,
identifier.tokens[2].value)
return identifier.get_real_name()

@staticmethod
def __is_result_operation(keyword):
for operation in RESULT_OPERATIONS:
if operation in keyword.upper():
return True
return False

@staticmethod
def __is_identifier(token):
return (
isinstance(token, IdentifierList) or isinstance(token, Identifier))

def __process_identifier(self, identifier):
# exclude subselects
if '(' not in '{}'.format(identifier):
self._table_names.add(SupersetQuery.__get_full_name(identifier))
return

# store aliases
if hasattr(identifier, 'get_alias'):
self._alias_names.add(identifier.get_alias())
if hasattr(identifier, 'tokens'):
# some aliases are not parsed properly
if identifier.tokens[0].ttype == Name:
self._alias_names.add(identifier.tokens[0].value)
self.__extract_from_token(identifier)

def __extract_from_token(self, token):
if not hasattr(token, 'tokens'):
return

table_name_preceding_token = False

for item in token.tokens:
if item.is_group and not self.__is_identifier(item):
self.__extract_from_token(item)

if item.ttype in Keyword:
if SupersetQuery.__precedes_table_name(item.value.upper()):
table_name_preceding_token = True
continue

if not table_name_preceding_token:
continue

if item.ttype in Keyword:
if SupersetQuery.__is_result_operation(item.value):
table_name_preceding_token = False
continue
# FROM clause is over
break

if isinstance(item, Identifier):
self.__process_identifier(item)

if isinstance(item, IdentifierList):
for token in item.tokens:
if SupersetQuery.__is_identifier(token):
self.__process_identifier(token)
43 changes: 37 additions & 6 deletions superset/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import superset
from superset import (
appbuilder, cache, db, models, viz, utils, app,
sm, sql_lab, results_backend, security,
sm, sql_lab, sql_parse, results_backend, security,
)
from superset.source_registry import SourceRegistry
from superset.models import DatasourceAccessRequest as DAR
Expand Down Expand Up @@ -74,6 +74,18 @@ def datasource_access(self, datasource):
self.can_access("datasource_access", datasource.perm)
)

def datasource_access_by_name(
self, database, datasource_name, schema=None):
if (self.database_access(database) or
self.all_datasource_access()):
return True
datasources = SourceRegistry.query_datasources_by_name(
db.session, database, datasource_name, schema=schema)
for datasource in datasources:
if self.can_access("datasource_access", datasource.perm):
return True
return False


class ListWidgetWithCheckboxes(ListWidget):
"""An alternative to list view that renders Boolean fields as checkboxes
Expand Down Expand Up @@ -2303,27 +2315,45 @@ def results(self, key):
@log_this
def sql_json(self):
"""Runs arbitrary sql and returns and json"""
def table_accessible(database, full_table_name, schema_name=None):
table_name_pieces = full_table_name.split(".")
if len(table_name_pieces) == 2:
table_schema = table_name_pieces[0]
table_name = table_name_pieces[1]
else:
table_schema = schema_name
table_name = table_name_pieces[0]
return self.datasource_access_by_name(
database, table_name, schema=table_schema)

async = request.form.get('runAsync') == 'true'
sql = request.form.get('sql')
database_id = request.form.get('database_id')

session = db.session()
mydb = session.query(models.Database).filter_by(id=database_id).first()
mydb = session.query(models.Database).filter_by(id=database_id).one()

if not mydb:
json_error_response(
'Database with id {} is missing.'.format(database_id))

if not self.database_access(mydb):
superset_query = sql_parse.SupersetQuery(sql)
schema = request.form.get('schema')
schema = schema if schema else None

rejected_tables = [
t for t in superset_query.tables if not
table_accessible(mydb, t, schema_name=schema)]
if rejected_tables:
json_error_response(
get_database_access_error_msg(mydb.database_name))
get_datasource_access_error_msg('{}'.format(rejected_tables)))
session.commit()

query = models.Query(
database_id=int(database_id),
limit=int(app.config.get('SQL_MAX_ROW', None)),
sql=sql,
schema=request.form.get('schema'),
schema=schema,
select_as_cta=request.form.get('select_as_cta') == 'true',
start_time=utils.now_as_float(),
tab_name=request.form.get('tab'),
Expand All @@ -2341,7 +2371,8 @@ def sql_json(self):
if async:
# Ignore the celery future object and the request may time out.
sql_lab.get_sql_results.delay(
query_id, return_results=False, store_results=not query.select_as_cta)
query_id, return_results=False,
store_results=not query.select_as_cta)
return Response(
json.dumps({'query': query.to_dict()},
default=utils.json_int_dttm_ser,
Expand Down
Loading

0 comments on commit dc98c67

Please sign in to comment.