From 2a34ee8da23ff2edd1a7b8aa7a4a7c7da3f62670 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 15 Nov 2013 17:42:55 +0400 Subject: [PATCH] Finally fixes #30: db.execute(), db.exists(), Entity.get_by_sql() and Entity.select_by_sql() methods were fixed --- pony/orm/core.py | 14 +++++++------- pony/orm/tests/test_frames.py | 24 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 7ceba318c..a7d441cd3 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -468,7 +468,7 @@ def rollback(database): @cut_traceback def execute(database, sql, globals=None, locals=None): database._get_cache().flush() - return database._exec_raw_sql(sql, globals, locals, 2) + return database._exec_raw_sql(sql, globals, locals, frame_depth=3) def _exec_raw_sql(database, sql, globals, locals, frame_depth): sql = sql[:] # sql = templating.plainstr(sql) if globals is None: @@ -509,7 +509,7 @@ def get(database, sql, globals=None, locals=None): @cut_traceback def exists(database, sql, globals=None, locals=None): if not select_re.match(sql): sql = 'select ' + sql - cursor = database._exec_raw_sql(sql, globals, locals, 2) + cursor = database._exec_raw_sql(sql, globals, locals, frame_depth=3) result = cursor.fetchone() return bool(result) @cut_traceback @@ -2571,20 +2571,20 @@ def __getitem__(entity, key): return objects[0] @cut_traceback def exists(entity, *args, **kwargs): - if args: return entity._query_from_args_(3, args, kwargs).exists() + if args: return entity._query_from_args_(args, kwargs, frame_depth=3).exists() try: objects = entity._find_(1, kwargs) except MultipleObjectsFoundError: return True return bool(objects) @cut_traceback def get(entity, *args, **kwargs): - if args: return entity._query_from_args_(3, args, kwargs).get() + if args: return entity._query_from_args_(args, kwargs, frame_depth=3).get() objects = entity._find_(1, kwargs) # can throw MultipleObjectsFoundError if not objects: return None assert len(objects) == 1 return objects[0] @cut_traceback def get_by_sql(entity, sql, globals=None, locals=None): - objects = entity._find_by_sql_(1, sql, globals, locals, 2) # can throw MultipleObjectsFoundError + objects = entity._find_by_sql_(1, sql, globals, locals, frame_depth=3) # can throw MultipleObjectsFoundError if not objects: return None assert len(objects) == 1 return objects[0] @@ -2600,7 +2600,7 @@ def select(entity, func=None): return entity._query_from_lambda_(func, globals, locals) @cut_traceback def select_by_sql(entity, sql, globals=None, locals=None): - return entity._find_by_sql_(None, sql, globals, locals, 2) + return entity._find_by_sql_(None, sql, globals, locals, frame_depth=3) @cut_traceback def order_by(entity, *args): query = Query(entity._default_iter_name_, entity._default_genexpr_, {}, { '.0' : entity }) @@ -2864,7 +2864,7 @@ def _load_many_(entity, objects): for obj in result: if obj not in batch: throw(UnrepeatableReadError, 'Phantom object %s disappeared' % safe_repr(obj)) - def _query_from_args_(entity, frame_depth, args, kwargs): + def _query_from_args_(entity, args, kwargs, frame_depth): if len(args) > 1: throw(TypeError, 'Only one positional argument expected') if kwargs: throw(TypeError, 'If positional argument presented, no keyword arguments expected') first_arg = args[0] diff --git a/pony/orm/tests/test_frames.py b/pony/orm/tests/test_frames.py index 10a016651..182ee79bc 100644 --- a/pony/orm/tests/test_frames.py +++ b/pony/orm/tests/test_frames.py @@ -77,6 +77,18 @@ def test_entity_get_str(self): p = Person.get('lambda p: p.age > x') self.assertEqual(p, Person[3]) + @db_session + def test_entity_get_by_sql(self): + x = 25 + p = Person.get_by_sql('select * from Person where age = $x') + self.assertEqual(p, Person[3]) + + @db_session + def test_entity_select_by_sql(self): + x = 25 + p = Person.select_by_sql('select * from Person where age = $x') + self.assertEqual(p, [ Person[3] ]) + @db_session def test_entity_exists(self): x = 23 @@ -141,5 +153,17 @@ def test_db_get(self): result = db.get('name from Person where age = $x') self.assertEqual(result, 'Mary') + @db_session + def test_db_execute(self): + x = 18 + result = db.execute('select name from Person where age = $x').fetchone() + self.assertEqual(result, ('Mary',)) + + @db_session + def test_db_exists(self): + x = 18 + result = db.exists('name from Person where age = $x') + self.assertEqual(result, True) + if __name__ == '__main__': unittest.main()