diff --git a/natural_keys/models.py b/natural_keys/models.py index 3f1261b..93eac6d 100644 --- a/natural_keys/models.py +++ b/natural_keys/models.py @@ -71,13 +71,14 @@ def get_by_natural_key(self, *args): return self.get(**kwargs) - def create_by_natural_key(self, *args): + def create_by_natural_key(self, *args, **kwargs): """ Create a new object from the provided natural key values. If the natural key contains related objects, recursively get or create them by their natural keys. """ + defaults = keyword_only_defaults(kwargs, 'create_by_natural_key') kwargs = self.natural_key_kwargs(*args) for name, rel_to in self.model.get_natural_key_info(): if not rel_to: @@ -90,24 +91,36 @@ def create_by_natural_key(self, *args): ) else: kwargs[name] = None - return self.create(**kwargs) - - def get_or_create_by_natural_key(self, *args): + if defaults: + attrs = defaults + attrs.update(kwargs) + else: + attrs = kwargs + return self.create(**attrs) + + def get_or_create_by_natural_key(self, *args, **kwargs): """ get_or_create + get_by_natural_key """ + defaults = keyword_only_defaults( + kwargs, 'get_or_create_by_natural_key' + ) try: return self.get_by_natural_key(*args), False except self.model.DoesNotExist: - return self.create_by_natural_key(*args), True + return self.create_by_natural_key(*args, defaults=defaults), True # Shortcut for common use case - def find(self, *args): + def find(self, *args, **kwargs): """ Shortcut for get_or_create_by_natural_key that discards the "created" boolean. """ - obj, is_new = self.get_or_create_by_natural_key(*args) + defaults = keyword_only_defaults(kwargs, 'find') + obj, is_new = self.get_or_create_by_natural_key( + *args, + defaults=defaults + ) return obj def natural_key_kwargs(self, *args): @@ -239,3 +252,16 @@ def extract_nested_key(key, cls, prefix=''): return values else: return None + + +# TODO: Once we drop Python 2.7 support, this can be removed +def keyword_only_defaults(kwargs, fname): + defaults = kwargs.pop('defaults', None) + if kwargs: + raise TypeError( + "{fname}() got an unexpected keyword argument '{arg}'".format( + fname=fname, + arg=list(kwargs.keys())[0] + ) + ) + return defaults diff --git a/tests/test_app/migrations/0001_initial.py b/tests/test_app/migrations/0001_initial.py index 980001b..a89e0e8 100644 --- a/tests/test_app/migrations/0001_initial.py +++ b/tests/test_app/migrations/0001_initial.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- -# Generated by Django 1.11 on 2017-04-21 13:17 -from __future__ import unicode_literals +# Generated by Django 2.2.3 on 2019-07-26 02:08 from django.db import migrations, models import django.db.models.deletion @@ -15,20 +13,24 @@ class Migration(migrations.Migration): operations = [ migrations.CreateModel( - name='ModelWithNaturalKey', + name='ModelWithSingleUniqueField', fields=[ ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('value', models.CharField(max_length=10)), + ('code', models.CharField(max_length=10, unique=True)), ], + options={ + 'abstract': False, + }, ), migrations.CreateModel( - name='ModelWithSingleUniqueField', + name='NaturalKeyParent', fields=[ ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('code', models.CharField(max_length=10, unique=True)), + ('code', models.CharField(max_length=10)), + ('group', models.CharField(max_length=10)), ], options={ - 'abstract': False, + 'unique_together': {('code', 'group')}, }, ), migrations.CreateModel( @@ -36,32 +38,30 @@ class Migration(migrations.Migration): fields=[ ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), ('mode', models.CharField(max_length=10)), + ('parent', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='test_app.NaturalKeyParent')), ], + options={ + 'unique_together': {('parent', 'mode')}, + }, ), migrations.CreateModel( - name='NaturalKeyParent', + name='ModelWithNaturalKey', fields=[ ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('code', models.CharField(max_length=10)), - ('group', models.CharField(max_length=10)), + ('value', models.CharField(max_length=10)), + ('key', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='test_app.NaturalKeyChild')), ], ), - migrations.AlterUniqueTogether( - name='naturalkeyparent', - unique_together=set([('code', 'group')]), - ), - migrations.AddField( - model_name='naturalkeychild', - name='parent', - field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='test_app.NaturalKeyParent'), - ), - migrations.AddField( - model_name='modelwithnaturalkey', - name='key', - field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='test_app.NaturalKeyChild'), - ), - migrations.AlterUniqueTogether( - name='naturalkeychild', - unique_together=set([('parent', 'mode')]), + migrations.CreateModel( + name='ModelWithExtraField', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('code', models.CharField(max_length=10, unique=True)), + ('date', models.DateField(max_length=10, unique=True)), + ('extra', models.TextField()), + ], + options={ + 'unique_together': {('code', 'date')}, + }, ), ] diff --git a/tests/test_app/models.py b/tests/test_app/models.py index dfcbc50..6751cea 100644 --- a/tests/test_app/models.py +++ b/tests/test_app/models.py @@ -25,3 +25,12 @@ class ModelWithNaturalKey(models.Model): class ModelWithSingleUniqueField(NaturalKeyModel): code = models.CharField(max_length=10, unique=True) + + +class ModelWithExtraField(NaturalKeyModel): + code = models.CharField(max_length=10, unique=True) + date = models.DateField(max_length=10, unique=True) + extra = models.TextField() + + class Meta: + unique_together = ['code', 'date'] diff --git a/tests/test_naturalkey.py b/tests/test_naturalkey.py index f5dd69e..56fc2c5 100644 --- a/tests/test_naturalkey.py +++ b/tests/test_naturalkey.py @@ -2,7 +2,7 @@ from rest_framework import status from tests.test_app.models import ( NaturalKeyParent, NaturalKeyChild, ModelWithNaturalKey, - ModelWithSingleUniqueField + ModelWithSingleUniqueField, ModelWithExtraField ) from natural_keys import NaturalKeySerializer from django.db.utils import IntegrityError @@ -315,3 +315,25 @@ def test_filter_with_Q(self): ModelWithSingleUniqueField.objects.filter(query).count(), 0 ) + + def test_find_with_defaults(self): + obj = ModelWithExtraField.objects.find( + 'extra1', + '2019-07-26', + defaults={'extra': 'Test 123'}, + ) + self.assertEqual( + obj.extra, + 'Test 123' + ) + + def test_find_with_kwargs(self): + with self.assertRaises(TypeError) as e: + ModelWithExtraField.objects.find( + 'extra1', + date='2019-07-26', + ) + self.assertEqual( + str(e.exception), + "find() got an unexpected keyword argument 'date'" + )