Skip to content

Commit

Permalink
Multiple file support + gzipped file support + encodings + other stuff
Browse files Browse the repository at this point in the history
* Multiple file support using file1+file2 on the FROM clause
* automatic .gz detection and decompression
* File encoding support
* Output formatting support
* Some error handling
* Extra columns basic support
  • Loading branch information
harelba committed Feb 1, 2012
1 parent e50f6f0 commit 83e3730
Showing 1 changed file with 103 additions and 34 deletions.
137 changes: 103 additions & 34 deletions q
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,17 @@ import os,sys
import random
import sqlite3
import gzip
import glob
from optparse import OptionParser
import traceback as tb
import codecs
import locale

# Encode stdout properly: use the terminal's encoding when attached to a tty, otherwise the locale's preferred encoding
if sys.stdout.isatty():
STDOUT = codecs.getwriter(sys.stdout.encoding)(sys.stdout)
else:
STDOUT = codecs.getwriter(locale.getpreferredencoding())(sys.stdout)

SHOW_SQL = False

Expand All @@ -30,8 +40,14 @@ parser.add_option("-z","--gzipped",dest="gzipped",default=False,action="store_tr
help="Data is gzipped. Useful for reading from stdin. For files, .gz means automatic gunzipping")
parser.add_option("-d","--delimiter",dest="delimiter",default=None,
help="Field delimiter. If none specified, then standard whitespace is used as a delimiter")
parser.add_option("-t","--tab-delimited-with-header",dest="tab_delimited_with_header",default=False,action="store_true",
help="Same as -d <tab> -H 1. Just a shorthand for handling standard tab delimited file with one header line at the beginning of the file")
parser.add_option("-H","--header-skip",dest="header_skip",default=0,
help="Skip n lines at the beginning of the data (still takes those lines into account in terms of structure)")
parser.add_option("-f","--formatting",dest="formatting",default=None,
help="Output-level formatting, in the format X=fmt,Y=fmt etc, where X,Y are output column numbers (e.g. 1 for first SELECT column etc.")
parser.add_option("-e","--encoding",dest="encoding",default='UTF-8',
help="Input file encoding. Defaults to UTF-8")

class Sqlite3DB(object):
def __init__(self,show_sql=SHOW_SQL):
Expand All @@ -50,9 +66,10 @@ class Sqlite3DB(object):
cursor.close()
return result

def generate_insert_row(self,table_name,col_vals):
def generate_insert_row(self,table_name,column_names,col_vals):
col_names_str = ",".join(["%s" % x for x in column_names])
col_vals_str = ",".join(['"%s"' % x for x in col_vals])
return 'INSERT INTO %s VALUES (%s)' % (table_name,col_vals_str)
return 'INSERT INTO %s (%s) VALUES (%s)' % (table_name,col_names_str,col_vals_str)

# Get a list of column names so order will be preserved (Could have used OrderedDict, but
# then we would need python 2.7)
Expand All @@ -79,7 +96,10 @@ class Sql(object):
self.sql_parts = sql.split()
self.column_list = self.sql_parts[0]
# Simplistic way to determine table name
self.table_name_position = [i+1 for i in range(0,len(self.sql_parts)) if self.sql_parts[i].upper() == 'FROM'][0]
try:
self.table_name_position = [i+1 for i in range(0,len(self.sql_parts)) if self.sql_parts[i].upper() == 'FROM'][0]
except:
raise Exception("Could not detect table name in query")
self.table_name = self.sql_parts[self.table_name_position]

self.actual_table_name = None
Expand Down Expand Up @@ -124,6 +144,14 @@ class TableColumnInferer(object):

# Column count according to first line only for now
self.column_count = len(self.line_splitter.split(self.example_lines[0]))

# FIXME: Hack to provide for some small variation in the column count. Will be fixed as soon as we have better column inferring
#self.column_count += max(6,int(self.column_count*0.2))
self.column_count += 5

if self.column_count == 0:
raise Exception("Detected a column count of zero... Failing")

# Only string type for now
self.column_types = [str for i in range(self.column_count)]
# Column names are cX starting from 1
Expand All @@ -144,42 +172,58 @@ class TableColumnInferer(object):
return self.column_types

