Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: scope support on apply change set #64

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 48 additions & 26 deletions netbox_diode_plugin/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,26 @@
return mod


def _get_index_class_fields(object_type):
def _get_index_class_fields(object_type: str | NetBoxType):
"""
Given an object type name (e.g., 'dcim.site'), dynamically find and return the corresponding Index class fields.

:param object_type: Object type name in the format 'app_label.model_name'
:return: The corresponding model and its Index class (e.g., SiteIndex) field names or None.
"""
try:
# Extract app_label and model_name from 'dcim.site'
app_label, model_name = object_type.split('.')
if isinstance(object_type, str):
app_label, model_name = object_type.split('.')
else:
app_label, model_name = object_type.app_label, object_type.model

# Get the model class dynamically
model = apps.get_model(app_label, model_name)

# TagIndex registered in the netbox_diode_plugin
if app_label == "extras" and model_name == "tag":
app_label = "netbox_diode_plugin"

# Import the module where index classes are defined (adjust if needed)
index_module = dynamic_import(f"{app_label}.search.{model.__name__}Index")

# Retrieve the index class fields tuple
fields = getattr(index_module, "fields", None)

# Extract the field names list from the tuple
field_names = [field[0] for field in fields]

return model, field_names
Expand Down Expand Up @@ -244,12 +239,13 @@
permission_classes = [IsAuthenticated, IsDiodeWriter]

@staticmethod
def _get_object_type_model(object_type: str):
def _get_object_type_model(object_type: str | NetBoxType):
"""Get the object type model from object_type."""
app_label, model_name = object_type.split(".")
object_content_type = NetBoxType.objects.get_by_natural_key(
app_label, model_name
)
if isinstance(object_type, str):
app_label, model_name = object_type.split(".")
object_content_type = NetBoxType.objects.get_by_natural_key(app_label, model_name)
else:
object_content_type = object_type
return object_content_type, object_content_type.model_class()

def _get_assigned_object_type(self, model_name: str):
Expand All @@ -274,19 +270,19 @@
object_data: dict,
):
"""Get the serializer for the object type."""
object_type_model, object_type_model_class = self._get_object_type_model(object_type)
_, object_type_model_class = self._get_object_type_model(object_type)

if change_type == "create":
return self._get_serializer_to_create(object_data, object_type, object_type_model, object_type_model_class)
return self._get_serializer_to_create(object_data, object_type, object_type_model_class)

if change_type == "update":
return self._get_serializer_to_update(object_data, object_id, object_type, object_type_model_class)

raise ValidationError("Invalid change_type")

def _get_serializer_to_create(self, object_data, object_type, object_type_model, object_type_model_class):
def _get_serializer_to_create(self, object_data, object_type, object_type_model_class):
# Get object data fields that are not dictionaries or lists
fields = self._get_fields_to_find_existing_objects(object_data, object_type, object_type_model)
fields = self._get_fields_to_find_existing_objects(object_data, object_type)
# Check if the object already exists
try:
instance = object_type_model_class.objects.get(**fields)
Expand Down Expand Up @@ -351,10 +347,11 @@
)
return serializer

def _get_fields_to_find_existing_objects(self, object_data, object_type, object_type_model):
def _get_fields_to_find_existing_objects(self, object_data, object_type):

Check failure on line 350 in netbox_diode_plugin/api/views.py

View workflow job for this annotation

GitHub Actions / tests (3.10)

Ruff (C901)

netbox_diode_plugin/api/views.py:350:9: C901 `_get_fields_to_find_existing_objects` is too complex (11 > 10)
fields = {}
for key, value in object_data.items():
self._add_nested_opts(fields, key, value)

match object_type:
case "dcim.interface" | "virtualization.vminterface":
mac_address = fields.pop("mac_address", None)
Expand All @@ -364,7 +361,18 @@
fields.pop("assigned_object_type")
fields["assigned_object_type_id"] = fields.pop("assigned_object_id")
case "ipam.prefix" | "virtualization.cluster":
fields["scope_type"] = object_type_model
if scope_type := object_data.get("scope_type"):
scope_type_model, _ = self._get_object_type_model(scope_type)
fields["scope_type"] = scope_type_model
case "virtualization.virtualmachine":
if cluster_scope_type := fields.get("cluster__scope_type"):
cluster_scope_type_model, _ = self._get_object_type_model(cluster_scope_type)
fields["cluster__scope_type"] = cluster_scope_type_model
case "virtualization.vminterface":
if cluster_scope_type := fields.get("virtual_machine__cluster__scope_type"):
cluster_scope_type_model, _ = self._get_object_type_model(cluster_scope_type)
fields["virtual_machine__cluster__scope_type"] = cluster_scope_type_model

return fields

