diff --git a/netbox_diode_plugin/api/views.py b/netbox_diode_plugin/api/views.py index acfee6e..e791ab8 100644 --- a/netbox_diode_plugin/api/views.py +++ b/netbox_diode_plugin/api/views.py @@ -35,7 +35,7 @@ def dynamic_import(name): return mod -def _get_index_class_fields(object_type): +def _get_index_class_fields(object_type: str | NetBoxType): """ Given an object type name (e.g., 'dcim.site'), dynamically find and return the corresponding Index class fields. @@ -43,23 +43,18 @@ def _get_index_class_fields(object_type): :return: The corresponding model and its Index class (e.g., SiteIndex) field names or None. """ try: - # Extract app_label and model_name from 'dcim.site' - app_label, model_name = object_type.split('.') + if isinstance(object_type, str): + app_label, model_name = object_type.split('.') + else: + app_label, model_name = object_type.app_label, object_type.model - # Get the model class dynamically model = apps.get_model(app_label, model_name) - # TagIndex registered in the netbox_diode_plugin if app_label == "extras" and model_name == "tag": app_label = "netbox_diode_plugin" - # Import the module where index classes are defined (adjust if needed) index_module = dynamic_import(f"{app_label}.search.{model.__name__}Index") - - # Retrieve the index class fields tuple fields = getattr(index_module, "fields", None) - - # Extract the field names list from the tuple field_names = [field[0] for field in fields] return model, field_names @@ -244,12 +239,13 @@ class ApplyChangeSetView(views.APIView): permission_classes = [IsAuthenticated, IsDiodeWriter] @staticmethod - def _get_object_type_model(object_type: str): + def _get_object_type_model(object_type: str | NetBoxType): """Get the object type model from object_type.""" - app_label, model_name = object_type.split(".") - object_content_type = NetBoxType.objects.get_by_natural_key( - app_label, model_name - ) + if isinstance(object_type, str): + app_label, model_name = object_type.split(".") + object_content_type = NetBoxType.objects.get_by_natural_key(app_label, model_name) + else: + object_content_type = object_type return object_content_type, object_content_type.model_class() def _get_assigned_object_type(self, model_name: str): @@ -274,19 +270,19 @@ def _get_serializer( object_data: dict, ): """Get the serializer for the object type.""" - object_type_model, object_type_model_class = self._get_object_type_model(object_type) + _, object_type_model_class = self._get_object_type_model(object_type) if change_type == "create": - return self._get_serializer_to_create(object_data, object_type, object_type_model, object_type_model_class) + return self._get_serializer_to_create(object_data, object_type, object_type_model_class) if change_type == "update": return self._get_serializer_to_update(object_data, object_id, object_type, object_type_model_class) raise ValidationError("Invalid change_type") - def _get_serializer_to_create(self, object_data, object_type, object_type_model, object_type_model_class): + def _get_serializer_to_create(self, object_data, object_type, object_type_model_class): # Get object data fields that are not dictionaries or lists - fields = self._get_fields_to_find_existing_objects(object_data, object_type, object_type_model) + fields = self._get_fields_to_find_existing_objects(object_data, object_type) # Check if the object already exists try: instance = object_type_model_class.objects.get(**fields) @@ -351,10 +347,11 @@ def _get_serializer_to_update(self, object_data, object_id, object_type, object_ ) return serializer - def _get_fields_to_find_existing_objects(self, object_data, object_type, object_type_model): + def _get_fields_to_find_existing_objects(self, object_data, object_type): fields = {} for key, value in object_data.items(): self._add_nested_opts(fields, key, value) + match object_type: case "dcim.interface" | "virtualization.vminterface": mac_address = fields.pop("mac_address", None) @@ -364,7 +361,18 @@ def _get_fields_to_find_existing_objects(self, object_data, object_type, object_ fields.pop("assigned_object_type") fields["assigned_object_type_id"] = fields.pop("assigned_object_id") case "ipam.prefix" | "virtualization.cluster": - fields["scope_type"] = object_type_model + if scope_type := object_data.get("scope_type"): + scope_type_model, _ = self._get_object_type_model(scope_type) + fields["scope_type"] = scope_type_model + case "virtualization.virtualmachine": + if cluster_scope_type := fields.get("cluster__scope_type"): + cluster_scope_type_model, _ = self._get_object_type_model(cluster_scope_type) + fields["cluster__scope_type"] = cluster_scope_type_model + case "virtualization.vminterface": + if cluster_scope_type := fields.get("virtual_machine__cluster__scope_type"): + cluster_scope_type_model, _ = self._get_object_type_model(cluster_scope_type) + fields["virtual_machine__cluster__scope_type"] = cluster_scope_type_model + return fields def _retrieve_primary_ip_address(self, primary_ip_attr: str, object_data: dict): @@ -515,13 +523,18 @@ def _handle_interface_mac_address_compat(self, instance, object_type: str, obje instance.save() return None - def _handle_scope(self, object_data: dict) -> Optional[Dict[str, Any]]: + def _handle_scope(self, object_data: dict, is_nested: bool = False) -> Optional[Dict[str, Any]]: """Handle scope object.""" if object_data.get("site"): site = object_data.pop("site") scope_type = "dcim.site" - _, object_type_model_class = self._get_object_type_model(scope_type) - object_data["scope_type"] = scope_type + object_type_model, object_type_model_class = self._get_object_type_model(scope_type) + # Scope type of the nested object happens to be resolved differently than in the top-level object + # and is expected to be a content type object instead of "app_label.model_name" string format + if is_nested: + object_data["scope_type"] = object_type_model + else: + object_data["scope_type"] = scope_type site_id = site.get("id", None) if site_id is None: try: @@ -544,9 +557,18 @@ def _transform_object_data(self, object_type: str, object_data: dict) -> Optiona case "ipam.ipaddress": errors = self._handle_ipaddress_assigned_object(object_data) case "ipam.prefix": - errors = self._handle_scope(object_data) + errors = self._handle_scope(object_data, False) case "virtualization.cluster": - errors = self._handle_scope(object_data) + errors = self._handle_scope(object_data, False) + case "virtualization.virtualmachine": + if cluster_object_data := object_data.get("cluster"): + errors = self._handle_scope(cluster_object_data, True) + object_data["cluster"] = cluster_object_data + case "virtualization.vminterface": + cluster_object_data = object_data.get("virtual_machine", {}).get("cluster") + if cluster_object_data is not None: + errors = self._handle_scope(cluster_object_data, True) + object_data["virtual_machine"]["cluster"] = cluster_object_data case _: pass diff --git a/netbox_diode_plugin/tests/test_api_apply_change_set.py b/netbox_diode_plugin/tests/test_api_apply_change_set.py index 6bef32d..62950d4 100644 --- a/netbox_diode_plugin/tests/test_api_apply_change_set.py +++ b/netbox_diode_plugin/tests/test_api_apply_change_set.py @@ -14,6 +14,7 @@ Site, ) from django.contrib.auth import get_user_model +from django.contrib.contenttypes.models import ContentType from ipam.models import ASN, RIR, IPAddress, Prefix from netaddr import IPNetwork from rest_framework import status @@ -101,9 +102,13 @@ def setUp(self): name="Cluster Type 1", slug="cluster-type-1" ) + self.cluster_types = (cluster_type,) + + site_content_type = ContentType.objects.get_for_model(Site) + self.clusters = ( - Cluster(name="Cluster 1", type=cluster_type), - Cluster(name="Cluster 2", type=cluster_type), + Cluster(name="Cluster 1", type=cluster_type, scope_type=site_content_type, scope_id=self.sites[0].id), + Cluster(name="Cluster 2", type=cluster_type, scope_type=site_content_type, scope_id=self.sites[0].id), ) Cluster.objects.bulk_create(self.clusters) @@ -1154,3 +1159,65 @@ def test_create_prefix_with_unknown_site_fails(self): response.json().get("errors")[0].get("site"), ) self.assertFalse(Prefix.objects.filter(prefix="192.168.0.0/24").exists()) + + def test_create_virtualization_cluster_with_site_stored_as_scope(self): + """Test create cluster with site stored as scope.""" + payload = { + "change_set_id": str(uuid.uuid4()), + "change_set": [ + { + "change_id": str(uuid.uuid4()), + "change_type": "create", + "object_version": None, + "object_type": "virtualization.cluster", + "object_id": None, + "data": { + "name": "Cluster 3", + "type": { + "name": self.cluster_types[0].name, + }, + "site": { + "name": self.sites[0].name, + }, + }, + }, + ], + } + response = self.send_request(payload) + + self.assertEqual(response.json().get("result"), "success") + self.assertEqual(Cluster.objects.get(name="Cluster 3").scope, self.sites[0]) + + def test_create_virtualmachine_with_cluster_site_stored_as_scope(self): + """Test create virtualmachine with cluster site stored as scope.""" + payload = { + "change_set_id": str(uuid.uuid4()), + "change_set": [ + { + "change_id": str(uuid.uuid4()), + "change_type": "create", + "object_version": None, + "object_type": "virtualization.virtualmachine", + "object_id": None, + "data": { + "name": "VM foobar", + "site": { + "name": self.sites[0].name, + }, + "cluster": { + "name": self.clusters[0].name, + "type": { + "name": self.cluster_types[0].name, + }, + "site": { + "name": self.sites[0].name, + }, + }, + }, + }, + ], + } + response = self.send_request(payload) + + self.assertEqual(response.json().get("result"), "success") + self.assertEqual(VirtualMachine.objects.get(name="VM foobar", site_id=self.sites[0].id).cluster.scope, self.sites[0])