Skip to content

Commit

Permalink
Merge pull request aws#3285 from kyleknap/lock-history
Browse files Browse the repository at this point in the history
Add lock to DatabaseRecordWriter
  • Loading branch information
kyleknap authored Apr 24, 2018
2 parents 458714b + 9868bb8 commit fe8026d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 13 deletions.
7 changes: 3 additions & 4 deletions awscli/customizations/history/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,15 @@ class DatabaseRecordWriter(object):

def __init__(self, connection):
self._connection = connection
self._lock = threading.Lock()

def close(self):
self._connection.close()

def write_record(self, record):
# This method is not threadsafe by itself, it is only threadsafe when
# used inside a handler bound to the HistoryRecorder in botocore which
# is protected by a lock.
db_record = self._create_db_record(record)
self._connection.execute(self._WRITE_RECORD, db_record)
with self._lock:
self._connection.execute(self._WRITE_RECORD, db_record)

def _create_db_record(self, record):
event_type = record['event_type']
Expand Down
16 changes: 7 additions & 9 deletions tests/functional/history/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,18 @@


class ThreadedRecordWriter(object):
def __init__(self, writer, lock):
def __init__(self, writer):
self._read_q = queue.Queue()
self._thread = threading.Thread(
target=self._threaded_record_writer,
args=(writer, lock))
args=(writer,))

def _threaded_record_writer(self, writer, lock):
def _threaded_record_writer(self, writer):
while True:
record = self._read_q.get()
if record is False:
return
with lock:
writer.write_record(record)
writer.write_record(record)

def write_record(self, record):
self._read_q.put_nowait(record)
Expand All @@ -63,9 +62,9 @@ def setUp(self):
self.threads = []
self.writer = DatabaseRecordWriter(self.connection)

def start_n_threads(self, n, lock):
def start_n_threads(self, n):
for _ in range(n):
t = ThreadedRecordWriter(self.writer, lock)
t = ThreadedRecordWriter(self.writer)
t.start()
self.threads.append(t)

Expand All @@ -85,8 +84,7 @@ def _write_records(self, thread_number, records):

def test_bulk_writes_all_succeed(self):
thread_count = 10
lock = threading.Lock()
self.start_n_threads(thread_count, lock)
self.start_n_threads(thread_count)
for i in range(thread_count):
self._write_records(i, [
{
Expand Down

0 comments on commit fe8026d

Please sign in to comment.