Skip to content

Commit

Permalink
Refs #33397 -- Added extra tests for resolving an output_field of Com…
Browse files Browse the repository at this point in the history
…binedExpression.
  • Loading branch information
spookylukey authored and felixxm committed Mar 30, 2022
1 parent fac662f commit 04ad0f2
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 1 deletion.
1 change: 1 addition & 0 deletions tests/annotations/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ def test_combined_annotation_commutative(self):
book2 = Book.objects.annotate(adjusted_rating=None + F("rating")).get(
pk=self.b1.pk
)
self.assertIs(book1.adjusted_rating, None)
self.assertEqual(book1.adjusted_rating, book2.adjusted_rating)

def test_update_with_annotation(self):
Expand Down
90 changes: 89 additions & 1 deletion tests/expressions/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1724,6 +1724,28 @@ def test_date_comparison(self):
]
self.assertEqual(test_set, self.expnames[: i + 1])

def test_datetime_and_durationfield_addition_with_filter(self):
test_set = Experiment.objects.filter(end=F("start") + F("estimated_time"))
self.assertGreater(test_set.count(), 0)
self.assertEqual(
[e.name for e in test_set],
[
e.name
for e in Experiment.objects.all()
if e.end == e.start + e.estimated_time
],
)

@skipUnlessDBFeature("supports_temporal_subtraction")
def test_datetime_subtraction_with_annotate_and_no_output_field(self):
test_set = Experiment.objects.annotate(
calculated_duration=F("end") - F("start")
)
self.assertEqual(
[e.calculated_duration for e in test_set],
[e.end - e.start for e in test_set],
)

def test_mixed_comparisons1(self):
for i, delay in enumerate(self.delays):
test_set = [
Expand Down Expand Up @@ -2373,7 +2395,7 @@ def test_reversed_xor(self):


class CombinedExpressionTests(SimpleTestCase):
def test_resolve_output_field(self):
def test_resolve_output_field_number(self):
tests = [
(IntegerField, AutoField, IntegerField),
(AutoField, IntegerField, IntegerField),
Expand All @@ -2395,6 +2417,72 @@ def test_resolve_output_field(self):
)
self.assertIsInstance(expr.output_field, combined)

def test_resolve_output_field_with_null(self):
def null():
return Value(None)

tests = [
# Numbers.
(AutoField, Combinable.ADD, null),
(DecimalField, Combinable.ADD, null),
(FloatField, Combinable.ADD, null),
(IntegerField, Combinable.ADD, null),
(IntegerField, Combinable.SUB, null),
(null, Combinable.ADD, IntegerField),
# Dates.
(DateField, Combinable.ADD, null),
(DateTimeField, Combinable.ADD, null),
(DurationField, Combinable.ADD, null),
(TimeField, Combinable.ADD, null),
(TimeField, Combinable.SUB, null),
(null, Combinable.ADD, DateTimeField),
(DateField, Combinable.SUB, null),
]
msg = "Expression contains mixed types: "
for lhs, connector, rhs in tests:
with self.subTest(lhs=lhs, connector=connector, rhs=rhs):
expr = CombinedExpression(
Expression(lhs()),
connector,
Expression(rhs()),
)
with self.assertRaisesMessage(FieldError, msg):
expr.output_field

def test_resolve_output_field_dates(self):
tests = [
# Add - same type.
(DurationField, Combinable.ADD, DurationField, DurationField),
# Subtract - same type.
(DurationField, Combinable.SUB, DurationField, DurationField),
# Subtract - different type.
(DurationField, Combinable.SUB, DateField, FieldError),
(DurationField, Combinable.SUB, DateTimeField, FieldError),
(DurationField, Combinable.SUB, DateTimeField, FieldError),
]
msg = "Expression contains mixed types: "
for lhs, connector, rhs, combined in tests:
with self.subTest(lhs=lhs, connector=connector, rhs=rhs, combined=combined):
expr = CombinedExpression(
Expression(lhs()),
connector,
Expression(rhs()),
)
if issubclass(combined, Exception):
with self.assertRaisesMessage(combined, msg):
expr.output_field
else:
self.assertIsInstance(expr.output_field, combined)

def test_mixed_char_date_with_annotate(self):
queryset = Experiment.objects.annotate(nonsense=F("name") + F("assigned"))
msg = (
"Expression contains mixed types: CharField, DateField. You must set "
"output_field."
)
with self.assertRaisesMessage(FieldError, msg):
list(queryset)


class ExpressionWrapperTests(SimpleTestCase):
def test_empty_group_by(self):
Expand Down

0 comments on commit 04ad0f2

Please sign in to comment.