diff --git a/taggit/managers.py b/taggit/managers.py index 0ff8e976..6d02eda5 100644 --- a/taggit/managers.py +++ b/taggit/managers.py @@ -1,5 +1,6 @@ import uuid from operator import attrgetter +from typing import List from django.conf import settings from django.contrib.contenttypes.fields import GenericRelation @@ -56,6 +57,20 @@ def clone(self): return type(self)(self.alias, self.col, self.content_types[:]) +class NaturalKeyManager(models.Manager): + def __init__(self, natural_key_fields: List[str], *args, **kwargs): + super().__init__(*args, **kwargs) + self.natural_key_fields = natural_key_fields + + def get_by_natural_key(self, *args): + if len(args) != len(self.model.natural_key_fields): + raise ValueError( + "Number of arguments does not match number of natural key fields." + ) + lookup_kwargs = dict(zip(self.model.natural_key_fields, args)) + return self.get(**lookup_kwargs) + + class _TaggableManager(models.Manager): # TODO investigate whether we can use a RelatedManager instead of all this stuff # to take advantage of all the Django goodness diff --git a/taggit/models.py b/taggit/models.py index 8d7f60bd..43064d43 100644 --- a/taggit/models.py +++ b/taggit/models.py @@ -1,3 +1,5 @@ +from typing import List + from django.conf import settings from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.models import ContentType @@ -15,7 +17,29 @@ def unidecode(tag): return tag -class TagBase(models.Model): +class NaturalKeyManager(models.Manager): + def __init__(self, natural_key_fields: List[str], *args, **kwargs): + super().__init__(*args, **kwargs) + self.natural_key_fields = natural_key_fields + + def get_by_natural_key(self, *args): + if len(args) != len(self.model.natural_key_fields): + raise ValueError( + "Number of arguments does not match number of natural key fields." + ) + lookup_kwargs = dict(zip(self.model.natural_key_fields, args)) + return self.get(**lookup_kwargs) + + +class NaturalKeyModel(models.Model): + def natural_key(self): + return (getattr(self, field) for field in self.natural_key_fields) + + class Meta: + abstract = True + + +class TagBase(NaturalKeyModel): name = models.CharField( verbose_name=pgettext_lazy("A tag name", "name"), unique=True, max_length=100 ) @@ -26,6 +50,9 @@ class TagBase(models.Model): allow_unicode=True, ) + natural_key_fields = ["name"] + objects = NaturalKeyManager(natural_key_fields) + def __str__(self): return self.name @@ -91,13 +118,15 @@ class Meta: app_label = "taggit" -class ItemBase(models.Model): +class ItemBase(NaturalKeyModel): def __str__(self): return gettext("%(object)s tagged with %(tag)s") % { "object": self.content_object, "tag": self.tag, } + objects = NaturalKeyManager(natural_key_fields=["name"]) + class Meta: abstract = True @@ -170,6 +199,7 @@ def tags_for(cls, model, instance=None, **extra_filters): class GenericTaggedItemBase(CommonGenericTaggedItemBase): object_id = models.IntegerField(verbose_name=_("object ID"), db_index=True) + natural_key_fields = ["object_id"] class Meta: abstract = True @@ -177,6 +207,7 @@ class Meta: class GenericUUIDTaggedItemBase(CommonGenericTaggedItemBase): object_id = models.UUIDField(verbose_name=_("object ID"), db_index=True) + natural_key_fields = ["object_id"] class Meta: abstract = True diff --git a/tests/tests.py b/tests/tests.py index 79489598..7b041db8 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -1,3 +1,4 @@ +import os from io import StringIO from unittest import mock @@ -1398,3 +1399,100 @@ def test_tests_have_no_pending_migrations(self): out = StringIO() call_command("makemigrations", "tests", dry_run=True, stdout=out) self.assertEqual(out.getvalue().strip(), "No changes detected in app 'tests'") + + +class NaturalKeyTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.tag_names = ["circle", "square", "triangle", "rectangle", "pentagon"] + cls.filename = "test_data_dump.json" + cls.tag_count = len(cls.tag_names) + + def setUp(self): + self.tags = self._create_tags() + + def tearDown(self): + self._clear_existing_tags() + try: + os.remove(self.filename) + except FileNotFoundError: + pass + + @property + def _queryset(self): + return Tag.objects.filter(name__in=self.tag_names) + + def _create_tags(self): + return Tag.objects.bulk_create( + [Tag(name=shape, slug=shape) for shape in self.tag_names], + ignore_conflicts=True, + ) + + def _clear_existing_tags(self): + self._queryset.delete() + + def _dump_model(self, model): + model_label = model._meta.label + with open(self.filename, "w") as f: + call_command( + "dumpdata", + model_label, + natural_primary=True, + use_natural_foreign_keys=True, + stdout=f, + ) + + def _load_model(self): + call_command("loaddata", self.filename) + + def test_tag_natural_key(self): + """ + Test that tags can be dumped and loaded using natural keys. + """ + + # confirm count in the DB + self.assertEqual(self._queryset.count(), self.tag_count) + + # dump all tags to a file + self._dump_model(Tag) + + # Delete all tags + self._clear_existing_tags() + + # confirm all tags clear + self.assertEqual(self._queryset.count(), 0) + + # load the tags from the file + self._load_model() + + # confirm count in the DB + self.assertEqual(self._queryset.count(), self.tag_count) + + def test_tag_reloading_with_changed_pk(self): + """Test that tags are not reliant on the primary key of the tag model. + + Test that data is correctly loaded after database state has changed. + + """ + original_shape = self._queryset.first() + original_pk = original_shape.pk + original_shape_name = original_shape.name + new_shape_name = "hexagon" + + # dump all tags to a file + self._dump_model(Tag) + + # Delete the tag + self._clear_existing_tags() + + # create new tag with the same PK + Tag.objects.create(name=new_shape_name, slug=new_shape_name, pk=original_pk) + + # Load the tags from the file + self._load_model() + + # confirm that load did not overwrite the new_shape + self.assertEqual(Tag.objects.get(pk=original_pk).name, new_shape_name) + + # confirm that the original shape was reloaded with a different PK + self.assertNotEqual(Tag.objects.get(name=original_shape_name).pk, original_pk)