diff --git a/.circleci/config.yml b/.circleci/config.yml index 23792cd1..92d9093b 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -6,8 +6,8 @@ executors: version: type: string docker: - - image: python:<< parameters.version >>-alpine - - image: postgres:11.0 + - image: python:<< parameters.version >>-buster + - image: postgres:13.0 environment: POSTGRES_DB: 'psqlextra' POSTGRES_USER: 'psqlextra' @@ -22,27 +22,24 @@ commands: steps: - run: name: Install packages - command: apk add postgresql-libs gcc musl-dev postgresql-dev git + command: apt-get update && apt-get install -y --no-install-recommends postgresql-client-11 libpq-dev build-essential git - run: name: Install Python packages - command: pip install --progress-bar off .[<< parameters.extra >>] + command: pip install --progress-bar off '.[<< parameters.extra >>]' run-tests: parameters: pyversion: type: integer - djversions: - type: string steps: - run: name: Run tests - command: tox -e 'py<< parameters.pyversion >>-dj{<< parameters.djversions >>}' + command: tox --listenvs | grep ^py<< parameters.pyversion >> | circleci tests split | xargs -n 1 tox -e environment: DATABASE_URL: 'postgres://psqlextra:psqlextra@localhost:5432/psqlextra' - jobs: test-python36: executor: @@ -54,7 +51,6 @@ jobs: extra: test - run-tests: pyversion: 36 - djversions: 20,21,22,30,31,32 test-python37: executor: @@ -66,7 +62,6 @@ jobs: extra: test - run-tests: pyversion: 37 - djversions: 20,21,22,30,31,32 test-python38: executor: @@ -78,7 +73,6 @@ jobs: extra: test - run-tests: pyversion: 38 - djversions: 20,21,22,30,31,32 test-python39: executor: @@ -90,7 +84,6 @@ jobs: extra: test - run-tests: pyversion: 39 - djversions: 21,22,30,31,32 test-python310: executor: @@ -102,7 +95,17 @@ jobs: extra: test - run-tests: pyversion: 310 - djversions: 21,22,30,31,32 + + test-python311: + executor: + name: python + version: "3.11" + steps: + - checkout + - install-dependencies: + extra: test + - run-tests: + pyversion: 311 - store_test_results: path: reports - run: @@ -116,19 +119,84 @@ jobs: steps: - checkout - install-dependencies: - extra: analysis + extra: analysis, test - run: name: Verify command: python setup.py verify + publish: + executor: + name: python + version: "3.9" + steps: + - checkout + - install-dependencies: + extra: publish + - run: + name: Set version number + command: echo "__version__ = \"${CIRCLE_TAG:1}\"" > psqlextra/_version.py + - run: + name: Build package + command: python -m build + - run: + name: Publish package + command: > + python -m twine upload + --username "__token__" + --password "${PYPI_API_TOKEN}" + --verbose + --non-interactive + --disable-progress-bar + dist/* workflows: - version: 2 build: jobs: - - test-python36 - - test-python37 - - test-python38 - - test-python39 - - test-python310 - - analysis + - test-python36: + filters: + tags: + only: /.*/ + branches: + only: /.*/ + - test-python37: + filters: + tags: + only: /.*/ + branches: + only: /.*/ + - test-python38: + filters: + tags: + only: /.*/ + branches: + only: /.*/ + - test-python39: + filters: + tags: + only: /.*/ + branches: + only: /.*/ + - test-python310: + filters: + tags: + only: /.*/ + branches: + only: /.*/ + - test-python311: + filters: + tags: + only: /.*/ + branches: + only: /.*/ + - analysis: + filters: + tags: + only: /.*/ + branches: + only: /.*/ + - publish: + filters: + tags: + only: /^v.*/ + branches: + ignore: /.*/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f51f66b2..cd0836a9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ 
-16,11 +16,4 @@ If you're unsure whether your change would be a good fit for `django-postgres-ex * PyLint passes. * PEP8 passes. * Features that allow creating custom indexes or fields must also implement the associated migrations. `django-postgres-extra` prides itself on the fact that it integrates smoothly with Django migrations. We'd like to keep it that way for all features. -* Sufficiently complicated changes must be accomponied by tests. - -## Our promise -* We'll promise to reply to each pull request within 24 hours of submission. -* We'll let you know whether we welcome the change or not within that timeframe. - * This avoids you wasting time on a feature that we feel is not a good fit. - -We feel that these promises are fair to whomever decides its worth spending their free time to contribute to `django-postgres-extra`. Please do let us know if you feel we are not living up to these promises. +* Sufficiently complicated changes must be accompanied by tests. diff --git a/README.md b/README.md index d98b4407..17037d87 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,9 @@ | :memo: | **License** | [![License](https://img.shields.io/:license-mit-blue.svg)](http://doge.mit-license.org) | | :package: | **PyPi** | [![PyPi](https://badge.fury.io/py/django-postgres-extra.svg)](https://pypi.python.org/pypi/django-postgres-extra) | | :four_leaf_clover: | **Code coverage** | [![Coverage Status](https://coveralls.io/repos/github/SectorLabs/django-postgres-extra/badge.svg?branch=coveralls)](https://coveralls.io/github/SectorLabs/django-postgres-extra?branch=master) | -| | **Django Versions** | 2.0, 2.1, 2.2, 3.0, 3.1, 3.2 | -| | **Python Versions** | 3.6, 3.7, 3.8, 3.9, 3.10 | +| | **Django Versions** | 2.0, 2.1, 2.2, 3.0, 3.1, 3.2, 4.0, 4.1, 4.2, 5.0 | +| | **Python Versions** | 3.6, 3.7, 3.8, 3.9, 3.10, 3.11 | +| | **Psycopg Versions** | 2, 3 | | :book: | **Documentation** | [Read The Docs](https://django-postgres-extra.readthedocs.io/en/master/) | | :warning: | **Upgrade** | [Upgrade from v1.x](https://django-postgres-extra.readthedocs.io/en/master/major_releases.html#new-features) | :checkered_flag: | **Installation** | [Installation Guide](https://django-postgres-extra.readthedocs.io/en/master/installation.html) | @@ -59,7 +60,7 @@ With seamless we mean that any features we add will work truly seamlessly. You s ### Prerequisites * PostgreSQL 10 or newer. -* Django 2.0 or newer (including 3.x). +* Django 2.0 or newer (including 3.x, 4.x). * Python 3.6 or newer. ### Getting started diff --git a/docs/source/annotations.rst b/docs/source/annotations.rst index d9431510..1f0c2847 100644 --- a/docs/source/annotations.rst +++ b/docs/source/annotations.rst @@ -10,7 +10,7 @@ Annotations Renaming annotations -------------------- -Django does allow you to create an annotation that conflicts with a field on the model. :meth:`psqlextra.query.QuerySet.rename_annotation` makes it possible to do just that. +Django does not allow you to create an annotation that conflicts with a field on the model. :meth:`psqlextra.query.QuerySet.rename_annotation` makes it possible to do just that. .. code-block:: python diff --git a/docs/source/api_reference.rst b/docs/source/api_reference.rst index 70e50ab6..7f175fe9 100644 --- a/docs/source/api_reference.rst +++ b/docs/source/api_reference.rst @@ -34,9 +34,23 @@ API Reference .. automodule:: psqlextra.indexes .. autoclass:: UniqueIndex + .. autoclass:: ConditionalUniqueIndex + .. autoclass:: CaseInsensitiveUniqueIndex +.. automodule:: psqlextra.locking + :members: + +.. 
automodule:: psqlextra.schema + :members: + +.. automodule:: psqlextra.partitioning + :members: + +.. automodule:: psqlextra.backend.migrations.operations + :members: + .. automodule:: psqlextra.types :members: :undoc-members: diff --git a/docs/source/conflict_handling.rst b/docs/source/conflict_handling.rst index f5108742..cb9423a9 100644 --- a/docs/source/conflict_handling.rst +++ b/docs/source/conflict_handling.rst @@ -87,6 +87,41 @@ Specifying multiple columns is necessary in case of a constraint that spans mult ) +Specific constraint +******************* + +Alternatively, instead of specifying the columns the constraint you're targeting applies to, you can also specify the exact constraint to use: + +.. code-block:: python + + from django.db import models + from psqlextra.models import PostgresModel + + class MyModel(PostgresModel): + class Meta: + constraints = [ + models.UniqueConstraint( + name="myconstraint", + fields=["first_name", "last_name"] + ), + ] + + first_name = models.CharField(max_length=255) + last_name = models.CharField(max_length=255) + + constraint = next( + (constraint + for constraint in MyModel._meta.constraints + if constraint.name == "myconstraint"), + None) + + obj = ( + MyModel.objects + .on_conflict(constraint, ConflictAction.UPDATE) + .insert_and_get(first_name='Henk', last_name='Jansen') + ) + + HStore keys *********** Catching conflicts in columns with a ``UNIQUE`` constraint on a :class:`~psqlextra.fields.HStoreField` key is also supported: @@ -197,6 +232,42 @@ Alternatively, with Django 3.1 or newer, :class:`~django:django.db.models.Q` obj Q(name__gt=ExcludedCol('priority')) +Update values +""""""""""""" + +Optionally, the fields to update can be overridden. The default is to update the same fields that were specified in the rows to insert. + +Refer to the insert values using the :class:`psqlextra.expressions.ExcludedCol` expression which translates to PostgreSQL's ``EXCLUDED.`` expression. All expressions and features that can be used with Django's :meth:`~django:django.db.models.query.QuerySet.update` can be used here. + +.. warning:: + + Specifying an empty ``update_values`` (``{}``) will transform the query into :attr:`~psqlextra.types.ConflictAction.NOTHING`. Only ``None`` triggers the default behaviour of updating all the fields that were specified. + +.. code-block:: python + + from django.db.models import F + + from psqlextra.expressions import ExcludedCol + + ( + MyModel + .objects + .on_conflict( + ['name'], + ConflictAction.UPDATE, + update_values=dict( + name=ExcludedCol('name'), + count=F('count') + 1, + ), + ) + .insert( + name='henk', + count=0, + ) + ) + + + ConflictAction.NOTHING ********************** @@ -219,7 +290,7 @@ This is preferable when the data you're about to insert is the same as the one t # obj2 is none!
object already exists obj2 = MyModel.objects.on_conflict(['name'], ConflictAction.NOTHING).insert(name="me") -This applies to both :meth:`~psqlextra.query.PostgresQuerySet.insert` and :meth:`~psqlextra.query.PostgresQuerySet.bulk_insert` +This applies to all methods: :meth:`~psqlextra.query.PostgresQuerySet.insert`, :meth:`~psqlextra.query.PostgresQuerySet.insert_and_get` and :meth:`~psqlextra.query.PostgresQuerySet.bulk_insert` Bulk diff --git a/docs/source/deletion.rst b/docs/source/deletion.rst index c27cdcb6..9308594c 100644 --- a/docs/source/deletion.rst +++ b/docs/source/deletion.rst @@ -48,3 +48,28 @@ By default, Postgres will raise an error if any other table is referencing one o MyModel.objects.truncate(cascade=True) print(MyModel1.objects.count()) # zero records left print(MyModel2.objects.count()) # zero records left + + +Restart identity +**************** + +If ``restart_identity`` is specified, any sequences on the table will be restarted. + +.. code-block:: python + + from django.db import models + from psqlextra.models import PostgresModel + + class MyModel(PostgresModel): + pass + + mymodel = MyModel.objects.create() + assert mymodel.id == 1 + + MyModel.objects.truncate(restart_identity=True) # table is empty after this + print(MyModel.objects.count()) # zero records left + + # Create a new row, it should get ID 1 again because + # the sequence got restarted. + mymodel = MyModel.objects.create() + assert mymodel.id == 1 diff --git a/docs/source/index.rst b/docs/source/index.rst index 28b61560..1959016e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -35,6 +35,15 @@ Explore the documentation to learn about all features: Support for ``TRUNCATE TABLE`` statements (including cascading). +* :ref:`Locking models & tables <locking_page>` + + Support for explicit table-level locks. + + +* :ref:`Creating/dropping schemas <schemas_page>` + + Support for managing Postgres schemas. + .. toctree:: :maxdepth: 2 @@ -49,6 +58,8 @@ Explore the documentation to learn about all features: table_partitioning expressions annotations + locking + schemas settings api_reference major_releases diff --git a/docs/source/locking.rst b/docs/source/locking.rst new file mode 100644 index 00000000..8cf8cf8e --- /dev/null +++ b/docs/source/locking.rst @@ -0,0 +1,56 @@ +.. include:: ./snippets/postgres_doc_links.rst + +.. _locking_page: + +Locking +======= + +`Explicit table-level locks`_ are supported through the :meth:`psqlextra.locking.postgres_lock_model` and :meth:`psqlextra.locking.postgres_lock_table` methods. All table-level lock methods are supported. + +Locks are always bound to the current transaction and are released when the transaction is committed or rolled back. There is no support (in PostgreSQL) for explicitly releasing a lock. + +.. warning:: + + Locks are only released when the *outer* transaction commits or when a nested transaction is rolled back. You can ensure that the transaction you created is the outermost one by passing the ``durable=True`` argument to ``transaction.atomic``. + +.. note:: + + Use `django-pglocks <https://pypi.org/project/django-pglocks/>`_ if you need an advisory lock. + +Locking a model +--------------- + +Use :class:`psqlextra.locking.PostgresTableLockMode` to indicate the type of lock to acquire. + +..
code-block:: python + + from django.db import transaction + + from psqlextra.locking import PostgresTableLockMode, postgres_lock_model + + with transaction.atomic(durable=True): + postgres_lock_model(MyModel, PostgresTableLockMode.EXCLUSIVE) + + # locks are released here, when the transaction is committed + + +Locking a table +--------------- + +Use :meth:`psqlextra.locking.postgres_lock_table` to lock arbitrary tables in arbitrary schemas. + +.. code-block:: python + + from django.db import transaction + + from psqlextra.locking import PostgresTableLockMode, postgres_lock_table + + with transaction.atomic(durable=True): + postgres_lock_table("mytable", PostgresTableLockMode.EXCLUSIVE) + postgres_lock_table( + "tableinotherschema", + PostgresTableLockMode.EXCLUSIVE, + schema_name="myschema" + ) + + # locks are released here, when the transaction is committed diff --git a/docs/source/schemas.rst b/docs/source/schemas.rst new file mode 100644 index 00000000..01fdd345 --- /dev/null +++ b/docs/source/schemas.rst @@ -0,0 +1,155 @@ +.. include:: ./snippets/postgres_doc_links.rst + +.. _schemas_page: + +Schema +====== + +The :class:`~psqlextra.schema.PostgresSchema` class provides basic schema management functionality. + +Django does **NOT** support custom schemas. This module does not attempt to solve that problem. + +This module merely allows you to create/drop schemas and execute raw SQL in a schema. It is not an attempt at bringing multi-schema support to Django. + + +Reference an existing schema +---------------------------- + +.. code-block:: python + + from psqlextra.schema import PostgresSchema + + schema = PostgresSchema("myschema") + + with schema.connection.cursor() as cursor: + cursor.execute("SELECT * FROM tablethatexistsinmyschema") + + +Checking if a schema exists +--------------------------- + +.. code-block:: python + + from psqlextra.schema import PostgresSchema + + schema = PostgresSchema("myschema") + if PostgresSchema.exists("myschema"): + print("exists!") + else: + print("does not exist!") + + +Creating a new schema +--------------------- + +With a custom name +****************** + +.. code-block:: python + + from psqlextra.schema import PostgresSchema + + # will raise an error if the schema already exists + schema = PostgresSchema.create("myschema") + + +Re-create if necessary with a custom name +***************************************** + +.. warning:: + + If the schema already exists and it is non-empty or something is referencing it, it will **NOT** be dropped. Specify ``cascade=True`` to drop all of the schema's contents and **anything referencing it**. + +.. code-block:: python + + from psqlextra.schema import PostgresSchema + + # will drop existing schema named `myschema` if it + # exists and re-create it + schema = PostgresSchema.drop_and_create("myschema") + + # will drop the schema and cascade it to its contents + # and anything referencing the schema + schema = PostgresSchema.drop_and_create("otherschema", cascade=True) + + +With a time-based name +********************** + +.. warning:: + + The time-based suffix is precise up to the second. If two threads or processes both try to create a time-based schema name with the same suffix in the same second, they will have conflicts. + +.. code-block:: python + + from psqlextra.schema import PostgresSchema + + # schema name will be "myprefix_<timestamp>" + schema = PostgresSchema.create_time_based("myprefix") + print(schema.name) + + +With a random name +****************** + +An 8-character suffix is appended.
Entropy is dependent on your system. See :meth:`~os.urandom` for more information. + +.. code-block:: python + + from psqlextra.schema import PostgresSchema + + # schema name will be "myprefix_<8 random characters>" + schema = PostgresSchema.create_random("myprefix") + print(schema.name) + + +Temporary schema with random name +********************************* + +Use the :meth:`~psqlextra.schema.postgres_temporary_schema` context manager to create a schema with a random name. The schema will only exist within the context manager. + +By default, the schema is not dropped if an exception occurs in the context manager. This prevents unexpected data loss. Specify ``drop_on_throw=True`` to drop the schema if an exception occurs. + +Without an outer transaction, the temporary schema might not be dropped when your program exits unexpectedly (for example, if it is killed with SIGKILL). Wrap the creation of the schema in a transaction to make sure the schema is cleaned up when an error occurs or your program exits suddenly. + +.. warning:: + + By default, the drop will fail if the schema is not empty or there is anything referencing the schema. Specify ``cascade=True`` to drop all of the schema's contents and **anything referencing it**. + +.. code-block:: python + + from psqlextra.schema import postgres_temporary_schema + + with postgres_temporary_schema("myprefix") as schema: + pass + + with postgres_temporary_schema("otherprefix", drop_on_throw=True) as schema: + raise ValueError("drop it like it's hot") + + with postgres_temporary_schema("greatprefix", cascade=True) as schema: + with schema.connection.cursor() as cursor: + cursor.execute(f"CREATE TABLE {schema.name}.mytable AS SELECT 'hello'") + + with postgres_temporary_schema("amazingprefix", drop_on_throw=True, cascade=True) as schema: + with schema.connection.cursor() as cursor: + cursor.execute(f"CREATE TABLE {schema.name}.mytable AS SELECT 'hello'") + + raise ValueError("oops") + +Deleting a schema +----------------- + +Any schema can be dropped, including ones not created by :class:`~psqlextra.schema.PostgresSchema`. + +The ``public`` schema cannot be dropped. This is a Postgres built-in and it is almost always a mistake to drop it. A :class:`~django.core.exceptions.SuspiciousOperation` error will be raised if you attempt to drop the ``public`` schema. + +.. warning:: + + By default, the drop will fail if the schema is not empty or there is anything referencing the schema. Specify ``cascade=True`` to drop all of the schema's contents and **anything referencing it**. + +.. code-block:: python + + from psqlextra.schema import PostgresSchema + + PostgresSchema.drop("myprefix") + PostgresSchema.drop("myprefix", cascade=True) diff --git a/docs/source/settings.rst b/docs/source/settings.rst index 5ab02fe2..662f9376 100644 --- a/docs/source/settings.rst +++ b/docs/source/settings.rst @@ -9,7 +9,7 @@ Settings ``DATABASES[db_name]['ENGINE']`` must be set to ``"psqlextra.backend"``. If you're already using a custom back-end, set ``POSTGRES_EXTRA_DB_BACKEND_BASE`` to your custom back-end. This will instruct ``django-postgres-extra`` to wrap the back-end you specified. - A good example of where this might be need is if you are using the PostGIS back-end: ``django.contrib.db.backends.postgis``. + A good example of where this might be needed is if you are using the PostGIS back-end: ``django.contrib.gis.db.backends.postgis``. **Default value**: ``django.db.backends.postgresql`` @@ -28,3 +28,13 @@ Settings ..
note:: If set to ``False``, you must ensure that the ``hstore`` extension is enabled on your database manually. If not enabled, any ``hstore`` related functionality will not work. + +.. _POSTGRES_EXTRA_ANNOTATE_SQL_: + +* ``POSTGRES_EXTRA_ANNOTATE_SQL`` + + If set to ``True``, will append a comment to all SQL queries with the path and line number that the query was made from. + + Format: ``/* */`` + + This can be useful when debugging queries found in PostgreSQL's ``pg_stat_activity`` or in its query log. diff --git a/docs/source/snippets/postgres_doc_links.rst b/docs/source/snippets/postgres_doc_links.rst index 90ebb51c..fe0f4d76 100644 --- a/docs/source/snippets/postgres_doc_links.rst +++ b/docs/source/snippets/postgres_doc_links.rst @@ -2,3 +2,4 @@ .. _TRUNCATE TABLE: https://www.postgresql.org/docs/9.1/sql-truncate.html .. _hstore: https://www.postgresql.org/docs/11/hstore.html .. _PostgreSQL Declarative Table Partitioning: https://www.postgresql.org/docs/current/ddl-partitioning.html#DDL-PARTITIONING-DECLARATIVE +.. _Explicit table-level locks: https://www.postgresql.org/docs/current/explicit-locking.html#LOCKING-TABLES diff --git a/docs/source/table_partitioning.rst b/docs/source/table_partitioning.rst index 5a5f572b..1bb5ba6f 100644 --- a/docs/source/table_partitioning.rst +++ b/docs/source/table_partitioning.rst @@ -80,18 +80,131 @@ This will generate a migration that creates the partitioned table with a default Do not use the standard ``python manage.py makemigrations`` command for partitioned models. Django will issue a standard :class:`~django:django.db.migrations.operations.CreateModel` operation. Doing this will not create a partitioned table and all subsequent operations will fail. -Adding/removing partitions manually ----------------------------------- +Automatically managing partitions +--------------------------------- -Postgres does not have support for automatically creating new partitions as needed. Therefore, one must manually add new partitions. Depending on the partitioning method you have chosen, the partition has to be created differently. +The ``python manage.py pgpartition`` command can help you automatically create new partitions ahead of time and delete old ones for time-based partitioning. -Partitions are tables. Each partition must be given a unique name. :class:`~psqlextra.models.PostgresPartitionedModel` does not require you to create a model for each partition because you are not supposed to query partitions directly. +You can run this command manually as needed, schedule it to run periodically, or run it every time you release a new version of your app. +.. warning:: + + We DO NOT recommend that you set up this command to automatically delete partitions without manual review. + + Specify ``--skip-delete`` to not delete partitions automatically. Run the command manually periodically without the ``--yes`` flag to review partitions to be deleted. + + +Command-line options +******************** + + ==================== ============= ================ ==================================================================================================== + Long flag Short flag Default Description + ==================== ============= ================ ==================================================================================================== + ``--yes`` ``-y`` ``False`` Specifies yes to all questions. You will NOT be asked for confirmation before partition deletion.
+ ``--using`` ``-u`` ``'default'`` Optional name of the database connection to use. + ``--skip-create`` ``False`` Whether to skip creating partitions. + ``--skip-delete`` ``False`` Whether to skip deleting partitions. + + ==================== ============= ================ ==================================================================================================== + + +Configuration +************* + +In order to use the command, you have to declare an instance of :class:`psqlextra.partitioning.PostgresPartitioningManager` and set ``PSQLEXTRA_PARTITIONING_MANAGER`` to a string with the import path to your instance of :class:`psqlextra.partitioning.PostgresPartitioningManager`. + +For example: + +.. code-block:: python + + # myapp/partitioning.py + from psqlextra.partitioning import PostgresPartitioningManager + + manager = PostgresPartitioningManager(...) + + # myapp/settings.py + PSQLEXTRA_PARTITIONING_MANAGER = 'myapp.partitioning.manager' + + +Time-based partitioning +~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + from dateutil.relativedelta import relativedelta + + from psqlextra.partitioning import ( + PostgresPartitioningManager, + PostgresCurrentTimePartitioningStrategy, + PostgresTimePartitionSize, + partition_by_current_time, + ) + from psqlextra.partitioning.config import PostgresPartitioningConfig + + manager = PostgresPartitioningManager([ + # 3 partitions ahead, each partition is one month + # delete partitions older than 6 months + # partitions will be named `[table_name]_[year]_[3-letter month name]`. + PostgresPartitioningConfig( + model=MyPartitionedModel, + strategy=PostgresCurrentTimePartitioningStrategy( + size=PostgresTimePartitionSize(months=1), + count=3, + max_age=relativedelta(months=6), + ), + ), + # 6 partitions ahead, each partition is two weeks + # delete partitions older than 8 months + # partitions will be named `[table_name]_[year]_week_[week number]`. + PostgresPartitioningConfig( + model=MyPartitionedModel, + strategy=PostgresCurrentTimePartitioningStrategy( + size=PostgresTimePartitionSize(weeks=2), + count=6, + max_age=relativedelta(months=8), + ), + ), + # 12 partitions ahead, each partition is 5 days + # old partitions are never deleted, `max_age` is not set + # partitions will be named `[table_name]_[year]_[month]_[month day number]`. + PostgresPartitioningConfig( + model=MyPartitionedModel, + strategy=PostgresCurrentTimePartitioningStrategy( + size=PostgresTimePartitionSize(days=5), + count=12, + ), + ), + ]) + + +Changing a time partitioning strategy +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When switching partitioning strategies, you might encounter the problem that partitions for part of a particular range already exist. + +In order to combat this, you can use the :class:`psqlextra.partitioning.PostgresTimePartitioningStrategy` and specify the ``start_datetime`` parameter. As a result, no partitions will be created before the given date/time. + + +Custom strategy +~~~~~~~~~~~~~~~ + +You can create a custom partitioning strategy by implementing the :class:`psqlextra.partitioning.PostgresPartitioningStrategy` interface. + +You can look at :class:`psqlextra.partitioning.PostgresCurrentTimePartitioningStrategy` as an example. + + +Manually managing partitions +---------------------------- + +If you are using list or hash partitioning, you most likely have a fixed number of partitions that can be created up front using migrations or using the schema editor.
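For context, the hash partition examples in the sections below assume a partitioned model declared roughly as follows. This is a minimal sketch rather than text from this page: it assumes the ``PartitioningMeta`` declaration style used by :class:`psqlextra.models.PostgresPartitionedModel`, and the model and field names are hypothetical.

.. code-block:: python

    from django.db import models

    from psqlextra.models import PostgresPartitionedModel
    from psqlextra.types import PostgresPartitioningMethod

    class MyPartitionedModel(PostgresPartitionedModel):
        class PartitioningMeta:
            # Rows are spread across a fixed set of partitions based on
            # a hash of the key column(s); each partition is identified
            # by a modulus/remainder pair.
            method = PostgresPartitioningMethod.HASH
            key = ["artist_id"]

        artist_id = models.IntegerField()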
+ +Using migration operations ************************** Adding a range partition ~~~~~~~~~~~~~~~~~~~~~~~~ -Use the :class:`~psqlextra.backend.migrations.operations.PostgresAddRangePartition` operation to add a new range partition. Only use this operation when your partitioned model uses the :attr:`psqlextra.types.PostgresPartitioningMethod.RANGE`. +Use the :class:`~psqlextra.backend.migrations.operations.PostgresAddRangePartition` operation to add a new range partition. Only use this operation when your partitioned model uses :attr:`psqlextra.types.PostgresPartitioningMethod.RANGE`. .. code-block:: python @@ -113,7 +226,7 @@ Use the :class:`~psqlextra.backend.migrations.operations.PostgresAddRangePartiti Adding a list partition ~~~~~~~~~~~~~~~~~~~~~~~ -Use the :class:`~psqlextra.backend.migrations.operations.PostgresAddListPartition` operation to add a new list partition. Only use this operation when your partitioned model uses the :attr:`psqlextra.types.PostgresPartitioningMethod.LIST`. +Use the :class:`~psqlextra.backend.migrations.operations.PostgresAddListPartition` operation to add a new list partition. Only use this operation when your partitioned model uses :attr:`psqlextra.types.PostgresPartitioningMethod.LIST`. .. code-block:: python @@ -131,12 +244,36 @@ Use the :class:`~psqlextra.backend.migrations.operations.PostgresAddListPartitio ] +Adding a hash partition +~~~~~~~~~~~~~~~~~~~~~~~ + +Use the :class:`~psqlextra.backend.migrations.operations.PostgresAddHashPartition` operation to add a new hash partition. Only use this operation when your partitioned model uses :attr:`psqlextra.types.PostgresPartitioningMethod.HASH`. + +.. code-block:: python + + from django.db import migrations, models + + from psqlextra.backend.migrations.operations import PostgresAddHashPartition + + class Migration(migrations.Migration): + operations = [ + PostgresAddHashPartition( + model_name="mypartitionedmodel", + name="pt1", + modulus=3, + remainder=1, + ), + ] + + Adding a default partition ~~~~~~~~~~~~~~~~~~~~~~~~~~ -Use the :class:`~psqlextra.backend.migrations.operations.PostgresAddDefaultPartition` operation to add a new default partition. A default partition is the partition where records get saved that couldn't fit in any other partition. +Use the :class:`~psqlextra.backend.migrations.operations.PostgresAddDefaultPartition` operation to add a new default partition. + +Note that you can only have one default partition per partitioned table/model. An error will be thrown if you try to create a second default partition. -Note that you can only have one default partition per partitioned table/model. +If you used ``python manage.py pgmakemigrations`` to generate a migration for your newly created partitioned model, you do not need this operation. This operation is added automatically when you create a new partitioned model. .. code-block:: python @@ -158,6 +295,12 @@ Deleting a default partition Use the :class:`~psqlextra.backend.migrations.operations.PostgresDeleteDefaultPartition` operation to delete an existing default partition. + +.. warning:: + + Deleting the default partition and leaving your model without a default partition can be dangerous. Rows that do not fit in any other partition will fail to be inserted. + + ..
code-block:: python from django.db import migrations, models @@ -176,7 +319,7 @@ Use the :class:`~psqlextra.backend.migrations.operations.PostgresDeleteDefaultPa Deleting a range partition ~~~~~~~~~~~~~~~~~~~~~~~~~~ -Use the :class:`psqlextra.backend.migrations.operations.PostgresDeleteRangePartition` operation to delete an existing range partition. +Use the :class:`psqlextra.backend.migrations.operations.PostgresDeleteRangePartition` operation to delete an existing range partition. Only use this operation when your partitioned model uses :attr:`psqlextra.types.PostgresPartitioningMethod.RANGE`. .. code-block:: python @@ -196,7 +339,7 @@ Use the :class:`psqlextra.backend.migrations.operations.PostgresDeleteRangeParti Deleting a list partition ~~~~~~~~~~~~~~~~~~~~~~~~~ -Use the :class:`~psqlextra.backend.migrations.operations.PostgresDeleteListPartition` operation to delete an existing list partition. +Use the :class:`psqlextra.backend.migrations.operations.PostgresDeleteListPartition` operation to delete an existing list partition. Only use this operation when your partitioned model uses :attr:`psqlextra.types.PostgresPartitioningMethod.LIST`. .. code-block:: python @@ -213,6 +356,26 @@ Use the :class:`~psqlextra.backend.migrations.operations.PostgresDeleteListParti ] +Deleting a hash partition +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Use the :class:`psqlextra.backend.migrations.operations.PostgresDeleteHashPartition` operation to delete an existing hash partition. Only use this operation when your partitioned model uses :attr:`psqlextra.types.PostgresPartitioningMethod.HASH`. + +.. code-block:: python + + from django.db import migrations, models + + from psqlextra.backend.migrations.operations import PostgresDeleteHashPartition + + class Migration(migrations.Migration): + operations = [ + PostgresDeleteHashPartition( + model_name="mypartitionedmodel", + name="pt1", + ), + ] + + Using the schema editor *********************** @@ -248,120 +411,42 @@ Adding a list partition ) -Adding a default partition -~~~~~~~~~~~~~~~~~~~~~~~~~~ +Adding a hash partition +~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: python from django.db import connection - connection.schema_editor().add_default_partition( + connection.schema_editor().add_hash_partition( model=MyPartitionedModel, - name="default", + name="pt1", + modulus=3, + remainder=1, ) -Deleting a partition -~~~~~~~~~~~~~~~~~~~~ +Adding a default partition +~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: python from django.db import connection - connection.schema_editor().delete_partition( + connection.schema_editor().add_default_partition( model=MyPartitionedModel, - name="default", + name="default", ) -Adding/removing partitions automatically ----------------------------------------- - -:class:`psqlextra.partitioning.PostgresPartitioningManager` an experimental helper class that can be called periodically to automatically create new partitions if you're using range partitioning. - -.. note:: - - There is currently no scheduler or command to automatically create new partitions. You'll have to run this function in your own cron jobs. - -The auto partitioner supports automatically creating yearly, monthly, weekly or daily partitions. Use the ``count`` parameter to configure how many partitions it should create ahead. - - -Partitioning strategies -*********************** - - -Time-based partitioning -~~~~~~~~~~~~~~~~~~~~~~~ ..
code-block:: python - from dateutil.relativedelta import relativedelta - - from psqlextra.partitioning import ( - PostgresPartitioningManager, - PostgresCurrentTimePartitioningStrategy, - PostgresTimePartitionSize, - partition_by_current_time, - ) - - manager = PostgresPartitioningManager([ - # 3 partitions ahead, each partition is one month - # delete partitions older than 6 months - # partitions will be named `[table_name]_[year]_[3-letter month name]`. - PostgresPartitioningConfig( - model=MyPartitionedModel, - strategy=PostgresCurrentTimePartitioningStrategy( - size=PostgresTimePartitionSize(months=1), - count=3, - max_age=relativedelta(months=6), - ), - ), - # 6 partitions ahead, each partition is two weeks - # delete partitions older than 8 months - # partitions will be named `[table_name]_[year]_week_[week number]`. - PostgresPartitioningConfig( - model=MyPartitionedModel, - strategy=PostgresCurrentTimePartitioningStrategy( - size=PostgresTimePartitionSize(weeks=2), - count=6, - max_age=relativedelta(months=8), - ), - ), - # 12 partitions ahead, each partition is 5 days - # old partitions are never deleted, `max_age` is not set - # partitions will be named `[table_name]_[year]_[month]_[month day number]`. - PostgresPartitioningConfig( - model=MyPartitionedModel, - strategy=PostgresCurrentTimePartitioningStrategy( - size=PostgresTimePartitionSize(wdyas=5), - count=12, - ), - ), - ]) + from django.db import connection - # these are the default arguments - partioning_plan = manager.plan( - skip_create=False, - skip_delete=False, - using='default' + connection.schema_editor().delete_partition( + model=MyPartitionedModel, + name="default", ) - - # prints a list of partitions to be created/deleted - partitioning_plan.print() - - # apply the plan - partitioning_plan.apply(using='default'); - - -Custom strategy -~~~~~~~~~~~~~~~ - -You can create a custom partitioning strategy by implementing the :class:`psqlextra.partitioning.PostgresPartitioningStrategy` interface. - -You can look at :class:`psqlextra.partitioning.PostgresCurrentTimePartitioningStrategy` as an example. - - -Switching partitioning strategies -********************************* - -When switching partitioning strategies, you might encounter the problem that partitions for part of a particular range already exist. In order to combat this, you can use the :class:`psqlextra.partitioning.PostgresTimePartitioningStrategy` and specify the `start_datetime` parameter. As a result, no partitions will be created before the given date/time. 
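To round off the schema editor examples above: the same methods can be called outside of migrations, for example from a maintenance script that provisions partitions at runtime. A minimal sketch, assuming ``add_range_partition`` accepts the same ``from_values``/``to_values`` bounds as the ``PostgresAddRangePartition`` migration operation and that ``MyPartitionedModel`` is range-partitioned by a date column; the partition name and bounds are hypothetical.

.. code-block:: python

    from django.db import connection

    # Create a partition covering January 2023. The upper bound of a
    # range partition is exclusive in PostgreSQL.
    connection.schema_editor().add_range_partition(
        model=MyPartitionedModel,
        name="pt_2023_jan",
        from_values="2023-01-01",
        to_values="2023-02-01",
    )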
diff --git a/manage.py b/manage.py old mode 100644 new mode 100755 diff --git a/psqlextra/__init__.py b/psqlextra/__init__.py index 5b7b21a1..474f803b 100644 --- a/psqlextra/__init__.py +++ b/psqlextra/__init__.py @@ -1,4 +1,15 @@ import django +from ._version import __version__ + if django.VERSION < (3, 2): # pragma: no cover default_app_config = "psqlextra.apps.PostgresExtraAppConfig" + + __all__ = [ + "default_app_config", + "__version__", + ] +else: + __all__ = [ + "__version__", + ] diff --git a/psqlextra/_version.py b/psqlextra/_version.py new file mode 100644 index 00000000..e8733fa0 --- /dev/null +++ b/psqlextra/_version.py @@ -0,0 +1 @@ +__version__ = "2.0.9rc4" diff --git a/psqlextra/apps.py b/psqlextra/apps.py index 61ba29c4..d8965f7d 100644 --- a/psqlextra/apps.py +++ b/psqlextra/apps.py @@ -4,3 +4,6 @@ class PostgresExtraAppConfig(AppConfig): name = "psqlextra" verbose_name = "PostgreSQL Extra" + + def ready(self) -> None: + from .lookups import InValuesLookup # noqa diff --git a/psqlextra/backend/base.py b/psqlextra/backend/base.py index 0d19d2bf..c8ae73c5 100644 --- a/psqlextra/backend/base.py +++ b/psqlextra/backend/base.py @@ -1,6 +1,12 @@ import logging +from typing import TYPE_CHECKING + from django.conf import settings +from django.contrib.postgres.signals import ( + get_hstore_oids, + register_type_handlers, +) from django.db import ProgrammingError from . import base_impl @@ -8,20 +14,63 @@ from .operations import PostgresOperations from .schema import PostgresSchemaEditor +from django.db.backends.postgresql.base import ( # isort:skip + DatabaseWrapper as PostgresDatabaseWrapper, +) + + logger = logging.getLogger(__name__) -class DatabaseWrapper(base_impl.backend()): +if TYPE_CHECKING: + + class Wrapper(PostgresDatabaseWrapper): + pass + +else: + Wrapper = base_impl.backend() + + +class DatabaseWrapper(Wrapper): """Wraps the standard PostgreSQL database back-end. Overrides the schema editor with our custom schema editor and makes sure the `hstore` extension is enabled. """ - SchemaEditorClass = PostgresSchemaEditor + SchemaEditorClass = PostgresSchemaEditor # type: ignore[assignment] introspection_class = PostgresIntrospection ops_class = PostgresOperations + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Some base back-ends such as the PostGIS back-end don't properly + # set `ops_class` and `introspection_class` and initialize these + # classes themselves. + # + # This can lead to broken functionality. We fix this automatically. + + if not isinstance(self.introspection, self.introspection_class): + self.introspection = self.introspection_class(self) + + if not isinstance(self.ops, self.ops_class): + self.ops = self.ops_class(self) + + for expected_compiler_class in self.ops.compiler_classes: + compiler_class = self.ops.compiler(expected_compiler_class.__name__) + + if not issubclass(compiler_class, expected_compiler_class): + logger.warning( + "Compiler '%s.%s' is not properly deriving from '%s.%s'." + % ( + compiler_class.__module__, + compiler_class.__name__, + expected_compiler_class.__module__, + expected_compiler_class.__name__, + ) + ) + def prepare_database(self): """Ran to prepare the configured database. @@ -49,3 +98,22 @@ def prepare_database(self): "or add the extension manually.", exc_info=True, ) + return + + # Clear old (non-existent), stale oids. + get_hstore_oids.cache_clear() + + # Verify that we (and Django) can find the OIDs + # for hstore. 
+ oids, _ = get_hstore_oids(self.alias) + if not oids: + logger.warning( + '"hstore" extension was created, but we cannot find the oids ' + "in the database. Something went wrong.", + ) + return + + # We must trigger Django into registering the type handlers now + # so that any subsequent code can properly use the newly + # registered types. + register_type_handlers(self) diff --git a/psqlextra/backend/base_impl.py b/psqlextra/backend/base_impl.py index d13e78c9..88bf9278 100644 --- a/psqlextra/backend/base_impl.py +++ b/psqlextra/backend/base_impl.py @@ -2,14 +2,23 @@ from django.conf import settings from django.core.exceptions import ImproperlyConfigured +from django.db import DEFAULT_DB_ALIAS, connections +from django.db.backends.postgresql.base import DatabaseWrapper +from django.db.backends.postgresql.introspection import ( # type: ignore[import] + DatabaseIntrospection, +) +from django.db.backends.postgresql.operations import DatabaseOperations +from django.db.backends.postgresql.schema import ( # type: ignore[import] + DatabaseSchemaEditor, +) from django.db.backends.postgresql.base import ( # isort:skip DatabaseWrapper as Psycopg2DatabaseWrapper, ) -def backend(): - """Gets the base class for the custom database back-end. +def base_backend_instance(): + """Gets an instance of the base class for the custom database back-end. This should be the Django PostgreSQL back-end. However, some people are already using a custom back-end from @@ -19,6 +28,10 @@ def backend(): As long as the specified base eventually also has the PostgreSQL back-end as a base, then everything should work as intended. + + We create an instance to inspect what classes to subclass + because not all back-ends set properties such as `ops_class` + properly. The PostGIS back-end is a good example. """ base_class_name = getattr( settings, @@ -49,34 +62,51 @@ def backend(): % base_class_name ) - return base_class + base_instance = base_class(connections.databases[DEFAULT_DB_ALIAS]) + if base_instance.connection: + raise ImproperlyConfigured( + ( + "'%s' establishes a connection during initialization." + " This is not expected and can lead to more connections" + " being established than necessary." + ) + % base_class_name + ) + + return base_instance + + +def backend() -> DatabaseWrapper: + """Gets the base class for the database back-end.""" + + return base_backend_instance().__class__ -def schema_editor(): +def schema_editor() -> DatabaseSchemaEditor: """Gets the base class for the schema editor. We have to use the configured base back-end's schema editor for this. """ - return backend().SchemaEditorClass + return base_backend_instance().SchemaEditorClass -def introspection(): +def introspection() -> DatabaseIntrospection: """Gets the base class for the introspection class. We have to use the configured base back-end's introspection class for this. """ - return backend().introspection_class + return base_backend_instance().introspection.__class__ -def operations(): +def operations() -> DatabaseOperations: """Gets the base class for the operations class. We have to use the configured base back-end's operations class for this.
""" - return backend().ops_class + return base_backend_instance().ops.__class__ diff --git a/psqlextra/backend/introspection.py b/psqlextra/backend/introspection.py index a85f27cd..bd775779 100644 --- a/psqlextra/backend/introspection.py +++ b/psqlextra/backend/introspection.py @@ -1,5 +1,9 @@ from dataclasses import dataclass -from typing import List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + +from django.db.backends.postgresql.introspection import ( # type: ignore[import] + DatabaseIntrospection, +) from psqlextra.types import PostgresPartitioningMethod @@ -45,15 +49,41 @@ def partition_by_name( ) -class PostgresIntrospection(base_impl.introspection()): +if TYPE_CHECKING: + + class Introspection(DatabaseIntrospection): + pass + +else: + Introspection = base_impl.introspection() + + +class PostgresIntrospection(Introspection): """Adds introspection features specific to PostgreSQL.""" + # TODO: This class is a mess, both here and in the + # the base. + # + # Some methods return untyped dicts, some named tuples, + # some flat lists of strings. It's horribly inconsistent. + # + # Most methods are poorly named. For example; `get_table_description` + # does not return a complete table description. It merely returns + # the columns. + # + # We do our best in this class to stay consistent with + # the base in Django by respecting its naming scheme + # and commonly used return types. Creating an API that + # matches the look&feel from the Django base class + # is more important than fixing those issues. + def get_partitioned_tables( self, cursor - ) -> PostgresIntrospectedPartitonedTable: + ) -> List[PostgresIntrospectedPartitonedTable]: """Gets a list of partitioned tables.""" - sql = """ + cursor.execute( + """ SELECT pg_class.relname, pg_partitioned_table.partstrat @@ -64,8 +94,7 @@ def get_partitioned_tables( ON pg_class.oid = pg_partitioned_table.partrelid """ - - cursor.execute(sql) + ) return [ PostgresIntrospectedPartitonedTable( @@ -172,6 +201,24 @@ def get_partition_key(self, cursor, table_name: str) -> List[str]: cursor.execute(sql, (table_name,)) return [row[0] for row in cursor.fetchall()] + def get_columns(self, cursor, table_name: str): + return self.get_table_description(cursor, table_name) + + def get_schema_list(self, cursor) -> List[str]: + """A flat list of available schemas.""" + + cursor.execute( + """ + SELECT + schema_name + FROM + information_schema.schemata + """, + tuple(), + ) + + return [name for name, in cursor.fetchall()] + def get_constraints(self, cursor, table_name: str): """Retrieve any constraints or keys (unique, pk, fk, check, index) across one or more columns. @@ -187,8 +234,83 @@ def get_constraints(self, cursor, table_name: str): "SELECT indexname, indexdef FROM pg_indexes WHERE tablename = %s", (table_name,), ) - for index, definition in cursor.fetchall(): - if constraints[index].get("definition") is None: - constraints[index]["definition"] = definition + for index_name, definition in cursor.fetchall(): + # PostgreSQL 13 or older won't give a definition if the + # index is actually a primary key. 
+ constraint = constraints.get(index_name) + if not constraint: + continue + + if constraint.get("definition") is None: + constraint["definition"] = definition return constraints + + def get_table_locks(self, cursor) -> List[Tuple[str, str, str]]: + cursor.execute( + """ + SELECT + n.nspname, + t.relname, + l.mode + FROM pg_locks l + INNER JOIN pg_class t ON t.oid = l.relation + INNER JOIN pg_namespace n ON n.oid = t.relnamespace + WHERE t.relnamespace >= 2200 + ORDER BY n.nspname, t.relname, l.mode + """ + ) + + return cursor.fetchall() + + def get_storage_settings(self, cursor, table_name: str) -> Dict[str, str]: + sql = """ + SELECT + unnest(c.reloptions || array(select 'toast.' || x from pg_catalog.unnest(tc.reloptions) x)) + FROM + pg_catalog.pg_class c + LEFT JOIN + pg_catalog.pg_class tc ON (c.reltoastrelid = tc.oid) + LEFT JOIN + pg_catalog.pg_am am ON (c.relam = am.oid) + WHERE + c.relname::text = %s + AND pg_catalog.pg_table_is_visible(c.oid) + """ + + cursor.execute(sql, (table_name,)) + + storage_settings = {} + for row in cursor.fetchall(): + # It's hard to believe, but storage settings are really + # represented as `key=value` strings in Postgres. + # See: https://www.postgresql.org/docs/current/catalog-pg-class.html + name, value = row[0].split("=") + storage_settings[name] = value + + return storage_settings + + def get_relations(self, cursor, table_name: str): + """Gets a dictionary {field_name: (field_name_other_table, + other_table)} representing all relations in the specified table. + + This is overridden because the query in Django does not handle + relations between tables in different schemas properly. + """ + + cursor.execute( + """ + SELECT a1.attname, c2.relname, a2.attname + FROM pg_constraint con + LEFT JOIN pg_class c1 ON con.conrelid = c1.oid + LEFT JOIN pg_class c2 ON con.confrelid = c2.oid + LEFT JOIN pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1] + LEFT JOIN pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1] + WHERE + con.conrelid = %s::regclass AND + con.contype = 'f' AND + pg_catalog.pg_table_is_visible(c1.oid) + """, + [table_name], + ) + return {row[0]: (row[2], row[1]) for row in cursor.fetchall()} diff --git a/psqlextra/backend/migrations/patched_autodetector.py b/psqlextra/backend/migrations/patched_autodetector.py index 66da6734..e5ba8938 100644 --- a/psqlextra/backend/migrations/patched_autodetector.py +++ b/psqlextra/backend/migrations/patched_autodetector.py @@ -1,6 +1,8 @@ from contextlib import contextmanager from unittest import mock +import django + from django.db.migrations import ( AddField, AlterField, @@ -10,8 +12,7 @@ RenameField, ) from django.db.migrations.autodetector import MigrationAutodetector -from django.db.migrations.operations.base import Operation -from django.db.models import Model +from django.db.migrations.operations.fields import FieldOperation from psqlextra.models import ( PostgresMaterializedViewModel, @@ -21,6 +22,11 @@ from psqlextra.types import PostgresPartitioningMethod from .
import operations +from .state import ( + PostgresMaterializedViewModelState, + PostgresPartitionedModelState, + PostgresViewModelState, +) # original `MigrationAutodetector.add_operation` # function, saved here so the patched version can @@ -77,7 +83,7 @@ def rename_field(self, operation: RenameField): return self._transform_view_field_operations(operation) - def _transform_view_field_operations(self, operation: Operation): + def _transform_view_field_operations(self, operation: FieldOperation): """Transforms operations on fields on a (materialized) view into state only operations. @@ -88,12 +94,26 @@ def _transform_view_field_operations(self, operation: Operation): actually applying it. """ - model = self.autodetector.new_apps.get_model( - self.app_label, operation.model_name - ) + if django.VERSION >= (4, 0): + model_identifier = (self.app_label, operation.model_name.lower()) + model_state = ( + self.autodetector.to_state.models.get(model_identifier) + or self.autodetector.from_state.models[model_identifier] + ) + + if isinstance(model_state, PostgresViewModelState): + return self.add( + operations.ApplyState(state_operation=operation) + ) + else: + model = self.autodetector.new_apps.get_model( + self.app_label, operation.model_name + ) - if issubclass(model, PostgresViewModel): - return self.add(operations.ApplyState(state_operation=operation)) + if issubclass(model, PostgresViewModel): + return self.add( + operations.ApplyState(state_operation=operation) + ) return self.add(operation) @@ -101,16 +121,28 @@ def add_create_model(self, operation: CreateModel): """Adds the specified :see:CreateModel operation to the list of operations to execute in the migration.""" - model = self.autodetector.new_apps.get_model( - self.app_label, operation.name - ) + if django.VERSION >= (4, 0): + model_state = self.autodetector.to_state.models[ + self.app_label, operation.name.lower() + ] + + if isinstance(model_state, PostgresPartitionedModelState): + return self.add_create_partitioned_model(operation) + elif isinstance(model_state, PostgresMaterializedViewModelState): + return self.add_create_materialized_view_model(operation) + elif isinstance(model_state, PostgresViewModelState): + return self.add_create_view_model(operation) + else: + model = self.autodetector.new_apps.get_model( + self.app_label, operation.name + ) - if issubclass(model, PostgresPartitionedModel): - return self.add_create_partitioned_model(model, operation) - elif issubclass(model, PostgresMaterializedViewModel): - return self.add_create_materialized_view_model(model, operation) - elif issubclass(model, PostgresViewModel): - return self.add_create_view_model(model, operation) + if issubclass(model, PostgresPartitionedModel): + return self.add_create_partitioned_model(operation) + elif issubclass(model, PostgresMaterializedViewModel): + return self.add_create_materialized_view_model(operation) + elif issubclass(model, PostgresViewModel): + return self.add_create_view_model(operation) return self.add(operation) @@ -118,44 +150,68 @@ def add_delete_model(self, operation: DeleteModel): """Adds the specified :see:Deletemodel operation to the list of operations to execute in the migration.""" - model = self.autodetector.old_apps.get_model( - self.app_label, operation.name - ) + if django.VERSION >= (4, 0): + model_state = self.autodetector.from_state.models[ + self.app_label, operation.name.lower() + ] + + if isinstance(model_state, PostgresPartitionedModelState): + return self.add_delete_partitioned_model(operation) + elif 
isinstance(model_state, PostgresMaterializedViewModelState): + return self.add_delete_materialized_view_model(operation) + elif isinstance(model_state, PostgresViewModelState): + return self.add_delete_view_model(operation) + else: + model = self.autodetector.old_apps.get_model( + self.app_label, operation.name + ) - if issubclass(model, PostgresPartitionedModel): - return self.add_delete_partitioned_model(model, operation) - elif issubclass(model, PostgresMaterializedViewModel): - return self.add_delete_materialized_view_model(model, operation) - elif issubclass(model, PostgresViewModel): - return self.add_delete_view_model(model, operation) + if issubclass(model, PostgresPartitionedModel): + return self.add_delete_partitioned_model(operation) + elif issubclass(model, PostgresMaterializedViewModel): + return self.add_delete_materialized_view_model(operation) + elif issubclass(model, PostgresViewModel): + return self.add_delete_view_model(operation) return self.add(operation) - def add_create_partitioned_model( - self, model: Model, operation: CreateModel - ): + def add_create_partitioned_model(self, operation: CreateModel): """Adds a :see:PostgresCreatePartitionedModel operation to the list of operations to execute in the migration.""" - partitioning_options = model._partitioning_meta.original_attrs + if django.VERSION >= (4, 0): + model_state = self.autodetector.to_state.models[ + self.app_label, operation.name.lower() + ] + partitioning_options = model_state.partitioning_options + else: + model = self.autodetector.new_apps.get_model( + self.app_label, operation.name + ) + partitioning_options = model._partitioning_meta.original_attrs + _, args, kwargs = operation.deconstruct() if partitioning_options["method"] != PostgresPartitioningMethod.HASH: self.add( operations.PostgresAddDefaultPartition( - model_name=model.__name__, name="default" + model_name=operation.name, name="default" ) ) + partitioned_kwargs = { + **kwargs, + "partitioning_options": partitioning_options, + } + self.add( operations.PostgresCreatePartitionedModel( - *args, **kwargs, partitioning_options=partitioning_options + *args, + **partitioned_kwargs, ) ) - def add_delete_partitioned_model( - self, model: Model, operation: DeleteModel - ): + def add_delete_partitioned_model(self, operation: DeleteModel): """Adds a :see:PostgresDeletePartitionedModel operation to the list of operations to execute in the migration.""" @@ -164,44 +220,61 @@ def add_delete_partitioned_model( operations.PostgresDeletePartitionedModel(*args, **kwargs) ) - def add_create_view_model(self, model: Model, operation: CreateModel): + def add_create_view_model(self, operation: CreateModel): """Adds a :see:PostgresCreateViewModel operation to the list of operations to execute in the migration.""" - view_options = model._view_meta.original_attrs + if django.VERSION >= (4, 0): + model_state = self.autodetector.to_state.models[ + self.app_label, operation.name.lower() + ] + view_options = model_state.view_options + else: + model = self.autodetector.new_apps.get_model( + self.app_label, operation.name + ) + view_options = model._view_meta.original_attrs + _, args, kwargs = operation.deconstruct() - self.add( - operations.PostgresCreateViewModel( - *args, **kwargs, view_options=view_options - ) - ) + view_kwargs = {**kwargs, "view_options": view_options} + + self.add(operations.PostgresCreateViewModel(*args, **view_kwargs)) - def add_delete_view_model(self, model: Model, operation: DeleteModel): + def add_delete_view_model(self, operation: DeleteModel): 
"""Adds a :see:PostgresDeleteViewModel operation to the list of operations to execute in the migration.""" _, args, kwargs = operation.deconstruct() return self.add(operations.PostgresDeleteViewModel(*args, **kwargs)) - def add_create_materialized_view_model( - self, model: Model, operation: CreateModel - ): + def add_create_materialized_view_model(self, operation: CreateModel): """Adds a :see:PostgresCreateMaterializedViewModel operation to the list of operations to execute in the migration.""" - view_options = model._view_meta.original_attrs + if django.VERSION >= (4, 0): + model_state = self.autodetector.to_state.models[ + self.app_label, operation.name.lower() + ] + view_options = model_state.view_options + else: + model = self.autodetector.new_apps.get_model( + self.app_label, operation.name + ) + view_options = model._view_meta.original_attrs + _, args, kwargs = operation.deconstruct() + view_kwargs = {**kwargs, "view_options": view_options} + self.add( operations.PostgresCreateMaterializedViewModel( - *args, **kwargs, view_options=view_options + *args, + **view_kwargs, ) ) - def add_delete_materialized_view_model( - self, model: Model, operation: DeleteModel - ): + def add_delete_materialized_view_model(self, operation: DeleteModel): """Adds a :see:PostgresDeleteMaterializedViewModel operation to the list of operations to execute in the migration.""" diff --git a/psqlextra/backend/migrations/state/model.py b/psqlextra/backend/migrations/state/model.py index 465b6152..797147f4 100644 --- a/psqlextra/backend/migrations/state/model.py +++ b/psqlextra/backend/migrations/state/model.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Type +from typing import Tuple, Type, cast from django.db.migrations.state import ModelState from django.db.models import Model @@ -17,8 +17,8 @@ class PostgresModelState(ModelState): """ @classmethod - def from_model( - cls, model: PostgresModel, *args, **kwargs + def from_model( # type: ignore[override] + cls, model: Type[PostgresModel], *args, **kwargs ) -> "PostgresModelState": """Creates a new :see:PostgresModelState object from the specified model. @@ -29,28 +29,32 @@ def from_model( We also need to patch up the base class for the model. """ - model_state = super().from_model(model, *args, **kwargs) - model_state = cls._pre_new(model, model_state) + model_state = super().from_model( + cast(Type[Model], model), *args, **kwargs + ) + model_state = cls._pre_new( + model, cast("PostgresModelState", model_state) + ) # django does not add abstract bases as a base in migrations # because it assumes the base does not add anything important # in a migration.. but it does, so we replace the Model # base with the actual base - bases = tuple() + bases: Tuple[Type[Model], ...] 
= tuple() for base in model_state.bases: if issubclass(base, Model): bases += (cls._get_base_model_class(),) else: bases += (base,) - model_state.bases = bases + model_state.bases = cast(Tuple[Type[Model]], bases) return model_state def clone(self) -> "PostgresModelState": """Gets an exact copy of this :see:PostgresModelState.""" model_state = super().clone() - return self._pre_clone(model_state) + return self._pre_clone(cast(PostgresModelState, model_state)) def render(self, apps): """Renders this state into an actual model.""" @@ -95,7 +99,9 @@ def render(self, apps): @classmethod def _pre_new( - cls, model: PostgresModel, model_state: "PostgresModelState" + cls, + model: Type[PostgresModel], + model_state: "PostgresModelState", ) -> "PostgresModelState": """Called when a new model state is created from the specified model.""" diff --git a/psqlextra/backend/migrations/state/partitioning.py b/psqlextra/backend/migrations/state/partitioning.py index aef7a5e3..e8b9a5eb 100644 --- a/psqlextra/backend/migrations/state/partitioning.py +++ b/psqlextra/backend/migrations/state/partitioning.py @@ -94,7 +94,7 @@ def delete_partition(self, name: str): del self.partitions[name] @classmethod - def _pre_new( + def _pre_new( # type: ignore[override] cls, model: PostgresPartitionedModel, model_state: "PostgresPartitionedModelState", @@ -108,7 +108,7 @@ def _pre_new( ) return model_state - def _pre_clone( + def _pre_clone( # type: ignore[override] self, model_state: "PostgresPartitionedModelState" ) -> "PostgresPartitionedModelState": """Called when this model state is cloned.""" diff --git a/psqlextra/backend/migrations/state/view.py b/psqlextra/backend/migrations/state/view.py index d59b3120..0f5b52eb 100644 --- a/psqlextra/backend/migrations/state/view.py +++ b/psqlextra/backend/migrations/state/view.py @@ -22,8 +22,10 @@ def __init__(self, *args, view_options={}, **kwargs): self.view_options = dict(view_options) @classmethod - def _pre_new( - cls, model: PostgresViewModel, model_state: "PostgresViewModelState" + def _pre_new( # type: ignore[override] + cls, + model: Type[PostgresViewModel], + model_state: "PostgresViewModelState", ) -> "PostgresViewModelState": """Called when a new model state is created from the specified model.""" @@ -31,7 +33,7 @@ def _pre_new( model_state.view_options = dict(model._view_meta.original_attrs) return model_state - def _pre_clone( + def _pre_clone( # type: ignore[override] self, model_state: "PostgresViewModelState" ) -> "PostgresViewModelState": """Called when this model state is cloned.""" diff --git a/psqlextra/backend/operations.py b/psqlextra/backend/operations.py index cab204a2..3bcf1897 100644 --- a/psqlextra/backend/operations.py +++ b/psqlextra/backend/operations.py @@ -1,27 +1,23 @@ -from importlib import import_module +from psqlextra.compiler import ( + SQLAggregateCompiler, + SQLCompiler, + SQLDeleteCompiler, + SQLInsertCompiler, + SQLUpdateCompiler, +) from . 
import base_impl -class PostgresOperations(base_impl.operations()): +class PostgresOperations(base_impl.operations()): # type: ignore[misc] """Simple operations specific to PostgreSQL.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + compiler_module = "psqlextra.compiler" - self._compiler_cache = None - - def compiler(self, compiler_name: str): - """Gets the SQL compiler with the specified name.""" - - # first let django try to find the compiler - try: - return super().compiler(compiler_name) - except AttributeError: - pass - - # django can't find it, look in our own module - if self._compiler_cache is None: - self._compiler_cache = import_module("psqlextra.compiler") - - return getattr(self._compiler_cache, compiler_name) + compiler_classes = [ + SQLCompiler, + SQLDeleteCompiler, + SQLAggregateCompiler, + SQLUpdateCompiler, + SQLInsertCompiler, + ] diff --git a/psqlextra/backend/schema.py b/psqlextra/backend/schema.py index cc55dffa..31a23414 100644 --- a/psqlextra/backend/schema.py +++ b/psqlextra/backend/schema.py @@ -1,14 +1,24 @@ -from typing import Any, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional, Type, cast from unittest import mock +import django + from django.core.exceptions import ( FieldDoesNotExist, ImproperlyConfigured, SuspiciousOperation, ) from django.db import transaction +from django.db.backends.ddl_references import Statement +from django.db.backends.postgresql.schema import ( # type: ignore[import] + DatabaseSchemaEditor, +) from django.db.models import Field, Model +from psqlextra.settings import ( + postgres_prepend_local_search_path, + postgres_reset_local_search_path, +) from psqlextra.type_assertions import is_sql_with_params from psqlextra.types import PostgresPartitioningMethod @@ -19,12 +29,35 @@ HStoreUniqueSchemaEditorSideEffect, ) +if TYPE_CHECKING: + + class SchemaEditor(DatabaseSchemaEditor): + pass + +else: + SchemaEditor = base_impl.schema_editor() + -class PostgresSchemaEditor(base_impl.schema_editor()): +class PostgresSchemaEditor(SchemaEditor): """Schema editor that adds extra methods for PostgreSQL specific features and hooks into existing implementations to add side effects specific to PostgreSQL.""" + sql_add_pk = "ALTER TABLE %s ADD PRIMARY KEY (%s)" + + sql_create_fk_not_valid = f"{SchemaEditor.sql_create_fk} NOT VALID" + sql_validate_fk = "ALTER TABLE %s VALIDATE CONSTRAINT %s" + + sql_create_sequence_with_owner = "CREATE SEQUENCE %s OWNED BY %s.%s" + + sql_alter_table_storage_setting = "ALTER TABLE %s SET (%s = %s)" + sql_reset_table_storage_setting = "ALTER TABLE %s RESET (%s)" + + sql_alter_table_schema = "ALTER TABLE %s SET SCHEMA %s" + sql_create_schema = "CREATE SCHEMA %s" + sql_delete_schema = "DROP SCHEMA %s" + sql_delete_schema_cascade = "DROP SCHEMA %s CASCADE" + sql_create_view = "CREATE VIEW %s AS (%s)" sql_replace_view = "CREATE OR REPLACE VIEW %s AS (%s)" sql_drop_view = "DROP VIEW IF EXISTS %s" @@ -48,9 +81,9 @@ class PostgresSchemaEditor(base_impl.schema_editor()): sql_delete_partition = "DROP TABLE %s" sql_table_comment = "COMMENT ON TABLE %s IS %s" - side_effects = [ - HStoreUniqueSchemaEditorSideEffect(), - HStoreRequiredSchemaEditorSideEffect(), + side_effects: List[DatabaseSchemaEditor] = [ + cast(DatabaseSchemaEditor, HStoreUniqueSchemaEditorSideEffect()), + cast(DatabaseSchemaEditor, HStoreRequiredSchemaEditorSideEffect()), ] def __init__(self, connection, collect_sql=False, atomic=True): @@ -63,7 +96,22 @@ def __init__(self, connection, collect_sql=False, 
atomic=True): self.deferred_sql = [] self.introspection = PostgresIntrospection(self.connection) - def create_model(self, model: Model) -> None: + def create_schema(self, name: str) -> None: + """Creates a Postgres schema.""" + + self.execute(self.sql_create_schema % self.quote_name(name)) + + def delete_schema(self, name: str, cascade: bool) -> None: + """Drops a Postgres schema.""" + + sql = ( + self.sql_delete_schema + if not cascade + else self.sql_delete_schema_cascade + ) + self.execute(sql % self.quote_name(name)) + + def create_model(self, model: Type[Model]) -> None: """Creates a new model.""" super().create_model(model) @@ -71,7 +119,7 @@ def create_model(self, model: Model) -> None: for side_effect in self.side_effects: side_effect.create_model(model) - def delete_model(self, model: Model) -> None: + def delete_model(self, model: Type[Model]) -> None: """Drops/deletes an existing model.""" for side_effect in self.side_effects: @@ -79,8 +127,395 @@ def delete_model(self, model: Model) -> None: super().delete_model(model) + def clone_model_structure_to_schema( + self, model: Type[Model], *, schema_name: str + ) -> None: + """Creates a clone of the columns for the specified model in a separate + schema. + + The table will have exactly the same name as the model table + in the default schema. It will have none of the constraints, + foreign keys and indexes. + + Use this to create a temporary clone of a model table to + replace the original model table later on. The lack of + indices and constraints allows for greater write speeds. + + The original model table will be unaffected. + + Arguments: + model: + Model to clone the table of into the + specified schema. + + schema_name: + Name of the schema to create the cloned + table in. + """ + + table_name = model._meta.db_table + quoted_table_name = self.quote_name(model._meta.db_table) + quoted_schema_name = self.quote_name(schema_name) + + quoted_table_fqn = f"{quoted_schema_name}.{quoted_table_name}" + + self.execute( + self.sql_create_table + % { + "table": quoted_table_fqn, + "definition": f"LIKE {quoted_table_name} INCLUDING ALL EXCLUDING CONSTRAINTS EXCLUDING INDEXES", + } + ) + + # Copy sequences + # + # Django 4.0 and older do not use IDENTITY so Postgres does + # not copy the sequences into the new table. We do it manually. + if django.VERSION < (4, 1): + with self.connection.cursor() as cursor: + sequences = self.introspection.get_sequences(cursor, table_name) + + for sequence in sequences: + if sequence["table"] != table_name: + continue + + quoted_sequence_name = self.quote_name(sequence["name"]) + quoted_sequence_fqn = ( + f"{quoted_schema_name}.{quoted_sequence_name}" + ) + quoted_column_name = self.quote_name(sequence["column"]) + + self.execute( + self.sql_create_sequence_with_owner + % ( + quoted_sequence_fqn, + quoted_table_fqn, + quoted_column_name, + ) + ) + + self.execute( + self.sql_alter_column + % { + "table": quoted_table_fqn, + "changes": self.sql_alter_column_default + % { + "column": quoted_column_name, + "default": "nextval('%s')" % quoted_sequence_fqn, + }, + } + ) + + # Copy storage settings + # + # Postgres only copies column-level storage options, not + # the table-level storage options. 
+        with self.connection.cursor() as cursor:
+            storage_settings = self.introspection.get_storage_settings(
+                cursor, model._meta.db_table
+            )
+
+        for setting_name, setting_value in storage_settings.items():
+            self.alter_table_storage_setting(
+                quoted_table_fqn, setting_name, setting_value
+            )
+
+    def clone_model_constraints_and_indexes_to_schema(
+        self, model: Type[Model], *, schema_name: str
+    ) -> None:
+        """Adds the constraints, foreign keys and indexes to a model table that
+        was cloned into a separate table without them by
+        `clone_model_structure_to_schema`.
+
+        Arguments:
+            model:
+                Model for which the cloned table was created.
+
+            schema_name:
+                Name of the schema in which the cloned table
+                resides.
+        """
+
+        with postgres_prepend_local_search_path(
+            [schema_name], using=self.connection.alias
+        ):
+            for constraint in model._meta.constraints:
+                self.add_constraint(model, constraint)  # type: ignore[attr-defined]
+
+            for index in model._meta.indexes:
+                self.add_index(model, index)
+
+            if model._meta.unique_together:
+                self.alter_unique_together(
+                    model, tuple(), model._meta.unique_together
+                )
+
+            if model._meta.index_together:
+                self.alter_index_together(
+                    model, tuple(), model._meta.index_together
+                )
+
+            for field in model._meta.local_concrete_fields:  # type: ignore[attr-defined]
+                # Django creates primary keys later added to the model with
+                # a custom name. We want the name as it was created originally.
+                if field.primary_key:
+                    with postgres_reset_local_search_path(
+                        using=self.connection.alias
+                    ):
+                        [primary_key_name] = self._constraint_names(  # type: ignore[attr-defined]
+                            model, primary_key=True
+                        )
+
+                    self.execute(
+                        self.sql_create_pk
+                        % {
+                            "table": self.quote_name(model._meta.db_table),
+                            "name": self.quote_name(primary_key_name),
+                            "columns": self.quote_name(
+                                field.db_column or field.attname
+                            ),
+                        }
+                    )
+                    continue
+
+                # Django creates foreign keys in a single statement, which acquires
+                # an AccessExclusiveLock on the referenced table. We want to avoid
+                # that and create the FK as NOT VALID. We can run VALIDATE in
+                # a separate transaction later to validate the entries without
+                # acquiring an AccessExclusiveLock.
+                if field.remote_field:
+                    with postgres_reset_local_search_path(
+                        using=self.connection.alias
+                    ):
+                        [fk_name] = self._constraint_names(  # type: ignore[attr-defined]
+                            model, [field.column], foreign_key=True
+                        )
+
+                    sql = Statement(
+                        self.sql_create_fk_not_valid,
+                        table=self.quote_name(model._meta.db_table),
+                        name=self.quote_name(fk_name),
+                        column=self.quote_name(field.column),
+                        to_table=self.quote_name(
+                            field.target_field.model._meta.db_table
+                        ),
+                        to_column=self.quote_name(field.target_field.column),
+                        deferrable=self.connection.ops.deferrable_sql(),
+                    )
+
+                    self.execute(sql)
+
+                # It's hard to alter a field's check constraint because it is
+                # defined by the field class, not the field instance. Handle
+                # this manually.
+                field_check = field.db_parameters(self.connection).get("check")
+                if field_check:
+                    with postgres_reset_local_search_path(
+                        using=self.connection.alias
+                    ):
+                        [field_check_name] = self._constraint_names(  # type: ignore[attr-defined]
+                            model,
+                            [field.column],
+                            check=True,
+                            exclude={
+                                constraint.name
+                                for constraint in model._meta.constraints
+                            },
+                        )
+
+                    self.execute(
+                        self._create_check_sql(  # type: ignore[attr-defined]
+                            model, field_check_name, field_check
+                        )
+                    )
+
+                # Clone the field and alter its state to match our current
+                # table definition. This will cause Django to see the missing
+                # indices and create them.
+ if field.remote_field: + # We add the foreign key constraint ourselves with NOT VALID, + # hence, we specify `db_constraint=False` on both old/new. + # Django won't touch the foreign key constraint. + old_field = self._clone_model_field( + field, db_index=False, unique=False, db_constraint=False + ) + new_field = self._clone_model_field( + field, db_constraint=False + ) + self.alter_field(model, old_field, new_field) + else: + old_field = self._clone_model_field( + field, db_index=False, unique=False + ) + new_field = self._clone_model_field(field) + self.alter_field(model, old_field, new_field) + + def clone_model_foreign_keys_to_schema( + self, model: Type[Model], schema_name: str + ) -> None: + """Validates the foreign keys in the cloned model table created by + `clone_model_structure_to_schema` and + `clone_model_constraints_and_indexes_to_schema`. + + Do NOT run this in the same transaction as the + foreign keys were added to the table. It WILL + acquire a long-lived AccessExclusiveLock. + + Arguments: + model: + Model for which the cloned table was created. + + schema_name: + Name of the schema in which the cloned table + resides. + """ + + constraint_names = self._constraint_names(model, foreign_key=True) # type: ignore[attr-defined] + + with postgres_prepend_local_search_path( + [schema_name], using=self.connection.alias + ): + for fk_name in constraint_names: + self.execute( + self.sql_validate_fk + % ( + self.quote_name(model._meta.db_table), + self.quote_name(fk_name), + ) + ) + + def alter_table_storage_setting( + self, table_name: str, name: str, value: str + ) -> None: + """Alters a storage setting for a table. + + See: https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-STORAGE-PARAMETERS + + Arguments: + table_name: + Name of the table to alter the setting for. + + name: + Name of the setting to alter. + + value: + Value to alter the setting to. + + Note that this is always a string, even if it looks + like a number or a boolean. That's how Postgres + stores storage settings internally. + """ + + self.execute( + self.sql_alter_table_storage_setting + % (self.quote_name(table_name), name, value) + ) + + def alter_model_storage_setting( + self, model: Type[Model], name: str, value: str + ) -> None: + """Alters a storage setting for the model's table. + + See: https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-STORAGE-PARAMETERS + + Arguments: + model: + Model of which to alter the table + setting. + + name: + Name of the setting to alter. + + value: + Value to alter the setting to. + + Note that this is always a string, even if it looks + like a number or a boolean. That's how Postgres + stores storage settings internally. + """ + + self.alter_table_storage_setting(model._meta.db_table, name, value) + + def reset_table_storage_setting(self, table_name: str, name: str) -> None: + """Resets a table's storage setting to the database or server default. + + See: https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-STORAGE-PARAMETERS + + Arguments: + table_name: + Name of the table to reset the setting for. + + name: + Name of the setting to reset. + """ + + self.execute( + self.sql_reset_table_storage_setting + % (self.quote_name(table_name), name) + ) + + def reset_model_storage_setting( + self, model: Type[Model], name: str + ) -> None: + """Resets a model's table storage setting to the database or server + default. 
+
+        See: https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-STORAGE-PARAMETERS
+
+        Arguments:
+            model:
+                Model for which to reset the table setting.
+
+            name:
+                Name of the setting to reset.
+        """
+
+        self.reset_table_storage_setting(model._meta.db_table, name)
+
+    def alter_table_schema(self, table_name: str, schema_name: str) -> None:
+        """Moves the specified table into the specified schema.
+
+        WARNING: Moving models into a different schema than the default
+        will break querying the model.
+
+        Arguments:
+            table_name:
+                Name of the table to move into the specified schema.
+
+            schema_name:
+                Name of the schema to move the table to.
+        """
+
+        self.execute(
+            self.sql_alter_table_schema
+            % (self.quote_name(table_name), self.quote_name(schema_name))
+        )
+
+    def alter_model_schema(self, model: Type[Model], schema_name: str) -> None:
+        """Moves the specified model's table into the specified schema.
+
+        WARNING: Moving models into a different schema than the default
+        will break querying the model.
+
+        Arguments:
+            model:
+                Model of which to move the table.
+
+            schema_name:
+                Name of the schema to move the model's table to.
+        """
+
+        self.execute(
+            self.sql_alter_table_schema
+            % (
+                self.quote_name(model._meta.db_table),
+                self.quote_name(schema_name),
+            )
+        )
+
     def refresh_materialized_view_model(
-        self, model: Model, concurrently: bool = False
+        self, model: Type[Model], concurrently: bool = False
     ) -> None:
         """Refreshes a materialized view."""
 
@@ -93,12 +528,12 @@ def refresh_materialized_view_model(
         sql = sql_template % self.quote_name(model._meta.db_table)
         self.execute(sql)
 
-    def create_view_model(self, model: Model) -> None:
+    def create_view_model(self, model: Type[Model]) -> None:
         """Creates a new view model."""
 
         self._create_view_model(self.sql_create_view, model)
 
-    def replace_view_model(self, model: Model) -> None:
+    def replace_view_model(self, model: Type[Model]) -> None:
         """Replaces a view model with a newer version.
 
         This is used to alter the backing query of a view.
@@ -106,18 +541,18 @@ def replace_view_model(self, model: Model) -> None:
 
         self._create_view_model(self.sql_replace_view, model)
 
-    def delete_view_model(self, model: Model) -> None:
+    def delete_view_model(self, model: Type[Model]) -> None:
         """Deletes a view model."""
 
         sql = self.sql_drop_view % self.quote_name(model._meta.db_table)
         self.execute(sql)
 
-    def create_materialized_view_model(self, model: Model) -> None:
+    def create_materialized_view_model(self, model: Type[Model]) -> None:
         """Creates a new materialized view model."""
 
         self._create_view_model(self.sql_create_materialized_view, model)
 
-    def replace_materialized_view_model(self, model: Model) -> None:
+    def replace_materialized_view_model(self, model: Type[Model]) -> None:
         """Replaces a materialized view with a newer version.
 
         This is used to alter the backing query of a materialized view.
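Taken together, the three `clone_model_*_to_schema` methods added above support a copy-then-swap workflow: clone the bare table into a scratch schema, bulk-load it while it has no indexes, re-add constraints and indexes, and only then validate foreign keys outside the original transaction. A minimal sketch, assuming the `psqlextra` backend is configured so `connection.schema_editor()` yields this editor; `MyModel` is hypothetical:

```python
from django.db import connection, transaction

with transaction.atomic(), connection.schema_editor() as editor:
    editor.create_schema("scratch")
    editor.clone_model_structure_to_schema(MyModel, schema_name="scratch")
    # ... bulk-load rows into the clone here, while it has no indexes ...
    editor.clone_model_constraints_and_indexes_to_schema(
        MyModel, schema_name="scratch"
    )

# Per the docstring above, FK validation must run in a separate
# transaction so the AccessExclusiveLock stays short-lived.
with connection.schema_editor(atomic=False) as editor:
    editor.clone_model_foreign_keys_to_schema(MyModel, schema_name="scratch")
```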
@@ -143,12 +578,12 @@ def replace_materialized_view_model(self, model: Model) -> None: if not constraint_options["definition"]: raise SuspiciousOperation( "Table %s has a constraint '%s' that no definition could be generated for", - (model._meta.db_tabel, constraint_name), + (model._meta.db_table, constraint_name), ) self.execute(constraint_options["definition"]) - def delete_materialized_view_model(self, model: Model) -> None: + def delete_materialized_view_model(self, model: Type[Model]) -> None: """Deletes a materialized view model.""" sql = self.sql_drop_materialized_view % self.quote_name( @@ -156,7 +591,7 @@ def delete_materialized_view_model(self, model: Model) -> None: ) self.execute(sql) - def create_partitioned_model(self, model: Model) -> None: + def create_partitioned_model(self, model: Type[Model]) -> None: """Creates a new partitioned model.""" meta = self._partitioning_properties_for_model(model) @@ -171,10 +606,13 @@ def create_partitioned_model(self, model: Model) -> None: # create a composite key that includes the partitioning key sql = sql.replace(" PRIMARY KEY", "") - sql = sql[:-1] + ", PRIMARY KEY (%s, %s))" % ( - self.quote_name(model._meta.pk.name), - partitioning_key_sql, - ) + if model._meta.pk and model._meta.pk.name not in meta.key: + sql = sql[:-1] + ", PRIMARY KEY (%s, %s))" % ( + self.quote_name(model._meta.pk.name), + partitioning_key_sql, + ) + else: + sql = sql[:-1] + ", PRIMARY KEY (%s))" % (partitioning_key_sql,) # extend the standard CREATE TABLE statement with # 'PARTITION BY ...' @@ -185,14 +623,14 @@ def create_partitioned_model(self, model: Model) -> None: self.execute(sql, params) - def delete_partitioned_model(self, model: Model) -> None: + def delete_partitioned_model(self, model: Type[Model]) -> None: """Drops the specified partitioned model.""" return self.delete_model(model) def add_range_partition( self, - model: Model, + model: Type[Model], name: str, from_values: Any, to_values: Any, @@ -243,7 +681,7 @@ def add_range_partition( def add_list_partition( self, - model: Model, + model: Type[Model], name: str, values: List[Any], comment: Optional[str] = None, @@ -286,7 +724,7 @@ def add_list_partition( def add_hash_partition( self, - model: Model, + model: Type[Model], name: str, modulus: int, remainder: int, @@ -331,7 +769,7 @@ def add_hash_partition( self.set_comment_on_table(table_name, comment) def add_default_partition( - self, model: Model, name: str, comment: Optional[str] = None + self, model: Type[Model], name: str, comment: Optional[str] = None ) -> None: """Creates a new default partition for the specified partitioned model. 
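The `add_*_partition`/`delete_partition` helpers take the partitioned model plus partition bounds and derive the partition table name via `create_partition_table_name` (`<db_table>_<name>`, lowercased). A sketch of manual partition management with a hypothetical range-partitioned `MyPartitionedModel`:

```python
from django.db import connection

with connection.schema_editor() as editor:
    # Creates the partition table "<db_table>_2023_jan" covering January 2023.
    editor.add_range_partition(
        MyPartitionedModel,
        name="2023_jan",
        from_values="2023-01-01",
        to_values="2023-02-01",
    )

    # Drops it again (DROP TABLE on the partition).
    editor.delete_partition(MyPartitionedModel, name="2023_jan")
```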
@@ -367,7 +805,7 @@ def add_default_partition(
         if comment:
             self.set_comment_on_table(table_name, comment)
 
-    def delete_partition(self, model: Model, name: str) -> None:
+    def delete_partition(self, model: Type[Model], name: str) -> None:
         """Deletes the partition with the specified name."""
 
         sql = self.sql_delete_partition % self.quote_name(
@@ -376,7 +814,7 @@ def delete_partition(self, model: Model, name: str) -> None:
         self.execute(sql)
 
     def alter_db_table(
-        self, model: Model, old_db_table: str, new_db_table: str
+        self, model: Type[Model], old_db_table: str, new_db_table: str
     ) -> None:
         """Alters a table/model."""
 
@@ -385,7 +823,7 @@ def alter_db_table(
         for side_effect in self.side_effects:
             side_effect.alter_db_table(model, old_db_table, new_db_table)
 
-    def add_field(self, model: Model, field: Field) -> None:
+    def add_field(self, model: Type[Model], field: Field) -> None:
         """Adds a new field to an existing model."""
 
         super().add_field(model, field)
@@ -393,7 +831,7 @@ def add_field(self, model: Model, field: Field) -> None:
         for side_effect in self.side_effects:
             side_effect.add_field(model, field)
 
-    def remove_field(self, model: Model, field: Field) -> None:
+    def remove_field(self, model: Type[Model], field: Field) -> None:
         """Removes a field from an existing model."""
 
         for side_effect in self.side_effects:
@@ -403,7 +841,7 @@ def remove_field(self, model: Model, field: Field) -> None:
     def alter_field(
         self,
-        model: Model,
+        model: Type[Model],
         old_field: Field,
         new_field: Field,
         strict: bool = False,
@@ -415,19 +853,110 @@ def alter_field(
         for side_effect in self.side_effects:
             side_effect.alter_field(model, old_field, new_field, strict)
 
+    def vacuum_table(
+        self,
+        table_name: str,
+        columns: List[str] = [],
+        *,
+        full: bool = False,
+        freeze: bool = False,
+        verbose: bool = False,
+        analyze: bool = False,
+        disable_page_skipping: bool = False,
+        skip_locked: bool = False,
+        index_cleanup: bool = False,
+        truncate: bool = False,
+        parallel: Optional[int] = None,
+    ) -> None:
+        """Runs the VACUUM statement on the specified table with the specified
+        options.
+
+        Arguments:
+            table_name:
+                Name of the table to run VACUUM on.
+
+            columns:
+                Optionally, a list of columns to vacuum. If not
+                specified, all columns are vacuumed.
+        """
+
+        if self.connection.in_atomic_block:
+            raise SuspiciousOperation("Vacuum cannot be done in a transaction")
+
+        options = []
+        if full:
+            options.append("FULL")
+        if freeze:
+            options.append("FREEZE")
+        if verbose:
+            options.append("VERBOSE")
+        if analyze:
+            options.append("ANALYZE")
+        if disable_page_skipping:
+            options.append("DISABLE_PAGE_SKIPPING")
+        if skip_locked:
+            options.append("SKIP_LOCKED")
+        if index_cleanup:
+            options.append("INDEX_CLEANUP")
+        if truncate:
+            options.append("TRUNCATE")
+        if parallel is not None:
+            options.append(f"PARALLEL {parallel}")
+
+        sql = "VACUUM"
+
+        if options:
+            options_sql = ", ".join(options)
+            sql += f" ({options_sql})"
+
+        sql += f" {self.quote_name(table_name)}"
+
+        if columns:
+            columns_sql = ", ".join(
+                [self.quote_name(column) for column in columns]
+            )
+            sql += f" ({columns_sql})"
+
+        self.execute(sql)
+
+    def vacuum_model(
+        self, model: Type[Model], fields: List[Field] = [], **kwargs
+    ) -> None:
+        """Runs the VACUUM statement on the table of the specified model with
+        the specified options.
+
+        Arguments:
+            model:
+                Model whose table to run VACUUM on.
+
+            fields:
+                Optionally, a list of fields to vacuum. If not
+                specified, all fields are vacuumed.
+ """ + + columns = [ + field.column + for field in fields + if getattr(field, "concrete", False) and field.column + ] + self.vacuum_table(model._meta.db_table, columns, **kwargs) + def set_comment_on_table(self, table_name: str, comment: str) -> None: """Sets the comment on the specified table.""" sql = self.sql_table_comment % (self.quote_name(table_name), "%s") self.execute(sql, (comment,)) - def _create_view_model(self, sql: str, model: Model) -> None: + def _create_view_model(self, sql: str, model: Type[Model]) -> None: """Creates a new view model using the specified SQL query.""" meta = self._view_properties_for_model(model) with self.connection.cursor() as cursor: - view_sql = cursor.mogrify(*meta.query).decode("utf-8") + view_sql = cursor.mogrify(*meta.query) + if isinstance(view_sql, bytes): + view_sql = view_sql.decode("utf-8") self.execute(sql % (self.quote_name(model._meta.db_table), view_sql)) @@ -446,7 +975,7 @@ def _extract_sql(self, method, *args): return tuple(execute.mock_calls[0])[1] @staticmethod - def _view_properties_for_model(model: Model): + def _view_properties_for_model(model: Type[Model]): """Gets the view options for the specified model. Raises: @@ -478,7 +1007,7 @@ def _view_properties_for_model(model: Model): return meta @staticmethod - def _partitioning_properties_for_model(model: Model): + def _partitioning_properties_for_model(model: Type[Model]): """Gets the partitioning options for the specified model. Raises: @@ -516,7 +1045,7 @@ def _partitioning_properties_for_model(model: Model): % (model.__name__, meta.method) ) - if not isinstance(meta.key, list): + if not isinstance(meta.key, (list, tuple)): raise ImproperlyConfigured( ( "Model '%s' is not properly configured to be partitioned." @@ -541,5 +1070,29 @@ def _partitioning_properties_for_model(model: Model): return meta - def create_partition_table_name(self, model: Model, name: str) -> str: + def create_partition_table_name(self, model: Type[Model], name: str) -> str: return "%s_%s" % (model._meta.db_table.lower(), name.lower()) + + def _clone_model_field(self, field: Field, **overrides) -> Field: + """Clones the specified model field and overrides its kwargs with the + specified overrides. + + The cloned field will not be contributed to the model. 
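A quick sketch of the vacuum helpers defined above. `VACUUM` cannot run inside a transaction block, which is why the editor raises `SuspiciousOperation` when `connection.in_atomic_block` is set; `MyModel` is hypothetical:

```python
from django.db import connection

# Open the schema editor without wrapping it in a transaction.
with connection.schema_editor(atomic=False) as editor:
    editor.vacuum_model(MyModel, analyze=True)   # VACUUM (ANALYZE) "<table>"
    editor.vacuum_table("my_table", full=True)   # VACUUM (FULL) "my_table"
```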
+        """
+
+        _, _, field_args, field_kwargs = field.deconstruct()
+
+        cloned_field_args = field_args[:]
+        cloned_field_kwargs = {**field_kwargs, **overrides}
+
+        cloned_field = field.__class__(
+            *cloned_field_args, **cloned_field_kwargs
+        )
+        cloned_field.model = field.model
+        cloned_field.set_attributes_from_name(field.name)
+
+        if cloned_field.remote_field and field.remote_field:
+            cloned_field.remote_field.model = field.remote_field.model
+            cloned_field.set_attributes_from_rel()  # type: ignore[attr-defined]
+
+        return cloned_field
diff --git a/psqlextra/compiler.py b/psqlextra/compiler.py
index 23866b5c..36aad204 100644
--- a/psqlextra/compiler.py
+++ b/psqlextra/compiler.py
@@ -1,19 +1,97 @@
+import inspect
+import os
+import sys
+
 from collections.abc import Iterable
-from typing import Tuple, Union
+from typing import TYPE_CHECKING, Tuple, Union, cast
 
 import django
 
+from django.conf import settings
 from django.core.exceptions import SuspiciousOperation
 from django.db.models import Expression, Model, Q
 from django.db.models.fields.related import RelatedField
-from django.db.models.sql.compiler import SQLInsertCompiler, SQLUpdateCompiler
-from django.db.utils import ProgrammingError
+from django.db.models.sql import compiler as django_compiler
 
 from .expressions import HStoreValue
 from .types import ConflictAction
 
+if TYPE_CHECKING:
+    from .sql import PostgresInsertQuery
+
+
+def append_caller_to_sql(sql):
+    """Append the caller to SQL queries.
+
+    Adds the calling file and function as an SQL comment to each query.
+    Examples:
+        INSERT INTO "tests_47ee19d1" ("id", "title")
+        VALUES (1, 'Test')
+        RETURNING "tests_47ee19d1"."id"
+        /* 998020 test_append_caller_to_sql_crud .../django-postgres-extra/tests/test_append_caller_to_sql.py 55 */
+
+        SELECT "tests_47ee19d1"."id", "tests_47ee19d1"."title"
+        FROM "tests_47ee19d1"
+        WHERE "tests_47ee19d1"."id" = 1
+        LIMIT 1
+        /* 998020 test_append_caller_to_sql_crud .../django-postgres-extra/tests/test_append_caller_to_sql.py 69 */
+
+        UPDATE "tests_47ee19d1"
+        SET "title" = 'success'
+        WHERE "tests_47ee19d1"."id" = 1
+        /* 998020 test_append_caller_to_sql_crud .../django-postgres-extra/tests/test_append_caller_to_sql.py 64 */
+
+        DELETE FROM "tests_47ee19d1"
+        WHERE "tests_47ee19d1"."id" IN (1)
+        /* 998020 test_append_caller_to_sql_crud .../django-postgres-extra/tests/test_append_caller_to_sql.py 74 */
+
+    Slow and blocking queries can easily be tracked down to their originator
+    within the source code using the "pg_stat_activity" table.
+
+    Set "POSTGRES_EXTRA_ANNOTATE_SQL" to a truthy value in your Django
+    settings to enable this feature.
+    """
+
+    if not getattr(settings, "POSTGRES_EXTRA_ANNOTATE_SQL", None):
+        return sql
+
+    try:
+        # Search for the first non-Django caller
+        stack = inspect.stack()
+        for stack_frame in stack[1:]:
+            frame_filename = stack_frame[1]
+            frame_line = stack_frame[2]
+            frame_function = stack_frame[3]
+            if "/django/" in frame_filename or "/psqlextra/" in frame_filename:
+                continue
+
+            return f"{sql} /* {os.getpid()} {frame_function} {frame_filename} {frame_line} */"
+
+        # Django internal commands (like migrations) end up here
+        return f"{sql} /* {os.getpid()} {sys.argv[0]} */"
+    except Exception:
+        # Don't break anything if this convenience function runs into
+        # an unexpected situation
+        return sql
 
-class PostgresUpdateCompiler(SQLUpdateCompiler):
+
+class SQLCompiler(django_compiler.SQLCompiler):  # type: ignore [attr-defined]
+    def as_sql(self, *args, **kwargs):
+        sql, params = super().as_sql(*args, **kwargs)
+        return append_caller_to_sql(sql), params
+
+
+class SQLDeleteCompiler(django_compiler.SQLDeleteCompiler):  # type: ignore [name-defined]
+    def as_sql(self, *args, **kwargs):
+        sql, params = super().as_sql(*args, **kwargs)
+        return append_caller_to_sql(sql), params
+
+
+class SQLAggregateCompiler(django_compiler.SQLAggregateCompiler):  # type: ignore [name-defined]
+    def as_sql(self, *args, **kwargs):
+        sql, params = super().as_sql(*args, **kwargs)
+        return append_caller_to_sql(sql), params
+
+
+class SQLUpdateCompiler(django_compiler.SQLUpdateCompiler):  # type: ignore [name-defined]
     """Compiler for SQL UPDATE statements that allows us to use expressions
     inside HStore values.
 
@@ -22,13 +100,13 @@ class PostgresUpdateCompiler(SQLUpdateCompiler):
         .update(name=dict(en=F('test')))
     """
 
-    def as_sql(self):
+    def as_sql(self, *args, **kwargs):
         self._prepare_query_values()
-        return super().as_sql()
+        sql, params = super().as_sql(*args, **kwargs)
+        return append_caller_to_sql(sql), params
 
     def _prepare_query_values(self):
-        """Extra prep on query values by converting dictionaries into.
-
+        """Extra prep on query values by converting dictionaries into
+        :see:HStoreValue expressions. This allows putting expressions in a dictionary.
The @@ -69,46 +147,40 @@ def _does_dict_contain_expression(data: dict) -> bool: return False -class PostgresInsertCompiler(SQLInsertCompiler): +class SQLInsertCompiler(django_compiler.SQLInsertCompiler): # type: ignore [name-defined] """Compiler for SQL INSERT statements.""" - def __init__(self, *args, **kwargs): - """Initializes a new instance of :see:PostgresInsertCompiler.""" + def as_sql(self, *args, **kwargs): + """Builds the SQL INSERT statement.""" + queries = [ + (append_caller_to_sql(sql), params) + for sql, params in super().as_sql(*args, **kwargs) + ] + + return queries + + +class PostgresInsertOnConflictCompiler(django_compiler.SQLInsertCompiler): # type: ignore [name-defined] + """Compiler for SQL INSERT statements.""" + query: "PostgresInsertQuery" + + def __init__(self, *args, **kwargs): + """Initializes a new instance of + :see:PostgresInsertOnConflictCompiler.""" super().__init__(*args, **kwargs) self.qn = self.connection.ops.quote_name - def as_sql(self, return_id=False): + def as_sql(self, return_id=False, *args, **kwargs): """Builds the SQL INSERT statement.""" queries = [ self._rewrite_insert(sql, params, return_id) - for sql, params in super().as_sql() + for sql, params in super().as_sql(*args, **kwargs) ] return queries - def execute_sql(self, return_id=False): - # execute all the generate queries - with self.connection.cursor() as cursor: - rows = [] - for sql, params in self.as_sql(return_id): - cursor.execute(sql, params) - try: - rows.extend(cursor.fetchall()) - except ProgrammingError: - pass - - # create a mapping between column names and column value - return [ - { - column.name: row[column_index] - for column_index, column in enumerate(cursor.description) - if row - } - for row in rows - ] - def _rewrite_insert(self, sql, params, return_id=False): """Rewrites a formed SQL INSERT query to include the ON CONFLICT clause. @@ -120,9 +192,9 @@ def _rewrite_insert(self, sql, params, return_id=False): params: The parameters passed to the query. - returning: - What to put in the `RETURNING` clause - of the resulting query. + return_id: + Whether to only return the ID or all + columns. Returns: A tuple of the rewritten SQL query and new params. 
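To turn the SQL annotation on, set the flag in the Django settings module; `append_caller_to_sql` reads it with `getattr`, so it is off unless explicitly truthy. A sketch:

```python
# settings.py (sketch)
POSTGRES_EXTRA_ANNOTATE_SQL = True

# Every statement then carries its origin as a trailing comment, e.g.:
#   SELECT ... FROM "app_mymodel" WHERE ...
#   /* 12345 my_view /srv/app/views.py 42 */
# which is visible in pg_stat_activity.query when hunting slow queries.
```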
@@ -132,40 +204,46 @@ def _rewrite_insert(self, sql, params, return_id=False): self.qn(self.query.model._meta.pk.attname) if return_id else "*" ) - return self._rewrite_insert_on_conflict( + (sql, params) = self._rewrite_insert_on_conflict( sql, params, self.query.conflict_action.value, returning ) + return append_caller_to_sql(sql), params + def _rewrite_insert_on_conflict( self, sql, params, conflict_action: ConflictAction, returning ): """Rewrites a normal SQL INSERT query to add the 'ON CONFLICT' clause.""" - update_columns = ", ".join( - [ - "{0} = EXCLUDED.{0}".format(self.qn(field.column)) - for field in self.query.update_fields - ] - ) - # build the conflict target, the columns to watch # for conflicts - conflict_target = self._build_conflict_target() - index_predicate = self.query.index_predicate - update_condition = self.query.conflict_update_condition + on_conflict_clause = self._build_on_conflict_clause() + index_predicate = self.query.index_predicate # type: ignore[attr-defined] + update_condition = self.query.conflict_update_condition # type: ignore[attr-defined] - rewritten_sql = f"{sql} ON CONFLICT {conflict_target}" + rewritten_sql = f"{sql} {on_conflict_clause}" if index_predicate: expr_sql, expr_params = self._compile_expression(index_predicate) rewritten_sql += f" WHERE {expr_sql}" params += tuple(expr_params) + # Fallback in case the user didn't specify any update values. We can still + # make the query work if we switch to ConflictAction.NOTHING + if ( + conflict_action == ConflictAction.UPDATE.value + and not self.query.update_values + ): + conflict_action = ConflictAction.NOTHING + rewritten_sql += f" DO {conflict_action}" - if conflict_action == "UPDATE": - rewritten_sql += f" SET {update_columns}" + if conflict_action == ConflictAction.UPDATE.value: + set_sql, sql_params = self._build_set_statement() + + rewritten_sql += f" SET {set_sql}" + params += sql_params if update_condition: expr_sql, expr_params = self._compile_expression( @@ -178,6 +256,38 @@ def _rewrite_insert_on_conflict( return (rewritten_sql, params) + def _build_set_statement(self) -> Tuple[str, tuple]: + """Builds the SET statement for the ON CONFLICT DO UPDATE clause. + + This uses the update compiler to provide full compatibility with + the standard Django's `update(...)`. + """ + + # Local import to work around the circular dependency between + # the compiler and the queries. 
+        from .sql import PostgresUpdateQuery
+
+        query = cast(PostgresUpdateQuery, self.query.chain(PostgresUpdateQuery))
+        query.add_update_values(self.query.update_values)
+
+        sql, params = query.get_compiler(self.connection.alias).as_sql()
+        return sql.split("SET")[1].split(" WHERE")[0], tuple(params)
+
+    def _build_on_conflict_clause(self):
+        if django.VERSION >= (2, 2):
+            from django.db.models.constraints import BaseConstraint
+            from django.db.models.indexes import Index
+
+            if isinstance(
+                self.query.conflict_target, BaseConstraint
+            ) or isinstance(self.query.conflict_target, Index):
+                return "ON CONFLICT ON CONSTRAINT %s" % self.qn(
+                    self.query.conflict_target.name
+                )
+
+        conflict_target = self._build_conflict_target()
+        return f"ON CONFLICT {conflict_target}"
+
     def _build_conflict_target(self):
         """Builds the `conflict_target` for the ON CONFLICT clause."""
 
@@ -263,12 +373,15 @@ def _get_model_field(self, name: str):
 
         field_name = self._normalize_field_name(name)
 
+        if not self.query.model:
+            return None
+
         # 'pk' has special meaning and always refers to the primary
         # key of a model, we have to respect this de-facto standard behaviour
         if field_name == "pk" and self.query.model._meta.pk:
             return self.query.model._meta.pk
 
-        for field in self.query.model._meta.local_concrete_fields:
+        for field in self.query.model._meta.local_concrete_fields:  # type: ignore[attr-defined]
             if field.name == field_name or field.column == field_name:
                 return field
 
@@ -310,7 +423,7 @@ def _format_field_value(self, field_name) -> str:
         if isinstance(field, RelatedField) and isinstance(value, Model):
             value = value.pk
 
-        return SQLInsertCompiler.prepare_value(
+        return django_compiler.SQLInsertCompiler.prepare_value(  # type: ignore[attr-defined]
             self,
             field,
             # Note: this deliberately doesn't use `pre_save_val` as we don't
diff --git a/psqlextra/error.py b/psqlextra/error.py
new file mode 100644
index 00000000..b3a5cf83
--- /dev/null
+++ b/psqlextra/error.py
@@ -0,0 +1,62 @@
+from typing import TYPE_CHECKING, Optional, Type, Union
+
+from django import db
+
+if TYPE_CHECKING:
+    from psycopg2 import Error as _Psycopg2Error
+
+    Psycopg2Error: Optional[Type[_Psycopg2Error]]
+
+    from psycopg import Error as _Psycopg3Error
+
+    Psycopg3Error: Optional[Type[_Psycopg3Error]]
+
+try:
+    from psycopg2 import Error as Psycopg2Error  # type: ignore[no-redef]
+except ImportError:
+    Psycopg2Error = None  # type: ignore[misc]
+
+try:
+    from psycopg import Error as Psycopg3Error  # type: ignore[no-redef]
+except ImportError:
+    Psycopg3Error = None  # type: ignore[misc]
+
+
+def extract_postgres_error(
+    error: db.Error,
+) -> Optional[Union["_Psycopg2Error", "_Psycopg3Error"]]:
+    """Extracts the underlying :see:psycopg2.Error from the specified Django
+    database error.
+
+    As per PEP-249, Django wraps all database errors in its own
+    exception. We can extract the underlying database error by examining
+    the cause of the error.
+    """
+
+    if (Psycopg2Error and not isinstance(error.__cause__, Psycopg2Error)) and (
+        Psycopg3Error and not isinstance(error.__cause__, Psycopg3Error)
+    ):
+        return None
+
+    return error.__cause__
+
+
+def extract_postgres_error_code(error: db.Error) -> Optional[str]:
+    """Extracts the underlying Postgres error code.
+
+    As per PEP-249, Django wraps all database errors in its own
+    exception. We can extract the underlying database error by examining
+    the cause of the error.
+    """
+
+    cause = error.__cause__
+    if not cause:
+        return None
+
+    if Psycopg2Error and isinstance(cause, Psycopg2Error):
+        return cause.pgcode
+
+    if Psycopg3Error and isinstance(cause, Psycopg3Error):
+        return cause.sqlstate
+
+    return None
diff --git a/psqlextra/expressions.py b/psqlextra/expressions.py
index 1840283c..20486dfa 100644
--- a/psqlextra/expressions.py
+++ b/psqlextra/expressions.py
@@ -1,4 +1,6 @@
-from django.db.models import CharField, expressions
+from typing import Union
+
+from django.db.models import CharField, Field, expressions
 
 
 class HStoreValue(expressions.Expression):
@@ -140,7 +142,7 @@ def __init__(self, name: str, key: str):
     def resolve_expression(self, *args, **kwargs):
         """Resolves the expression into a :see:HStoreColumn expression."""
 
-        original_expression: expressions.Col = super().resolve_expression(
+        original_expression: expressions.Col = super().resolve_expression(  # type: ignore[annotation-unchecked]
             *args, **kwargs
         )
         expression = HStoreColumn(
@@ -212,11 +214,21 @@ class ExcludedCol(expressions.Expression):
     """References a column in PostgreSQL's special EXCLUDED column, which is
     used in upserts to refer to the data about to be inserted/updated.
 
-    See: https://www.postgresql.org/docs/9.5/sql-insert.html#SQL-ON-CONFLICT
+    See: https://www.postgresql.org/docs/current/sql-insert.html#SQL-ON-CONFLICT
     """
 
-    def __init__(self, name: str):
-        self.name = name
+    def __init__(self, field_or_name: Union[Field, str]):
+
+        # We support both field classes or just field names here. We prefer
+        # fields because when the expression is compiled, it might need
+        # the field information to figure out the correct placeholder.
+        # Even though that isn't required for this particular expression.
+        if isinstance(field_or_name, Field):
+            super().__init__(field_or_name)
+            self.name = field_or_name.column
+        else:
+            super().__init__(None)
+            self.name = field_or_name
 
     def as_sql(self, compiler, connection):
         quoted_name = connection.ops.quote_name(self.name)
diff --git a/psqlextra/introspect/__init__.py b/psqlextra/introspect/__init__.py
new file mode 100644
index 00000000..bd85935f
--- /dev/null
+++ b/psqlextra/introspect/__init__.py
@@ -0,0 +1,8 @@
+from .fields import inspect_model_local_concrete_fields
+from .models import model_from_cursor, models_from_cursor
+
+__all__ = [
+    "models_from_cursor",
+    "model_from_cursor",
+    "inspect_model_local_concrete_fields",
+]
diff --git a/psqlextra/introspect/fields.py b/psqlextra/introspect/fields.py
new file mode 100644
index 00000000..27ef28f7
--- /dev/null
+++ b/psqlextra/introspect/fields.py
@@ -0,0 +1,21 @@
+from typing import List, Type
+
+from django.db.models import Field, Model
+
+
+def inspect_model_local_concrete_fields(model: Type[Model]) -> List[Field]:
+    """Gets a complete list of local and concrete fields on a model; these are
+    fields that map directly to a database column on the table backing the
+    model.
+
+    This is similar to Django's `Meta.local_concrete_fields`, which is a
+    private API. This method utilizes only public APIs.
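Circling back to the new `psqlextra.error` module above, a usage sketch that catches a unique violation regardless of whether psycopg2 or psycopg 3 is installed (`MyModel` is hypothetical):

```python
from django import db

from psqlextra.error import extract_postgres_error_code

UNIQUE_VIOLATION = "23505"  # SQLSTATE code for unique_violation

try:
    MyModel.objects.create(slug="duplicate")
except db.IntegrityError as error:
    # psycopg2 exposes the code as pgcode, psycopg 3 as sqlstate;
    # extract_postgres_error_code() normalizes the difference.
    if extract_postgres_error_code(error) == UNIQUE_VIOLATION:
        ...
```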
+ """ + + local_concrete_fields = [] + + for field in model._meta.get_fields(include_parents=False): + if isinstance(field, Field) and field.column and not field.many_to_many: + local_concrete_fields.append(field) + + return local_concrete_fields diff --git a/psqlextra/introspect/models.py b/psqlextra/introspect/models.py new file mode 100644 index 00000000..e160bcaf --- /dev/null +++ b/psqlextra/introspect/models.py @@ -0,0 +1,175 @@ +from typing import ( + Any, + Dict, + Generator, + Iterable, + List, + Optional, + Type, + TypeVar, + Union, + cast, +) + +from django.core.exceptions import FieldDoesNotExist +from django.db import connection, models +from django.db.models import Field, Model +from django.db.models.expressions import Expression + +from .fields import inspect_model_local_concrete_fields + +TModel = TypeVar("TModel", bound=models.Model) + + +def _construct_model( + model: Type[TModel], + columns: Iterable[str], + values: Iterable[Any], + *, + apply_converters: bool = True +) -> TModel: + fields_by_name_and_column = {} + for concrete_field in inspect_model_local_concrete_fields(model): + fields_by_name_and_column[concrete_field.attname] = concrete_field + + if concrete_field.db_column: + fields_by_name_and_column[concrete_field.db_column] = concrete_field + + indexable_columns = list(columns) + + row = {} + + for index, value in enumerate(values): + column = indexable_columns[index] + try: + field: Optional[Field] = cast(Field, model._meta.get_field(column)) + except FieldDoesNotExist: + field = fields_by_name_and_column.get(column) + + if not field: + continue + + field_column_expression = field.get_col(model._meta.db_table) + + if apply_converters: + converters = cast(Expression, field).get_db_converters( + connection + ) + connection.ops.get_db_converters(field_column_expression) + + converted_value = value + for converter in converters: + converted_value = converter( + converted_value, + field_column_expression, + connection, + ) + else: + converted_value = value + + row[field.attname] = converted_value + + instance = model(**row) + instance._state.adding = False + instance._state.db = connection.alias + + return instance + + +def models_from_cursor( + model: Type[TModel], cursor, *, related_fields: List[str] = [] +) -> Generator[TModel, None, None]: + """Fetches all rows from a cursor and converts the values into model + instances. + + This is roughly what Django does internally when you do queries. This + goes further than `Model.from_db` as it also applies converters to make + sure that values are converted into their Python equivalent. + + Use this when you've outgrown the ORM and you are writing performant + queries yourself and you need to map the results back into ORM objects. + + Arguments: + model: + Model to construct. + + cursor: + Cursor to read the rows from. + + related_fields: + List of ForeignKey/OneToOneField names that were joined + into the raw query. Use this to achieve the same thing + that Django's `.select_related()` does. + + Field names should be specified in the order that they + are SELECT'd in. 
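A usage sketch for `models_from_cursor`, mapping a hand-written query back into real model instances (`MyModel` is hypothetical):

```python
from django.db import connection

from psqlextra.introspect import models_from_cursor

table = connection.ops.quote_name(MyModel._meta.db_table)

with connection.cursor() as cursor:
    cursor.execute(f"SELECT * FROM {table}")

    # Yields MyModel instances with DB converters applied,
    # streamed in batches via cursor.fetchmany().
    for instance in models_from_cursor(MyModel, cursor):
        print(instance.pk)
```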
+    """
+
+    columns = [col[0] for col in cursor.description]
+    field_offset = len(inspect_model_local_concrete_fields(model))
+
+    rows = cursor.fetchmany()
+
+    while rows:
+        for values in rows:
+            instance = _construct_model(
+                model, columns[:field_offset], values[:field_offset]
+            )
+
+            for index, related_field_name in enumerate(related_fields):
+                related_model = cast(
+                    Union[Type[Model], None],
+                    model._meta.get_field(related_field_name).related_model,
+                )
+                if not related_model:
+                    continue
+
+                related_field_count = len(
+                    inspect_model_local_concrete_fields(related_model)
+                )
+
+                # autopep8: off
+                related_columns = columns[
+                    field_offset : field_offset + related_field_count  # noqa
+                ]
+                related_values = values[
+                    field_offset : field_offset + related_field_count  # noqa
+                ]
+                # autopep8: on
+
+                if (
+                    not related_columns
+                    or not related_values
+                    or all([value is None for value in related_values])
+                ):
+                    continue
+
+                related_instance = _construct_model(
+                    cast(Type[Model], related_model),
+                    related_columns,
+                    related_values,
+                )
+                instance._state.fields_cache[related_field_name] = related_instance  # type: ignore
+
+                field_offset += len(
+                    inspect_model_local_concrete_fields(related_model)
+                )
+
+            yield instance
+
+        rows = cursor.fetchmany()
+
+
+def model_from_cursor(
+    model: Type[TModel], cursor, *, related_fields: List[str] = []
+) -> Optional[TModel]:
+    return next(
+        models_from_cursor(model, cursor, related_fields=related_fields), None
+    )
+
+
+def model_from_dict(
+    model: Type[TModel], row: Dict[str, Any], *, apply_converters: bool = True
+) -> TModel:
+    return _construct_model(
+        model, row.keys(), row.values(), apply_converters=apply_converters
+    )
diff --git a/psqlextra/locking.py b/psqlextra/locking.py
new file mode 100644
index 00000000..da8ff567
--- /dev/null
+++ b/psqlextra/locking.py
@@ -0,0 +1,104 @@
+from enum import Enum
+from typing import Optional, Type
+
+from django.db import DEFAULT_DB_ALIAS, connections, models
+
+
+class PostgresTableLockMode(Enum):
+    """List of table locking modes.
+
+    See: https://www.postgresql.org/docs/current/explicit-locking.html
+    """
+
+    ACCESS_SHARE = "ACCESS SHARE"
+    ROW_SHARE = "ROW SHARE"
+    ROW_EXCLUSIVE = "ROW EXCLUSIVE"
+    SHARE_UPDATE_EXCLUSIVE = "SHARE UPDATE EXCLUSIVE"
+    SHARE = "SHARE"
+    SHARE_ROW_EXCLUSIVE = "SHARE ROW EXCLUSIVE"
+    EXCLUSIVE = "EXCLUSIVE"
+    ACCESS_EXCLUSIVE = "ACCESS EXCLUSIVE"
+
+    @property
+    def alias(self) -> str:
+        return (
+            "".join([word.title() for word in self.name.lower().split("_")])
+            + "Lock"
+        )
+
+
+def postgres_lock_table(
+    table_name: str,
+    lock_mode: PostgresTableLockMode,
+    *,
+    schema_name: Optional[str] = None,
+    using: str = DEFAULT_DB_ALIAS,
+) -> None:
+    """Locks the specified table with the specified mode.
+
+    The lock is held until the end of the current transaction.
+
+    Arguments:
+        table_name:
+            Unquoted table name to acquire the lock on.
+
+        lock_mode:
+            Type of lock to acquire.
+
+        schema_name:
+            Optionally, the unquoted name of the schema
+            the table to lock is in. If not specified,
+            the table name is resolved by PostgreSQL
+            using its ``search_path``.
+
+        using:
+            Optional name of the database connection to use.
+    """
+
+    connection = connections[using]
+
+    with connection.cursor() as cursor:
+        quoted_fqn = connection.ops.quote_name(table_name)
+        if schema_name:
+            quoted_fqn = (
+                connection.ops.quote_name(schema_name) + "." + quoted_fqn
+            )
+
+        cursor.execute(f"LOCK TABLE {quoted_fqn} IN {lock_mode.value} MODE")
+
+
+def postgres_lock_model(
+    model: Type[models.Model],
+    lock_mode: PostgresTableLockMode,
+    *,
+    using: str = DEFAULT_DB_ALIAS,
+    schema_name: Optional[str] = None,
+) -> None:
+    """Locks the specified model with the specified mode.
+
+    The lock is held until the end of the current transaction.
+
+    Arguments:
+        model:
+            The model of which to lock the table.
+
+        lock_mode:
+            Type of lock to acquire.
+
+        schema_name:
+            Optionally, the unquoted name of the schema
+            the table to lock is in. If not specified,
+            the table name is resolved by PostgreSQL
+            using its ``search_path``.
+
+            Django models always reside in the default
+            ("public") schema. You should not specify
+            this unless you're doing something special.
+
+        using:
+            Optional name of the database connection to use.
+    """
+
+    postgres_lock_table(
+        model._meta.db_table, lock_mode, schema_name=schema_name, using=using
+    )
diff --git a/psqlextra/lookups.py b/psqlextra/lookups.py
new file mode 100644
index 00000000..4010310b
--- /dev/null
+++ b/psqlextra/lookups.py
@@ -0,0 +1,34 @@
+from django.db.models import lookups
+from django.db.models.fields import Field, related_lookups
+from django.db.models.fields.related import ForeignObject
+
+
+class InValuesLookupMixin:
+    """Performs a `lhs IN VALUES ((a), (b), (c))` lookup.
+
+    This can be significantly faster than a normal `IN (a, b, c)`. The
+    latter sometimes causes the Postgres query planner to do a
+    sequential scan.
+    """
+
+    def as_sql(self, compiler, connection):
+
+        if not self.rhs_is_direct_value():
+            return super().as_sql(compiler, connection)
+
+        lhs, lhs_params = self.process_lhs(compiler, connection)
+
+        _, rhs_params = self.process_rhs(compiler, connection)
+        rhs = ",".join([f"(%s)" for _ in rhs_params])  # noqa: F541
+
+        return f"{lhs} IN (VALUES {rhs})", lhs_params + list(rhs_params)
+
+
+@Field.register_lookup
+class InValuesLookup(InValuesLookupMixin, lookups.In):
+    lookup_name = "invalues"
+
+
+@ForeignObject.register_lookup
+class InValuesRelatedLookup(InValuesLookupMixin, related_lookups.RelatedIn):
+    lookup_name = "invalues"
diff --git a/psqlextra/management/commands/pgmakemigrations.py b/psqlextra/management/commands/pgmakemigrations.py
index cdb7131b..7b678855 100644
--- a/psqlextra/management/commands/pgmakemigrations.py
+++ b/psqlextra/management/commands/pgmakemigrations.py
@@ -1,4 +1,6 @@
-from django.core.management.commands import makemigrations
+from django.core.management.commands import (  # type: ignore[attr-defined]
+    makemigrations,
+)
 
 from psqlextra.backend.migrations import postgres_patched_migrations
 
diff --git a/psqlextra/management/commands/pgpartition.py b/psqlextra/management/commands/pgpartition.py
index d8e5d993..8a6fa636 100644
--- a/psqlextra/management/commands/pgpartition.py
+++ b/psqlextra/management/commands/pgpartition.py
@@ -2,9 +2,6 @@
 
 from typing import Optional
 
-import colorama
-
-from ansimarkup import ansiprint, ansistring
 from django.conf import settings
 from django.core.management.base import BaseCommand
 from django.utils.module_loading import import_string
@@ -40,7 +37,7 @@ def add_arguments(self, parser):
         parser.add_argument(
             "--using",
             "-u",
-            help="Name of the database connection to use.",
+            help="Optional name of the database connection to use.",
             default="default",
         )
 
@@ -60,7 +57,7 @@ def add_arguments(self, parser):
             default=False,
         )
 
-    def handle(
+    def handle(  # type: ignore[override]
         self,
         dry: bool,
         yes: bool,
@@ -70,10 +67,6 @@ def
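Usage sketch for the new locking helpers above; the lock is released at transaction end, so they only make sense inside `transaction.atomic()` (`MyModel` is hypothetical):

```python
from django.db import transaction

from psqlextra.locking import PostgresTableLockMode, postgres_lock_model

with transaction.atomic():
    # EXCLUSIVE blocks concurrent writers until commit/rollback,
    # while plain SELECTs (ACCESS SHARE) continue to run.
    postgres_lock_model(MyModel, PostgresTableLockMode.EXCLUSIVE)
    ...
```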
handle( *args, **kwargs, ): - # disable coloring if no terminal is attached - if not sys.stdout.isatty(): - colorama.init(strip=True) - partitioning_manager = self._partitioning_manager() plan = partitioning_manager.plan( @@ -83,7 +76,7 @@ def handle( creations_count = len(plan.creations) deletions_count = len(plan.deletions) if creations_count == 0 and deletions_count == 0: - ansiprint("Nothing to be done.") + print("Nothing to be done.") return plan.print() @@ -92,18 +85,14 @@ def handle( return if not yes: - sys.stdout.write( - ansistring( - "Do you want to proceed? (y/N) " - ) - ) + sys.stdout.write("Do you want to proceed? (y/N) ") if not self._ask_for_confirmation(): - ansiprint("Operation aborted.") + print("Operation aborted.") return plan.apply(using=using) - ansiprint("Operations applied.") + print("Operations applied.") @staticmethod def _ask_for_confirmation() -> bool: @@ -119,7 +108,7 @@ def _ask_for_confirmation() -> bool: @staticmethod def _partitioning_manager(): partitioning_manager = getattr( - settings, "PSQLEXTRA_PARTITIONING_MANAGER" + settings, "PSQLEXTRA_PARTITIONING_MANAGER", None ) if not partitioning_manager: raise PostgresPartitioningError( diff --git a/psqlextra/manager/manager.py b/psqlextra/manager/manager.py index 4b96e34f..ee1eb58b 100644 --- a/psqlextra/manager/manager.py +++ b/psqlextra/manager/manager.py @@ -8,7 +8,7 @@ from psqlextra.query import PostgresQuerySet -class PostgresManager(Manager.from_queryset(PostgresQuerySet)): +class PostgresManager(Manager.from_queryset(PostgresQuerySet)): # type: ignore[misc] """Adds support for PostgreSQL specifics.""" use_in_migrations = True @@ -37,7 +37,10 @@ def __init__(self, *args, **kwargs): ) def truncate( - self, cascade: bool = False, using: Optional[str] = None + self, + cascade: bool = False, + restart_identity: bool = False, + using: Optional[str] = None, ) -> None: """Truncates this model/table using the TRUNCATE statement. @@ -51,14 +54,19 @@ def truncate( False, an error will be raised if there are rows in other tables referencing the rows you're trying to delete. + restart_identity: + Automatically restart sequences owned by + columns of the truncated table(s). 
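The new `restart_identity` flag maps onto `TRUNCATE ... RESTART IDENTITY`. A sketch with a hypothetical model whose default manager is `PostgresManager`:

```python
# Emits: TRUNCATE TABLE "app_mymodel" RESTART IDENTITY CASCADE
MyModel.objects.truncate(cascade=True, restart_identity=True)
```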
""" connection = connections[using or "default"] table_name = connection.ops.quote_name(self.model._meta.db_table) with connection.cursor() as cursor: - sql = "TRUNCATE TABLE %s" % table_name + sql = f"TRUNCATE TABLE {table_name}" if cascade: sql += " CASCADE" + if restart_identity: + sql += " RESTART IDENTITY" cursor.execute(sql) diff --git a/psqlextra/models/base.py b/psqlextra/models/base.py index 21caad36..d240237a 100644 --- a/psqlextra/models/base.py +++ b/psqlextra/models/base.py @@ -1,4 +1,7 @@ +from typing import Any + from django.db import models +from django.db.models import Manager from psqlextra.manager import PostgresManager @@ -10,4 +13,4 @@ class Meta: abstract = True base_manager_name = "objects" - objects = PostgresManager() + objects: "Manager[Any]" = PostgresManager() diff --git a/psqlextra/models/partitioned.py b/psqlextra/models/partitioned.py index c03f3e93..f0115367 100644 --- a/psqlextra/models/partitioned.py +++ b/psqlextra/models/partitioned.py @@ -1,3 +1,5 @@ +from typing import Iterable + from django.db.models.base import ModelBase from psqlextra.types import PostgresPartitioningMethod @@ -15,7 +17,7 @@ class PostgresPartitionedModelMeta(ModelBase): """ default_method = PostgresPartitioningMethod.RANGE - default_key = [] + default_key: Iterable[str] = [] def __new__(cls, name, bases, attrs, **kwargs): new_class = super().__new__(cls, name, bases, attrs, **kwargs) @@ -38,6 +40,8 @@ class PostgresPartitionedModel( """Base class for taking advantage of PostgreSQL's 11.x native support for table partitioning.""" + _partitioning_meta: PostgresPartitionedModelOptions + class Meta: abstract = True base_manager_name = "objects" diff --git a/psqlextra/models/view.py b/psqlextra/models/view.py index a9497057..b19f88c8 100644 --- a/psqlextra/models/view.py +++ b/psqlextra/models/view.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast from django.core.exceptions import ImproperlyConfigured from django.db import connections @@ -12,6 +12,9 @@ from .base import PostgresModel from .options import PostgresViewOptions +if TYPE_CHECKING: + from psqlextra.backend.schema import PostgresSchemaEditor + ViewQueryValue = Union[QuerySet, SQLWithParams, SQL] ViewQuery = Optional[Union[ViewQueryValue, Callable[[], ViewQueryValue]]] @@ -77,23 +80,26 @@ def _view_query_as_sql_with_params( " to be a valid `django.db.models.query.QuerySet`" " SQL string, or tuple of SQL string and params." 
) - % (model.__name__) + % (model.__class__.__name__) ) # querysets can easily be converted into sql, params if is_query_set(view_query): - return view_query.query.sql_with_params() + return cast("QuerySet[Any]", view_query).query.sql_with_params() # query was already specified in the target format if is_sql_with_params(view_query): - return view_query + return cast(SQLWithParams, view_query) - return view_query, tuple() + view_query_sql = cast(str, view_query) + return view_query_sql, tuple() class PostgresViewModel(PostgresModel, metaclass=PostgresViewModelMeta): """Base class for creating a model that is a view.""" + _view_meta: PostgresViewOptions + class Meta: abstract = True base_manager_name = "objects" @@ -127,4 +133,6 @@ def refresh( conn_name = using or "default" with connections[conn_name].schema_editor() as schema_editor: - schema_editor.refresh_materialized_view_model(cls, concurrently) + cast( + "PostgresSchemaEditor", schema_editor + ).refresh_materialized_view_model(cls, concurrently) diff --git a/psqlextra/partitioning/__init__.py b/psqlextra/partitioning/__init__.py index 9e67ddf5..970c3ad7 100644 --- a/psqlextra/partitioning/__init__.py +++ b/psqlextra/partitioning/__init__.py @@ -5,6 +5,7 @@ from .partition import PostgresPartition from .plan import PostgresModelPartitioningPlan, PostgresPartitioningPlan from .range_partition import PostgresRangePartition +from .range_strategy import PostgresRangePartitioningStrategy from .shorthands import partition_by_current_time from .strategy import PostgresPartitioningStrategy from .time_partition import PostgresTimePartition @@ -22,8 +23,8 @@ "PostgresTimePartition", "PostgresPartitioningStrategy", "PostgresTimePartitioningStrategy", - "PostgresCurrentTimePartitioningStrategy", "PostgresRangePartitioningStrategy", + "PostgresCurrentTimePartitioningStrategy", "PostgresPartitioningConfig", "PostgresTimePartitionSize", ] diff --git a/psqlextra/partitioning/config.py b/psqlextra/partitioning/config.py index df21c057..976bf1ae 100644 --- a/psqlextra/partitioning/config.py +++ b/psqlextra/partitioning/config.py @@ -1,3 +1,5 @@ +from typing import Type + from psqlextra.models import PostgresPartitionedModel from .strategy import PostgresPartitioningStrategy @@ -9,7 +11,7 @@ class PostgresPartitioningConfig: def __init__( self, - model: PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], strategy: PostgresPartitioningStrategy, ) -> None: self.model = model diff --git a/psqlextra/partitioning/current_time_strategy.py b/psqlextra/partitioning/current_time_strategy.py index a0268be6..114a1aaf 100644 --- a/psqlextra/partitioning/current_time_strategy.py +++ b/psqlextra/partitioning/current_time_strategy.py @@ -24,6 +24,7 @@ def __init__( size: PostgresTimePartitionSize, count: int, max_age: Optional[relativedelta] = None, + name_format: Optional[str] = None, ) -> None: """Initializes a new instance of :see:PostgresTimePartitioningStrategy. 
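A sketch of constructing the strategy directly with the new `name_format` parameter (both names are exported from `psqlextra.partitioning`); most callers will go through the `partition_by_current_time` shorthand further down, and the format string here is an illustrative assumption.

```python
from dateutil.relativedelta import relativedelta

from psqlextra.partitioning import (
    PostgresCurrentTimePartitioningStrategy,
    PostgresTimePartitionSize,
)

strategy = PostgresCurrentTimePartitioningStrategy(
    size=PostgresTimePartitionSize(months=1),
    count=12,                         # partitions to create ahead of now
    max_age=relativedelta(months=6),  # older partitions become deletable
    name_format="%Y_%m",              # strftime pattern for partition names
)
```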
@@ -44,13 +45,16 @@ def __init__( self.size = size self.count = count self.max_age = max_age + self.name_format = name_format def to_create(self) -> Generator[PostgresTimePartition, None, None]: current_datetime = self.size.start(self.get_start_datetime()) for _ in range(self.count): yield PostgresTimePartition( - start_datetime=current_datetime, size=self.size + start_datetime=current_datetime, + size=self.size, + name_format=self.name_format, ) current_datetime += self.size.as_delta() @@ -65,7 +69,9 @@ def to_delete(self) -> Generator[PostgresTimePartition, None, None]: while True: yield PostgresTimePartition( - start_datetime=current_datetime, size=self.size + start_datetime=current_datetime, + size=self.size, + name_format=self.name_format, ) current_datetime -= self.size.as_delta() diff --git a/psqlextra/partitioning/manager.py b/psqlextra/partitioning/manager.py index 28aee91e..074cc1c6 100644 --- a/psqlextra/partitioning/manager.py +++ b/psqlextra/partitioning/manager.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Type from django.db import connections @@ -39,7 +39,7 @@ def plan( for deletion, regardless of the configuration. using: - Name of the database connection to use. + Optional name of the database connection to use. Returns: A plan describing what partitions would be created @@ -111,7 +111,9 @@ def _plan_for_config( return model_plan @staticmethod - def _get_partitioned_table(connection, model: PostgresPartitionedModel): + def _get_partitioned_table( + connection, model: Type[PostgresPartitionedModel] + ): with connection.cursor() as cursor: table = connection.introspection.get_partitioned_table( cursor, model._meta.db_table diff --git a/psqlextra/partitioning/partition.py b/psqlextra/partitioning/partition.py index ca64bbdc..4c13fda0 100644 --- a/psqlextra/partitioning/partition.py +++ b/psqlextra/partitioning/partition.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Optional +from typing import Optional, Type from psqlextra.backend.schema import PostgresSchemaEditor from psqlextra.models import PostgresPartitionedModel @@ -15,7 +15,7 @@ def name(self) -> str: @abstractmethod def create( self, - model: PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], schema_editor: PostgresSchemaEditor, comment: Optional[str] = None, ) -> None: @@ -24,7 +24,7 @@ def create( @abstractmethod def delete( self, - model: PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], schema_editor: PostgresSchemaEditor, ) -> None: """Deletes this partition from the database.""" diff --git a/psqlextra/partitioning/plan.py b/psqlextra/partitioning/plan.py index bdcf04d0..3fcac44d 100644 --- a/psqlextra/partitioning/plan.py +++ b/psqlextra/partitioning/plan.py @@ -1,13 +1,15 @@ from dataclasses import dataclass, field -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional, cast -from ansimarkup import ansiprint from django.db import connections, transaction from .config import PostgresPartitioningConfig from .constants import AUTO_PARTITIONED_COMMENT from .partition import PostgresPartition +if TYPE_CHECKING: + from psqlextra.backend.schema import PostgresSchemaEditor + @dataclass class PostgresModelPartitioningPlan: @@ -29,7 +31,7 @@ def apply(self, using: Optional[str]) -> None: Arguments: using: - Name of the database connection to use. + Optional name of the database connection to use. 
""" connection = connections[using or "default"] @@ -39,27 +41,30 @@ def apply(self, using: Optional[str]) -> None: for partition in self.creations: partition.create( self.config.model, - schema_editor, + cast("PostgresSchemaEditor", schema_editor), comment=AUTO_PARTITIONED_COMMENT, ) for partition in self.deletions: - partition.delete(self.config.model, schema_editor) + partition.delete( + self.config.model, + cast("PostgresSchemaEditor", schema_editor), + ) def print(self) -> None: """Prints this model plan to the terminal in a readable format.""" - ansiprint(f"{self.config.model.__name__}:") + print(f"{self.config.model.__name__}:") for partition in self.deletions: - ansiprint(" - %s" % partition.name()) + print(" - %s" % partition.name()) for key, value in partition.deconstruct().items(): - ansiprint(f" {key}: {value}") + print(f" {key}: {value}") for partition in self.creations: - ansiprint(" + %s" % partition.name()) + print(" + %s" % partition.name()) for key, value in partition.deconstruct().items(): - ansiprint(f" {key}: {value}") + print(f" {key}: {value}") @dataclass @@ -104,12 +109,8 @@ def print(self) -> None: create_count = len(self.creations) delete_count = len(self.deletions) - ansiprint( - f"{delete_count} partitions will be deleted" - ) - ansiprint( - f"{create_count} partitions will be created" - ) + print(f"{delete_count} partitions will be deleted") + print(f"{create_count} partitions will be created") __all__ = ["PostgresPartitioningPlan", "PostgresModelPartitioningPlan"] diff --git a/psqlextra/partitioning/range_partition.py b/psqlextra/partitioning/range_partition.py index b49fe784..a2f3e82f 100644 --- a/psqlextra/partitioning/range_partition.py +++ b/psqlextra/partitioning/range_partition.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Optional, Type from psqlextra.backend.schema import PostgresSchemaEditor from psqlextra.models import PostgresPartitionedModel @@ -23,7 +23,7 @@ def deconstruct(self) -> dict: def create( self, - model: PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], schema_editor: PostgresSchemaEditor, comment: Optional[str] = None, ) -> None: @@ -37,7 +37,7 @@ def create( def delete( self, - model: PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], schema_editor: PostgresSchemaEditor, ) -> None: schema_editor.delete_partition(model, self.name()) diff --git a/psqlextra/partitioning/shorthands.py b/psqlextra/partitioning/shorthands.py index 05ce4a34..30175273 100644 --- a/psqlextra/partitioning/shorthands.py +++ b/psqlextra/partitioning/shorthands.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Type from dateutil.relativedelta import relativedelta @@ -10,13 +10,14 @@ def partition_by_current_time( - model: PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], count: int, years: Optional[int] = None, months: Optional[int] = None, weeks: Optional[int] = None, days: Optional[int] = None, max_age: Optional[relativedelta] = None, + name_format: Optional[str] = None, ) -> PostgresPartitioningConfig: """Short-hand for generating a partitioning config that partitions the specified model by time. @@ -48,6 +49,10 @@ def partition_by_current_time( Partitions older than this are deleted when running a delete/cleanup run. + + name_format: + The datetime format which is being passed to datetime.strftime + to generate the partition name. 
""" size = PostgresTimePartitionSize( @@ -57,7 +62,10 @@ def partition_by_current_time( return PostgresPartitioningConfig( model=model, strategy=PostgresCurrentTimePartitioningStrategy( - size=size, count=count, max_age=max_age + size=size, + count=count, + max_age=max_age, + name_format=name_format, ), ) diff --git a/psqlextra/partitioning/time_partition.py b/psqlextra/partitioning/time_partition.py index b6be67a1..3c8a4d87 100644 --- a/psqlextra/partitioning/time_partition.py +++ b/psqlextra/partitioning/time_partition.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Optional from .error import PostgresPartitioningError from .range_partition import PostgresRangePartition @@ -22,7 +23,10 @@ class PostgresTimePartition(PostgresRangePartition): } def __init__( - self, size: PostgresTimePartitionSize, start_datetime: datetime + self, + size: PostgresTimePartitionSize, + start_datetime: datetime, + name_format: Optional[str] = None, ) -> None: end_datetime = start_datetime + size.as_delta() @@ -34,9 +38,12 @@ def __init__( self.size = size self.start_datetime = start_datetime self.end_datetime = end_datetime + self.name_format = name_format def name(self) -> str: - name_format = self._unit_name_format.get(self.size.unit) + name_format = self.name_format or self._unit_name_format.get( + self.size.unit + ) if not name_format: raise PostgresPartitioningError("Unknown size/unit") diff --git a/psqlextra/py.typed b/psqlextra/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/psqlextra/query.py b/psqlextra/query.py index f52f31e0..6a86f18e 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -1,19 +1,55 @@ from collections import OrderedDict from itertools import chain -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + Iterable, + List, + Optional, + Tuple, + TypeVar, + Union, +) from django.core.exceptions import SuspiciousOperation -from django.db import connections, models, router -from django.db.models import Expression, Q +from django.db import models, router +from django.db.backends.utils import CursorWrapper +from django.db.models import Expression, Q, QuerySet from django.db.models.fields import NOT_PROVIDED +from .expressions import ExcludedCol +from .introspect import model_from_cursor, models_from_cursor from .sql import PostgresInsertQuery, PostgresQuery from .types import ConflictAction -ConflictTarget = List[Union[str, Tuple[str]]] +if TYPE_CHECKING: + from django.db.models.constraints import BaseConstraint + from django.db.models.indexes import Index +ConflictTarget = Union[List[Union[str, Tuple[str]]], "BaseConstraint", "Index"] -class PostgresQuerySet(models.QuerySet): + +TModel = TypeVar("TModel", bound=models.Model, covariant=True) + +if TYPE_CHECKING: + from typing_extensions import Self + + QuerySetBase = QuerySet[TModel] +else: + QuerySetBase = QuerySet + + +def peek_iterator(iterable): + try: + first = next(iterable) + except StopIteration: + return None + return list(chain([first], iterable)) + + +class PostgresQuerySet(QuerySetBase, Generic[TModel]): """Adds support for PostgreSQL specifics.""" def __init__(self, model=None, query=None, using=None, hints=None): @@ -27,8 +63,9 @@ def __init__(self, model=None, query=None, using=None, hints=None): self.conflict_action = None self.conflict_update_condition = None self.index_predicate = None + self.update_values = None - def annotate(self, **annotations): + def annotate(self, **annotations) -> "Self": # 
type: ignore[valid-type, override] """Custom version of the standard annotate function that allows using field names as annotated fields. @@ -84,6 +121,7 @@ def on_conflict( action: ConflictAction, index_predicate: Optional[Union[Expression, Q, str]] = None, update_condition: Optional[Union[Expression, Q, str]] = None, + update_values: Optional[Dict[str, Union[Any, Expression]]] = None, ): """Sets the action to take when conflicts arise when attempting to insert/create a new row. @@ -101,18 +139,24 @@ def on_conflict( update_condition: Only update if this SQL expression evaluates to true. + + update_values: + Optionally, values/expressions to use when rows + conflict. If not specified, all columns specified + in the rows are updated with the values you specified. """ self.conflict_target = fields self.conflict_action = action self.conflict_update_condition = update_condition self.index_predicate = index_predicate + self.update_values = update_values return self def bulk_insert( self, - rows: List[dict], + rows: Iterable[Dict[str, Any]], return_model: bool = False, using: Optional[str] = None, ): @@ -131,13 +175,20 @@ def bulk_insert( just dicts. using: - Name of the database connection to use for + Optional name of the database connection to use for this query. Returns: A list of either the dicts of the rows inserted, including the pk or the models of the rows inserted with defaults for any fields not specified """ + if rows is None: + return [] + + rows = peek_iterator(iter(rows)) + + if not rows: + return [] if not self.conflict_target and not self.conflict_action: # no special action required, use the standard Django bulk_create(..) @@ -165,14 +216,17 @@ def bulk_insert( deduped_rows.append(row) compiler = self._build_insert_compiler(deduped_rows, using=using) - objs = compiler.execute_sql(return_id=not return_model) - if return_model: - return [ - self._create_model_instance(dict(row, **obj), compiler.using) - for row, obj in zip(deduped_rows, objs) - ] - return [dict(row, **obj) for row, obj in zip(deduped_rows, objs)] + with compiler.connection.cursor() as cursor: + for sql, params in compiler.as_sql(return_id=not return_model): + cursor.execute(sql, params) + + if return_model: + return list(models_from_cursor(self.model, cursor)) + + return self._consume_cursor_as_dicts( + cursor, original_rows=deduped_rows + ) def insert(self, using: Optional[str] = None, **fields): """Creates a new record in the database. @@ -193,14 +247,20 @@ def insert(self, using: Optional[str] = None, **fields): """ if self.conflict_target or self.conflict_action: + if not self.model or not self.model.pk: + return None + compiler = self._build_insert_compiler([fields], using=using) - rows = compiler.execute_sql(return_id=True) - pk_field_name = self.model._meta.pk.name - if not rows or len(rows) == 0: - return None + with compiler.connection.cursor() as cursor: + for sql, params in compiler.as_sql(return_id=True): + cursor.execute(sql, params) + + row = cursor.fetchone() + if not row: + return None - return rows[0][pk_field_name] + return row[0] # no special action required, use the standard Django create(..) 
return super().create(**fields).pk @@ -228,30 +288,12 @@ def insert_and_get(self, using: Optional[str] = None, **fields): return super().create(**fields) compiler = self._build_insert_compiler([fields], using=using) - rows = compiler.execute_sql(return_id=False) - - if not rows: - return None - - columns = rows[0] - - # get a list of columns that are officially part of the model and - # preserve the fact that the attribute name - # might be different than the database column name - model_columns = {} - for field in self.model._meta.local_concrete_fields: - model_columns[field.column] = field.attname - # strip out any columns/fields returned by the db that - # are not present in the model - model_init_fields = {} - for column_name, column_value in columns.items(): - try: - model_init_fields[model_columns[column_name]] = column_value - except KeyError: - pass + with compiler.connection.cursor() as cursor: + for sql, params in compiler.as_sql(return_id=False): + cursor.execute(sql, params) - return self._create_model_instance(model_init_fields, compiler.using) + return model_from_cursor(self.model, cursor) def upsert( self, @@ -260,6 +302,7 @@ def upsert( index_predicate: Optional[Union[Expression, Q, str]] = None, using: Optional[str] = None, update_condition: Optional[Union[Expression, Q, str]] = None, + update_values: Optional[Dict[str, Union[Any, Expression]]] = None, ) -> int: """Creates a new record or updates the existing one with the specified data. @@ -282,17 +325,27 @@ def upsert( update_condition: Only update if this SQL expression evaluates to true. + update_values: + Optionally, values/expressions to use when rows + conflict. If not specified, all columns specified + in the rows are updated with the values you specified. + Returns: The primary key of the row that was created/updated. """ self.on_conflict( conflict_target, - ConflictAction.UPDATE, + ConflictAction.UPDATE + if (update_condition or update_condition is None) + else ConflictAction.NOTHING, index_predicate=index_predicate, update_condition=update_condition, + update_values=update_values, ) - return self.insert(**fields, using=using) + + kwargs = {**fields, "using": using} + return self.insert(**kwargs) def upsert_and_get( self, @@ -301,6 +354,7 @@ def upsert_and_get( index_predicate: Optional[Union[Expression, Q, str]] = None, using: Optional[str] = None, update_condition: Optional[Union[Expression, Q, str]] = None, + update_values: Optional[Dict[str, Union[Any, Expression]]] = None, ): """Creates a new record or updates the existing one with the specified data and then gets the row. @@ -323,6 +377,11 @@ def upsert_and_get( update_condition: Only update if this SQL expression evaluates to true. + update_values: + Optionally, values/expressions to use when rows + conflict. If not specified, all columns specified + in the rows are updated with the values you specified. + Returns: The model instance representing the row that was created/updated. 
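A sketch of what the new `update_values` argument enables; `PageView` and its unique `url` column are hypothetical, and whether a given expression is acceptable follows the usual `ON CONFLICT DO UPDATE` rules.

```python
from django.db.models import F

from myapp.models import PageView  # hypothetical, unique on "url"

# The first write inserts count=1; conflicting writes increment the
# existing row instead of overwriting it with the inserted value.
PageView.objects.upsert(
    conflict_target=["url"],
    fields=dict(url="https://example.com", count=1),
    update_values=dict(count=F("count") + 1),
)
```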
@@ -333,8 +392,11 @@ def upsert_and_get( ConflictAction.UPDATE, index_predicate=index_predicate, update_condition=update_condition, + update_values=update_values, ) - return self.insert_and_get(**fields, using=using) + + kwargs = {**fields, "using": using} + return self.insert_and_get(**kwargs) def bulk_upsert( self, @@ -344,6 +406,7 @@ def bulk_upsert( return_model: bool = False, using: Optional[str] = None, update_condition: Optional[Union[Expression, Q, str]] = None, + update_values: Optional[Dict[str, Union[Any, Expression]]] = None, ): """Creates a set of new records or updates the existing ones with the specified data. @@ -370,62 +433,43 @@ def bulk_upsert( update_condition: Only update if this SQL expression evaluates to true. + update_values: + Optionally, values/expressions to use when rows + conflict. If not specified, all columns specified + in the rows are updated with the values you specified. + Returns: A list of either the dicts of the rows upserted, including the pk or the models of the rows upserted """ - def is_empty(r): - return all([False for _ in r]) - - if not rows or is_empty(rows): - return [] - self.on_conflict( conflict_target, ConflictAction.UPDATE, index_predicate=index_predicate, update_condition=update_condition, + update_values=update_values, ) - return self.bulk_insert(rows, return_model, using=using) - - def _create_model_instance( - self, field_values: dict, using: str, apply_converters: bool = True - ): - """Creates a new instance of the model with the specified field. - Use this after the row was inserted into the database. The new - instance will marked as "saved". - """ - - converted_field_values = field_values.copy() - - if apply_converters: - connection = connections[using] - - for field in self.model._meta.local_concrete_fields: - if field.attname not in converted_field_values: - continue - - # converters can be defined on the field, or by - # the database back-end we're using - field_column = field.get_col(self.model._meta.db_table) - converters = field.get_db_converters( - connection - ) + connection.ops.get_db_converters(field_column) - - for converter in converters: - converted_field_values[field.attname] = converter( - converted_field_values[field.attname], - field_column, - connection, - ) - - instance = self.model(**converted_field_values) - instance._state.db = using - instance._state.adding = False + return self.bulk_insert(rows, return_model, using=using) - return instance + @staticmethod + def _consume_cursor_as_dicts( + cursor: CursorWrapper, *, original_rows: Iterable[Dict[str, Any]] + ) -> List[dict]: + cursor_description = cursor.description + + return [ + { + **original_row, + **{ + column.name: row[column_index] + for column_index, column in enumerate(cursor_description) + if row + }, + } + for original_row, row in zip(original_rows, cursor) + ] def _build_insert_compiler( self, rows: Iterable[Dict], using: Optional[str] = None @@ -447,7 +491,7 @@ def _build_insert_compiler( # ask the db router which connection to use using = ( - using or self._db or router.db_for_write(self.model, **self._hints) + using or self._db or router.db_for_write(self.model, **self._hints) # type: ignore[attr-defined] ) # create model objects, we also have to detect cases @@ -469,12 +513,17 @@ def _build_insert_compiler( ).format(index) ) - objs.append( - self._create_model_instance(row, using, apply_converters=False) - ) + obj = self.model(**row.copy()) + obj._state.db = using + obj._state.adding = False + objs.append(obj) # get the fields to be used 
during update/insert - insert_fields, update_fields = self._get_upsert_fields(first_row) + insert_fields, update_values = self._get_upsert_fields(first_row) + + # allow the user to override what should happen on update + if self.update_values is not None: + update_values = self.update_values # build a normal insert query query = PostgresInsertQuery(self.model) @@ -482,7 +531,7 @@ query.conflict_target = self.conflict_target query.conflict_update_condition = self.conflict_update_condition query.index_predicate = self.index_predicate - query.values(objs, insert_fields, update_fields) + query.insert_on_conflict_values(objs, insert_fields, update_values) compiler = query.get_compiler(using) return compiler @@ -547,13 +596,13 @@ def _get_upsert_fields(self, kwargs): model_instance = self.model(**kwargs) insert_fields = [] - update_fields = [] + update_values = {} for field in model_instance._meta.local_concrete_fields: has_default = field.default != NOT_PROVIDED if field.name in kwargs or field.column in kwargs: insert_fields.append(field) - update_fields.append(field) + update_values[field.name] = ExcludedCol(field) continue elif has_default: insert_fields.append(field) @@ -564,13 +613,13 @@ # instead of a concrete field, we have to handle that if field.primary_key is True and "pk" in kwargs: insert_fields.append(field) - update_fields.append(field) + update_values[field.name] = ExcludedCol(field) continue if self._is_magical_field(model_instance, field, is_insert=True): insert_fields.append(field) if self._is_magical_field(model_instance, field, is_insert=False): - update_fields.append(field) + update_values[field.name] = ExcludedCol(field) - return insert_fields, update_fields + return insert_fields, update_values diff --git a/psqlextra/schema.py b/psqlextra/schema.py new file mode 100644 index 00000000..9edb83bd --- /dev/null +++ b/psqlextra/schema.py @@ -0,0 +1,227 @@ +import os + +from contextlib import contextmanager +from typing import TYPE_CHECKING, Generator, cast + +from django.core.exceptions import SuspiciousOperation, ValidationError +from django.db import DEFAULT_DB_ALIAS, connections, transaction +from django.utils import timezone + +if TYPE_CHECKING: + from psqlextra.backend.introspection import PostgresIntrospection + from psqlextra.backend.schema import PostgresSchemaEditor + + +class PostgresSchema: + """Represents a Postgres schema. + + See: https://www.postgresql.org/docs/current/ddl-schemas.html + """ + + NAME_MAX_LENGTH = 63 + + name: str + + default: "PostgresSchema" + + def __init__(self, name: str) -> None: + self.name = name + + @classmethod + def create( + cls, name: str, *, using: str = DEFAULT_DB_ALIAS + ) -> "PostgresSchema": + """Creates a new schema with the specified name. + + This throws if the schema already exists as that is most likely + a problem that requires careful handling. Pretending everything + is ok might cause the caller to overwrite data, thinking it got + an empty schema. + + Arguments: + name: + The name to give to the new schema (max 63 characters). + + using: + Optional name of the database connection to use.
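A usage sketch of the `PostgresSchema` API; `exists`, `delete`, and `delete_and_create` are defined just below, and the schema name is hypothetical.

```python
from psqlextra.schema import PostgresSchema

schema = PostgresSchema.create("staging")  # raises if it already exists
assert PostgresSchema.exists("staging")

schema.delete()  # refuses unless the schema is empty

# Drop the schema (and its contents) and start over atomically.
schema = PostgresSchema.delete_and_create("staging", cascade=True)
```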
+ """ + + if len(name) > cls.NAME_MAX_LENGTH: + raise ValidationError( + f"Schema name '{name}' is longer than Postgres's limit of {cls.NAME_MAX_LENGTH} characters" + ) + + with connections[using].schema_editor() as schema_editor: + cast("PostgresSchemaEditor", schema_editor).create_schema(name) + + return cls(name) + + @classmethod + def create_time_based( + cls, prefix: str, *, using: str = DEFAULT_DB_ALIAS + ) -> "PostgresSchema": + """Creates a new schema with a time-based suffix. + + The time is precise up to the second. Creating + multiple time based schema in the same second + WILL lead to conflicts. + + Arguments: + prefix: + Name to prefix the final name with. The name plus + prefix cannot be longer than 63 characters. + + using: + Name of the database connection to use. + """ + + suffix = timezone.now().strftime("%Y%m%d%H%m%S") + name = cls._create_generated_name(prefix, suffix) + + return cls.create(name, using=using) + + @classmethod + def create_random( + cls, prefix: str, *, using: str = DEFAULT_DB_ALIAS + ) -> "PostgresSchema": + """Creates a new schema with a random suffix. + + Arguments: + prefix: + Name to prefix the final name with. The name plus + prefix cannot be longer than 63 characters. + + using: + Name of the database connection to use. + """ + + suffix = os.urandom(4).hex() + name = cls._create_generated_name(prefix, suffix) + + return cls.create(name, using=using) + + @classmethod + def delete_and_create( + cls, name: str, *, cascade: bool = False, using: str = DEFAULT_DB_ALIAS + ) -> "PostgresSchema": + """Deletes the schema if it exists before re-creating it. + + Arguments: + name: + Name of the schema to delete+create (max 63 characters). + + cascade: + Whether to delete the contents of the schema + and anything that references it if it exists. + + using: + Optional name of the database connection to use. + """ + + with transaction.atomic(using=using): + cls(name).delete(cascade=cascade, using=using) + return cls.create(name, using=using) + + @classmethod + def exists(cls, name: str, *, using: str = DEFAULT_DB_ALIAS) -> bool: + """Gets whether a schema with the specified name exists. + + Arguments: + name: + Name of the schema to check of whether it + exists. + + using: + Optional name of the database connection to use. + """ + + connection = connections[using] + + with connection.cursor() as cursor: + return name in cast( + "PostgresIntrospection", connection.introspection + ).get_schema_list(cursor) + + def delete( + self, *, cascade: bool = False, using: str = DEFAULT_DB_ALIAS + ) -> None: + """Deletes the schema and optionally deletes the contents of the schema + and anything that references it. + + Arguments: + cascade: + Cascade the delete to the contents of the schema + and anything that references it. + + If not set, the schema will refuse to be deleted + unless it is empty and there are not remaining + references. + """ + + if self.name == "public": + raise SuspiciousOperation( + "Pretty sure you are about to make a mistake by trying to drop the 'public' schema. I have stopped you. Thank me later." 
+ ) + + with connections[using].schema_editor() as schema_editor: + cast("PostgresSchemaEditor", schema_editor).delete_schema( + self.name, cascade=cascade + ) + + @classmethod + def _create_generated_name(cls, prefix: str, suffix: str) -> str: + separator = "_" + generated_name = f"{prefix}{separator}{suffix}" + max_prefix_length = cls.NAME_MAX_LENGTH - len(suffix) - len(separator) + + if len(generated_name) > cls.NAME_MAX_LENGTH: + raise ValidationError( + f"Schema prefix '{prefix}' is longer than {max_prefix_length} characters. Together with the separator and generated suffix of {len(suffix)} characters, the name would exceed Postgres's limit of {cls.NAME_MAX_LENGTH} characters." + ) + + return generated_name + + +PostgresSchema.default = PostgresSchema("public") + + +@contextmanager +def postgres_temporary_schema( + prefix: str, + *, + cascade: bool = False, + delete_on_throw: bool = False, + using: str = DEFAULT_DB_ALIAS, +) -> Generator[PostgresSchema, None, None]: + """Creates a temporary schema that only lives in the context of this + context manager. + + Arguments: + prefix: + Name to prefix the final name with. + + cascade: + Whether to cascade the delete when dropping the + schema. If enabled, the contents of the schema + are deleted as well as anything that references + the schema. + + delete_on_throw: + Whether to automatically drop the schema if + any error occurs within the context manager. + + using: + Optional name of the database connection to use. + """ + + schema = PostgresSchema.create_random(prefix, using=using) + + try: + yield schema + except Exception as e: + if delete_on_throw: + schema.delete(cascade=cascade, using=using) + + raise e + + schema.delete(cascade=cascade, using=using) diff --git a/psqlextra/settings.py b/psqlextra/settings.py new file mode 100644 index 00000000..6f75c779 --- /dev/null +++ b/psqlextra/settings.py @@ -0,0 +1,120 @@ +from contextlib import contextmanager +from typing import Generator, List, Optional, Union + +from django.core.exceptions import SuspiciousOperation +from django.db import DEFAULT_DB_ALIAS, connections + + +@contextmanager +def postgres_set_local( + *, + using: str = DEFAULT_DB_ALIAS, + **options: Optional[Union[str, int, float, List[str]]], +) -> Generator[None, None, None]: + """Sets the specified PostgreSQL options using SET LOCAL so that they apply + to the current transaction only. + + The effect is undone when the context manager exits. + + See https://www.postgresql.org/docs/current/runtime-config-client.html + for an overview of all available options. + """ + + connection = connections[using] + qn = connection.ops.quote_name + + if not connection.in_atomic_block: + raise SuspiciousOperation( + "SET LOCAL makes no sense outside a transaction. Start a transaction first." + ) + + sql = [] + params: List[Union[str, int, float, List[str]]] = [] + for name, value in options.items(): + if value is None: + sql.append(f"SET LOCAL {qn(name)} TO DEFAULT") + continue + + # Settings that accept a list of values are actually + # stored as string lists. We cannot just pass a list + # of values. We have to create the comma-separated + # string ourselves.
+ if isinstance(value, list) or isinstance(value, tuple): + placeholder = ", ".join(["%s" for _ in value]) + params.extend(value) + else: + placeholder = "%s" + params.append(value) + + sql.append(f"SET LOCAL {qn(name)} = {placeholder}") + + with connection.cursor() as cursor: + cursor.execute( + "SELECT name, setting FROM pg_settings WHERE name = ANY(%s)", + (list(options.keys()),), + ) + original_values = dict(cursor.fetchall()) + cursor.execute("; ".join(sql), params) + + yield + + # Put everything back to how it was. DEFAULT is + # not good enough as an outer SET LOCAL might + # have set a different value. + with connection.cursor() as cursor: + sql = [] + params = [] + + for name, value in options.items(): + original_value = original_values.get(name) + if original_value: + sql.append(f"SET LOCAL {qn(name)} = {original_value}") + else: + sql.append(f"SET LOCAL {qn(name)} TO DEFAULT") + + cursor.execute("; ".join(sql), params) + + +@contextmanager +def postgres_set_local_search_path( + search_path: List[str], *, using: str = DEFAULT_DB_ALIAS +) -> Generator[None, None, None]: + """Sets the search path to the specified schemas.""" + + with postgres_set_local(search_path=search_path, using=using): + yield + + +@contextmanager +def postgres_prepend_local_search_path( + search_path: List[str], *, using: str = DEFAULT_DB_ALIAS +) -> Generator[None, None, None]: + """Prepends the specified schemas to the current local search path.""" + + connection = connections[using] + + with connection.cursor() as cursor: + cursor.execute("SHOW search_path") + [ + original_search_path, + ] = cursor.fetchone() + + placeholders = ", ".join(["%s" for _ in search_path]) + cursor.execute( + f"SET LOCAL search_path = {placeholders}, {original_search_path}", + tuple(search_path), + ) + + yield + + cursor.execute(f"SET LOCAL search_path = {original_search_path}") + + +@contextmanager +def postgres_reset_local_search_path( + *, using: str = DEFAULT_DB_ALIAS +) -> Generator[None, None, None]: + """Resets the local search path to the default.""" + + with postgres_set_local(search_path=None, using=using): + yield diff --git a/psqlextra/sql.py b/psqlextra/sql.py index 7f624623..cf12d8c1 100644 --- a/psqlextra/sql.py +++ b/psqlextra/sql.py @@ -1,20 +1,25 @@ from collections import OrderedDict -from typing import List, Optional, Tuple +from collections.abc import Iterable +from typing import Any, Dict, List, Optional, Tuple, Union import django from django.core.exceptions import SuspiciousOperation from django.db import connections, models -from django.db.models import sql +from django.db.models import Expression, sql from django.db.models.constants import LOOKUP_SEP +from django.db.models.expressions import Ref -from .compiler import PostgresInsertCompiler, PostgresUpdateCompiler +from .compiler import PostgresInsertOnConflictCompiler +from .compiler import SQLUpdateCompiler as PostgresUpdateCompiler from .expressions import HStoreColumn from .fields import HStoreField from .types import ConflictAction class PostgresQuery(sql.Query): + select: Tuple[Expression, ...] + def chain(self, klass=None): """Chains this query to another. @@ -61,13 +66,28 @@ def rename_annotations(self, annotations) -> None: new_annotations[new_name or old_name] = annotation if new_name and self.annotation_select_mask: - self.annotation_select_mask.discard(old_name) - self.annotation_select_mask.add(new_name) + # It's a set in all versions prior to Django 5.x + # and a list in Django 5.x and newer.
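A sketch of using these helpers; because `postgres_set_local` refuses to run outside a transaction, the examples sit inside `transaction.atomic()`. The schema name and timeout value are illustrative.

```python
from django.db import transaction

from psqlextra.settings import (
    postgres_prepend_local_search_path,
    postgres_set_local,
)

with transaction.atomic():
    # Applies to this transaction only and is reverted on exit.
    with postgres_set_local(statement_timeout="5s"):
        ...  # run time-limited queries here

    # Unqualified table names now resolve against "myschema" first.
    with postgres_prepend_local_search_path(["myschema"]):
        ...
```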
+ # https://github.com/django/django/commit/d6b6e5d0fd4e6b6d0183b4cf6e4bd4f9afc7bf67 + if isinstance(self.annotation_select_mask, set): + self.annotation_select_mask.discard(old_name) + self.annotation_select_mask.add(new_name) + elif isinstance(self.annotation_select_mask, list): + self.annotation_select_mask.remove(old_name) + self.annotation_select_mask.append(new_name) + + if isinstance(self.group_by, Iterable): + for statement in self.group_by: + if not isinstance(statement, Ref): + continue + + if statement.refs in annotations: # type: ignore[attr-defined] + statement.refs = annotations[statement.refs] # type: ignore[attr-defined] self.annotations.clear() self.annotations.update(new_annotations) - def add_fields(self, field_names: List[str], *args, **kwargs) -> None: + def add_fields(self, field_names, *args, **kwargs) -> None: """Adds the given (model) fields to the select set. The field names are added in the order specified. This overrides @@ -99,10 +119,11 @@ def add_fields(self, field_names: List[str], *args, **kwargs) -> None: if len(parts) > 1: column_name, hstore_key = parts[:2] is_hstore, field = self._is_hstore_field(column_name) - if is_hstore: + if self.model and is_hstore: select.append( HStoreColumn( - self.model._meta.db_table or self.model.name, + self.model._meta.db_table + or self.model.__class__.__name__, field, hstore_key, ) @@ -114,7 +135,7 @@ def add_fields(self, field_names: List[str], *args, **kwargs) -> None: super().add_fields(field_names_without_hstore, *args, **kwargs) if len(select) > 0: - self.set_select(self.select + tuple(select)) + self.set_select(list(self.select + tuple(select))) def _is_hstore_field( self, field_name: str @@ -126,8 +147,11 @@ def _is_hstore_field( instance. """ + if not self.model: + return (False, None) + field_instance = None - for field in self.model._meta.local_concrete_fields: + for field in self.model._meta.local_concrete_fields: # type: ignore[attr-defined] if field.name == field_name or field.column == field_name: field_instance = field break @@ -147,10 +171,14 @@ def __init__(self, *args, **kwargs): self.conflict_action = ConflictAction.UPDATE self.conflict_update_condition = None self.index_predicate = None - - self.update_fields = [] - - def values(self, objs: List, insert_fields: List, update_fields: List = []): + self.update_values = {} + + def insert_on_conflict_values( + self, + objs: List, + insert_fields: List, + update_values: Dict[str, Union[Any, Expression]] = {}, + ): """Sets the values to be used in this query. Insert fields are fields that are definitely @@ -169,17 +197,18 @@ def values(self, objs: List, insert_fields: List, update_fields: List = []): insert_fields: The fields to use in the INSERT statement - update_fields: - The fields to only use in the UPDATE statement. + update_values: + Expressions/values to use when a conflict + occurs and an UPDATE is performed. 
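The mask and `GROUP BY` bookkeeping above backs the queryset-level `annotate` override shown earlier, which allows an annotation to take the name of a concrete field. A sketch, with a hypothetical model whose manager is `PostgresManager`:

```python
from django.db.models.functions import Upper

from myapp.models import Person  # hypothetical, has a "name" field

# Stock Django rejects annotate(name=...) because "name" is a concrete
# field; psqlextra renames the annotation internally to avoid the clash.
people = Person.objects.annotate(name=Upper("name"))
```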
""" self.insert_values(insert_fields, objs, raw=False) - self.update_fields = update_fields + self.update_values = update_values def get_compiler(self, using=None, connection=None): if using: connection = connections[using] - return PostgresInsertCompiler(self, connection, using) + return PostgresInsertOnConflictCompiler(self, connection, using) class PostgresUpdateQuery(sql.UpdateQuery): diff --git a/psqlextra/type_assertions.py b/psqlextra/type_assertions.py index 0a7e8608..e18d13be 100644 --- a/psqlextra/type_assertions.py +++ b/psqlextra/type_assertions.py @@ -7,7 +7,7 @@ def is_query_set(value: Any) -> bool: """Gets whether the specified value is a :see:QuerySet.""" - return isinstance(value, QuerySet) + return isinstance(value, QuerySet) # type: ignore[misc] def is_sql(value: Any) -> bool: diff --git a/psqlextra/types.py b/psqlextra/types.py index a325fd9e..f1118075 100644 --- a/psqlextra/types.py +++ b/psqlextra/types.py @@ -28,6 +28,9 @@ class ConflictAction(Enum): def all(cls) -> List["ConflictAction"]: return [choice for choice in cls] + def __str__(self) -> str: + return self.value + class PostgresPartitioningMethod(StrEnum): """Methods of partitioning supported by PostgreSQL 11.x native support for diff --git a/psqlextra/util.py b/psqlextra/util.py index edc4e955..d0bca000 100644 --- a/psqlextra/util.py +++ b/psqlextra/util.py @@ -1,10 +1,15 @@ from contextlib import contextmanager +from typing import Generator, Type + +from django.db import models from .manager import PostgresManager @contextmanager -def postgres_manager(model): +def postgres_manager( + model: Type[models.Model], +) -> Generator[PostgresManager, None, None]: """Allows you to use the :see:PostgresManager with the specified model instance on the fly. diff --git a/pyproject.toml b/pyproject.toml index 126ae9a3..fb35b3b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,3 +10,18 @@ exclude = ''' )/ ) ''' + +[tool.mypy] +python_version = "3.8" +plugins = ["mypy_django_plugin.main"] +mypy_path = ["stubs", "."] +exclude = "(env|build|dist|migrations)" + +[[tool.mypy.overrides]] +module = [ + "psycopg.*" +] +ignore_missing_imports = true + +[tool.django-stubs] +django_settings_module = "settings" diff --git a/settings.py b/settings.py index ed0d0f98..7266ccb4 100644 --- a/settings.py +++ b/settings.py @@ -11,7 +11,7 @@ 'default': dj_database_url.config(default='postgres:///psqlextra'), } -DATABASES['default']['ENGINE'] = 'psqlextra.backend' +DATABASES['default']['ENGINE'] = 'tests.psqlextra_test_backend' LANGUAGE_CODE = 'en' LANGUAGES = ( @@ -24,3 +24,6 @@ 'psqlextra', 'tests', ) + +USE_TZ = True +TIME_ZONE = 'UTC' diff --git a/setup.py b/setup.py index 365f5861..918beb87 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,8 @@ from setuptools import find_packages, setup +exec(open("psqlextra/_version.py").read()) + class BaseCommand(distutils.cmd.Command): user_options = [] @@ -36,8 +38,9 @@ def run(self): setup( name="django-postgres-extra", - version="2.0.4rc2", + version=__version__, packages=find_packages(exclude=["tests"]), + package_data={"psqlextra": ["py.typed"]}, include_package_data=True, license="MIT License", description="Bringing all of PostgreSQL's awesomeness to Django.", @@ -58,14 +61,15 @@ def run(self): "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Internet :: WWW/HTTP", "Topic :: Internet :: WWW/HTTP :: Dynamic Content", ], 
python_requires=">=3.6", install_requires=[ - "Django>=2.0", + "Django>=2.0,<6.0", "python-dateutil>=2.8.0,<=3.0.0", - "ansimarkup>=1.4.0,<=2.0.0", ], extras_require={ ':python_version <= "3.6"': ["dataclasses"], @@ -77,23 +81,47 @@ def run(self): "pytest-benchmark==3.4.1", "pytest-django==4.4.0", "pytest-cov==3.0.0", + "pytest-lazy-fixture==0.6.3", + "pytest-freezegun==0.4.2", "tox==3.24.4", "freezegun==1.1.0", "coveralls==3.3.0", "snapshottest==0.6.0", ], "analysis": [ - "black==21.10b0", + "black==22.3.0", "flake8==4.0.1", "autoflake==1.4", "autopep8==1.6.0", "isort==5.10.0", "docformatter==1.4", + "mypy==1.2.0; python_version > '3.6'", + "mypy==0.971; python_version <= '3.6'", + "django-stubs==4.2.7; python_version > '3.6'", + "django-stubs==1.9.0; python_version <= '3.6'", + "typing-extensions==4.5.0; python_version > '3.6'", + "typing-extensions==4.1.0; python_version <= '3.6'", + "types-dj-database-url==1.3.0.0", + "types-psycopg2==2.9.21.9", + "types-python-dateutil==2.8.19.12", + ], + "publish": [ + "build==0.7.0", + "twine==3.7.1", ], }, cmdclass={ "lint": create_command( - "Lints the code", [["flake8", "setup.py", "psqlextra", "tests"]] + "Lints the code", + [ + [ + "flake8", + "--builtin=__version__", + "setup.py", + "psqlextra", + "tests", + ] + ], ), "lint_fix": create_command( "Lints the code", @@ -110,6 +138,18 @@ def run(self): ["autopep8", "-i", "-r", "setup.py", "psqlextra", "tests"], ], ), + "lint_types": create_command( + "Type-checks the code", + [ + [ + "mypy", + "--package", + "psqlextra", + "--pretty", + "--show-error-codes", + ], + ], + ), "format": create_command( "Formats the code", [["black", "setup.py", "psqlextra", "tests"]] ), @@ -148,6 +188,7 @@ def run(self): ["python", "setup.py", "sort_imports"], ["python", "setup.py", "lint_fix"], ["python", "setup.py", "lint"], + ["python", "setup.py", "lint_types"], ], ), "verify": create_command( @@ -157,6 +198,7 @@ def run(self): ["python", "setup.py", "format_docstrings_verify"], ["python", "setup.py", "sort_imports_verify"], ["python", "setup.py", "lint"], + ["python", "setup.py", "lint_types"], ], ), "test": create_command( diff --git a/tests/conftest.py b/tests/conftest.py index f90692af..387edd3b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,7 +29,7 @@ def fake_app(): def postgres_server_version(db) -> int: """Gets the PostgreSQL server version.""" - return connection.cursor().connection.server_version + return connection.cursor().connection.info.server_version @pytest.fixture(autouse=True) diff --git a/tests/db_introspection.py b/tests/db_introspection.py index bdcd4b19..285cd0e4 100644 --- a/tests/db_introspection.py +++ b/tests/db_introspection.py @@ -4,38 +4,100 @@ This makes test code less verbose and easier to read/write. 
""" +from contextlib import contextmanager +from typing import Optional + from django.db import connection +from psqlextra.settings import postgres_set_local + + +@contextmanager +def introspect(schema_name: Optional[str] = None): + with postgres_set_local(search_path=schema_name or None): + with connection.cursor() as cursor: + yield connection.introspection, cursor -def table_names(include_views: bool = True): + +def table_names( + include_views: bool = True, *, schema_name: Optional[str] = None +): """Gets a flat list of tables in the default database.""" - with connection.cursor() as cursor: - introspection = connection.introspection + with introspect(schema_name) as (introspection, cursor): return introspection.table_names(cursor, include_views) -def get_partitioned_table(table_name: str): +def get_partitioned_table( + table_name: str, + *, + schema_name: Optional[str] = None, +): """Gets the definition of a partitioned table in the default database.""" - with connection.cursor() as cursor: - introspection = connection.introspection + with introspect(schema_name) as (introspection, cursor): return introspection.get_partitioned_table(cursor, table_name) -def get_partitions(table_name: str): +def get_partitions( + table_name: str, + *, + schema_name: Optional[str] = None, +): """Gets a list of partitions for the specified partitioned table in the default database.""" - with connection.cursor() as cursor: - introspection = connection.introspection + with introspect(schema_name) as (introspection, cursor): return introspection.get_partitions(cursor, table_name) -def get_constraints(table_name: str): - """Gets a complete list of constraints and indexes for the specified - table.""" +def get_columns( + table_name: str, + *, + schema_name: Optional[str] = None, +): + """Gets a list of columns for the specified table.""" + + with introspect(schema_name) as (introspection, cursor): + return introspection.get_columns(cursor, table_name) + + +def get_relations( + table_name: str, + *, + schema_name: Optional[str] = None, +): + """Gets a list of relations for the specified table.""" + + with introspect(schema_name) as (introspection, cursor): + return introspection.get_relations(cursor, table_name) - with connection.cursor() as cursor: - introspection = connection.introspection + +def get_constraints( + table_name: str, + *, + schema_name: Optional[str] = None, +): + """Gets a list of constraints and indexes for the specified table.""" + + with introspect(schema_name) as (introspection, cursor): return introspection.get_constraints(cursor, table_name) + + +def get_sequences( + table_name: str, + *, + schema_name: Optional[str] = None, +): + """Gets a list of sequences own by the specified table.""" + + with introspect(schema_name) as (introspection, cursor): + return introspection.get_sequences(cursor, table_name) + + +def get_storage_settings(table_name: str, *, schema_name: Optional[str] = None): + """Gets a list of all storage settings that have been set on the specified + table.""" + + with introspect(schema_name) as (introspection, cursor): + return introspection.get_storage_settings(cursor, table_name) diff --git a/tests/fake_model.py b/tests/fake_model.py index 1254e762..ec626f3a 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -3,9 +3,10 @@ import uuid from contextlib import contextmanager +from typing import Type from django.apps import AppConfig, apps -from django.db import connection +from django.db import connection, models from psqlextra.models import ( 
PostgresMaterializedViewModel, @@ -39,6 +40,17 @@ def define_fake_model( return model +def undefine_fake_model(model: Type[models.Model]) -> None: + """Removes the fake model from the app registry.""" + + app_label = model._meta.app_label or "tests" + app_models = apps.app_configs[app_label].models + + for model_name in [model.__name__, model.__name__.lower()]: + if model_name in app_models: + del app_models[model_name] + + def define_fake_view_model( fields=None, view_options={}, meta_options={}, model_base=PostgresViewModel ): @@ -115,6 +127,15 @@ def get_fake_model(fields=None, model_base=PostgresModel, meta_options={}): return model +def delete_fake_model(model: Type[models.Model]) -> None: + """Deletes a fake model from the database and the internal app registry.""" + + undefine_fake_model(model) + + with connection.schema_editor() as schema_editor: + schema_editor.delete_model(model) + + @contextmanager def define_fake_app(): """Creates and registers a fake Django app.""" diff --git a/tests/psqlextra_test_backend/__init__.py b/tests/psqlextra_test_backend/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/psqlextra_test_backend/base.py b/tests/psqlextra_test_backend/base.py new file mode 100644 index 00000000..0961a2bc --- /dev/null +++ b/tests/psqlextra_test_backend/base.py @@ -0,0 +1,23 @@ +from datetime import timezone + +import django + +from django.conf import settings + +from psqlextra.backend.base import DatabaseWrapper as PSQLExtraDatabaseWrapper + + +class DatabaseWrapper(PSQLExtraDatabaseWrapper): + # Works around the compatibility issue of Django <3.0 and psycopg2.9 + # in combination with USE_TZ + # + # See: https://github.com/psycopg/psycopg2/issues/1293#issuecomment-862835147 + if django.VERSION < (3, 1): + + def create_cursor(self, name=None): + cursor = super().create_cursor(name) + cursor.tzinfo_factory = ( + lambda offset: timezone.utc if settings.USE_TZ else None + ) + + return cursor diff --git a/tests/test_append_caller_to_sql.py b/tests/test_append_caller_to_sql.py new file mode 100644 index 00000000..50ca0b5e --- /dev/null +++ b/tests/test_append_caller_to_sql.py @@ -0,0 +1,81 @@ +import pytest + +from django.db import connection, models +from django.test.utils import CaptureQueriesContext, override_settings + +from psqlextra.compiler import append_caller_to_sql + +from .fake_model import get_fake_model + + +class psqlextraSimulated: + def callMockedClass(self): + return MockedClass().mockedMethod() + + +class MockedClass: + def mockedMethod(self): + return append_caller_to_sql("sql") + + +def mockedFunction(): + return append_caller_to_sql("sql") + + +@override_settings(POSTGRES_EXTRA_ANNOTATE_SQL=False) +def test_disable_append_caller_to_sql(): + commented_sql = mockedFunction() + assert commented_sql == "sql" + + +@pytest.mark.parametrize( + "entry_point", + [ + MockedClass().mockedMethod, + psqlextraSimulated().callMockedClass, + ], +) +@override_settings(POSTGRES_EXTRA_ANNOTATE_SQL=True) +def test_append_caller_to_sql_class(entry_point): + commented_sql = entry_point() + assert commented_sql.startswith("sql /* ") + assert "mockedMethod" in commented_sql + assert __file__ in commented_sql + + +@override_settings(POSTGRES_EXTRA_ANNOTATE_SQL=True) +def test_append_caller_to_sql_function(): + commented_sql = mockedFunction() + assert commented_sql.startswith("sql /* ") + assert "mockedFunction" in commented_sql + assert __file__ in commented_sql + + +@override_settings(POSTGRES_EXTRA_ANNOTATE_SQL=True) +def 
test_append_caller_to_sql_crud(): + model = get_fake_model( + { + "title": models.CharField(max_length=255, null=True), + } + ) + + obj = None + with CaptureQueriesContext(connection) as queries: + obj = model.objects.create( + id=1, + title="Test", + ) + assert "test_append_caller_to_sql_crud " in queries[0]["sql"] + + obj.title = "success" + with CaptureQueriesContext(connection) as queries: + obj.save() + assert "test_append_caller_to_sql_crud " in queries[0]["sql"] + + with CaptureQueriesContext(connection) as queries: + assert model.objects.filter(id=obj.id)[0].id == obj.id + assert "test_append_caller_to_sql_crud " in queries[0]["sql"] + + with CaptureQueriesContext(connection) as queries: + obj.delete() + assert "test_append_caller_to_sql_crud " in queries[0]["sql"] diff --git a/tests/test_introspect.py b/tests/test_introspect.py new file mode 100644 index 00000000..5e5a9ffc --- /dev/null +++ b/tests/test_introspect.py @@ -0,0 +1,462 @@ +import django +import pytest + +from django.contrib.postgres.fields import ArrayField +from django.db import connection, models +from django.test.utils import CaptureQueriesContext +from django.utils import timezone + +from psqlextra.introspect import model_from_cursor, models_from_cursor + +from .fake_model import get_fake_model + +django_31_skip_reason = "Django < 3.1 does not support JSON fields which are required for these tests" + + +@pytest.fixture +def mocked_model_varying_fields(): + return get_fake_model( + { + "title": models.TextField(null=True), + "updated_at": models.DateTimeField(null=True), + "content": models.JSONField(null=True), + "items": ArrayField(models.TextField(), null=True), + } + ) + + +@pytest.fixture +def mocked_model_single_field(): + return get_fake_model( + { + "name": models.TextField(), + } + ) + + +@pytest.fixture +def mocked_model_foreign_keys( + mocked_model_varying_fields, mocked_model_single_field +): + return get_fake_model( + { + "varying_fields": models.ForeignKey( + mocked_model_varying_fields, null=True, on_delete=models.CASCADE + ), + "single_field": models.ForeignKey( + mocked_model_single_field, null=True, on_delete=models.CASCADE + ), + } + ) + + +@pytest.fixture +def mocked_model_varying_fields_instance(freezer, mocked_model_varying_fields): + return mocked_model_varying_fields.objects.create( + title="hello world", + updated_at=timezone.now(), + content={"a": 1}, + items=["a", "b"], + ) + + +@pytest.fixture +def models_from_cursor_wrapper_multiple(): + def _wrapper(*args, **kwargs): + return list(models_from_cursor(*args, **kwargs))[0] + + return _wrapper + + +@pytest.fixture +def models_from_cursor_wrapper_single(): + return model_from_cursor + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_applies_converters( + mocked_model_varying_fields, + mocked_model_varying_fields_instance, + models_from_cursor_wrapper, +): + with connection.cursor() as cursor: + cursor.execute( + *mocked_model_varying_fields.objects.all().query.sql_with_params() + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_varying_fields, cursor + ) + + assert queried_instance.id == mocked_model_varying_fields_instance.id + assert queried_instance.title == mocked_model_varying_fields_instance.title + assert ( + queried_instance.updated_at + == 
mocked_model_varying_fields_instance.updated_at + ) + assert ( + queried_instance.content == mocked_model_varying_fields_instance.content + ) + assert queried_instance.items == mocked_model_varying_fields_instance.items + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_handles_field_order( + mocked_model_varying_fields, + mocked_model_varying_fields_instance, + models_from_cursor_wrapper, +): + with connection.cursor() as cursor: + cursor.execute( + f'SELECT content, items, id, title, updated_at FROM "{mocked_model_varying_fields._meta.db_table}"', + tuple(), + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_varying_fields, cursor + ) + + assert queried_instance.id == mocked_model_varying_fields_instance.id + assert queried_instance.title == mocked_model_varying_fields_instance.title + assert ( + queried_instance.updated_at + == mocked_model_varying_fields_instance.updated_at + ) + assert ( + queried_instance.content == mocked_model_varying_fields_instance.content + ) + assert queried_instance.items == mocked_model_varying_fields_instance.items + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_handles_partial_fields( + mocked_model_varying_fields, + mocked_model_varying_fields_instance, + models_from_cursor_wrapper, +): + with connection.cursor() as cursor: + cursor.execute( + f'SELECT id FROM "{mocked_model_varying_fields._meta.db_table}"', + tuple(), + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_varying_fields, cursor + ) + + assert queried_instance.id == mocked_model_varying_fields_instance.id + assert queried_instance.title is None + assert queried_instance.updated_at is None + assert queried_instance.content is None + assert queried_instance.items is None + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_handles_null( + mocked_model_varying_fields, models_from_cursor_wrapper +): + instance = mocked_model_varying_fields.objects.create() + + with connection.cursor() as cursor: + cursor.execute( + *mocked_model_varying_fields.objects.all().query.sql_with_params() + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_varying_fields, cursor + ) + + assert queried_instance.id == instance.id + assert queried_instance.title is None + assert queried_instance.updated_at is None + assert queried_instance.content is None + assert queried_instance.items is None + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_foreign_key( + mocked_model_single_field, + mocked_model_foreign_keys, + models_from_cursor_wrapper, +): + instance 
= mocked_model_foreign_keys.objects.create( + varying_fields=None, + single_field=mocked_model_single_field.objects.create(name="test"), + ) + + with connection.cursor() as cursor: + cursor.execute( + *mocked_model_foreign_keys.objects.all().query.sql_with_params() + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_foreign_keys, cursor + ) + + with CaptureQueriesContext(connection) as ctx: + assert queried_instance.id == instance.id + assert queried_instance.varying_fields_id is None + assert queried_instance.varying_fields is None + assert queried_instance.single_field_id == instance.single_field_id + assert queried_instance.single_field.id == instance.single_field.id + assert queried_instance.single_field.name == instance.single_field.name + + assert len(ctx.captured_queries) == 1 + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_related_fields( + mocked_model_varying_fields, + mocked_model_single_field, + mocked_model_foreign_keys, + models_from_cursor_wrapper, +): + instance = mocked_model_foreign_keys.objects.create( + varying_fields=mocked_model_varying_fields.objects.create( + title="test", updated_at=timezone.now() + ), + single_field=mocked_model_single_field.objects.create(name="test"), + ) + + with connection.cursor() as cursor: + cursor.execute( + *mocked_model_foreign_keys.objects.select_related( + "varying_fields", "single_field" + ) + .all() + .query.sql_with_params() + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_foreign_keys, + cursor, + related_fields=["varying_fields", "single_field"], + ) + + with CaptureQueriesContext(connection) as ctx: + assert queried_instance.id == instance.id + + assert queried_instance.varying_fields_id == instance.varying_fields_id + assert queried_instance.varying_fields.id == instance.varying_fields.id + assert ( + queried_instance.varying_fields.title + == instance.varying_fields.title + ) + assert ( + queried_instance.varying_fields.updated_at + == instance.varying_fields.updated_at + ) + assert ( + queried_instance.varying_fields.content + == instance.varying_fields.content + ) + assert ( + queried_instance.varying_fields.items + == instance.varying_fields.items + ) + + assert queried_instance.single_field_id == instance.single_field_id + assert queried_instance.single_field.id == instance.single_field.id + assert queried_instance.single_field.name == instance.single_field.name + + assert len(ctx.captured_queries) == 0 + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +@pytest.mark.parametrize( + "selected", [True, False], ids=["selected", "not_selected"] +) +def test_models_from_cursor_related_fields_optional( + mocked_model_varying_fields, + mocked_model_foreign_keys, + models_from_cursor_wrapper, + selected, +): + instance = mocked_model_foreign_keys.objects.create( + varying_fields=mocked_model_varying_fields.objects.create( + title="test", updated_at=timezone.now() + ), + single_field=None, + ) + + with connection.cursor() as cursor: + select_related = ["varying_fields"] + if selected: + 
select_related.append("single_field") + + cursor.execute( + *mocked_model_foreign_keys.objects.select_related(*select_related) + .all() + .query.sql_with_params() + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_foreign_keys, + cursor, + related_fields=["varying_fields", "single_field"], + ) + + assert queried_instance.id == instance.id + assert queried_instance.varying_fields_id == instance.varying_fields_id + assert queried_instance.single_field_id == instance.single_field_id + + with CaptureQueriesContext(connection) as ctx: + assert queried_instance.varying_fields.id == instance.varying_fields.id + assert ( + queried_instance.varying_fields.title + == instance.varying_fields.title + ) + assert ( + queried_instance.varying_fields.updated_at + == instance.varying_fields.updated_at + ) + assert ( + queried_instance.varying_fields.content + == instance.varying_fields.content + ) + assert ( + queried_instance.varying_fields.items + == instance.varying_fields.items + ) + + assert queried_instance.single_field is None + + assert len(ctx.captured_queries) == 0 + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +def test_models_from_cursor_generator_efficiency( + mocked_model_varying_fields, mocked_model_single_field +): + mocked_model_single_field.objects.create(name="a") + mocked_model_single_field.objects.create(name="b") + + with connection.cursor() as cursor: + cursor.execute( + *mocked_model_single_field.objects.all().query.sql_with_params() + ) + + instances_generator = models_from_cursor( + mocked_model_single_field, cursor + ) + assert cursor.rownumber == 0 + + next(instances_generator) + assert cursor.rownumber == 1 + + next(instances_generator) + assert cursor.rownumber == 2 + + assert not next(instances_generator, None) + assert cursor.rownumber == 2 + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +def test_models_from_cursor_tolerates_additional_columns( + mocked_model_foreign_keys, mocked_model_varying_fields +): + with connection.cursor() as cursor: + cursor.execute( + f"ALTER TABLE {mocked_model_foreign_keys._meta.db_table} ADD COLUMN new_col text DEFAULT NULL" + ) + cursor.execute( + f"ALTER TABLE {mocked_model_varying_fields._meta.db_table} ADD COLUMN new_col text DEFAULT NULL" + ) + + instance = mocked_model_foreign_keys.objects.create( + varying_fields=mocked_model_varying_fields.objects.create( + title="test", updated_at=timezone.now() + ), + single_field=None, + ) + + with connection.cursor() as cursor: + cursor.execute( + f""" + SELECT fk_t.*, vf_t.* FROM {mocked_model_foreign_keys._meta.db_table} fk_t + INNER JOIN {mocked_model_varying_fields._meta.db_table} vf_t ON vf_t.id = fk_t.varying_fields_id + """ + ) + + queried_instances = list( + models_from_cursor( + mocked_model_foreign_keys, + cursor, + related_fields=["varying_fields"], + ) + ) + + assert len(queried_instances) == 1 + assert queried_instances[0].id == instance.id + assert ( + queried_instances[0].varying_fields.id == instance.varying_fields.id + ) diff --git a/tests/test_locking.py b/tests/test_locking.py new file mode 100644 index 00000000..6414689d --- /dev/null +++ b/tests/test_locking.py @@ -0,0 +1,106 @@ +import uuid + +import pytest + +from django.db import connection, models, transaction + +from psqlextra.locking import ( + PostgresTableLockMode, + postgres_lock_model, + postgres_lock_table, +) + +from .fake_model import get_fake_model + + +@pytest.fixture +def mocked_model(): + return get_fake_model( + { 
+ "name": models.TextField(), + } + ) + + +def get_table_locks(): + with connection.cursor() as cursor: + return connection.introspection.get_table_locks(cursor) + + +@pytest.mark.django_db(transaction=True) +def test_postgres_lock_table(mocked_model): + lock_signature = ( + "public", + mocked_model._meta.db_table, + "AccessExclusiveLock", + ) + with transaction.atomic(): + postgres_lock_table( + mocked_model._meta.db_table, PostgresTableLockMode.ACCESS_EXCLUSIVE + ) + assert lock_signature in get_table_locks() + + assert lock_signature not in get_table_locks() + + +@pytest.mark.django_db(transaction=True) +def test_postgres_lock_table_in_schema(): + schema_name = str(uuid.uuid4())[:8] + table_name = str(uuid.uuid4())[:8] + quoted_schema_name = connection.ops.quote_name(schema_name) + quoted_table_name = connection.ops.quote_name(table_name) + + with connection.cursor() as cursor: + cursor.execute(f"CREATE SCHEMA {quoted_schema_name}") + cursor.execute( + f"CREATE TABLE {quoted_schema_name}.{quoted_table_name} AS SELECT 'hello world'" + ) + + lock_signature = (schema_name, table_name, "ExclusiveLock") + with transaction.atomic(): + postgres_lock_table( + table_name, PostgresTableLockMode.EXCLUSIVE, schema_name=schema_name + ) + assert lock_signature in get_table_locks() + + assert lock_signature not in get_table_locks() + + +@pytest.mark.parametrize("lock_mode", list(PostgresTableLockMode)) +@pytest.mark.django_db(transaction=True) +def test_postgres_lock_model(mocked_model, lock_mode): + lock_signature = ( + "public", + mocked_model._meta.db_table, + lock_mode.alias, + ) + + with transaction.atomic(): + postgres_lock_model(mocked_model, lock_mode) + assert lock_signature in get_table_locks() + + assert lock_signature not in get_table_locks() + + +@pytest.mark.django_db(transaction=True) +def test_postgres_lock_model_in_schema(mocked_model): + schema_name = str(uuid.uuid4())[:8] + quoted_schema_name = connection.ops.quote_name(schema_name) + quoted_table_name = connection.ops.quote_name(mocked_model._meta.db_table) + + with connection.cursor() as cursor: + cursor.execute(f"CREATE SCHEMA {quoted_schema_name}") + cursor.execute( + f"CREATE TABLE {quoted_schema_name}.{quoted_table_name} (LIKE public.{quoted_table_name} INCLUDING ALL)" + ) + + lock_signature = (schema_name, mocked_model._meta.db_table, "ExclusiveLock") + with transaction.atomic(): + postgres_lock_model( + mocked_model, + PostgresTableLockMode.EXCLUSIVE, + schema_name=schema_name, + ) + assert lock_signature in get_table_locks() + + assert lock_signature not in get_table_locks() diff --git a/tests/test_lookups.py b/tests/test_lookups.py new file mode 100644 index 00000000..32d3f1c9 --- /dev/null +++ b/tests/test_lookups.py @@ -0,0 +1,102 @@ +from django.db import models + +from .fake_model import get_fake_model + + +def test_invalues_lookup_text_field(): + model = get_fake_model({"name": models.TextField()}) + [a, b] = model.objects.bulk_create( + [ + model(name="a"), + model(name="b"), + ] + ) + + results = list(model.objects.filter(name__invalues=[a.name, b.name, "c"])) + assert results == [a, b] + + +def test_invalues_lookup_integer_field(): + model = get_fake_model({"number": models.IntegerField()}) + [a, b] = model.objects.bulk_create( + [ + model(number=1), + model(number=2), + ] + ) + + results = list( + model.objects.filter(number__invalues=[a.number, b.number, 3]) + ) + assert results == [a, b] + + +def test_invalues_lookup_uuid_field(): + model = get_fake_model({"value": models.UUIDField()}) + [a, b] = 
model.objects.bulk_create(
+        [
+            model(value="f8fe0431-29f8-4c4c-839c-8a6bf29f95d5"),
+            model(value="2fb0f45b-afaf-4e24-8637-2d81ded997bb"),
+        ]
+    )
+
+    results = list(
+        model.objects.filter(
+            value__invalues=[
+                a.value,
+                b.value,
+                "d7a8df83-f3f8-487b-b982-547c8f22b0bb",
+            ]
+        )
+    )
+    assert results == [a, b]
+
+
+def test_invalues_lookup_related_field():
+    model_1 = get_fake_model({"name": models.TextField()})
+    model_2 = get_fake_model(
+        {"relation": models.ForeignKey(model_1, on_delete=models.CASCADE)}
+    )
+
+    [a_relation, b_relation] = model_1.objects.bulk_create(
+        [
+            model_1(name="a"),
+            model_1(name="b"),
+        ]
+    )
+
+    [a, b] = model_2.objects.bulk_create(
+        [model_2(relation=a_relation), model_2(relation=b_relation)]
+    )
+
+    results = list(
+        model_2.objects.filter(relation__invalues=[a_relation, b_relation])
+    )
+    assert results == [a, b]
+
+
+def test_invalues_lookup_related_field_subquery():
+    model_1 = get_fake_model({"name": models.TextField()})
+    model_2 = get_fake_model(
+        {"relation": models.ForeignKey(model_1, on_delete=models.CASCADE)}
+    )
+
+    [a_relation, b_relation] = model_1.objects.bulk_create(
+        [
+            model_1(name="a"),
+            model_1(name="b"),
+        ]
+    )
+
+    [a, b] = model_2.objects.bulk_create(
+        [model_2(relation=a_relation), model_2(relation=b_relation)]
+    )
+
+    results = list(
+        model_2.objects.filter(
+            relation__invalues=model_1.objects.all().values_list(
+                "id", flat=True
+            )
+        )
+    )
+    assert results == [a, b]
diff --git a/tests/test_on_conflict.py b/tests/test_on_conflict.py
index cb4a88ca..02eda62f 100644
--- a/tests/test_on_conflict.py
+++ b/tests/test_on_conflict.py
@@ -3,6 +3,7 @@
 from django.core.exceptions import SuspiciousOperation
 from django.db import connection, models
+from django.test.utils import CaptureQueriesContext, override_settings
 from django.utils import timezone

 from psqlextra.fields import HStoreField
@@ -13,6 +14,7 @@

 @pytest.mark.parametrize("conflict_action", ConflictAction.all())
+@override_settings(POSTGRES_EXTRA_ANNOTATE_SQL=True)
 def test_on_conflict(conflict_action):
     """Tests whether simple inserts work correctly."""

@@ -23,9 +25,11 @@ def test_on_conflict(conflict_action):
         }
     )

-    obj = model.objects.on_conflict(
-        [("title", "key1")], conflict_action
-    ).insert_and_get(title={"key1": "beer"}, cookies="cheers")
+    with CaptureQueriesContext(connection) as queries:
+        obj = model.objects.on_conflict(
+            [("title", "key1")], conflict_action
+        ).insert_and_get(title={"key1": "beer"}, cookies="cheers")
+        assert " test_on_conflict " in queries[0]["sql"]

     model.objects.on_conflict(
         [("title", "key1")], conflict_action
diff --git a/tests/test_on_conflict_nothing.py b/tests/test_on_conflict_nothing.py
index 78c4c5f4..92e74dfc 100644
--- a/tests/test_on_conflict_nothing.py
+++ b/tests/test_on_conflict_nothing.py
@@ -170,17 +170,26 @@ def test_on_conflict_nothing_foreign_key_by_id():
     assert obj1.data == "some data"


-def test_on_conflict_nothing_duplicate_rows():
+@pytest.mark.parametrize(
+    "rows,expected_row_count",
+    [
+        ([dict(amount=1), dict(amount=1)], 1),
+        (iter([dict(amount=1), dict(amount=1)]), 1),
+        ((row for row in [dict(amount=1), dict(amount=1)]), 1),
+        ([], 0),
+        (iter([]), 0),
+        ((row for row in []), 0),
+    ],
+)
+def test_on_conflict_nothing_duplicate_rows(rows, expected_row_count):
     """Tests whether duplicate rows are filtered out when doing an insert
    NOTHING and no error is raised when the list of rows contains
    duplicates."""

     model = get_fake_model({"amount": models.IntegerField(unique=True)})

-    rows = [dict(amount=1),
dict(amount=1)] + inserted_rows = model.objects.on_conflict( + ["amount"], ConflictAction.NOTHING + ).bulk_insert(rows) - ( - model.objects.on_conflict( - ["amount"], ConflictAction.NOTHING - ).bulk_insert(rows) - ) + assert len(inserted_rows) == expected_row_count diff --git a/tests/test_on_conflict_update.py b/tests/test_on_conflict_update.py index 8425e3d3..b93e5781 100644 --- a/tests/test_on_conflict_update.py +++ b/tests/test_on_conflict_update.py @@ -1,3 +1,4 @@ +import django import pytest from django.db import models @@ -41,6 +42,35 @@ def test_on_conflict_update(): assert obj2.cookies == "choco" +@pytest.mark.skipif( + django.VERSION < (2, 2), + reason="Django < 2.2 doesn't implement constraints", +) +def test_on_conflict_update_by_unique_constraint(): + model = get_fake_model( + { + "title": models.CharField(max_length=255, null=True), + }, + meta_options={ + "constraints": [ + models.UniqueConstraint(name="test_uniq", fields=["title"]), + ], + }, + ) + + constraint = next( + ( + constraint + for constraint in model._meta.constraints + if constraint.name == "test_uniq" + ) + ) + + model.objects.on_conflict(constraint, ConflictAction.UPDATE).insert_and_get( + title="title" + ) + + def test_on_conflict_update_foreign_key_by_object(): """Tests whether simple upsert works correctly when the conflicting field is a foreign key specified as an object.""" diff --git a/tests/test_partitioning_time.py b/tests/test_partitioning_time.py index 68808324..9f6b5bf1 100644 --- a/tests/test_partitioning_time.py +++ b/tests/test_partitioning_time.py @@ -115,6 +115,45 @@ def test_partitioning_time_monthly_apply(): assert table.partitions[13].name == "2020_feb" +@pytest.mark.postgres_version(lt=110000) +def test_partitioning_time_monthly_with_custom_naming_apply(): + """Tests whether automatically created new partitions are named according + to the specified name_format.""" + + model = define_fake_partitioned_model( + {"timestamp": models.DateTimeField()}, {"key": ["timestamp"]} + ) + + schema_editor = connection.schema_editor() + schema_editor.create_partitioned_model(model) + + # create partitions for the next 12 months (including the current) + with freezegun.freeze_time("2019-1-30"): + manager = PostgresPartitioningManager( + [ + partition_by_current_time( + model, months=1, count=12, name_format="%Y_%m" + ) + ] + ) + manager.plan().apply() + + table = _get_partitioned_table(model) + assert len(table.partitions) == 12 + assert table.partitions[0].name == "2019_01" + assert table.partitions[1].name == "2019_02" + assert table.partitions[2].name == "2019_03" + assert table.partitions[3].name == "2019_04" + assert table.partitions[4].name == "2019_05" + assert table.partitions[5].name == "2019_06" + assert table.partitions[6].name == "2019_07" + assert table.partitions[7].name == "2019_08" + assert table.partitions[8].name == "2019_09" + assert table.partitions[9].name == "2019_10" + assert table.partitions[10].name == "2019_11" + assert table.partitions[11].name == "2019_12" + + @pytest.mark.postgres_version(lt=110000) def test_partitioning_time_weekly_apply(): """Tests whether automatically creating new partitions ahead weekly works diff --git a/tests/test_query.py b/tests/test_query.py index e1496f51..38d6b3cb 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,5 +1,9 @@ -from django.db import models -from django.db.models import Case, F, Q, Value, When +from datetime import datetime, timezone + +from django.db import connection, models +from django.db.models import Case, F, Min, Q, 
Value, When +from django.db.models.functions.datetime import TruncSecond +from django.test.utils import CaptureQueriesContext, override_settings from psqlextra.expressions import HStoreRef from psqlextra.fields import HStoreField @@ -95,6 +99,40 @@ def test_query_annotate_in_expression(): assert result.is_he_henk == "really henk" +def test_query_annotate_group_by(): + """Tests whether annotations with GROUP BY clauses are properly renamed + when the annotation overwrites a field name.""" + + model = get_fake_model( + { + "name": models.TextField(), + "timestamp": models.DateTimeField(null=False), + "value": models.IntegerField(), + } + ) + + timestamp = datetime(2024, 1, 1, 0, 0, 0, 0, tzinfo=timezone.utc) + + model.objects.create(name="me", timestamp=timestamp, value=1) + + result = ( + model.objects.values("name") + .annotate( + timestamp=TruncSecond("timestamp", tzinfo=timezone.utc), + value=Min("value"), + ) + .values_list( + "name", + "value", + "timestamp", + ) + .order_by("name") + .first() + ) + + assert result == ("me", 1, timestamp) + + def test_query_hstore_value_update_f_ref(): """Tests whether F(..) expressions can be used in hstore values when performing update queries.""" @@ -134,3 +172,21 @@ def test_query_hstore_value_update_escape(): inst = model.objects.all().first() assert inst.title.get("en") == "console.log('test')" + + +@override_settings(POSTGRES_EXTRA_ANNOTATE_SQL=True) +def test_query_comment(): + """Tests whether the query is commented.""" + + model = get_fake_model( + { + "name": models.CharField(max_length=10), + "value": models.IntegerField(), + } + ) + + with CaptureQueriesContext(connection) as queries: + qs = model.objects.all() + assert " test_query_comment " in str(qs.query) + list(qs) + assert " test_query_comment " in queries[0]["sql"] diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 00000000..7ae4a3f2 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,201 @@ +import freezegun +import pytest + +from django.core.exceptions import SuspiciousOperation, ValidationError +from django.db import InternalError, ProgrammingError, connection +from psycopg2 import errorcodes + +from psqlextra.error import extract_postgres_error_code +from psqlextra.schema import PostgresSchema, postgres_temporary_schema + + +def _does_schema_exist(name: str) -> bool: + with connection.cursor() as cursor: + return name in connection.introspection.get_schema_list(cursor) + + +def test_postgres_schema_create(): + schema = PostgresSchema.create("myschema") + assert schema.name == "myschema" + + assert _does_schema_exist(schema.name) + + +def test_postgres_schema_does_not_overwrite(): + schema = PostgresSchema.create("myschema") + + with pytest.raises(ProgrammingError): + PostgresSchema.create(schema.name) + + +def test_postgres_schema_create_max_name_length(): + with pytest.raises(ValidationError) as exc_info: + PostgresSchema.create( + "stringthatislongerhtan63charactersforsureabsolutelysurethisislongerthanthat" + ) + + assert "is longer than Postgres's limit" in str(exc_info.value) + + +def test_postgres_schema_create_name_that_requires_escaping(): + # 'table' needs escaping because it conflicts with + # the SQL keyword TABLE + schema = PostgresSchema.create("table") + assert schema.name == "table" + + assert _does_schema_exist("table") + + +def test_postgres_schema_create_time_based(): + with freezegun.freeze_time("2023-04-07 13:37:23.4"): + schema = PostgresSchema.create_time_based("myprefix") + + assert schema.name == "myprefix_20230407130423" + 
assert _does_schema_exist(schema.name) + + +def test_postgres_schema_create_time_based_long_prefix(): + with pytest.raises(ValidationError) as exc_info: + with freezegun.freeze_time("2023-04-07 13:37:23.4"): + PostgresSchema.create_time_based("a" * 49) + + assert "is longer than 48 characters" in str(exc_info.value) + + +def test_postgres_schema_create_random(): + schema = PostgresSchema.create_random("myprefix") + + prefix, suffix = schema.name.split("_") + assert prefix == "myprefix" + assert len(suffix) == 8 + + assert _does_schema_exist(schema.name) + + +def test_postgres_schema_create_random_long_prefix(): + with pytest.raises(ValidationError) as exc_info: + PostgresSchema.create_random("a" * 55) + + assert "is longer than 54 characters" in str(exc_info.value) + + +def test_postgres_schema_delete_and_create(): + schema = PostgresSchema.create("test") + + with connection.cursor() as cursor: + cursor.execute("CREATE TABLE test.bla AS SELECT 'hello'") + cursor.execute("SELECT * FROM test.bla") + + assert cursor.fetchone() == ("hello",) + + # Should refuse to delete since we added a table to the schema + with pytest.raises(InternalError) as exc_info: + schema = PostgresSchema.delete_and_create(schema.name) + + pg_error = extract_postgres_error_code(exc_info.value) + assert pg_error == errorcodes.DEPENDENT_OBJECTS_STILL_EXIST + + # Verify that the schema and table still exist + assert _does_schema_exist(schema.name) + with connection.cursor() as cursor: + cursor.execute("SELECT * FROM test.bla") + assert cursor.fetchone() == ("hello",) + + # Dropping the schema should work with cascade=True + schema = PostgresSchema.delete_and_create(schema.name, cascade=True) + assert _does_schema_exist(schema.name) + + # Since the schema was deleted and re-created, the `bla` + # table should not exist anymore. 
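+    # Postgres reports the now-missing table as UNDEFINED_TABLE; the error code is verified below.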
+ with pytest.raises(ProgrammingError) as exc_info: + with connection.cursor() as cursor: + cursor.execute("SELECT * FROM test.bla") + assert cursor.fetchone() == ("hello",) + + pg_error = extract_postgres_error_code(exc_info.value) + assert pg_error == errorcodes.UNDEFINED_TABLE + + +def test_postgres_schema_delete(): + schema = PostgresSchema.create("test") + assert _does_schema_exist(schema.name) + + schema.delete() + assert not _does_schema_exist(schema.name) + + +def test_postgres_schema_delete_not_empty(): + schema = PostgresSchema.create("test") + assert _does_schema_exist(schema.name) + + with connection.cursor() as cursor: + cursor.execute("CREATE TABLE test.bla AS SELECT 'hello'") + + with pytest.raises(InternalError) as exc_info: + schema.delete() + + pg_error = extract_postgres_error_code(exc_info.value) + assert pg_error == errorcodes.DEPENDENT_OBJECTS_STILL_EXIST + + +def test_postgres_schema_delete_cascade_not_empty(): + schema = PostgresSchema.create("test") + assert _does_schema_exist(schema.name) + + with connection.cursor() as cursor: + cursor.execute("CREATE TABLE test.bla AS SELECT 'hello'") + + schema.delete(cascade=True) + assert not _does_schema_exist(schema.name) + + +def test_postgres_schema_no_delete_default(): + with pytest.raises(SuspiciousOperation): + PostgresSchema.default.delete() + + with pytest.raises(SuspiciousOperation): + PostgresSchema("public").delete() + + +def test_postgres_temporary_schema(): + with postgres_temporary_schema("temp") as schema: + name_prefix, name_suffix = schema.name.split("_") + assert name_prefix == "temp" + assert len(name_suffix) == 8 + + assert _does_schema_exist(schema.name) + + assert not _does_schema_exist(schema.name) + + +def test_postgres_temporary_schema_not_empty(): + with pytest.raises(InternalError) as exc_info: + with postgres_temporary_schema("temp") as schema: + with connection.cursor() as cursor: + cursor.execute( + f"CREATE TABLE {schema.name}.mytable AS SELECT 'hello world'" + ) + + pg_error = extract_postgres_error_code(exc_info.value) + assert pg_error == errorcodes.DEPENDENT_OBJECTS_STILL_EXIST + + +def test_postgres_temporary_schema_not_empty_cascade(): + with postgres_temporary_schema("temp", cascade=True) as schema: + with connection.cursor() as cursor: + cursor.execute( + f"CREATE TABLE {schema.name}.mytable AS SELECT 'hello world'" + ) + + assert not _does_schema_exist(schema.name) + + +@pytest.mark.parametrize("delete_on_throw", [True, False]) +def test_postgres_temporary_schema_no_delete_on_throw(delete_on_throw): + with pytest.raises(ValueError): + with postgres_temporary_schema( + "temp", delete_on_throw=delete_on_throw + ) as schema: + raise ValueError("test") + + assert _does_schema_exist(schema.name) != delete_on_throw diff --git a/tests/test_schema_editor_alter_schema.py b/tests/test_schema_editor_alter_schema.py new file mode 100644 index 00000000..7fda103b --- /dev/null +++ b/tests/test_schema_editor_alter_schema.py @@ -0,0 +1,44 @@ +import pytest + +from django.db import connection, models + +from psqlextra.backend.schema import PostgresSchemaEditor + +from .fake_model import get_fake_model + + +@pytest.fixture +def fake_model(): + return get_fake_model( + { + "text": models.TextField(), + } + ) + + +def test_schema_editor_alter_table_schema(fake_model): + obj = fake_model.objects.create(text="hello") + + with connection.cursor() as cursor: + cursor.execute("CREATE SCHEMA target") + + schema_editor = PostgresSchemaEditor(connection) + 
schema_editor.alter_table_schema(fake_model._meta.db_table, "target") + + with connection.cursor() as cursor: + cursor.execute(f"SELECT * FROM target.{fake_model._meta.db_table}") + assert cursor.fetchall() == [(obj.id, obj.text)] + + +def test_schema_editor_alter_model_schema(fake_model): + obj = fake_model.objects.create(text="hello") + + with connection.cursor() as cursor: + cursor.execute("CREATE SCHEMA target") + + schema_editor = PostgresSchemaEditor(connection) + schema_editor.alter_model_schema(fake_model, "target") + + with connection.cursor() as cursor: + cursor.execute(f"SELECT * FROM target.{fake_model._meta.db_table}") + assert cursor.fetchall() == [(obj.id, obj.text)] diff --git a/tests/test_schema_editor_clone_model_to_schema.py b/tests/test_schema_editor_clone_model_to_schema.py new file mode 100644 index 00000000..c3d41917 --- /dev/null +++ b/tests/test_schema_editor_clone_model_to_schema.py @@ -0,0 +1,330 @@ +import os + +from typing import Set, Tuple + +import django +import pytest + +from django.contrib.postgres.fields import ArrayField +from django.contrib.postgres.indexes import GinIndex +from django.db import connection, models, transaction +from django.db.models import Q + +from psqlextra.backend.schema import PostgresSchemaEditor + +from . import db_introspection +from .fake_model import delete_fake_model, get_fake_model + +django_32_skip_reason = "Django < 3.2 can't support cloning models because it has hard coded references to the public schema" + + +def _create_schema() -> str: + name = os.urandom(4).hex() + + with connection.cursor() as cursor: + cursor.execute( + "DROP SCHEMA IF EXISTS %s CASCADE" + % connection.ops.quote_name(name), + tuple(), + ) + cursor.execute( + "CREATE SCHEMA %s" % connection.ops.quote_name(name), tuple() + ) + + return name + + +@transaction.atomic +def _assert_cloned_table_is_same( + source_table_fqn: Tuple[str, str], + target_table_fqn: Tuple[str, str], + excluding_constraints_and_indexes: bool = False, +): + source_schema_name, source_table_name = source_table_fqn + target_schema_name, target_table_name = target_table_fqn + + source_columns = db_introspection.get_columns( + source_table_name, schema_name=source_schema_name + ) + target_columns = db_introspection.get_columns( + target_table_name, schema_name=target_schema_name + ) + assert source_columns == target_columns + + source_relations = db_introspection.get_relations( + source_table_name, schema_name=source_schema_name + ) + target_relations = db_introspection.get_relations( + target_table_name, schema_name=target_schema_name + ) + if excluding_constraints_and_indexes: + assert target_relations == {} + else: + assert source_relations == target_relations + + source_constraints = db_introspection.get_constraints( + source_table_name, schema_name=source_schema_name + ) + target_constraints = db_introspection.get_constraints( + target_table_name, schema_name=target_schema_name + ) + if excluding_constraints_and_indexes: + assert target_constraints == {} + else: + assert source_constraints == target_constraints + + source_sequences = db_introspection.get_sequences( + source_table_name, schema_name=source_schema_name + ) + target_sequences = db_introspection.get_sequences( + target_table_name, schema_name=target_schema_name + ) + assert source_sequences == target_sequences + + source_storage_settings = db_introspection.get_storage_settings( + source_table_name, + schema_name=source_schema_name, + ) + target_storage_settings = db_introspection.get_storage_settings( + 
target_table_name, schema_name=target_schema_name + ) + assert source_storage_settings == target_storage_settings + + +def _list_lock_modes_in_schema(schema_name: str) -> Set[str]: + with connection.cursor() as cursor: + cursor.execute( + """ + SELECT + l.mode + FROM pg_locks l + INNER JOIN pg_class t ON t.oid = l.relation + INNER JOIN pg_namespace n ON n.oid = t.relnamespace + WHERE + t.relnamespace >= 2200 + AND n.nspname = %s + ORDER BY n.nspname, t.relname, l.mode + """, + (schema_name,), + ) + + return {lock_mode for lock_mode, in cursor.fetchall()} + + +def _clone_model_into_schema(model): + schema_name = _create_schema() + + with PostgresSchemaEditor(connection) as schema_editor: + schema_editor.clone_model_structure_to_schema( + model, schema_name=schema_name + ) + schema_editor.clone_model_constraints_and_indexes_to_schema( + model, schema_name=schema_name + ) + schema_editor.clone_model_foreign_keys_to_schema( + model, schema_name=schema_name + ) + + return schema_name + + +@pytest.fixture +def fake_model_fk_target_1(): + model = get_fake_model( + { + "name": models.TextField(), + }, + ) + + yield model + + delete_fake_model(model) + + +@pytest.fixture +def fake_model_fk_target_2(): + model = get_fake_model( + { + "name": models.TextField(), + }, + ) + + yield model + + delete_fake_model(model) + + +@pytest.fixture +def fake_model(fake_model_fk_target_1, fake_model_fk_target_2): + model = get_fake_model( + { + "first_name": models.TextField(null=True), + "last_name": models.TextField(), + "age": models.PositiveIntegerField(), + "height": models.FloatField(), + "nicknames": ArrayField(base_field=models.TextField()), + "blob": models.JSONField(), + "family": models.ForeignKey( + fake_model_fk_target_1, on_delete=models.CASCADE + ), + "alternative_family": models.ForeignKey( + fake_model_fk_target_2, null=True, on_delete=models.SET_NULL + ), + }, + meta_options={ + "indexes": [ + models.Index(fields=["age", "height"]), + models.Index(fields=["age"], name="age_index"), + GinIndex(fields=["nicknames"], name="nickname_index"), + ], + "constraints": [ + models.UniqueConstraint( + fields=["first_name", "last_name"], + name="first_last_name_uniq", + ), + models.CheckConstraint( + check=Q(age__gt=0, height__gt=0), name="age_height_check" + ), + ], + "unique_together": ( + "first_name", + "nicknames", + ), + "index_together": ( + "blob", + "age", + ), + }, + ) + + yield model + + delete_fake_model(model) + + +@pytest.mark.skipif( + django.VERSION < (3, 2), + reason=django_32_skip_reason, +) +@pytest.mark.django_db(transaction=True) +def test_schema_editor_clone_model_to_schema( + fake_model, fake_model_fk_target_1, fake_model_fk_target_2 +): + """Tests that cloning a model into a separate schema without obtaining + AccessExclusiveLock on the source table works as expected.""" + + schema_editor = PostgresSchemaEditor(connection) + + with schema_editor: + schema_editor.alter_table_storage_setting( + fake_model._meta.db_table, "autovacuum_enabled", "false" + ) + + table_name = fake_model._meta.db_table + source_schema_name = "public" + target_schema_name = _create_schema() + + with schema_editor: + schema_editor.clone_model_structure_to_schema( + fake_model, schema_name=target_schema_name + ) + + assert _list_lock_modes_in_schema(source_schema_name) == { + "AccessShareLock" + } + + _assert_cloned_table_is_same( + (source_schema_name, table_name), + (target_schema_name, table_name), + excluding_constraints_and_indexes=True, + ) + + with schema_editor: + 
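+        # Copying constraints and indexes takes a ShareRowExclusiveLock on
+        # the source table, but never an AccessExclusiveLock (asserted below).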
schema_editor.clone_model_constraints_and_indexes_to_schema( + fake_model, schema_name=target_schema_name + ) + + assert _list_lock_modes_in_schema(source_schema_name) == { + "AccessShareLock", + "ShareRowExclusiveLock", + } + + _assert_cloned_table_is_same( + (source_schema_name, table_name), + (target_schema_name, table_name), + ) + + with schema_editor: + schema_editor.clone_model_foreign_keys_to_schema( + fake_model, schema_name=target_schema_name + ) + + assert _list_lock_modes_in_schema(source_schema_name) == { + "AccessShareLock", + "RowShareLock", + } + + _assert_cloned_table_is_same( + (source_schema_name, table_name), + (target_schema_name, table_name), + ) + + +@pytest.mark.skipif( + django.VERSION < (3, 2), + reason=django_32_skip_reason, +) +def test_schema_editor_clone_model_to_schema_custom_constraint_names( + fake_model, fake_model_fk_target_1 +): + """Tests that even if constraints were given custom names, the cloned table + has those same custom names.""" + + table_name = fake_model._meta.db_table + source_schema_name = "public" + + constraints = db_introspection.get_constraints(table_name) + + primary_key_constraint = next( + ( + name + for name, constraint in constraints.items() + if constraint["primary_key"] + ), + None, + ) + foreign_key_constraint = next( + ( + name + for name, constraint in constraints.items() + if constraint["foreign_key"] + == (fake_model_fk_target_1._meta.db_table, "id") + ), + None, + ) + check_constraint = next( + ( + name + for name, constraint in constraints.items() + if constraint["check"] and constraint["columns"] == ["age"] + ), + None, + ) + + with connection.cursor() as cursor: + cursor.execute( + f"ALTER TABLE {table_name} RENAME CONSTRAINT {primary_key_constraint} TO custompkname" + ) + cursor.execute( + f"ALTER TABLE {table_name} RENAME CONSTRAINT {foreign_key_constraint} TO customfkname" + ) + cursor.execute( + f"ALTER TABLE {table_name} RENAME CONSTRAINT {check_constraint} TO customcheckname" + ) + + target_schema_name = _clone_model_into_schema(fake_model) + + _assert_cloned_table_is_same( + (source_schema_name, table_name), + (target_schema_name, table_name), + ) diff --git a/tests/test_schema_editor_partitioning.py b/tests/test_schema_editor_partitioning.py index 17f2469f..c80efd52 100644 --- a/tests/test_schema_editor_partitioning.py +++ b/tests/test_schema_editor_partitioning.py @@ -76,6 +76,39 @@ def test_schema_editor_create_delete_partitioned_model_list(): assert len(partitions) == 0 +@pytest.mark.postgres_version(lt=110000) +@pytest.mark.parametrize("key", [["name"], ["id", "name"]]) +def test_schema_editor_create_delete_partitioned_model_hash(key): + """Tests whether creating a partitioned model and adding a hash partition + to it using the :see:PostgresSchemaEditor works.""" + + method = PostgresPartitioningMethod.HASH + + model = define_fake_partitioned_model( + {"name": models.TextField()}, + {"method": method, "key": key}, + ) + + schema_editor = PostgresSchemaEditor(connection) + schema_editor.create_partitioned_model(model) + + schema_editor.add_hash_partition(model, "pt1", modulus=1, remainder=0) + + table = db_introspection.get_partitioned_table(model._meta.db_table) + assert table.name == model._meta.db_table + assert table.method == method + assert table.key == key + assert table.partitions[0].full_name == model._meta.db_table + "_pt1" + + schema_editor.delete_partitioned_model(model) + + table = db_introspection.get_partitioned_table(model._meta.db_table) + assert not table + + partitions = 
db_introspection.get_partitions(model._meta.db_table) + assert len(partitions) == 0 + + @pytest.mark.postgres_version(lt=110000) def test_schema_editor_create_delete_partitioned_model_default(): """Tests whether creating a partitioned model and adding a default diff --git a/tests/test_schema_editor_storage_settings.py b/tests/test_schema_editor_storage_settings.py new file mode 100644 index 00000000..0f45934f --- /dev/null +++ b/tests/test_schema_editor_storage_settings.py @@ -0,0 +1,47 @@ +import pytest + +from django.db import connection, models + +from psqlextra.backend.schema import PostgresSchemaEditor + +from . import db_introspection +from .fake_model import get_fake_model + + +@pytest.fixture +def fake_model(): + return get_fake_model( + { + "text": models.TextField(), + } + ) + + +def test_schema_editor_storage_settings_table_alter_and_reset(fake_model): + table_name = fake_model._meta.db_table + schema_editor = PostgresSchemaEditor(connection) + + schema_editor.alter_table_storage_setting( + table_name, "autovacuum_enabled", "false" + ) + assert db_introspection.get_storage_settings(table_name) == { + "autovacuum_enabled": "false" + } + + schema_editor.reset_table_storage_setting(table_name, "autovacuum_enabled") + assert db_introspection.get_storage_settings(table_name) == {} + + +def test_schema_editor_storage_settings_model_alter_and_reset(fake_model): + table_name = fake_model._meta.db_table + schema_editor = PostgresSchemaEditor(connection) + + schema_editor.alter_model_storage_setting( + fake_model, "autovacuum_enabled", "false" + ) + assert db_introspection.get_storage_settings(table_name) == { + "autovacuum_enabled": "false" + } + + schema_editor.reset_model_storage_setting(fake_model, "autovacuum_enabled") + assert db_introspection.get_storage_settings(table_name) == {} diff --git a/tests/test_schema_editor_vacuum.py b/tests/test_schema_editor_vacuum.py new file mode 100644 index 00000000..59772e86 --- /dev/null +++ b/tests/test_schema_editor_vacuum.py @@ -0,0 +1,147 @@ +import pytest + +from django.core.exceptions import SuspiciousOperation +from django.db import connection, models +from django.test.utils import CaptureQueriesContext + +from psqlextra.backend.schema import PostgresSchemaEditor + +from .fake_model import delete_fake_model, get_fake_model + + +@pytest.fixture +def fake_model(): + model = get_fake_model( + { + "name": models.TextField(), + } + ) + + yield model + + delete_fake_model(model) + + +@pytest.fixture +def fake_model_non_concrete_field(fake_model): + model = get_fake_model( + { + "fk": models.ForeignKey( + fake_model, on_delete=models.CASCADE, related_name="fakes" + ), + } + ) + + yield model + + delete_fake_model(model) + + +def test_schema_editor_vacuum_not_in_transaction(fake_model): + schema_editor = PostgresSchemaEditor(connection) + + with pytest.raises(SuspiciousOperation): + schema_editor.vacuum_table(fake_model._meta.db_table) + + +@pytest.mark.parametrize( + "kwargs,query", + [ + (dict(), "VACUUM %s"), + (dict(full=True), "VACUUM (FULL) %s"), + (dict(analyze=True), "VACUUM (ANALYZE) %s"), + (dict(parallel=8), "VACUUM (PARALLEL 8) %s"), + (dict(analyze=True, verbose=True), "VACUUM (VERBOSE, ANALYZE) %s"), + ( + dict(analyze=True, parallel=8, verbose=True), + "VACUUM (VERBOSE, ANALYZE, PARALLEL 8) %s", + ), + (dict(freeze=True), "VACUUM (FREEZE) %s"), + (dict(verbose=True), "VACUUM (VERBOSE) %s"), + (dict(disable_page_skipping=True), "VACUUM (DISABLE_PAGE_SKIPPING) %s"), + (dict(skip_locked=True), "VACUUM (SKIP_LOCKED) %s"), + 
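+        # Note: the INDEX_CLEANUP and TRUNCATE options require PostgreSQL 12 or newer.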
(dict(index_cleanup=True), "VACUUM (INDEX_CLEANUP) %s"),
+        (dict(truncate=True), "VACUUM (TRUNCATE) %s"),
+    ],
+)
+@pytest.mark.django_db(transaction=True)
+def test_schema_editor_vacuum_table(fake_model, kwargs, query):
+    schema_editor = PostgresSchemaEditor(connection)
+
+    with CaptureQueriesContext(connection) as ctx:
+        schema_editor.vacuum_table(fake_model._meta.db_table, **kwargs)
+
+    queries = [query["sql"] for query in ctx.captured_queries]
+    assert queries == [
+        query % connection.ops.quote_name(fake_model._meta.db_table)
+    ]
+
+
+@pytest.mark.django_db(transaction=True)
+def test_schema_editor_vacuum_table_columns(fake_model):
+    schema_editor = PostgresSchemaEditor(connection)
+
+    with CaptureQueriesContext(connection) as ctx:
+        schema_editor.vacuum_table(
+            fake_model._meta.db_table, ["id", "name"], analyze=True
+        )
+
+    queries = [query["sql"] for query in ctx.captured_queries]
+    assert queries == [
+        'VACUUM (ANALYZE) %s ("id", "name")'
+        % connection.ops.quote_name(fake_model._meta.db_table)
+    ]
+
+
+@pytest.mark.django_db(transaction=True)
+def test_schema_editor_vacuum_model(fake_model):
+    schema_editor = PostgresSchemaEditor(connection)
+
+    with CaptureQueriesContext(connection) as ctx:
+        schema_editor.vacuum_model(fake_model, analyze=True, parallel=8)
+
+    queries = [query["sql"] for query in ctx.captured_queries]
+    assert queries == [
+        "VACUUM (ANALYZE, PARALLEL 8) %s"
+        % connection.ops.quote_name(fake_model._meta.db_table)
+    ]
+
+
+@pytest.mark.django_db(transaction=True)
+def test_schema_editor_vacuum_model_fields(fake_model):
+    schema_editor = PostgresSchemaEditor(connection)
+
+    with CaptureQueriesContext(connection) as ctx:
+        schema_editor.vacuum_model(
+            fake_model,
+            [fake_model._meta.get_field("name")],
+            analyze=True,
+            parallel=8,
+        )
+
+    queries = [query["sql"] for query in ctx.captured_queries]
+    assert queries == [
+        'VACUUM (ANALYZE, PARALLEL 8) %s ("name")'
+        % connection.ops.quote_name(fake_model._meta.db_table)
+    ]
+
+
+@pytest.mark.django_db(transaction=True)
+def test_schema_editor_vacuum_model_non_concrete_fields(
+    fake_model, fake_model_non_concrete_field
+):
+    schema_editor = PostgresSchemaEditor(connection)
+
+    with CaptureQueriesContext(connection) as ctx:
+        schema_editor.vacuum_model(
+            fake_model,
+            [fake_model._meta.get_field("fakes")],
+            analyze=True,
+            parallel=8,
+        )
+
+    queries = [query["sql"] for query in ctx.captured_queries]
+    assert queries == [
+        "VACUUM (ANALYZE, PARALLEL 8) %s"
+        % connection.ops.quote_name(fake_model._meta.db_table)
+    ]
diff --git a/tests/test_settings.py b/tests/test_settings.py
new file mode 100644
index 00000000..44519714
--- /dev/null
+++ b/tests/test_settings.py
@@ -0,0 +1,93 @@
+import pytest
+
+from django.core.exceptions import SuspiciousOperation
+from django.db import connection
+
+from psqlextra.settings import (
+    postgres_prepend_local_search_path,
+    postgres_reset_local_search_path,
+    postgres_set_local,
+    postgres_set_local_search_path,
+)
+
+
+def _get_current_setting(name: str) -> str:
+    with connection.cursor() as cursor:
+        cursor.execute(f"SHOW {name}")
+        return cursor.fetchone()[0]
+
+
+@postgres_set_local(statement_timeout="2s", lock_timeout="3s")
+def test_postgres_set_local_function_decorator():
+    assert _get_current_setting("statement_timeout") == "2s"
+    assert _get_current_setting("lock_timeout") == "3s"
+
+
+def test_postgres_set_local_context_manager():
+    with postgres_set_local(statement_timeout="2s"):
+        assert _get_current_setting("statement_timeout") == "2s"
+
+    assert
_get_current_setting("statement_timeout") == "0" + + +def test_postgres_set_local_iterable(): + with postgres_set_local(search_path=["a", "public"]): + assert _get_current_setting("search_path") == "a, public" + + assert _get_current_setting("search_path") == '"$user", public' + + +def test_postgres_set_local_nested(): + with postgres_set_local(statement_timeout="2s"): + assert _get_current_setting("statement_timeout") == "2s" + + with postgres_set_local(statement_timeout="3s"): + assert _get_current_setting("statement_timeout") == "3s" + + assert _get_current_setting("statement_timeout") == "2s" + + assert _get_current_setting("statement_timeout") == "0" + + +@pytest.mark.django_db(transaction=True) +def test_postgres_set_local_no_transaction(): + with pytest.raises(SuspiciousOperation): + with postgres_set_local(statement_timeout="2s"): + pass + + +def test_postgres_set_local_search_path(): + with postgres_set_local_search_path(["a", "public"]): + assert _get_current_setting("search_path") == "a, public" + + assert _get_current_setting("search_path") == '"$user", public' + + +def test_postgres_reset_local_search_path(): + with postgres_set_local_search_path(["a", "public"]): + with postgres_reset_local_search_path(): + assert _get_current_setting("search_path") == '"$user", public' + + assert _get_current_setting("search_path") == "a, public" + + assert _get_current_setting("search_path") == '"$user", public' + + +def test_postgres_prepend_local_search_path(): + with postgres_prepend_local_search_path(["a", "b"]): + assert _get_current_setting("search_path") == 'a, b, "$user", public' + + assert _get_current_setting("search_path") == '"$user", public' + + +def test_postgres_prepend_local_search_path_nested(): + with postgres_prepend_local_search_path(["a", "b"]): + with postgres_prepend_local_search_path(["c"]): + assert ( + _get_current_setting("search_path") + == 'c, a, b, "$user", public' + ) + + assert _get_current_setting("search_path") == 'a, b, "$user", public' + + assert _get_current_setting("search_path") == '"$user", public' diff --git a/tests/test_upsert.py b/tests/test_upsert.py index 42e2eb86..a9e567b2 100644 --- a/tests/test_upsert.py +++ b/tests/test_upsert.py @@ -1,12 +1,14 @@ import django import pytest -from django.db import models -from django.db.models import Q +from django.db import connection, models +from django.db.models import F, Q from django.db.models.expressions import CombinedExpression, Value +from django.test.utils import CaptureQueriesContext from psqlextra.expressions import ExcludedCol from psqlextra.fields import HStoreField +from psqlextra.query import ConflictAction from .fake_model import get_fake_model @@ -82,6 +84,22 @@ def test_upsert_explicit_pk(): assert obj2.cookies == "second-boo" +def test_upsert_one_to_one_field(): + model1 = get_fake_model({"title": models.TextField(unique=True)}) + model2 = get_fake_model( + {"model1": models.OneToOneField(model1, on_delete=models.CASCADE)} + ) + + obj1 = model1.objects.create(title="hello world") + + obj2_id = model2.objects.upsert( + conflict_target=["model1"], fields=dict(model1=obj1) + ) + + obj2 = model2.objects.get(id=obj2_id) + assert obj2.model1 == obj1 + + def test_upsert_with_update_condition(): """Tests that an expression can be used as an upsert update condition.""" @@ -127,6 +145,83 @@ def test_upsert_with_update_condition(): assert obj1.active +@pytest.mark.parametrize("update_condition_value", [0, False]) +def test_upsert_with_update_condition_false(update_condition_value): + """Tests that 
a falsy update condition (0 or False) results in
+    ON CONFLICT DO NOTHING and leaves the existing row untouched."""

+    model = get_fake_model(
+        {
+            "name": models.TextField(unique=True),
+            "priority": models.IntegerField(),
+            "active": models.BooleanField(),
+        }
+    )
+
+    obj1 = model.objects.create(name="joe", priority=1, active=False)
+
+    with CaptureQueriesContext(connection) as ctx:
+        upsert_result = model.objects.upsert(
+            conflict_target=["name"],
+            update_condition=update_condition_value,
+            fields=dict(name="joe", priority=2, active=True),
+        )
+        assert upsert_result is None
+        assert len(ctx) == 1
+        assert 'ON CONFLICT ("name") DO NOTHING' in ctx[0]["sql"]
+
+    obj1.refresh_from_db()
+    assert obj1.priority == 1
+    assert not obj1.active
+
+
+def test_upsert_with_update_values():
+    """Tests that the default update values can be overridden with custom
+    expressions."""
+
+    model = get_fake_model(
+        {
+            "name": models.TextField(unique=True),
+            "count": models.IntegerField(default=0),
+        }
+    )
+
+    obj1 = model.objects.create(name="joe")
+
+    model.objects.upsert(
+        conflict_target=["name"],
+        fields=dict(name="joe"),
+        update_values=dict(
+            count=F("count") + 1,
+        ),
+    )
+
+    obj1.refresh_from_db()
+    assert obj1.count == 1
+
+
+def test_upsert_with_update_values_empty():
+    """Tests that an upsert with an empty dict turns into ON CONFLICT DO
+    NOTHING."""
+
+    model = get_fake_model(
+        {
+            "name": models.TextField(unique=True),
+            "count": models.IntegerField(default=0),
+        }
+    )
+
+    obj1 = model.objects.create(name="joe")
+
+    model.objects.upsert(
+        conflict_target=["name"],
+        fields=dict(name="joe"),
+        update_values={},
+    )
+
+    obj1.refresh_from_db()
+    assert obj1.count == 0
+
+
 @pytest.mark.skipif(
     django.VERSION < (3, 1), reason="requires django 3.1 or newer"
 )
@@ -183,7 +278,7 @@ def from_db_value(self, value, expression, connection):
     assert obj.title == "bye"


-def test_upsert_bulk():
+def test_bulk_upsert():
     """Tests whether bulk_upsert works properly."""

     model = get_fake_model(
@@ -229,10 +324,18 @@ def test_upsert_bulk_no_rows():
         {"name": models.CharField(max_length=255, null=True, unique=True)}
     )

+    model.objects.on_conflict(ConflictAction.UPDATE, ["name"]).bulk_insert(
+        rows=[]
+    )
+
     model.objects.bulk_upsert(conflict_target=["name"], rows=[])
     model.objects.bulk_upsert(conflict_target=["name"], rows=None)

+    model.objects.on_conflict(ConflictAction.UPDATE, ["name"]).bulk_insert(
+        rows=None
+    )
+

 def test_bulk_upsert_return_models():
     """Tests whether models are returned instead of dictionaries when
@@ -312,3 +415,93 @@ def __iter__(self):
     for index, obj in enumerate(objs, 1):
         assert isinstance(obj, model)
         assert obj.id == index
+
+
+def test_bulk_upsert_update_values():
+    """Tests that custom update values are applied to the rows that
+    conflicted."""
+
+    model = get_fake_model(
+        {
+            "name": models.CharField(max_length=255, unique=True),
+            "count": models.IntegerField(default=0),
+        }
+    )
+
+    model.objects.bulk_create(
+        [
+            model(name="joe"),
+            model(name="john"),
+        ]
+    )
+
+    objs = model.objects.bulk_upsert(
+        conflict_target=["name"],
+        rows=[dict(name="joe"), dict(name="john")],
+        return_model=True,
+        update_values=dict(count=F("count") + 1),
+    )
+
+    assert len(objs) == 2
+    assert all(obj.count == 1 for obj in objs)
+
+
+@pytest.mark.parametrize("return_model", [True, False])
+def test_bulk_upsert_extra_columns_in_schema(return_model):
+    """Tests that extra columns being returned by the database that aren't
+    known by Django don't make the bulk upsert crash."""
+
+    model = get_fake_model(
+        {
+            "name": models.CharField(max_length=255, unique=True),
+        }
+    )
+
+    with connection.cursor() as cursor:
+        cursor.execute(
+            f"ALTER TABLE {model._meta.db_table} ADD COLUMN new_name text
NOT NULL DEFAULT %s", + ("newjoe",), + ) + + objs = model.objects.bulk_upsert( + conflict_target=["name"], + rows=[ + dict(name="joe"), + ], + return_model=return_model, + ) + + assert len(objs) == 1 + + if return_model: + assert objs[0].name == "joe" + else: + assert objs[0]["name"] == "joe" + assert sorted(list(objs[0].keys())) == ["id", "name"] + + +def test_upsert_extra_columns_in_schema(): + """Tests that extra columns being returned by the database that aren't + known by Django don't make the upsert crash.""" + + model = get_fake_model( + { + "name": models.CharField(max_length=255, unique=True), + } + ) + + with connection.cursor() as cursor: + cursor.execute( + f"ALTER TABLE {model._meta.db_table} ADD COLUMN new_name text NOT NULL DEFAULT %s", + ("newjoe",), + ) + + obj_id = model.objects.upsert( + conflict_target=["name"], + fields=dict(name="joe"), + ) + + assert obj_id == 1 + + obj = model.objects.upsert_and_get( + conflict_target=["name"], + fields=dict(name="joe"), + ) + + assert obj.name == "joe" diff --git a/tests/test_view_models.py b/tests/test_view_models.py index d9333ddc..b0ee8669 100644 --- a/tests/test_view_models.py +++ b/tests/test_view_models.py @@ -2,6 +2,7 @@ from django.core.exceptions import ImproperlyConfigured from django.db import models +from django.test.utils import override_settings from psqlextra.models import PostgresMaterializedViewModel, PostgresViewModel @@ -11,6 +12,7 @@ @pytest.mark.parametrize( "model_base", [PostgresViewModel, PostgresMaterializedViewModel] ) +@override_settings(POSTGRES_EXTRA_ANNOTATE_SQL=True) def test_view_model_meta_query_set(model_base): """Tests whether you can set a :see:QuerySet to be used as the underlying query for a view.""" @@ -26,7 +28,8 @@ def test_view_model_meta_query_set(model_base): expected_sql = 'SELECT "{0}"."id", "{0}"."name" FROM "{0}"'.format( model._meta.db_table ) - assert view_model._view_meta.query == (expected_sql, tuple()) + assert view_model._view_meta.query[0].startswith(expected_sql + " /* ") + assert view_model._view_meta.query[1] == tuple() @pytest.mark.parametrize( diff --git a/tox.ini b/tox.ini index bb4e0d77..70a0e8ce 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,9 @@ [tox] -envlist = py36-dj{20,21,22,30,31,32}, py37-dj{20,21,22,30,31,32}, py38-dj{20,21,22,30,31,32}, py39-dj{21,22,30,31,32}, py310-dj{21,22,30,31,32} +envlist = + {py36,py37}-dj{20,21,22,30,31,32}-psycopg{28,29} + {py38,py39,py310}-dj{21,22,30,31,32,40}-psycopg{28,29} + {py38,py39,py310,py311}-dj{41}-psycopg{28,29} + {py310,py311}-dj{42,50}-psycopg{28,29,31} [testenv] deps = @@ -9,6 +13,13 @@ deps = dj30: Django~=3.0.0 dj31: Django~=3.1.0 dj32: Django~=3.2.0 + dj40: Django~=4.0.0 + dj41: Django~=4.1.0 + dj42: Django~=4.2.0 + dj50: Django~=5.0.1 + psycopg28: psycopg2[binary]~=2.8 + psycopg29: psycopg2[binary]~=2.9 + psycopg31: psycopg[binary]~=3.1 .[test] setenv = DJANGO_SETTINGS_MODULE=settings