Skip to content

Commit

Permalink
Merge pull request doableware#212 from Raznak/master
Browse files Browse the repository at this point in the history
Bug fix add & drop columns & handle column renaming
  • Loading branch information
nesdis authored Jan 20, 2019
2 parents cc6c7d6 + e497b21 commit dd22f4a
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 5 deletions.
45 changes: 41 additions & 4 deletions djongo/sql2mongo/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,8 @@ class AlterQuery(VoidQuery):

def __init__(self, *args):
self._iden_name = None
self._old_column = None
self._new_column = None
self._default = None
self._cascade = None
self._null = None
Expand All @@ -530,11 +532,44 @@ def parse(self):
tok_id = self._drop(tok_id)
elif tok.match(tokens.Keyword.DDL, 'ALTER'):
tok_id = self._alter(tok_id)
elif tok.match(tokens.Keyword, 'RENAME'):
tok_id = self._rename(tok_id)
else:
raise NotImplementedError

tok_id, tok = sm.token_next(tok_id)

def _rename(self, tok_id):
sm = self.statement
tok_id, tok = sm.token_next(tok_id)

to = False
while tok_id is not None:
if tok.match(tokens.Keyword, ('COLUMN'),):
self.execute = self._rename_column
if tok.match(tokens.Keyword, ('TO'),):
to = True
elif isinstance(tok, Identifier):
if not to:
self.old_column = tok.get_real_name()
else:
self.new_column = tok.get_real_name()

tok_id, tok = sm.token_next(tok_id)

return tok_id

def _rename_column(self):
self.db_ref[self.left_table].update(
{},
{
'$rename': {
self.old_column: self.new_column
}
},
multi=True
)

def _alter(self, tok_id):
self.execute = lambda: None

Expand All @@ -559,7 +594,7 @@ def _drop(self, tok_id):
)):
pass
elif isinstance(tok, Identifier):
self._iden_name = tok.get_name()
self._iden_name = tok.get_real_name()
elif tok.match(tokens.Keyword, 'CONSTRAINT'):
self.execute = self._drop_constraint
elif tok.match(tokens.Keyword, 'COLUMN'):
Expand All @@ -581,7 +616,8 @@ def _drop_column(self):
'$unset': {
self._iden_name: ''
}
}
},
multi=True
)

def _add(self, tok_id):
Expand All @@ -600,7 +636,7 @@ def _add(self, tok_id):
)):
pass
elif isinstance(tok, Identifier):
self._iden_name = tok.get_name()
self._iden_name = tok.get_real_name()
elif isinstance(tok, Parenthesis):
self.field_dir = [
(field.strip(' "'), 1)
Expand Down Expand Up @@ -637,7 +673,8 @@ def _add_column(self):
'$set': {
self._iden_name: self._default
}
}
},
multi=True
)
def _index(self):
self.db_ref[self.left_table].create_index(
Expand Down
4 changes: 4 additions & 0 deletions tests/django_tests/tests/queries/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@ def __str__(self):
return self.name


class ArticleDerived(Article):
pass


class Food(models.Model):
name = models.CharField(max_length=20, unique=True)

Expand Down
16 changes: 15 additions & 1 deletion tests/django_tests/tests/queries/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from django.test.utils import CaptureQueriesContext

from .models import (
FK1, Annotation, Article, Author, BaseA, Book, CategoryItem,
FK1, Annotation, Article, ArticleDerived, Author, BaseA, Book, CategoryItem,
CategoryRelationship, Celebrity, Channel, Chapter, Child, ChildObjectA,
Classroom, CommonMixedCaseForeignKeys, Company, Cover, CustomPk,
CustomPkTag, Detail, DumbCategory, Eaten, Employment, ExtraInfo, Fan, Food,
Expand Down Expand Up @@ -2348,9 +2348,23 @@ def setUpTestData(cls):
Article.objects.create(
name="Article {}".format(i), created=some_date)

for i in range(1, 8):
ArticleDerived.objects.create(
name="ArticleDerived {}".format(i), created=some_date)

def get_ordered_articles(self):
return Article.objects.all().order_by('name')

def get_ordered_derived_articles(self):
return ArticleDerived.objects.all().order_by('name')

def test_can_get_items_using_index_and_slice_notation_with_derived_model(self):
self.assertEqual(self.get_ordered_derived_articles()[0].name, 'ArticleDerived 1')
self.assertQuerysetEqual(
self.get_ordered_derived_articles()[4:6],
["<ArticleDerived: ArticleDerived 5>", "<ArticleDerived: ArticleDerived 6>"]
)

def test_can_get_items_using_index_and_slice_notation(self):
self.assertEqual(self.get_ordered_articles()[0].name, 'Article 1')
self.assertQuerysetEqual(
Expand Down

0 comments on commit dd22f4a

Please sign in to comment.