Skip to content

Commit

Permalink
708 - Add Natural Key Support to Tags model.
Browse files Browse the repository at this point in the history
  • Loading branch information
Trafire committed Jul 14, 2024
1 parent 5cdfef7 commit f1cc22e
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 2 deletions.
15 changes: 15 additions & 0 deletions taggit/managers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import uuid
from operator import attrgetter
from typing import List

from django.conf import settings
from django.contrib.contenttypes.fields import GenericRelation
Expand Down Expand Up @@ -56,6 +57,20 @@ def clone(self):
return type(self)(self.alias, self.col, self.content_types[:])


class NaturalKeyManager(models.Manager):
def __init__(self, natural_key_fields: List[str], *args, **kwargs):
super().__init__(*args, **kwargs)
self.natural_key_fields = natural_key_fields

def get_by_natural_key(self, *args):
if len(args) != len(self.model.natural_key_fields):
raise ValueError(
"Number of arguments does not match number of natural key fields."
)
lookup_kwargs = dict(zip(self.model.natural_key_fields, args))
return self.get(**lookup_kwargs)


class _TaggableManager(models.Manager):
# TODO investigate whether we can use a RelatedManager instead of all this stuff
# to take advantage of all the Django goodness
Expand Down
35 changes: 33 additions & 2 deletions taggit/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

from django.conf import settings
from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
Expand All @@ -15,7 +17,29 @@ def unidecode(tag):
return tag


class TagBase(models.Model):
class NaturalKeyManager(models.Manager):
def __init__(self, natural_key_fields: List[str], *args, **kwargs):
super().__init__(*args, **kwargs)
self.natural_key_fields = natural_key_fields

def get_by_natural_key(self, *args):
if len(args) != len(self.model.natural_key_fields):
raise ValueError(
"Number of arguments does not match number of natural key fields."
)
lookup_kwargs = dict(zip(self.model.natural_key_fields, args))
return self.get(**lookup_kwargs)


class NaturalKeyModel(models.Model):
def natural_key(self):
return (getattr(self, field) for field in self.natural_key_fields)

class Meta:
abstract = True


class TagBase(NaturalKeyModel):
name = models.CharField(
verbose_name=pgettext_lazy("A tag name", "name"), unique=True, max_length=100
)
Expand All @@ -26,6 +50,9 @@ class TagBase(models.Model):
allow_unicode=True,
)

natural_key_fields = ["name"]
objects = NaturalKeyManager(natural_key_fields)

def __str__(self):
return self.name

Expand Down Expand Up @@ -91,13 +118,15 @@ class Meta:
app_label = "taggit"


class ItemBase(models.Model):
class ItemBase(NaturalKeyModel):
def __str__(self):
return gettext("%(object)s tagged with %(tag)s") % {
"object": self.content_object,
"tag": self.tag,
}

objects = NaturalKeyManager(natural_key_fields=["name"])

class Meta:
abstract = True

Expand Down Expand Up @@ -170,13 +199,15 @@ def tags_for(cls, model, instance=None, **extra_filters):

class GenericTaggedItemBase(CommonGenericTaggedItemBase):
object_id = models.IntegerField(verbose_name=_("object ID"), db_index=True)
natural_key_fields = ["object_id"]

class Meta:
abstract = True


class GenericUUIDTaggedItemBase(CommonGenericTaggedItemBase):
object_id = models.UUIDField(verbose_name=_("object ID"), db_index=True)
natural_key_fields = ["object_id"]

class Meta:
abstract = True
Expand Down
98 changes: 98 additions & 0 deletions tests/tests.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from io import StringIO
from unittest import mock

Expand Down Expand Up @@ -1398,3 +1399,100 @@ def test_tests_have_no_pending_migrations(self):
out = StringIO()
call_command("makemigrations", "tests", dry_run=True, stdout=out)
self.assertEqual(out.getvalue().strip(), "No changes detected in app 'tests'")


class NaturalKeyTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.tag_names = ["circle", "square", "triangle", "rectangle", "pentagon"]
cls.filename = "test_data_dump.json"
cls.tag_count = len(cls.tag_names)

def setUp(self):
self.tags = self._create_tags()

def tearDown(self):
self._clear_existing_tags()
try:
os.remove(self.filename)
except FileNotFoundError:
pass

@property
def _queryset(self):
return Tag.objects.filter(name__in=self.tag_names)

def _create_tags(self):
return Tag.objects.bulk_create(
[Tag(name=shape, slug=shape) for shape in self.tag_names],
ignore_conflicts=True,
)

def _clear_existing_tags(self):
self._queryset.delete()

def _dump_model(self, model):
model_label = model._meta.label
with open(self.filename, "w") as f:
call_command(
"dumpdata",
model_label,
natural_primary=True,
use_natural_foreign_keys=True,
stdout=f,
)

def _load_model(self):
call_command("loaddata", self.filename)

def test_tag_natural_key(self):
"""
Test that tags can be dumped and loaded using natural keys.
"""

# confirm count in the DB
self.assertEqual(self._queryset.count(), self.tag_count)

# dump all tags to a file
self._dump_model(Tag)

# Delete all tags
self._clear_existing_tags()

# confirm all tags clear
self.assertEqual(self._queryset.count(), 0)

# load the tags from the file
self._load_model()

# confirm count in the DB
self.assertEqual(self._queryset.count(), self.tag_count)

def test_tag_reloading_with_changed_pk(self):
"""Test that tags are not reliant on the primary key of the tag model.
Test that data is correctly loaded after database state has changed.
"""
original_shape = self._queryset.first()
original_pk = original_shape.pk
original_shape_name = original_shape.name
new_shape_name = "hexagon"

# dump all tags to a file
self._dump_model(Tag)

# Delete the tag
self._clear_existing_tags()

# create new tag with the same PK
Tag.objects.create(name=new_shape_name, slug=new_shape_name, pk=original_pk)

# Load the tags from the file
self._load_model()

# confirm that load did not overwrite the new_shape
self.assertEqual(Tag.objects.get(pk=original_pk).name, new_shape_name)

# confirm that the original shape was reloaded with a different PK
self.assertNotEqual(Tag.objects.get(name=original_shape_name).pk, original_pk)

0 comments on commit f1cc22e

Please sign in to comment.