Skip to content

Commit

Permalink
Merge pull request pallets-eco#1612 from jasgrider/master
Browse files Browse the repository at this point in the history
Updating views.py to support changes in Peewee version 3.0 (for Field…
  • Loading branch information
mrjoes authored May 16, 2018
2 parents 247e19f + 704519b commit 765e442
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 18 deletions.
9 changes: 4 additions & 5 deletions flask_admin/contrib/peewee/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@
from peewee import (CharField, DateTimeField, DateField, TimeField,
PrimaryKeyField, ForeignKeyField)

# Fix for Issue: #1602 & #1606
# Section below trys BaseModel (for versions of PeeWee < 3.0) and if that fails,
# load the new ModelBase as BaseModel (to not break things looking for BaseModel in flask-peewee and etc.)

try:
from peewee import BaseModel
except ImportError:
Expand Down Expand Up @@ -274,7 +270,10 @@ def contribute(self, converter, model, form_class, inline_model):
allow_pk=True,
converter=converter)

prop_name = reverse_field.related_name
try:
prop_name = reverse_field.related_name
except AttributeError:
prop_name = reverse_field.backref

label = self.get_label(info, prop_name)

Expand Down
48 changes: 35 additions & 13 deletions flask_admin/contrib/peewee/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,24 @@ def scaffold_filters(self, name):
raise Exception('Failed to find field for filter: %s' % name)

# Check if field is in different model
if attr.model_class != self.model:
visible_name = '%s / %s' % (self.get_column_name(attr.model_class.__name__),
try:
if attr.model_class != self.model:
visible_name = '%s / %s' % (self.get_column_name(attr.model_class.__name__),
self.get_column_name(attr.name))
else:
if not isinstance(name, string_types):
visible_name = self.get_column_name(attr.name)
else:
visible_name = self.get_column_name(name)
if not isinstance(name, string_types):
visible_name = self.get_column_name(attr.name)
else:
visible_name = self.get_column_name(name)
except AttributeError:
if attr.model != self.model:
visible_name = '%s / %s' % (self.get_column_name(attr.model.__name__),
self.get_column_name(attr.name))
else:
if not isinstance(name, string_types):
visible_name = self.get_column_name(attr.name)
else:
visible_name = self.get_column_name(name)

type_name = type(attr).__name__
flt = self.filter_converter.convert(type_name,
Expand Down Expand Up @@ -307,12 +317,20 @@ def _create_ajax_loader(self, name, options):
return create_ajax_loader(self.model, name, name, options)

def _handle_join(self, query, field, joins):
if field.model_class != self.model:
model_name = field.model_class.__name__
try:
if field.model_class != self.model:
model_name = field.model_class.__name__

if model_name not in joins:
query = query.join(field.model_class, JOIN.LEFT_OUTER)
joins.add(model_name)
except AttributeError:
if field.model != self.model:
model_name = field.model.__name__

if model_name not in joins:
query = query.join(field.model_class, JOIN.LEFT_OUTER)
joins.add(model_name)
if model_name not in joins:
query = query.join(field.model, JOIN.LEFT_OUTER)
joins.add(model_name)

return query

Expand All @@ -321,8 +339,12 @@ def _order_by(self, query, joins, sort_field, sort_desc):
field = getattr(self.model, sort_field)
query = query.order_by(field.desc() if sort_desc else field.asc())
elif isinstance(sort_field, Field):
if sort_field.model_class != self.model:
query = self._handle_join(query, sort_field, joins)
try:
if sort_field.model_class != self.model:
query = self._handle_join(query, sort_field, joins)
except AttributeError:
if sort_field.model != self.model:
query = self._handle_join(query, sort_field, joins)

query = query.order_by(sort_field.desc() if sort_desc else sort_field.asc())

Expand Down

0 comments on commit 765e442

Please sign in to comment.