Skip to content

Commit

Permalink
Implemented several arrayReferenceField Methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
nesdis committed Jun 2, 2018
1 parent d820ef7 commit d4d3542
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 43 deletions.
2 changes: 1 addition & 1 deletion djongo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@

__version__ = '1.2.28'
__version__ = '1.2.29'
192 changes: 159 additions & 33 deletions djongo/models/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from django import forms
from django.core.exceptions import ValidationError
from django.db import router, connections
from django.db import router, connections, transaction
from django.db import connections as pymongo_connections
import typing
import functools
Expand Down Expand Up @@ -388,28 +388,31 @@ class EmbeddedModelField(Field):
Example:
class Author(models.Model):
class Blog(models.Model):
name = models.CharField(max_length=100)
email = models.CharField(max_length=100)
tagline = models.TextField()
class Meta:
abstract = True
class AuthorForm(forms.ModelForm):
class BlogForm(forms.ModelForm):
class Meta:
model = Author
model = Blog
fields = (
'name', 'email'
'comment', 'author'
)
class MultipleBlogPosts(models.Model):
h1 = models.CharField(max_length=100)
content = models.ArrayModelField(
model_container=BlogContent,
model_form_class=BlogContentForm
class Entry(models.Model):
blog = models.EmbeddedModelField(
model_container=Blog,
model_form_class=BlogForm
)
headline = models.CharField(max_length=255)
objects = models.DjongoManager()
"""
empty_strings_allowed = False

Expand Down Expand Up @@ -483,8 +486,8 @@ def to_python(self, value):
return value
assert isinstance(value, dict)

self.instance = make_mdl(self.model_container, value)
return self.instance
instance = make_mdl(self.model_container, value)
return instance

def formfield(self, **kwargs):
defaults = {
Expand Down Expand Up @@ -648,6 +651,33 @@ def get_queryset(self):
queryset = super().get_queryset()
return self._apply_rel_filters(queryset)

def update_or_create(self, **kwargs):
db = router.db_for_write(self.instance.__class__, instance=self.instance)
obj, created = super(ArrayReferenceManagerMixin, self.db_manager(db)).update_or_create(**kwargs)
# We only need to add() if created because if we got an object back
# from get() then the relationship already exists.
if created:
self.add(obj)
return obj, created
update_or_create.alters_data = True

def get_or_create(self, **kwargs):
db = router.db_for_write(self.instance.__class__, instance=self.instance)
obj, created = super(ArrayReferenceManagerMixin, self.db_manager(db)).get_or_create(**kwargs)
# We only need to add() if created because if we got an object back
# from get() then the relationship already exists.
if created:
self.add(obj)
return obj, created
get_or_create.alters_data = True

def create(self, **kwargs):
db = router.db_for_write(self.instance.__class__, instance=self.instance)
new_obj = super(ArrayReferenceManagerMixin, self.db_manager(db)).create(**kwargs)
self.add(new_obj)
return new_obj
create.alters_data = True


def create_reverse_array_reference_manager(superclass, rel):
if issubclass(superclass, DjongoManager):
Expand All @@ -665,6 +695,12 @@ def __init__(self, instance):
name = rel.remote_field.name
self.core_filters = {name: instance}

def __call__(self, *, manager):
manager = getattr(self.model, manager)
manager_class = create_reverse_array_reference_manager(manager.__class__, rel)
return manager_class(instance=self.instance)
do_not_call_in_templates = True

def _apply_rel_filters(self, queryset):
queryset = super()._apply_rel_filters(queryset)
db = self._db or router.db_for_read(self.model, instance=self.instance)
Expand All @@ -686,16 +722,36 @@ def _make_filter(self, *objs):

def add(self, *objs):
_filter = self._make_filter(*objs)
for lh_field, rh_field in self.field.related_fields:
self.mongo_update(
_filter,
{
'$addToSet': {
lh_field.get_attname():
getattr(self.instance, rh_field.get_attname())
}
lh_field, rh_field = self.field.related_fields[0]
self.mongo_update(
_filter,
{
'$addToSet': {
lh_field.get_attname():
getattr(self.instance, rh_field.get_attname())
}
)
}
)
for obj in objs:
fk_field = getattr(obj, lh_field.get_attname())
fk_field.add(getattr(self.instance, rh_field.get_attname()))
add.alters_data = True

def remove(self, *objs):
pass
remove.alters_data = True

def clear(self):
pass
clear.alters_data = True

def set(self, objs, *, clear=False):
pass
set.alters_data = True

def create(self, **kwargs):
pass
create.alters_data = True

return ReverseArrayReferenceManager

Expand All @@ -721,6 +777,12 @@ def __init__(self, instance):

self.core_filters = {f'{name}__in': ids}

def __call__(self, *, manager):
manager = getattr(self.model, manager)
manager_class = create_forward_array_reference_manager(manager.__class__, rel)
return manager_class(instance=self.instance)
do_not_call_in_templates = True

def _apply_rel_filters(self, queryset):
queryset = super()._apply_rel_filters(queryset)
if not getattr(self.instance, self.field.attname):
Expand All @@ -744,9 +806,9 @@ def add(self, *objs):
setattr(self.instance, self.field.get_attname(), fks)

new_fks = set()
rh_field = self.field.foreign_related_fields[0]
for obj in objs:
for rh_field in self.field.foreign_related_fields:
new_fks.add(getattr(obj, rh_field.get_attname()))
new_fks.add(getattr(obj, rh_field.get_attname()))
fks.update(new_fks)

db = router.db_for_write(self.instance.__class__, instance=self.instance)
Expand All @@ -760,13 +822,68 @@ def add(self, *objs):
}
}
)

add.alters_data = True

def remove(self, *objs):
pass
to_del = set(
getattr(obj, self.field.foreign_related_fields[0].attname)
for obj in objs
)
self._remove(to_del)

remove.alters_data = True

def _remove(self, to_del):
fks = getattr(self.instance, self.field.attname)
fks.difference_update(to_del)
db = self._db or router.db_for_write(self.instance.__class__, instance=self.instance)
self.instance_manager.db_manager(db).mongo_update(
self._make_filter(),
{
'$pull': {
self.field.attname: {
'$in': list(to_del)
}
}
}
)

def clear(self):
pass
db = router.db_for_write(self.instance.__class__, instance=self.instance)
self.instance_manager.db_manager(db).mongo_update(
self._make_filter(),
{
'$set': {
self.field.attname: []
}
}
)
setattr(self.instance, self.field.attname, {})

clear.alters_data = True

def set(self, objs, *, clear=False):
objs = tuple(objs)

db = router.db_for_write(self.through, instance=self.instance)
with transaction.atomic(using=db, savepoint=False):
if clear:
self.clear()
self.add(*objs)
else:
fks = getattr(self.instance, self.field.attname)
rh_field = self.field.foreign_related_fields[0]
new_fks = set(getattr(obj, rh_field.get_attname()) for obj in objs)
to_del = fks - new_fks
self._remove(to_del)
fks = getattr(self.instance, self.field.attname)
to_add = []
for obj in objs:
if getattr(obj, rh_field.get_attname()) not in fks:
to_add.append(obj)
self.add(to_add)

set.alters_data = True

return ArrayReferenceManager

Expand Down Expand Up @@ -854,7 +971,7 @@ def __init__(self, to, on_delete=None, related_name=None, related_query_name=Non
limit_choices_to=None, parent_link=False, to_field=None,
db_constraint=True, **kwargs):

on_delete = on_delete or CASCADE
on_delete = on_delete or self._on_delete
super().__init__(to, on_delete=on_delete, related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
Expand All @@ -863,9 +980,20 @@ def __init__(self, to, on_delete=None, related_name=None, related_query_name=Non

self.concrete = False

# def contribute_to_class(self, cls, name, private_only=False, **kwargs):
# super().contribute_to_class(cls, name, private_only, **kwargs)
# cls._meta.local_fields.remove(self)
@staticmethod
def _on_delete(collector, field, sub_objs, using):
for model, instances in collector.data.items():
for obj in sub_objs:
getattr(obj, field.name).db_manager(using).remove(*instances)


def from_db_value(self, value, expression, connection, context):
return self.to_python(value)

def to_python(self, value):
if value is None:
return set()
return set(value)

def get_db_prep_value(self, value, connection, prepared=False):
if value is None:
Expand All @@ -879,5 +1007,3 @@ def get_db_prep_save(self, value, connection):
return list(value)


class GenericReferenceField(FieldCacheMixin):
pass
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from djongo import models
from djongo.models import DjongoManager
from djongo.models import DjongoManager, CASCADE
from django import forms


class ReferenceBlog(models.Model):
name = models.CharField(max_length=100)
tagline = models.TextField()
_id = models.ObjectIdField()

class Meta:
abstract = True
# class Meta:
# abstract = True


class BlogForm(forms.ModelForm):
Expand Down Expand Up @@ -73,6 +74,7 @@ class ReferenceEntry(models.Model):

# authors = models.ArrayReferenceField(ReferenceAuthor)
authors = models.ArrayReferenceField(ReferenceAuthor)
# authors_fk = models.ForeignKey(ReferenceBlog, CASCADE)

# n_comments = models.IntegerField()

Expand Down
11 changes: 11 additions & 0 deletions tests/djongo_tests/project/dummy/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,16 @@ def test_create(self):
email='[email protected]'
)

self.assertEqual([], list(e1.authors.all()))
self.assertEqual([], list(a1.referenceentry_set.all()))

e1.authors.add(a1)
self.assertEqual(e1.authors_id, {a1.pk})
self.assertEqual([a1], list(e1.authors.all()))
self.assertEqual([e1], list(a1.referenceentry_set.all()))

e2.authors.add(a1,a2)
self.assertEqual(e2.authors_id, {a1.pk, a2.pk})
self.assertEqual([a1, a2], list(e2.authors.all()))
self.assertEqual([e1, e2], list(a1.referenceentry_set.all()))
self.assertEqual([e2], list(a2.referenceentry_set.all()))
Expand All @@ -38,8 +43,14 @@ def test_create(self):
self.assertEqual([e1, e2], g)

a2.referenceentry_set.add(e1)
self.assertEqual(e1.authors_id, {a1.pk, a2.pk})
self.assertEqual([e1, e2], list(a2.referenceentry_set.all()))

a2.delete()
self.assertEqual([a1], list(e2.authors.all()))
self.assertEqual([a1], list(e1.authors.all()))
self.assertEqual(e1.authors_id, {a1.pk})

class TestEmbedded(TestCase):

@classmethod
Expand Down
12 changes: 6 additions & 6 deletions tests/djongo_tests/project/project/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@

INSTALLED_APPS = [
'dummy.apps.DummyConfig',
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
# 'django.contrib.admin',
# 'django.contrib.auth',
# 'django.contrib.contenttypes',
# 'django.contrib.sessions',
# 'django.contrib.messages',
# 'django.contrib.staticfiles',
]


Expand Down

0 comments on commit d4d3542

Please sign in to comment.