def _retrieve_primary_ip_address(self, primary_ip_attr: str, object_data: dict):
Expand Down Expand Up @@ -515,13 +523,18 @@
instance.save()
return None

def _handle_scope(self, object_data: dict) -> Optional[Dict[str, Any]]:
def _handle_scope(self, object_data: dict, is_nested: bool = False) -> Optional[Dict[str, Any]]:
"""Handle scope object."""
if object_data.get("site"):
site = object_data.pop("site")
scope_type = "dcim.site"
_, object_type_model_class = self._get_object_type_model(scope_type)
object_data["scope_type"] = scope_type
object_type_model, object_type_model_class = self._get_object_type_model(scope_type)
# Scope type of the nested object happens to be resolved differently than in the top-level object
# and is expected to be a content type object instead of "app_label.model_name" string format
if is_nested:
object_data["scope_type"] = object_type_model
else:
object_data["scope_type"] = scope_type
site_id = site.get("id", None)
if site_id is None:
try:
Expand All @@ -544,9 +557,18 @@
case "ipam.ipaddress":
errors = self._handle_ipaddress_assigned_object(object_data)
case "ipam.prefix":
errors = self._handle_scope(object_data)
errors = self._handle_scope(object_data, False)
case "virtualization.cluster":
errors = self._handle_scope(object_data)
errors = self._handle_scope(object_data, False)
case "virtualization.virtualmachine":
if cluster_object_data := object_data.get("cluster"):
errors = self._handle_scope(cluster_object_data, True)
object_data["cluster"] = cluster_object_data
case "virtualization.vminterface":
cluster_object_data = object_data.get("virtual_machine", {}).get("cluster")
if cluster_object_data is not None:
errors = self._handle_scope(cluster_object_data, True)
object_data["virtual_machine"]["cluster"] = cluster_object_data
case _:
pass

Expand Down
71 changes: 69 additions & 2 deletions netbox_diode_plugin/tests/test_api_apply_change_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Site,
)
from django.contrib.auth import get_user_model
from django.contrib.contenttypes.models import ContentType
from ipam.models import ASN, RIR, IPAddress, Prefix
from netaddr import IPNetwork
from rest_framework import status
Expand Down Expand Up @@ -101,9 +102,13 @@ def setUp(self):
name="Cluster Type 1", slug="cluster-type-1"
)

self.cluster_types = (cluster_type,)

site_content_type = ContentType.objects.get_for_model(Site)

self.clusters = (
Cluster(name="Cluster 1", type=cluster_type),
Cluster(name="Cluster 2", type=cluster_type),
Cluster(name="Cluster 1", type=cluster_type, scope_type=site_content_type, scope_id=self.sites[0].id),
Cluster(name="Cluster 2", type=cluster_type, scope_type=site_content_type, scope_id=self.sites[0].id),
)
Cluster.objects.bulk_create(self.clusters)

Expand Down Expand Up @@ -1154,3 +1159,65 @@ def test_create_prefix_with_unknown_site_fails(self):
response.json().get("errors")[0].get("site"),
)
self.assertFalse(Prefix.objects.filter(prefix="192.168.0.0/24").exists())

def test_create_virtualization_cluster_with_site_stored_as_scope(self):
"""Test create cluster with site stored as scope."""
payload = {
"change_set_id": str(uuid.uuid4()),
"change_set": [
{
"change_id": str(uuid.uuid4()),
"change_type": "create",
"object_version": None,
"object_type": "virtualization.cluster",
"object_id": None,
"data": {
"name": "Cluster 3",
"type": {
"name": self.cluster_types[0].name,
},
"site": {
"name": self.sites[0].name,
},
},
},
],
}
response = self.send_request(payload)

self.assertEqual(response.json().get("result"), "success")
self.assertEqual(Cluster.objects.get(name="Cluster 3").scope, self.sites[0])

def test_create_virtualmachine_with_cluster_site_stored_as_scope(self):
"""Test create virtualmachine with cluster site stored as scope."""
payload = {
"change_set_id": str(uuid.uuid4()),
"change_set": [
{
"change_id": str(uuid.uuid4()),
"change_type": "create",
"object_version": None,
"object_type": "virtualization.virtualmachine",
"object_id": None,
"data": {
"name": "VM foobar",
"site": {
"name": self.sites[0].name,
},
"cluster": {
"name": self.clusters[0].name,
"type": {
"name": self.cluster_types[0].name,
},
"site": {
"name": self.sites[0].name,
},
},
},
},
],
}
response = self.send_request(payload)

self.assertEqual(response.json().get("result"), "success")
self.assertEqual(VirtualMachine.objects.get(name="VM foobar", site_id=self.sites[0].id).cluster.scope, self.sites[0])
Loading