diff --git a/README.md b/README.md index 17037d87..9fddaecd 100644 --- a/README.md +++ b/README.md @@ -1,101 +1,89 @@ +# django-pg-extra-extended (Fork for Django 5+) +

django-postgres-extra

- + +This is a fork of `django-postgres-extra`, updated to support Django 5+ while maintaining PostgreSQL enhancements for the Django ORM. + | | | | |--------------------|---------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| :white_check_mark: | **Tests** | [![CircleCI](https://circleci.com/gh/SectorLabs/django-postgres-extra/tree/master.svg?style=svg)](https://circleci.com/gh/SectorLabs/django-postgres-extra/tree/master) | | :memo: | **License** | [![License](https://img.shields.io/:license-mit-blue.svg)](http://doge.mit-license.org) | -| :package: | **PyPi** | [![PyPi](https://badge.fury.io/py/django-postgres-extra.svg)](https://pypi.python.org/pypi/django-postgres-extra) | -| :four_leaf_clover: | **Code coverage** | [![Coverage Status](https://coveralls.io/repos/github/SectorLabs/django-postgres-extra/badge.svg?branch=coveralls)](https://coveralls.io/github/SectorLabs/django-postgres-extra?branch=master) | -| | **Django Versions** | 2.0, 2.1, 2.2, 3.0, 3.1, 3.2, 4.0, 4.1, 4.2, 5.0 | -| | **Python Versions** | 3.6, 3.7, 3.8, 3.9, 3.10, 3.11 | +| :package: | **PyPi** | (Coming soon) | +| | **Django Versions** | 5.0+ | +| | **Python Versions** | 3.8, 3.9, 3.10, 3.11, 3.12 | | | **Psycopg Versions** | 2, 3 | -| :book: | **Documentation** | [Read The Docs](https://django-postgres-extra.readthedocs.io/en/master/) | -| :warning: | **Upgrade** | [Upgrade from v1.x](https://django-postgres-extra.readthedocs.io/en/master/major_releases.html#new-features) -| :checkered_flag: | **Installation** | [Installation Guide](https://django-postgres-extra.readthedocs.io/en/master/installation.html) | -| :fire: | **Features** | [Features & Documentation](https://django-postgres-extra.readthedocs.io/en/master/index.html#features) | -| :droplet: | **Future enhancements** | [Potential features](https://github.com/SectorLabs/django-postgres-extra/issues?q=is%3Aopen+is%3Aissue+label%3Aenhancement) | - -`django-postgres-extra` aims to make all of PostgreSQL's awesome features available through the Django ORM. We do this by taking care of all the hassle. As opposed to the many small packages that are available to try to bring a single feature to Django with minimal effort. ``django-postgres-extra`` goes the extra mile, with well tested implementations, seamless migrations and much more. - -With seamless we mean that any features we add will work truly seamlessly. You should not have to manually modify your migrations to work with fields and objects provided by this package. - ---- - -:warning: **This README is for v2. See the `v1` branch for v1.x.** - ---- - -## Major features +| :book: | **Documentation** | (Coming soon) | +| :fire: | **Features** | [Features & Documentation](https://github.com/MONSTER-HARSH/django-pg-extra-extended/) | +| :droplet: | **Future enhancements** | [Potential features](https://github.com/MONSTER-HARSH/django-pg-extra-extended/issues?q=is%3Aopen+is%3Aissue+label%3Aenhancement) | -[See the full list](http://django-postgres-extra.readthedocs.io/#features) +## About -* **Native upserts** +This fork of `django-postgres-extra` extends support for Django 5+, keeping all the powerful PostgreSQL features, including: - * Single query - * Concurrency safe - * With bulk support (single query) +- **Native upserts** with bulk support +- **Extended support for HStoreField** (unique constraints, null constraints, etc.) +- **Declarative table partitioning** for PostgreSQL 11+ +- **Faster deletes** using table truncation +- **Advanced indexing options** (conditional and case-sensitive unique indexes) -* **Extended support for HStoreField** +## Installation - * Unique constraints - * Null constraints - * Select individual keys using ``.values()`` or ``.values_list()`` +Coming soon to PyPI. -* **PostgreSQL 11.x declarative table partitioning** +For now, install directly from GitHub: - * Supports both range and list partitioning +```sh +pip install git+https://github.com/MONSTER-HARSH/django-pg-extra-extended.git +``` -* **Faster deletes** +## Getting Started - * Truncate tables (with cascade) - -* **Indexes** - - * Conditional unique index. - * Case sensitive unique index. - -## Working with the code -### Prerequisites +1. Clone the repository: -* PostgreSQL 10 or newer. -* Django 2.0 or newer (including 3.x, 4.x). -* Python 3.6 or newer. + ```sh + git clone https://github.com/MONSTER-HARSH/django-pg-extra-extended.git + ``` -### Getting started +2. Create a virtual environment: -1. Clone the repository: + ```sh + cd django-pg-extra-extended + python -m venv env + source env/bin/activate + ``` - λ git clone https://github.com/SectorLabs/django-postgres-extra.git +3. Create a PostgreSQL user for testing: -2. Create a virtual environment: + ```sh + createuser --superuser psqlextra --pwprompt + export DATABASE_URL=postgres://psqlextra:@localhost/psqlextra + ``` - λ cd django-postgres-extra - λ virtualenv env - λ source env/bin/activate +4. Install dependencies: -3. Create a postgres user for use in tests (skip if your default user is a postgres superuser): + ```sh + pip install .[test] .[analysis] + ``` - λ createuser --superuser psqlextra --pwprompt - λ export DATABASE_URL=postgres://psqlextra:@localhost/psqlextra +5. Run tests: - Hint: if you're using virtualenvwrapper, you might find it beneficial to put - the ``export`` line in ``$VIRTUAL_ENV/bin/postactivate`` so that it's always - available when using this virtualenv. + ```sh + tox + ``` -4. Install the development/test dependencies: +## Migration from Original django-postgres-extra - λ pip install .[test] .[analysis] +If you're upgrading from `django-postgres-extra` (SectorLabs version), ensure that: -5. Run the tests: +- You update the package source to this fork. +- You check for any API changes due to Django 5+ compatibility adjustments. - λ tox +## Contributing -6. Run the benchmarks: +Contributions are welcome! Please open an issue or submit a pull request. - λ py.test -c pytest-benchmark.ini +## License -7. Auto-format code, sort imports and auto-fix linting errors: +MIT License. See [LICENSE](LICENSE) for details. - λ python setup.py fix diff --git a/psqlextra/__init__.py b/psqlextra/__init__.py index 474f803b..e69de29b 100644 --- a/psqlextra/__init__.py +++ b/psqlextra/__init__.py @@ -1,15 +0,0 @@ -import django - -from ._version import __version__ - -if django.VERSION < (3, 2): # pragma: no cover - default_app_config = "psqlextra.apps.PostgresExtraAppConfig" - - __all__ = [ - "default_app_config", - "__version__", - ] -else: - __all__ = [ - "__version__", - ] diff --git a/psqlextra/backend/base.py b/psqlextra/backend/base.py index c8ae73c5..5c788a05 100644 --- a/psqlextra/backend/base.py +++ b/psqlextra/backend/base.py @@ -3,10 +3,6 @@ from typing import TYPE_CHECKING from django.conf import settings -from django.contrib.postgres.signals import ( - get_hstore_oids, - register_type_handlers, -) from django.db import ProgrammingError from . import base_impl @@ -98,22 +94,3 @@ def prepare_database(self): "or add the extension manually.", exc_info=True, ) - return - - # Clear old (non-existent), stale oids. - get_hstore_oids.cache_clear() - - # Verify that we (and Django) can find the OIDs - # for hstore. - oids, _ = get_hstore_oids(self.alias) - if not oids: - logger.warning( - '"hstore" extension was created, but we cannot find the oids' - "in the database. Something went wrong.", - ) - return - - # We must trigger Django into registering the type handlers now - # so that any subsequent code can properly use the newly - # registered types. - register_type_handlers(self) diff --git a/psqlextra/backend/schema.py b/psqlextra/backend/schema.py index 31a23414..28e9211a 100644 --- a/psqlextra/backend/schema.py +++ b/psqlextra/backend/schema.py @@ -1045,7 +1045,7 @@ def _partitioning_properties_for_model(model: Type[Model]): % (model.__name__, meta.method) ) - if not isinstance(meta.key, (list, tuple)): + if not isinstance(meta.key, list): raise ImproperlyConfigured( ( "Model '%s' is not properly configured to be partitioned." diff --git a/psqlextra/manager/manager.py b/psqlextra/manager/manager.py index ee1eb58b..4b96e34f 100644 --- a/psqlextra/manager/manager.py +++ b/psqlextra/manager/manager.py @@ -8,7 +8,7 @@ from psqlextra.query import PostgresQuerySet -class PostgresManager(Manager.from_queryset(PostgresQuerySet)): # type: ignore[misc] +class PostgresManager(Manager.from_queryset(PostgresQuerySet)): """Adds support for PostgreSQL specifics.""" use_in_migrations = True @@ -37,10 +37,7 @@ def __init__(self, *args, **kwargs): ) def truncate( - self, - cascade: bool = False, - restart_identity: bool = False, - using: Optional[str] = None, + self, cascade: bool = False, using: Optional[str] = None ) -> None: """Truncates this model/table using the TRUNCATE statement. @@ -54,19 +51,14 @@ def truncate( False, an error will be raised if there are rows in other tables referencing the rows you're trying to delete. - restart_identity: - Automatically restart sequences owned by - columns of the truncated table(s). """ connection = connections[using or "default"] table_name = connection.ops.quote_name(self.model._meta.db_table) with connection.cursor() as cursor: - sql = f"TRUNCATE TABLE {table_name}" + sql = "TRUNCATE TABLE %s" % table_name if cascade: sql += " CASCADE" - if restart_identity: - sql += " RESTART IDENTITY" cursor.execute(sql) diff --git a/psqlextra/models/base.py b/psqlextra/models/base.py index d240237a..21caad36 100644 --- a/psqlextra/models/base.py +++ b/psqlextra/models/base.py @@ -1,7 +1,4 @@ -from typing import Any - from django.db import models -from django.db.models import Manager from psqlextra.manager import PostgresManager @@ -13,4 +10,4 @@ class Meta: abstract = True base_manager_name = "objects" - objects: "Manager[Any]" = PostgresManager() + objects = PostgresManager() diff --git a/psqlextra/models/partitioned.py b/psqlextra/models/partitioned.py index f0115367..69036040 100644 --- a/psqlextra/models/partitioned.py +++ b/psqlextra/models/partitioned.py @@ -1,5 +1,3 @@ -from typing import Iterable - from django.db.models.base import ModelBase from psqlextra.types import PostgresPartitioningMethod @@ -17,7 +15,7 @@ class PostgresPartitionedModelMeta(ModelBase): """ default_method = PostgresPartitioningMethod.RANGE - default_key: Iterable[str] = [] + default_key = [] def __new__(cls, name, bases, attrs, **kwargs): new_class = super().__new__(cls, name, bases, attrs, **kwargs) @@ -34,14 +32,10 @@ def __new__(cls, name, bases, attrs, **kwargs): return new_class -class PostgresPartitionedModel( - PostgresModel, metaclass=PostgresPartitionedModelMeta -): +class PostgresPartitionedModel(PostgresModel, metaclass=PostgresPartitionedModelMeta): """Base class for taking advantage of PostgreSQL's 11.x native support for table partitioning.""" - _partitioning_meta: PostgresPartitionedModelOptions - class Meta: abstract = True base_manager_name = "objects" diff --git a/psqlextra/query.py b/psqlextra/query.py index 6a86f18e..2756fd8c 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -1,55 +1,19 @@ from collections import OrderedDict from itertools import chain -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Generic, - Iterable, - List, - Optional, - Tuple, - TypeVar, - Union, -) +from typing import Dict, Iterable, List, Optional, Tuple, Union from django.core.exceptions import SuspiciousOperation -from django.db import models, router -from django.db.backends.utils import CursorWrapper -from django.db.models import Expression, Q, QuerySet +from django.db import connections, models, router +from django.db.models import Expression, Q from django.db.models.fields import NOT_PROVIDED -from .expressions import ExcludedCol -from .introspect import model_from_cursor, models_from_cursor from .sql import PostgresInsertQuery, PostgresQuery from .types import ConflictAction -if TYPE_CHECKING: - from django.db.models.constraints import BaseConstraint - from django.db.models.indexes import Index +ConflictTarget = List[Union[str, Tuple[str]]] -ConflictTarget = Union[List[Union[str, Tuple[str]]], "BaseConstraint", "Index"] - -TModel = TypeVar("TModel", bound=models.Model, covariant=True) - -if TYPE_CHECKING: - from typing_extensions import Self - - QuerySetBase = QuerySet[TModel] -else: - QuerySetBase = QuerySet - - -def peek_iterator(iterable): - try: - first = next(iterable) - except StopIteration: - return None - return list(chain([first], iterable)) - - -class PostgresQuerySet(QuerySetBase, Generic[TModel]): +class PostgresQuerySet(models.QuerySet): """Adds support for PostgreSQL specifics.""" def __init__(self, model=None, query=None, using=None, hints=None): @@ -63,9 +27,8 @@ def __init__(self, model=None, query=None, using=None, hints=None): self.conflict_action = None self.conflict_update_condition = None self.index_predicate = None - self.update_values = None - def annotate(self, **annotations) -> "Self": # type: ignore[valid-type, override] + def annotate(self, **annotations): """Custom version of the standard annotate function that allows using field names as annotated fields. @@ -121,7 +84,6 @@ def on_conflict( action: ConflictAction, index_predicate: Optional[Union[Expression, Q, str]] = None, update_condition: Optional[Union[Expression, Q, str]] = None, - update_values: Optional[Dict[str, Union[Any, Expression]]] = None, ): """Sets the action to take when conflicts arise when attempting to insert/create a new row. @@ -139,24 +101,18 @@ def on_conflict( update_condition: Only update if this SQL expression evaluates to true. - - update_values: - Optionally, values/expressions to use when rows - conflict. If not specified, all columns specified - in the rows are updated with the values you specified. """ self.conflict_target = fields self.conflict_action = action self.conflict_update_condition = update_condition self.index_predicate = index_predicate - self.update_values = update_values return self def bulk_insert( self, - rows: Iterable[Dict[str, Any]], + rows: List[dict], return_model: bool = False, using: Optional[str] = None, ): @@ -175,19 +131,18 @@ def bulk_insert( just dicts. using: - Optional name of the database connection to use for + Name of the database connection to use for this query. Returns: A list of either the dicts of the rows inserted, including the pk or the models of the rows inserted with defaults for any fields not specified """ - if rows is None: - return [] - rows = peek_iterator(iter(rows)) + def is_empty(r): + return all([False for _ in r]) - if not rows: + if not rows or is_empty(rows): return [] if not self.conflict_target and not self.conflict_action: @@ -216,17 +171,14 @@ def bulk_insert( deduped_rows.append(row) compiler = self._build_insert_compiler(deduped_rows, using=using) + objs = compiler.execute_sql(return_id=not return_model) + if return_model: + return [ + self._create_model_instance(dict(row, **obj), compiler.using) + for row, obj in zip(deduped_rows, objs) + ] - with compiler.connection.cursor() as cursor: - for sql, params in compiler.as_sql(return_id=not return_model): - cursor.execute(sql, params) - - if return_model: - return list(models_from_cursor(self.model, cursor)) - - return self._consume_cursor_as_dicts( - cursor, original_rows=deduped_rows - ) + return [dict(row, **obj) for row, obj in zip(deduped_rows, objs)] def insert(self, using: Optional[str] = None, **fields): """Creates a new record in the database. @@ -247,20 +199,14 @@ def insert(self, using: Optional[str] = None, **fields): """ if self.conflict_target or self.conflict_action: - if not self.model or not self.model.pk: - return None - compiler = self._build_insert_compiler([fields], using=using) + rows = compiler.execute_sql(return_id=True) - with compiler.connection.cursor() as cursor: - for sql, params in compiler.as_sql(return_id=True): - cursor.execute(sql, params) - - row = cursor.fetchone() - if not row: - return None + _, pk_db_column = self.model._meta.pk.get_attname_column() + if not rows or len(rows) == 0: + return None - return row[0] + return rows[0][pk_db_column] # no special action required, use the standard Django create(..) return super().create(**fields).pk @@ -288,12 +234,30 @@ def insert_and_get(self, using: Optional[str] = None, **fields): return super().create(**fields) compiler = self._build_insert_compiler([fields], using=using) + rows = compiler.execute_sql(return_id=False) + + if not rows: + return None + + columns = rows[0] + + # get a list of columns that are officially part of the model and + # preserve the fact that the attribute name + # might be different than the database column name + model_columns = {} + for field in self.model._meta.local_concrete_fields: + model_columns[field.column] = field.attname - with compiler.connection.cursor() as cursor: - for sql, params in compiler.as_sql(return_id=False): - cursor.execute(sql, params) + # strip out any columns/fields returned by the db that + # are not present in the model + model_init_fields = {} + for column_name, column_value in columns.items(): + try: + model_init_fields[model_columns[column_name]] = column_value + except KeyError: + pass - return model_from_cursor(self.model, cursor) + return self._create_model_instance(model_init_fields, compiler.using) def upsert( self, @@ -302,7 +266,6 @@ def upsert( index_predicate: Optional[Union[Expression, Q, str]] = None, using: Optional[str] = None, update_condition: Optional[Union[Expression, Q, str]] = None, - update_values: Optional[Dict[str, Union[Any, Expression]]] = None, ) -> int: """Creates a new record or updates the existing one with the specified data. @@ -325,27 +288,17 @@ def upsert( update_condition: Only update if this SQL expression evaluates to true. - update_values: - Optionally, values/expressions to use when rows - conflict. If not specified, all columns specified - in the rows are updated with the values you specified. - Returns: The primary key of the row that was created/updated. """ self.on_conflict( conflict_target, - ConflictAction.UPDATE - if (update_condition or update_condition is None) - else ConflictAction.NOTHING, + ConflictAction.UPDATE, index_predicate=index_predicate, update_condition=update_condition, - update_values=update_values, ) - - kwargs = {**fields, "using": using} - return self.insert(**kwargs) + return self.insert(**fields, using=using) def upsert_and_get( self, @@ -354,7 +307,6 @@ def upsert_and_get( index_predicate: Optional[Union[Expression, Q, str]] = None, using: Optional[str] = None, update_condition: Optional[Union[Expression, Q, str]] = None, - update_values: Optional[Dict[str, Union[Any, Expression]]] = None, ): """Creates a new record or updates the existing one with the specified data and then gets the row. @@ -377,11 +329,6 @@ def upsert_and_get( update_condition: Only update if this SQL expression evaluates to true. - update_values: - Optionally, values/expressions to use when rows - conflict. If not specified, all columns specified - in the rows are updated with the values you specified. - Returns: The model instance representing the row that was created/updated. @@ -392,11 +339,8 @@ def upsert_and_get( ConflictAction.UPDATE, index_predicate=index_predicate, update_condition=update_condition, - update_values=update_values, ) - - kwargs = {**fields, "using": using} - return self.insert_and_get(**kwargs) + return self.insert_and_get(**fields, using=using) def bulk_upsert( self, @@ -406,7 +350,6 @@ def bulk_upsert( return_model: bool = False, using: Optional[str] = None, update_condition: Optional[Union[Expression, Q, str]] = None, - update_values: Optional[Dict[str, Union[Any, Expression]]] = None, ): """Creates a set of new records or updates the existing ones with the specified data. @@ -433,11 +376,6 @@ def bulk_upsert( update_condition: Only update if this SQL expression evaluates to true. - update_values: - Optionally, values/expressions to use when rows - conflict. If not specified, all columns specified - in the rows are updated with the values you specified. - Returns: A list of either the dicts of the rows upserted, including the pk or the models of the rows upserted @@ -448,28 +386,46 @@ def bulk_upsert( ConflictAction.UPDATE, index_predicate=index_predicate, update_condition=update_condition, - update_values=update_values, ) - return self.bulk_insert(rows, return_model, using=using) - @staticmethod - def _consume_cursor_as_dicts( - cursor: CursorWrapper, *, original_rows: Iterable[Dict[str, Any]] - ) -> List[dict]: - cursor_description = cursor.description - - return [ - { - **original_row, - **{ - column.name: row[column_index] - for column_index, column in enumerate(cursor_description) - if row - }, - } - for original_row, row in zip(original_rows, cursor) - ] + def _create_model_instance( + self, field_values: dict, using: str, apply_converters: bool = True + ): + """Creates a new instance of the model with the specified field. + + Use this after the row was inserted into the database. The new + instance will marked as "saved". + """ + + converted_field_values = field_values.copy() + + if apply_converters: + connection = connections[using] + + for field in self.model._meta.local_concrete_fields: + if field.attname not in converted_field_values: + continue + + # converters can be defined on the field, or by + # the database back-end we're using + field_column = field.get_col(self.model._meta.db_table) + converters = field.get_db_converters( + connection + ) + connection.ops.get_db_converters(field_column) + + for converter in converters: + converted_field_values[field.attname] = converter( + converted_field_values[field.attname], + field_column, + connection, + ) + + instance = self.model(**converted_field_values) + instance._state.db = using + instance._state.adding = False + + return instance def _build_insert_compiler( self, rows: Iterable[Dict], using: Optional[str] = None @@ -491,7 +447,7 @@ def _build_insert_compiler( # ask the db router which connection to use using = ( - using or self._db or router.db_for_write(self.model, **self._hints) # type: ignore[attr-defined] + using or self._db or router.db_for_write(self.model, **self._hints) ) # create model objects, we also have to detect cases @@ -513,17 +469,12 @@ def _build_insert_compiler( ).format(index) ) - obj = self.model(**row.copy()) - obj._state.db = using - obj._state.adding = False - objs.append(obj) + objs.append( + self._create_model_instance(row, using, apply_converters=False) + ) # get the fields to be used during update/insert - insert_fields, update_values = self._get_upsert_fields(first_row) - - # allow the user to override what should happen on update - if self.update_values is not None: - update_values = self.update_values + insert_fields, update_fields = self._get_upsert_fields(first_row) # build a normal insert query query = PostgresInsertQuery(self.model) @@ -531,7 +482,7 @@ def _build_insert_compiler( query.conflict_target = self.conflict_target query.conflict_update_condition = self.conflict_update_condition query.index_predicate = self.index_predicate - query.insert_on_conflict_values(objs, insert_fields, update_values) + query.values(objs, insert_fields, update_fields) compiler = query.get_compiler(using) return compiler @@ -596,13 +547,13 @@ def _get_upsert_fields(self, kwargs): model_instance = self.model(**kwargs) insert_fields = [] - update_values = {} + update_fields = [] for field in model_instance._meta.local_concrete_fields: has_default = field.default != NOT_PROVIDED if field.name in kwargs or field.column in kwargs: insert_fields.append(field) - update_values[field.name] = ExcludedCol(field) + update_fields.append(field) continue elif has_default: insert_fields.append(field) @@ -613,13 +564,13 @@ def _get_upsert_fields(self, kwargs): # instead of a concrete field, we have to handle that if field.primary_key is True and "pk" in kwargs: insert_fields.append(field) - update_values[field.name] = ExcludedCol(field) + update_fields.append(field) continue if self._is_magical_field(model_instance, field, is_insert=True): insert_fields.append(field) if self._is_magical_field(model_instance, field, is_insert=False): - update_values[field.name] = ExcludedCol(field) + update_fields.append(field) - return insert_fields, update_values + return insert_fields, update_fields diff --git a/psqlextra/sql.py b/psqlextra/sql.py index cf12d8c1..25c8314e 100644 --- a/psqlextra/sql.py +++ b/psqlextra/sql.py @@ -1,14 +1,12 @@ from collections import OrderedDict -from collections.abc import Iterable -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple import django from django.core.exceptions import SuspiciousOperation from django.db import connections, models -from django.db.models import Expression, sql +from django.db.models import sql from django.db.models.constants import LOOKUP_SEP -from django.db.models.expressions import Ref from .compiler import PostgresInsertOnConflictCompiler from .compiler import SQLUpdateCompiler as PostgresUpdateCompiler @@ -18,8 +16,6 @@ class PostgresQuery(sql.Query): - select: Tuple[Expression, ...] - def chain(self, klass=None): """Chains this query to another. @@ -66,28 +62,13 @@ def rename_annotations(self, annotations) -> None: new_annotations[new_name or old_name] = annotation if new_name and self.annotation_select_mask: - # It's a set in all versions prior to Django 5.x - # and a list in Django 5.x and newer. - # https://github.com/django/django/commit/d6b6e5d0fd4e6b6d0183b4cf6e4bd4f9afc7bf67 - if isinstance(self.annotation_select_mask, set): - self.annotation_select_mask.discard(old_name) - self.annotation_select_mask.add(new_name) - elif isinstance(self.annotation_select_mask, list): - self.annotation_select_mask.remove(old_name) - self.annotation_select_mask.append(new_name) - - if isinstance(self.group_by, Iterable): - for statement in self.group_by: - if not isinstance(statement, Ref): - continue - - if statement.refs in annotations: # type: ignore[attr-defined] - statement.refs = annotations[statement.refs] # type: ignore[attr-defined] + self.annotation_select_mask.discard(old_name) + self.annotation_select_mask.add(new_name) self.annotations.clear() self.annotations.update(new_annotations) - def add_fields(self, field_names, *args, **kwargs) -> None: + def add_fields(self, field_names: List[str], *args, **kwargs) -> None: """Adds the given (model) fields to the select set. The field names are added in the order specified. This overrides @@ -119,11 +100,10 @@ def add_fields(self, field_names, *args, **kwargs) -> None: if len(parts) > 1: column_name, hstore_key = parts[:2] is_hstore, field = self._is_hstore_field(column_name) - if self.model and is_hstore: + if is_hstore: select.append( HStoreColumn( - self.model._meta.db_table - or self.model.__class__.__name__, + self.model._meta.db_table or self.model.name, field, hstore_key, ) @@ -135,7 +115,7 @@ def add_fields(self, field_names, *args, **kwargs) -> None: super().add_fields(field_names_without_hstore, *args, **kwargs) if len(select) > 0: - self.set_select(list(self.select + tuple(select))) + self.set_select(self.select + tuple(select)) def _is_hstore_field( self, field_name: str @@ -147,11 +127,8 @@ def _is_hstore_field( instance. """ - if not self.model: - return (False, None) - field_instance = None - for field in self.model._meta.local_concrete_fields: # type: ignore[attr-defined] + for field in self.model._meta.local_concrete_fields: if field.name == field_name or field.column == field_name: field_instance = field break @@ -171,14 +148,10 @@ def __init__(self, *args, **kwargs): self.conflict_action = ConflictAction.UPDATE self.conflict_update_condition = None self.index_predicate = None - self.update_values = {} - - def insert_on_conflict_values( - self, - objs: List, - insert_fields: List, - update_values: Dict[str, Union[Any, Expression]] = {}, - ): + + self.update_fields = [] + + def values(self, objs: List, insert_fields: List, update_fields: List = []): """Sets the values to be used in this query. Insert fields are fields that are definitely @@ -197,13 +170,12 @@ def insert_on_conflict_values( insert_fields: The fields to use in the INSERT statement - update_values: - Expressions/values to use when a conflict - occurs and an UPDATE is performed. + update_fields: + The fields to only use in the UPDATE statement. """ self.insert_values(insert_fields, objs, raw=False) - self.update_values = update_values + self.update_fields = update_fields def get_compiler(self, using=None, connection=None): if using: diff --git a/psqlextra/types.py b/psqlextra/types.py index f1118075..a325fd9e 100644 --- a/psqlextra/types.py +++ b/psqlextra/types.py @@ -28,9 +28,6 @@ class ConflictAction(Enum): def all(cls) -> List["ConflictAction"]: return [choice for choice in cls] - def __str__(self) -> str: - return self.value - class PostgresPartitioningMethod(StrEnum): """Methods of partitioning supported by PostgreSQL 11.x native support for