class TableCreator(object):
def __init__(self,db,filename,line_splitter,header_skip=0,gzipped=False):
def __init__(self,db,filenames_str,line_splitter,header_skip=0,gzipped=False,encoding='UTF-8'):
self.db = db
self.filename = filename
self.filenames_str = filenames_str
self.header_skip = header_skip
self.gzipped = gzipped
self.table_created = False
self.line_splitter = line_splitter
self.encoding = encoding
self.column_inferer = TableColumnInferer(line_splitter)

self.lines_read = 0

# Filled only after table population since we're inferring the table creation data
self.table_name = None

def get_table_name(self):
return self.table_name

def populate(self):
# Determine file object
if self.filename != "-":
f = file(self.filename)
else:
f = sys.stdin

# If data is gzipped, then wrap the file object
if self.gzipped:
f = gzip.GzipFile(fileobj=f)

try:
line = f.readline()
while line:
self._insert_row(line)
line = f.readline()
finally:
if f != sys.stdin:
f.close()
# Get the list of filenames
filenames = self.filenames_str.split("+")
# for each filename (or pattern)
for fileglob in filenames:
# Allow either stdin or a glob match
if fileglob == '-':
files_to_go_over = ['-']
else:
files_to_go_over = glob.glob(fileglob)
# For each match
for filename in files_to_go_over:
self.lines_read = 0

# Check if it's standard input or a file
if filename == '-':
f = sys.stdin
else:
f = file(filename,'rb')

# Wrap it with gzip decompression if needed
if self.gzipped or filename.endswith('.gz'):
f = gzip.GzipFile(fileobj=f)

# And wrap it in a decoder (e.g. ascii, UTF-8 etc)
f = codecs.getreader(self.encoding)(f)

# Read all the lines
try:
line = f.readline()
while line:
self._insert_row(line)
line = f.readline()
finally:
if f != sys.stdin:
f.close()

def _insert_row(self,line):
# If table has not been created yet
Expand All @@ -198,7 +242,14 @@ class TableCreator(object):
if self.lines_read <= self.header_skip:
return
col_vals = line_splitter.split(line)
insert_row_stmt = self.db.generate_insert_row(self.table_name,col_vals)

# If we have more columns than we inferred
if len(col_vals) > len(self.column_inferer.column_names):
raise Exception('Encountered a line in an invalid format - %s columns instead of %s. Did you make sure to set the correct delimiter?' % (len(col_vals),len(self.column_inferer.column_names)))

effective_column_names = self.column_inferer.column_names[:len(col_vals)]

insert_row_stmt = self.db.generate_insert_row(self.table_name,effective_column_names,col_vals)
self.db.execute_and_fetch(insert_row_stmt)

def try_to_create_table(self,line):
Expand Down Expand Up @@ -250,11 +301,16 @@ sql_object = Sql('%s' % args[0])
# Get "table name" which is actually the file name
filename = sql_object.table_name

# If the user flagged for a tab-delimited file then set the delimiter to tab
if options.tab_delimited_with_header:
options.delimiter = '\t'
options.header_skip = "1"

# Create a line splitter
line_splitter = LineSplitter(options.delimiter)

# Create the matching database table and populate it
table_creator = TableCreator(db,filename,line_splitter,int(options.header_skip),options.gzipped)
table_creator = TableCreator(db,filename,line_splitter,int(options.header_skip),options.gzipped,options.encoding)
table_creator.populate()

# Replace the logical table name with the real table name
Expand All @@ -272,17 +328,30 @@ if options.delimiter:
else:
output_delimiter = " "

if options.formatting:
formatting_dict = dict([(x.split("=")[0],x.split("=")[1]) for x in options.formatting.split(",")])
else:
formatting_dict = None

try:
for row in m:
for rownum,row in enumerate(m):
row_str = []
for i,col in enumerate(row):
if options.beautify:
fmt_str = "%%-%ss" % max_lengths[i]
if formatting_dict is not None and str(i+1) in formatting_dict.keys():
fmt_str = formatting_dict[str(i+1)]
else:
fmt_str = "%s"
row_str.append(fmt_str % col)
sys.stdout.write(output_delimiter.join(row_str)+"\n")
except:
if options.beautify:
fmt_str = "%%-%ss" % max_lengths[i]
else:
fmt_str = "%s"

if col is not None:
row_str.append(fmt_str % col)
else:
row_str.append(fmt_str % "")

STDOUT.write(output_delimiter.join(row_str)+"\n")
except KeyboardInterrupt:
pass

table_creator.drop_table()

0 comments on commit 83e3730

Please sign in to comment.