Skip to content

Commit

Permalink
Fix csv download. (apache#2036)
Browse files (browse the repository at this point in the history)
  • Loading branch information
bkyryliuk authored Jan 26, 2017
1 parent c5c7302 commit b1bba96
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 12 deletions.
5 changes: 3 additions & 2 deletions superset/sql_lab.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -101,9 +101,10 @@ def handle_error(msg):
logging.exception(e)
msg = "Template rendering failed: " + utils.error_msg_from_exception(e)
handle_error(msg)

query.executed_sql = executed_sql
logging.info("Running query: \n{}".format(executed_sql))
try:
query.executed_sql = executed_sql
logging.info("Running query: \n{}".format(executed_sql))
result_proxy = engine.execute(query.executed_sql, schema=query.schema)
except Exception as e:
logging.exception(e)
Expand Down
6 changes: 2 additions & 4 deletions superset/sql_parse.py
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
import sqlparse
from sqlparse.sql import IdentifierList, Identifier
from sqlparse.tokens import Keyword, Name
from sqlparse.tokens import DML, Keyword, Name

RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT'}
PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}
Expand All @@ -9,7 +9,6 @@
# TODO: some sql_lab logic here.
class SupersetQuery(object):
def __init__(self, sql_statement):
self._tokens = []
self.sql = sql_statement
self._table_names = set()
self._alias_names = set()
Expand All @@ -23,9 +22,8 @@ def __init__(self, sql_statement):
def tables(self):
return self._table_names

# TODO: use sqlparse for this check.
def is_select(self):
return self.sql.upper().startswith('SELECT')
return self._parsed[0].get_type() == 'SELECT'

def stripped(self):
sql = self.sql
Expand Down
19 changes: 14 additions & 5 deletions superset/views.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -6,6 +6,7 @@
from datetime import datetime, timedelta
import json
import logging
import pandas as pd
import pickle
import re
import sys
Expand Down Expand Up @@ -2563,11 +2564,19 @@ def csv(self, client_id):
if rejected_tables:
flash(get_datasource_access_error_msg('{}'.format(rejected_tables)))
return redirect('/')

sql = query.select_sql or query.sql
df = query.database.get_df(sql, query.schema)
# TODO(bkyryliuk): add compression=gzip for big files.
csv = df.to_csv(index=False, encoding='utf-8')
blob = None
if results_backend and query.results_key:
blob = results_backend.get(query.results_key)
if blob:
json_payload = zlib.decompress(blob)
obj = json.loads(json_payload)
df = pd.DataFrame.from_records(obj['data'])
csv = df.to_csv(index=False, encoding='utf-8')
else:
sql = query.select_sql or query.executed_sql
df = query.database.get_df(sql, query.schema)
# TODO(bkyryliuk): add compression=gzip for big files.
csv = df.to_csv(index=False, encoding='utf-8')
response = Response(csv, mimetype='text/csv')
response.headers['Content-Disposition'] = (
'attachment; filename={}.csv'.format(query.name))
Expand Down
2 changes: 1 addition & 1 deletion tests/core_tests.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -357,7 +357,7 @@ def test_csv_endpoint(self):
WHERE first_name='admin'
"""
client_id = "{}".format(random.getrandbits(64))[:10]
self.run_sql(sql, client_id)
self.run_sql(sql, client_id, raise_on_error=True)

resp = self.get_resp('/superset/csv/{}'.format(client_id))
data = csv.reader(io.StringIO(resp))
Expand Down

0 comments on commit b1bba96

Please sign in to comment.