Skip to content

Commit

Permalink
Support for Django 3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
nesdis committed May 14, 2020
1 parent ce0dd36 commit 7795fd2
Show file tree
Hide file tree
Showing 33 changed files with 1,283 additions and 810 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
dist/
build/
venv/
venv3.8/
.idea/
.tox/
*__pycache__*
Expand Down
3 changes: 3 additions & 0 deletions djongo/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_timezones = False
uses_savepoints = False
can_clone_databases = True
test_db_allows_multiple_connections = False
supports_unspecified_pk = True

21 changes: 2 additions & 19 deletions djongo/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,11 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):

sql_create_index = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s INDEX (%(columns)s)%(extra)s"
sql_delete_index = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s INDEX"
sql_delete_index2 = "DROP INDEX %(name)s ON %(table)s"


def quote_value(self, value):
raise NotImplementedError()
return value

def prepare_default(self, value):
raise NotImplementedError()

# def create_model(self, model):
# db_con = self.connection.connection
# db_con.create_collection(model._meta.db_table)
# logger.debug('Created table {}'.format(model._meta.db_table))
#
# for field in model._meta.local_fields:
# if field.get_internal_type() in ("AutoField", "BigAutoField"):
# db_con['__schema__'].\
# insert_one(
# {
# 'name': model._meta.db_table,
# 'auto': {
# 'field_name': field.column,
# 'seq': 0
# }
# }
# )
1 change: 1 addition & 0 deletions djongo/sql2mongo/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ def evaluate(self):
op.evaluate()
self._op = op


class WhereOp(_Op, _StatementParser):

def __init__(self, *args, **kwargs):
Expand Down
154 changes: 82 additions & 72 deletions djongo/sql2mongo/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
from logging import getLogger
from typing import Optional, Dict, List, Union as U, Sequence, Set
from dataclasses import dataclass, field
from dataclasses import dataclass, field as dataclass_field
from pymongo import MongoClient
from pymongo import ReturnDocument
from pymongo.command_cursor import CommandCursor
Expand All @@ -23,7 +23,7 @@
from ..exceptions import SQLDecodeError, MigrationError, print_warn
from .functions import SQLFunc
from .sql_tokens import (SQLToken, SQLStatement, SQLIdentifier,
AliasableToken, SQLConstIdentifier)
AliasableToken, SQLConstIdentifier, SQLColumnDef, SQLColumnConstraint)
from .converters import (
ColumnSelectConverter, AggColumnSelectConverter, FromConverter, WhereConverter,
AggWhereConverter, InnerJoinConverter, OuterJoinConverter, LimitConverter, AggLimitConverter, OrderConverter,
Expand All @@ -33,15 +33,16 @@
from djongo import base
logger = getLogger(__name__)


@dataclass
class TokenAlias:
alias2token: Dict[str, U[AliasableToken,
SQLFunc,
SQLIdentifier]] = field(default_factory=dict)
SQLIdentifier]] = dataclass_field(default_factory=dict)
token2alias: Dict[U[AliasableToken,
SQLFunc,
SQLIdentifier], str] = field(default_factory=dict)
aliased_names: Set[str] = field(default_factory=set)
SQLIdentifier], str] = dataclass_field(default_factory=dict)
aliased_names: Set[str] = dataclass_field(default_factory=set)


class BaseQuery(abc.ABC):
Expand Down Expand Up @@ -639,80 +640,89 @@ class CreateQuery(DDLQuery):
def __init__(self, *args):
super().__init__(*args)

def parse(self):
statement = SQLStatement(self.statement)
statement.skip(2)
def _create_table(self, statement):
if '__schema__' not in self.connection_properties.cached_collections:
self.db.create_collection('__schema__')
self.connection_properties.cached_collections.add('__schema__')
self.db['__schema__'].create_index('name', unique=True)
self.db['__schema__'].create_index('auto')

tok = statement.next()
if tok.match(tokens.Keyword, 'TABLE'):
if '__schema__' not in self.connection_properties.cached_collections:
self.db.create_collection('__schema__')
self.connection_properties.cached_collections.add('__schema__')
self.db['__schema__'].create_index('name', unique=True)
self.db['__schema__'].create_index('auto')
table = SQLToken.token2sql(tok, self).table
try:
self.db.create_collection(table)
except CollectionInvalid:
if self.connection_properties.enforce_schema:
raise
else:
return

tok = statement.next()
table = SQLToken.token2sql(tok, self).table
try:
self.db.create_collection(table)
except CollectionInvalid:
if self.connection_properties.enforce_schema:
raise
else:
return
logger.debug('Created table: {}'.format(table))

logger.debug('Created table: {}'.format(table))
tok = statement.next()
if not isinstance(tok, Parenthesis):
raise SQLDecodeError(f'Unexpected sql syntax'
f' for column definition: {statement}')

if statement.next():
raise SQLDecodeError(f'Unexpected sql syntax'
f' for column definition: {statement}')

_filter = {
'name': table
}
_set = {}
push = {}
update = {}

for col in SQLColumnDef.statement2col_defs(tok):
if isinstance(col, SQLColumnConstraint):
print_warn('column CONSTRAINTS')
else:
field = col.name
if field == '_id':
continue

