Skip to content

Commit

Permalink
Basic multiple insertion flow + option for skipping encoding
Browse files Browse the repository at this point in the history
* Option for skipping encoding
* Multiple insert basic flow, best approach against the DB needs to be tested
* More descriptive failure messages
  • Loading branch information
harelba committed Feb 6, 2012
1 parent 0fe5040 commit 013682f
Showing 1 changed file with 74 additions and 11 deletions.
85 changes: 74 additions & 11 deletions q
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ from optparse import OptionParser
import traceback as tb
import codecs
import locale
import time
import re

DEBUG = False

# Encode stdout properly,
if sys.stdout.isatty():
Expand All @@ -47,30 +51,62 @@ parser.add_option("-H","--header-skip",dest="header_skip",default=0,
# -f: per-output-column formatting specification
parser.add_option("-f","--formatting",dest="formatting",default=None,
                  help="Output-level formatting, in the format X=fmt,Y=fmt etc, where X,Y are output column numbers (e.g. 1 for first SELECT column etc.")
# -e: input encoding; only one help= keyword may be given, so the newer text
# (which documents the 'none' value) is the one kept
parser.add_option("-e","--encoding",dest="encoding",default='UTF-8',
                  help="Input file encoding. Defaults to UTF-8. set to none for not setting any encoding - faster, but at your own risk...")

class Sqlite3DB(object):
def __init__(self,show_sql=SHOW_SQL):
    """Open an in-memory SQLite database and keep one cursor on it.

    show_sql -- stored on the instance; when true, execute_and_fetch
                prints each statement before running it
    """
    # Python type -> SQLite column type name
    self.type_names = {str: 'TEXT', int: 'INT', float: 'FLOAT'}
    self.show_sql = show_sql
    connection = sqlite3.connect(':memory:')
    self.conn = connection
    self.cursor = connection.cursor()

def execute_and_fetch(self,q):
    """Run the SQL statement *q* exactly once and return all resulting rows.

    When ``self.show_sql`` is set, *q* is printed to stdout first. The
    statement runs on ``self.cursor`` (the cursor kept on the instance), and
    the list returned by ``fetchall()`` is handed back to the caller. Any
    exception raised while executing propagates unchanged.
    """
    if self.show_sql:
        # A single parenthesised argument prints identically under the
        # Python 2 print statement and the Python 3 print function
        print(q)
    # Execute on the shared cursor only; running the statement a second time
    # on a separate per-call cursor would apply every INSERT twice
    self.cursor.execute(q)
    return self.cursor.fetchall()

def _get_as_list_str(self,l,quote=False):
if not quote:
return ",".join(["%s" % x for x in l])
else:
return ",".join(["\"%s\"" % x for x in l])

def generate_insert_row(self,table_name,column_names,col_vals):
    """Return the text of an INSERT statement for a single row.

    Column names are joined unquoted and values are joined double-quoted,
    both through ``self._get_as_list_str``. Nothing is executed here.
    """
    target_columns = self._get_as_list_str(column_names)
    quoted_values = self._get_as_list_str(col_vals,quote=True)
    statement_parts = (table_name,target_columns,quoted_values)
    return 'INSERT INTO %s (%s) VALUES (%s)' % statement_parts

def generate_begin_transaction(self):
    """Return the SQL text that opens a transaction."""
    begin_sql = "BEGIN TRANSACTION"
    return begin_sql

def generate_end_transaction(self):
    """Return the SQL text that commits the open transaction."""
    commit_sql = "COMMIT"
    return commit_sql

# Get a list of column names so order will be preserved (Could have used OrderedDict, but
# then we would need python 2.7)
def generate_create_table(self,table_name,column_names,column_dict):
    """Return the text of a CREATE TABLE statement for *table_name*.

    column_names -- column names, in the order they should appear
    column_dict  -- maps each column name to a Python type that is a key
                    of ``self.type_names``

    Raises KeyError when a type in *column_dict* has no entry in
    ``self.type_names``, or when a name in *column_names* is missing from
    *column_dict*.
    """
    # Convert dict from python types to db types. items() is used rather
    # than the Python-2-only iteritems() so this runs on Python 3 as well.
    db_type_by_column = dict((name,self.type_names[py_type]) for name,py_type in column_dict.items())
    column_defs = ','.join(['%s %s' % (name,db_type_by_column[name]) for name in column_names])
    return 'CREATE TABLE %s (%s)' % (table_name,column_defs)


def generate_insert_row(self,table_name,column_names,col_vals):
    """Build the INSERT statement text for one row of *table_name*.

    column_names -- names of the columns being filled, joined unquoted
    col_vals     -- the matching values, each rendered double-quoted

    Both lists are rendered once, through ``self._get_as_list_str``; the
    earlier inline joins that were immediately overwritten are gone.
    """
    col_names_str = self._get_as_list_str(column_names)
    col_vals_str = self._get_as_list_str(col_vals,quote=True)
    return 'INSERT INTO %s (%s) VALUES (%s)' % (table_name,col_names_str,col_vals_str)

def generate_begin_transaction(self):
    """Give back the statement text used to start a transaction."""
    statement = "BEGIN TRANSACTION"
    return statement

def generate_end_transaction(self):
    """Give back the statement text used to commit a transaction."""
    statement = "COMMIT"
    return statement

# Get a list of column names so order will be preserved (Could have used OrderedDict, but
# then we would need python 2.7)
def generate_create_table(self,table_name,column_names,column_dict):
Expand Down Expand Up @@ -236,10 +272,16 @@ class TableCreator(object):
# Filled only after table population since we're inferring the table creation data
self.table_name = None

self.buffered_inserts = []

def get_table_name(self):
    """Return the current value of ``self.table_name``.

    The attribute starts out as None; it is filled in only after the
    table has been populated.
    """
    current_name = self.table_name
    return current_name

def populate(self):
if self.encoding != 'none' and self.encoding is not None:
encoder = codecs.getreader(self.encoding)
else:
encoder = None
# Get the list of filenames
filenames = self.filenames_str.split("+")
# for each filename (or pattern)
Expand All @@ -256,6 +298,7 @@ class TableCreator(object):

# For each match
for filename in files_to_go_over:
self.current_filename = filename
self.lines_read = 0

# Check if it's standard input or a file
Expand All @@ -269,7 +312,9 @@ class TableCreator(object):
f = gzip.GzipFile(fileobj=f)

# And wrap it in an decoder (e.g. ascii, UTF-8 etc)
f = codecs.getreader(self.encoding)(f)
if encoder is not None:
f = encoder(f)


# Read all the lines
try:
Expand All @@ -280,6 +325,7 @@ class TableCreator(object):
finally:
if f != sys.stdin:
f.close()
self._flush_inserts()
if not self.table_created:
raise Exception('Table should have already been created for file %s' % filename)

Expand All @@ -303,12 +349,26 @@ class TableCreator(object):

# If we have more columns than we inferred
if len(col_vals) > len(self.column_inferer.column_names):
raise Exception('Encountered a line in an invalid format - %s columns instead of %s. Did you make sure to set the correct delimiter?' % (len(col_vals),len(self.column_inferer.column_names)))
raise Exception('Encountered a line in an invalid format %s:%s - %s columns instead of %s. Did you make sure to set the correct delimiter?' % (self.current_filename,self.lines_read,len(col_vals),len(self.column_inferer.column_names)))

effective_column_names = self.column_inferer.column_names[:len(col_vals)]

insert_row_stmt = self.db.generate_insert_row(self.table_name,effective_column_names,col_vals)
self.db.execute_and_fetch(insert_row_stmt)
self.buffered_inserts.append((effective_column_names,col_vals))

if len(self.buffered_inserts) < 1000:
return
self._flush_inserts()

def _flush_inserts(self):
#print self.db.execute_and_fetch(self.db.generate_begin_transaction())

for col_names,col_vals in self.buffered_inserts:
insert_row_stmt = self.db.generate_insert_row(self.table_name,col_names,col_vals)
self.db.execute_and_fetch(insert_row_stmt)

#print self.db.execute_and_fetch(self.db.generate_end_transaction())
self.buffered_inserts = []


def try_to_create_table(self,line):
if self.table_created:
Expand Down Expand Up @@ -369,7 +429,10 @@ line_splitter = LineSplitter(options.delimiter)
for filename in sql_object.qtable_names:
# Create the matching database table and populate it
table_creator = TableCreator(db,filename,line_splitter,int(options.header_skip),options.gzipped,options.encoding)
start_time = time.time()
table_creator.populate()
if DEBUG:
print "TIMING - populate time is %4.3f" % (time.time() - start_time)

# Replace the logical table name with the real table name
sql_object.set_effective_table_name(filename,table_creator.table_name)
Expand Down

0 comments on commit 013682f

Please sign in to comment.