Skip to content

Commit

Permalink
Fixed #29048 -- Added **extra_context to database function as_vendor(…
Browse files Browse the repository at this point in the history
…) methods.
  • Loading branch information
priyanshsaxena authored and timgraham committed Aug 23, 2018
1 parent 08f3603 commit 83b04d4
Show file tree
Hide file tree
Showing 11 changed files with 108 additions and 83 deletions.
4 changes: 2 additions & 2 deletions django/contrib/gis/db/models/aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def as_sql(self, compiler, connection, function=None, **extra_context):
**extra_context
)

def as_oracle(self, compiler, connection):
def as_oracle(self, compiler, connection, **extra_context):
tolerance = self.extra.get('tolerance') or getattr(self, 'tolerance', 0.05)
template = None if self.is_extent else '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))'
return self.as_sql(compiler, connection, template=template, tolerance=tolerance)
return self.as_sql(compiler, connection, template=template, tolerance=tolerance, **extra_context)

def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
Expand Down
48 changes: 26 additions & 22 deletions django/contrib/gis/db/models/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,19 +102,23 @@ class SQLiteDecimalToFloatMixin:
By default, Decimal values are converted to str by the SQLite backend, which
is not acceptable by the GIS functions expecting numeric values.
"""
def as_sqlite(self, compiler, connection):
def as_sqlite(self, compiler, connection, **extra_context):
for expr in self.get_source_expressions():
if hasattr(expr, 'value') and isinstance(expr.value, Decimal):
expr.value = float(expr.value)
return super().as_sql(compiler, connection)
return super().as_sql(compiler, connection, **extra_context)


class OracleToleranceMixin:
tolerance = 0.05

def as_oracle(self, compiler, connection):
def as_oracle(self, compiler, connection, **extra_context):
tol = self.extra.get('tolerance', self.tolerance)
return self.as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol)
return self.as_sql(
compiler, connection,
template="%%(function)s(%%(expressions)s, %s)" % tol,
**extra_context
)


class Area(OracleToleranceMixin, GeoFunc):
Expand Down Expand Up @@ -181,11 +185,11 @@ def as_oracle(self, compiler, connection, **extra_context):


class AsKML(AsGML):
def as_sqlite(self, compiler, connection):
def as_sqlite(self, compiler, connection, **extra_context):
# No version parameter
clone = self.copy()
clone.set_source_expressions(self.get_source_expressions()[1:])
return clone.as_sql(compiler, connection)
return clone.as_sql(compiler, connection, **extra_context)


class AsSVG(GeoFunc):
Expand All @@ -205,10 +209,10 @@ class BoundingCircle(OracleToleranceMixin, GeoFunc):
def __init__(self, expression, num_seg=48, **extra):
super().__init__(expression, num_seg, **extra)

def as_oracle(self, compiler, connection):
def as_oracle(self, compiler, connection, **extra_context):
clone = self.copy()
clone.set_source_expressions([self.get_source_expressions()[0]])
return super(BoundingCircle, clone).as_oracle(compiler, connection)
return super(BoundingCircle, clone).as_oracle(compiler, connection, **extra_context)


class Centroid(OracleToleranceMixin, GeomOutputGeoFunc):
Expand Down Expand Up @@ -239,7 +243,7 @@ def __init__(self, expr1, expr2, spheroid=None, **extra):
self.spheroid = self._handle_param(spheroid, 'spheroid', bool)
super().__init__(*expressions, **extra)

def as_postgresql(self, compiler, connection):
def as_postgresql(self, compiler, connection, **extra_context):
clone = self.copy()
function = None
expr2 = clone.source_expressions[1]
Expand All @@ -262,7 +266,7 @@ def as_postgresql(self, compiler, connection):
clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
else:
function = connection.ops.spatial_function_name('DistanceSphere')
return super(Distance, clone).as_sql(compiler, connection, function=function)
return super(Distance, clone).as_sql(compiler, connection, function=function, **extra_context)

def as_sqlite(self, compiler, connection, **extra_context):
if self.geo_field.geodetic(connection):
Expand Down Expand Up @@ -300,12 +304,12 @@ def __init__(self, expression, precision=None, **extra):
expressions.append(self._handle_param(precision, 'precision', int))
super().__init__(*expressions, **extra)

def as_mysql(self, compiler, connection):
def as_mysql(self, compiler, connection, **extra_context):
clone = self.copy()
# If no precision is provided, set it to the maximum.
if len(clone.source_expressions) < 2:
clone.source_expressions.append(Value(100))
return clone.as_sql(compiler, connection)
return clone.as_sql(compiler, connection, **extra_context)


class Intersection(OracleToleranceMixin, GeomOutputGeoFunc):
Expand Down Expand Up @@ -333,7 +337,7 @@ def as_sql(self, compiler, connection, **extra_context):
raise NotSupportedError("This backend doesn't support Length on geodetic fields")
return super().as_sql(compiler, connection, **extra_context)

def as_postgresql(self, compiler, connection):
def as_postgresql(self, compiler, connection, **extra_context):
clone = self.copy()
function = None
if self.source_is_geography():
Expand All @@ -346,13 +350,13 @@ def as_postgresql(self, compiler, connection):
dim = min(f.dim for f in self.get_source_fields() if f)
if dim > 2:
function = connection.ops.length3d
return super(Length, clone).as_sql(compiler, connection, function=function)
return super(Length, clone).as_sql(compiler, connection, function=function, **extra_context)

def as_sqlite(self, compiler, connection):
def as_sqlite(self, compiler, connection, **extra_context):
function = None
if self.geo_field.geodetic(connection):
function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength'
return super().as_sql(compiler, connection, function=function)
return super().as_sql(compiler, connection, function=function, **extra_context)


class LineLocatePoint(GeoFunc):
Expand Down Expand Up @@ -383,19 +387,19 @@ class NumPoints(GeoFunc):
class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
arity = 1

def as_postgresql(self, compiler, connection):
def as_postgresql(self, compiler, connection, **extra_context):
function = None
if self.geo_field.geodetic(connection) and not self.source_is_geography():
raise NotSupportedError("ST_Perimeter cannot use a non-projected non-geography field.")
dim = min(f.dim for f in self.get_source_fields())
if dim > 2:
function = connection.ops.perimeter3d
return super().as_sql(compiler, connection, function=function)
return super().as_sql(compiler, connection, function=function, **extra_context)

def as_sqlite(self, compiler, connection):
def as_sqlite(self, compiler, connection, **extra_context):
if self.geo_field.geodetic(connection):
raise NotSupportedError("Perimeter cannot use a non-projected field.")
return super().as_sql(compiler, connection)
return super().as_sql(compiler, connection, **extra_context)


class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc):
Expand Down Expand Up @@ -454,12 +458,12 @@ def __init__(self, expression, srid, **extra):


class Translate(Scale):
def as_sqlite(self, compiler, connection):
def as_sqlite(self, compiler, connection, **extra_context):
clone = self.copy()
if len(self.source_expressions) < 4:
# Always provide the z parameter for ST_Translate
clone.source_expressions.append(Value(0))
return super(Translate, clone).as_sqlite(compiler, connection)
return super(Translate, clone).as_sqlite(compiler, connection, **extra_context)


class Union(OracleToleranceMixin, GeomOutputGeoFunc):
Expand Down
21 changes: 12 additions & 9 deletions django/db/models/aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ def as_sql(self, compiler, connection, **extra_context):
if connection.features.supports_aggregate_filter_clause:
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
template = self.filter_template % extra_context.get('template', self.template)
sql, params = super().as_sql(compiler, connection, template=template, filter=filter_sql)
sql, params = super().as_sql(
compiler, connection, template=template, filter=filter_sql,
**extra_context
)
return sql, params + filter_params
else:
copy = self.copy()
Expand Down Expand Up @@ -92,20 +95,20 @@ def _resolve_output_field(self):
return FloatField()
return super()._resolve_output_field()

def as_mysql(self, compiler, connection):
sql, params = super().as_sql(compiler, connection)
def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == 'DurationField':
sql = 'CAST(%s as SIGNED)' % sql
return sql, params

def as_oracle(self, compiler, connection):
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == 'DurationField':
expression = self.get_source_expressions()[0]
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
return compiler.compile(
SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter))
)
return super().as_sql(compiler, connection)
return super().as_sql(compiler, connection, **extra_context)


class Count(Aggregate):
Expand Down Expand Up @@ -157,20 +160,20 @@ class Sum(Aggregate):
function = 'SUM'
name = 'Sum'

def as_mysql(self, compiler, connection):
sql, params = super().as_sql(compiler, connection)
def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == 'DurationField':
sql = 'CAST(%s as SIGNED)' % sql
return sql, params

def as_oracle(self, compiler, connection):
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == 'DurationField':
expression = self.get_source_expressions()[0]
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
return compiler.compile(
SecondsToInterval(Sum(IntervalToSeconds(expression)))
)
return super().as_sql(compiler, connection)
return super().as_sql(compiler, connection, **extra_context)


class Variance(Aggregate):
Expand Down
22 changes: 11 additions & 11 deletions django/db/models/functions/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ def as_sql(self, compiler, connection, **extra_context):
extra_context['db_type'] = self.output_field.cast_db_type(connection)
return super().as_sql(compiler, connection, **extra_context)

def as_mysql(self, compiler, connection):
def as_mysql(self, compiler, connection, **extra_context):
# MySQL doesn't support explicit cast to float.
template = '(%(expressions)s + 0.0)' if self.output_field.get_internal_type() == 'FloatField' else None
return self.as_sql(compiler, connection, template=template)
return self.as_sql(compiler, connection, template=template, **extra_context)

def as_postgresql(self, compiler, connection):
def as_postgresql(self, compiler, connection, **extra_context):
# CAST would be valid too, but the :: shortcut syntax is more readable.
# 'expressions' is wrapped in parentheses in case it's a complex
# expression.
return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s')
return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s', **extra_context)


class Coalesce(Func):
Expand All @@ -35,7 +35,7 @@ def __init__(self, *expressions, **extra):
raise ValueError('Coalesce must take at least two expressions')
super().__init__(*expressions, **extra)

def as_oracle(self, compiler, connection):
def as_oracle(self, compiler, connection, **extra_context):
# Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
# so convert all fields to NCLOB when that type is expected.
if self.output_field.get_internal_type() == 'TextField':
Expand All @@ -47,8 +47,8 @@ class ToNCLOB(Func):
]
clone = self.copy()
clone.set_source_expressions(expressions)
return super(Coalesce, clone).as_sql(compiler, connection)
return self.as_sql(compiler, connection)
return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
return self.as_sql(compiler, connection, **extra_context)


class Greatest(Func):
Expand All @@ -66,9 +66,9 @@ def __init__(self, *expressions, **extra):
raise ValueError('Greatest must take at least two expressions')
super().__init__(*expressions, **extra)

def as_sqlite(self, compiler, connection):
def as_sqlite(self, compiler, connection, **extra_context):
"""Use the MAX function on SQLite."""
return super().as_sqlite(compiler, connection, function='MAX')
return super().as_sqlite(compiler, connection, function='MAX', **extra_context)


class Least(Func):
Expand All @@ -86,6 +86,6 @@ def __init__(self, *expressions, **extra):
raise ValueError('Least must take at least two expressions')
super().__init__(*expressions, **extra)

def as_sqlite(self, compiler, connection):
def as_sqlite(self, compiler, connection, **extra_context):
"""Use the MIN function on SQLite."""
return super().as_sqlite(compiler, connection, function='MIN')
return super().as_sqlite(compiler, connection, function='MIN', **extra_context)
4 changes: 2 additions & 2 deletions django/db/models/functions/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ class Now(Func):
template = 'CURRENT_TIMESTAMP'
output_field = fields.DateTimeField()

def as_postgresql(self, compiler, connection):
def as_postgresql(self, compiler, connection, **extra_context):
# PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
# transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
# other databases.
return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()')
return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()', **extra_context)


class TruncBase(TimezoneMixin, Transform):
Expand Down
Loading

0 comments on commit 83b04d4

Please sign in to comment.