tok = statement.next()
if isinstance(tok, Parenthesis):
_filter = {
'name': table
_set[f'fields.{field}'] = {
'type_code': col.data_type
}
_set = {}
push = {}
update = {}

for col in tok.value.strip('()').split(','):
props = col.strip().split(' ')
field = props[0].strip('"')
type_code = props[1]

_set[f'fields.{field}'] = {
'type_code': type_code
}

if field == '_id':
continue

if col.find('AUTOINCREMENT') != -1:
try:
push['auto.field_names']['$each'].append(field)
except KeyError:
push['auto.field_names'] = {
'$each': [field]
}

_set['auto.seq'] = 0

if col.find('PRIMARY KEY') != -1:
self.db[table].create_index(field, unique=True, name='__primary_key__')

if col.find('UNIQUE') != -1:
self.db[table].create_index(field, unique=True)

if col.find('NOT NULL') != -1:
print_warn('NOT NULL column validation check')

if _set:
update['$set'] = _set
if push:
update['$push'] = push
if update:
self.db['__schema__'].update_one(
filter=_filter,
update=update,
upsert=True
)
if SQLColumnDef.autoincrement in col.col_constraints:
try:
push['auto.field_names']['$each'].append(field)
except KeyError:
push['auto.field_names'] = {
'$each': [field]
}
_set['auto.seq'] = 0

if SQLColumnDef.primarykey in col.col_constraints:
self.db[table].create_index(field, unique=True, name='__primary_key__')

if SQLColumnDef.unique in col.col_constraints:
self.db[table].create_index(field, unique=True)

if (SQLColumnDef.not_null in col.col_constraints or
SQLColumnDef.null in col.col_constraints):
print_warn('NULL, NOT NULL column validation check')

if _set:
update['$set'] = _set
if push:
update['$push'] = push
if update:
self.db['__schema__'].update_one(
filter=_filter,
update=update,
upsert=True
)

def parse(self):
statement = SQLStatement(self.statement)
statement.skip(2)
tok = statement.next()
if tok.match(tokens.Keyword, 'TABLE'):
self._create_table(statement)
elif tok.match(tokens.Keyword, 'DATABASE'):
pass
else:
Expand Down
93 changes: 88 additions & 5 deletions djongo/sql2mongo/sql_tokens.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import abc
import re
from typing import Union as U, Iterator

from typing import Union as U, Iterator, Optional as O
from pymongo import ASCENDING, DESCENDING
from sqlparse import tokens, parse as sqlparse
from sqlparse.sql import Token, Identifier, Function, Comparison, Parenthesis, IdentifierList, Statement
from . import query as query_module
from ..exceptions import SQLDecodeError
from ..exceptions import SQLDecodeError, NotSupportedError

all_token_types = U['SQLConstIdentifier',
'djongo.sql2mongo.functions.CountFunc',
Expand Down Expand Up @@ -222,6 +221,7 @@ def get_value(self, tok: Token):
else:
raise SQLDecodeError


class SQLStatement:

@property
Expand Down Expand Up @@ -249,9 +249,12 @@ def __getitem__(self, item: slice):
sql = sqlparse(sql)[0]
return SQLStatement(sql)

def next(self) -> Token:
def next(self) -> O[Token]:
# self._tok_id, token = self._statement.token_next(self._tok_id)
return next(self._gen_inst)
try:
return next(self._gen_inst)
except StopIteration:
return None

def skip(self, num):
self._tok_id += num
Expand All @@ -271,6 +274,86 @@ def _generator(self):
self._tok_id, token = self._statement.token_next(self._tok_id)


class SQLColumnDef:

not_null = object()
unique = object()
autoincrement = object()
primarykey = object()
null = object()
_map = {
'UNIQUE': unique,
'AUTOINCREMENT': autoincrement,
'PRIMARY KEY': primarykey,
'NOT NULL': not_null,
'NULL': null
}

def __init__(self,
name: str = None,
data_type: str = None,
col_constraints: set = None):
self.name = name
self.data_type = data_type
self.col_constraints = col_constraints

@staticmethod
def _get_constraints(others: str):
while others:
try:
name, others = others.split(' ', 1)
except ValueError:
name = others
others = None
try:
yield SQLColumnDef._map[name]
except KeyError:
if others:
try:
part2, others = others.split(' ', 1)
except ValueError:
part2 = others
others = None

name = f'{name} {part2}'
try:
yield SQLColumnDef._map[name]
except KeyError:
raise SQLDecodeError(f'Unknown column constraint: {name}')
else:
raise SQLDecodeError(f'Unknown column constraint: {name}')

@staticmethod
def statement2col_defs(token: Token):
from djongo.base import DatabaseWrapper
supported_data_types = set(DatabaseWrapper.data_types.values())

defs = token.value.strip('()').split(',')
for col in defs:
col = col.strip()
name, other = col.split(' ', 1)
if name == 'CONSTRAINT':
yield SQLColumnConstraint()
else:
if col[0] != '"':
raise SQLDecodeError('Column identifier not quoted')
name, other = col[1:].split('"', 1)
other = other.strip()

data_type, constraint_sql = other.split(' ', 1)
if data_type not in supported_data_types:
raise NotSupportedError(f'Data of type: {data_type}')

col_constraints = set(SQLColumnDef._get_constraints(constraint_sql))
yield SQLColumnDef(name=name,
data_type=data_type,
col_constraints=col_constraints)


class SQLColumnConstraint(SQLColumnDef):
pass


ORDER_BY_MAP = {
'ASC': ASCENDING,
'DESC': DESCENDING
Expand Down
Loading

0 comments on commit 7795fd2

Please sign in to comment.