Skip to content

Commit

Permalink
Merge "Consider version_id_prop when emitting bulk UPDATE"
Browse files Browse the repository at this point in the history
  • Loading branch information
zzzeek authored and Gerrit Code Review committed Nov 10, 2016
2 parents e81660d + c9d8a67 commit 0342981
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 1 deletion.
9 changes: 9 additions & 0 deletions doc/build/changelog/changelog_10.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@
collection of the mapped table, thereby interfering with the
initialization of relationships.

.. change::
:tags: bug, orm
:tickets: 3781
:versions: 1.1.4

Fixed bug in :meth:`.Session.bulk_save` where an UPDATE would
not function correctly in conjunction with a mapping that
implements a version id counter.

.. changelog::
:version: 1.0.15
:released: September 1, 2016
Expand Down
7 changes: 6 additions & 1 deletion lib/sqlalchemy/orm/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,16 @@ def _bulk_update(mapper, mappings, session_transaction,

cached_connections = _cached_connection_dict(base_mapper)

search_keys = mapper._primary_key_propkeys
if mapper._version_id_prop:
search_keys = set([mapper._version_id_prop.key]).union(search_keys)

def _changed_dict(mapper, state):
return dict(
(k, v)
for k, v in state.dict.items() if k in state.committed_state or k
in mapper._primary_key_propkeys
in search_keys

)

if isstates:
Expand Down
51 changes: 51 additions & 0 deletions test/orm/test_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,57 @@ class BulkTest(testing.AssertsExecutionResults):
run_define_tables = 'each'


class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest):
@classmethod
def define_tables(cls, metadata):
Table('version_table', metadata,
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('version_id', Integer, nullable=False),
Column('value', String(40), nullable=False))

@classmethod
def setup_classes(cls):
class Foo(cls.Comparable):
pass

@classmethod
def setup_mappers(cls):
Foo, version_table = cls.classes.Foo, cls.tables.version_table

mapper(Foo, version_table, version_id_col=version_table.c.version_id)

def test_bulk_insert_via_save(self):
Foo = self.classes.Foo

s = Session()

s.bulk_save_objects([Foo(value='value')])

eq_(
s.query(Foo).all(),
[Foo(version_id=1, value='value')]
)

def test_bulk_update_via_save(self):
Foo = self.classes.Foo

s = Session()

s.add(Foo(value='value'))
s.commit()

f1 = s.query(Foo).first()
f1.value = 'new value'
s.bulk_save_objects([f1])
s.expunge_all()

eq_(
s.query(Foo).all(),
[Foo(version_id=2, value='new value')]
)


class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):

@classmethod
Expand Down

0 comments on commit 0342981

Please sign in to comment.