diff --git a/django/db/migrations/writer.py b/django/db/migrations/writer.py index 3c4bbb042714..6d77c6121a28 100644 --- a/django/db/migrations/writer.py +++ b/django/db/migrations/writer.py @@ -46,29 +46,14 @@ def __init__(self, operation): self.buff = [] def serialize(self): - imports = set() - name, args, kwargs = self.operation.deconstruct() - argspec = inspect.getargspec(self.operation.__init__) - normalized_kwargs = inspect.getcallargs(self.operation.__init__, *args, **kwargs) - - # See if this operation is in django.db.migrations. If it is, - # We can just use the fact we already have that imported, - # otherwise, we need to add an import for the operation class. - if getattr(migrations, name, None) == self.operation.__class__: - self.feed('migrations.%s(' % name) - else: - imports.add('import %s' % (self.operation.__class__.__module__)) - self.feed('%s.%s(' % (self.operation.__class__.__module__, name)) - self.indent() - for arg_name in argspec.args[1:]: - arg_value = normalized_kwargs[arg_name] - if (arg_name in self.operation.serialization_expand_args and - isinstance(arg_value, (list, tuple, dict))): - if isinstance(arg_value, dict): - self.feed('%s={' % arg_name) + def _write(_arg_name, _arg_value): + if (_arg_name in self.operation.serialization_expand_args and + isinstance(_arg_value, (list, tuple, dict))): + if isinstance(_arg_value, dict): + self.feed('%s={' % _arg_name) self.indent() - for key, value in arg_value.items(): + for key, value in _arg_value.items(): key_string, key_imports = MigrationWriter.serialize(key) arg_string, arg_imports = MigrationWriter.serialize(value) self.feed('%s: %s,' % (key_string, arg_string)) @@ -77,18 +62,47 @@ def serialize(self): self.unindent() self.feed('},') else: - self.feed('%s=[' % arg_name) + self.feed('%s=[' % _arg_name) self.indent() - for item in arg_value: + for item in _arg_value: arg_string, arg_imports = MigrationWriter.serialize(item) self.feed('%s,' % arg_string) imports.update(arg_imports) self.unindent() self.feed('],') else: - arg_string, arg_imports = MigrationWriter.serialize(arg_value) - self.feed('%s=%s,' % (arg_name, arg_string)) + arg_string, arg_imports = MigrationWriter.serialize(_arg_value) + self.feed('%s=%s,' % (_arg_name, arg_string)) imports.update(arg_imports) + + imports = set() + name, args, kwargs = self.operation.deconstruct() + argspec = inspect.getargspec(self.operation.__init__) + + # See if this operation is in django.db.migrations. If it is, + # We can just use the fact we already have that imported, + # otherwise, we need to add an import for the operation class. + if getattr(migrations, name, None) == self.operation.__class__: + self.feed('migrations.%s(' % name) + else: + imports.add('import %s' % (self.operation.__class__.__module__)) + self.feed('%s.%s(' % (self.operation.__class__.__module__, name)) + + self.indent() + + # Start at one because argspec includes "self" + for i, arg in enumerate(args, 1): + arg_value = arg + arg_name = argspec.args[i] + _write(arg_name, arg_value) + + i = len(args) + # Only iterate over remaining arguments + for arg_name in argspec.args[i + 1:]: + if arg_name in kwargs: + arg_value = kwargs[arg_name] + _write(arg_name, arg_value) + self.unindent() self.feed('),') return self.render(), imports diff --git a/tests/custom_migration_operations/operations.py b/tests/custom_migration_operations/operations.py index 3a4127d75336..bd62280f81a2 100644 --- a/tests/custom_migration_operations/operations.py +++ b/tests/custom_migration_operations/operations.py @@ -31,3 +31,64 @@ def database_backwards(self, app_label, schema_editor, from_state, to_state): class CreateModel(TestOperation): pass + + +class ArgsOperation(TestOperation): + def __init__(self, arg1, arg2): + self.arg1, self.arg2 = arg1, arg2 + + def deconstruct(self): + return ( + self.__class__.__name__, + [self.arg1, self.arg2], + {} + ) + + +class KwargsOperation(TestOperation): + def __init__(self, kwarg1=None, kwarg2=None): + self.kwarg1, self.kwarg2 = kwarg1, kwarg2 + + def deconstruct(self): + kwargs = {} + if self.kwarg1 is not None: + kwargs['kwarg1'] = self.kwarg1 + if self.kwarg2 is not None: + kwargs['kwarg2'] = self.kwarg2 + return ( + self.__class__.__name__, + [], + kwargs + ) + + +class ArgsKwargsOperation(TestOperation): + def __init__(self, arg1, arg2, kwarg1=None, kwarg2=None): + self.arg1, self.arg2 = arg1, arg2 + self.kwarg1, self.kwarg2 = kwarg1, kwarg2 + + def deconstruct(self): + kwargs = {} + if self.kwarg1 is not None: + kwargs['kwarg1'] = self.kwarg1 + if self.kwarg2 is not None: + kwargs['kwarg2'] = self.kwarg2 + return ( + self.__class__.__name__, + [self.arg1, self.arg2], + kwargs, + ) + + +class ExpandArgsOperation(TestOperation): + serialization_expand_args = ['arg'] + + def __init__(self, arg): + self.arg = arg + + def deconstruct(self): + return ( + self.__class__.__name__, + [self.arg], + {} + ) diff --git a/tests/migrations/test_writer.py b/tests/migrations/test_writer.py index edcc5b285d1d..2519e43f4275 100644 --- a/tests/migrations/test_writer.py +++ b/tests/migrations/test_writer.py @@ -10,8 +10,8 @@ from django.core.validators import RegexValidator, EmailValidator from django.db import models, migrations -from django.db.migrations.writer import MigrationWriter, SettingsReference -from django.test import TestCase, ignore_warnings +from django.db.migrations.writer import MigrationWriter, OperationWriter, SettingsReference +from django.test import SimpleTestCase, TestCase, ignore_warnings from django.conf import settings from django.utils import datetime_safe, six from django.utils.deconstruct import deconstructible @@ -30,6 +30,79 @@ def upload_to(self): thing = models.FileField(upload_to=upload_to) +class OperationWriterTests(SimpleTestCase): + + def test_empty_signature(self): + operation = custom_migration_operations.operations.TestOperation() + writer = OperationWriter(operation) + writer.indentation = 0 + buff, imports = writer.serialize() + self.assertEqual(imports, {'import custom_migration_operations.operations'}) + self.assertEqual( + buff, + 'custom_migration_operations.operations.TestOperation(\n' + '),' + ) + + def test_args_signature(self): + operation = custom_migration_operations.operations.ArgsOperation(1, 2) + writer = OperationWriter(operation) + writer.indentation = 0 + buff, imports = writer.serialize() + self.assertEqual(imports, {'import custom_migration_operations.operations'}) + self.assertEqual( + buff, + 'custom_migration_operations.operations.ArgsOperation(\n' + ' arg1=1,\n' + ' arg2=2,\n' + '),' + ) + + def test_kwargs_signature(self): + operation = custom_migration_operations.operations.KwargsOperation(kwarg1=1) + writer = OperationWriter(operation) + writer.indentation = 0 + buff, imports = writer.serialize() + self.assertEqual(imports, {'import custom_migration_operations.operations'}) + self.assertEqual( + buff, + 'custom_migration_operations.operations.KwargsOperation(\n' + ' kwarg1=1,\n' + '),' + ) + + def test_args_kwargs_signature(self): + operation = custom_migration_operations.operations.ArgsKwargsOperation(1, 2, kwarg2=4) + writer = OperationWriter(operation) + writer.indentation = 0 + buff, imports = writer.serialize() + self.assertEqual(imports, {'import custom_migration_operations.operations'}) + self.assertEqual( + buff, + 'custom_migration_operations.operations.ArgsKwargsOperation(\n' + ' arg1=1,\n' + ' arg2=2,\n' + ' kwarg2=4,\n' + '),' + ) + + def test_expand_args_signature(self): + operation = custom_migration_operations.operations.ExpandArgsOperation([1, 2]) + writer = OperationWriter(operation) + writer.indentation = 0 + buff, imports = writer.serialize() + self.assertEqual(imports, {'import custom_migration_operations.operations'}) + self.assertEqual( + buff, + 'custom_migration_operations.operations.ExpandArgsOperation(\n' + ' arg=[\n' + ' 1,\n' + ' 2,\n' + ' ],\n' + '),' + ) + + class WriterTests(TestCase): """ Tests the migration writer (makes migration files from Migration instances)