@@ -35,31 +35,26 @@ def dynamic_import(name):
35
35
return mod
36
36
37
37
38
- def _get_index_class_fields (object_type ):
38
+ def _get_index_class_fields (object_type : str | NetBoxType ):
39
39
"""
40
40
Given an object type name (e.g., 'dcim.site'), dynamically find and return the corresponding Index class fields.
41
41
42
42
:param object_type: Object type name in the format 'app_label.model_name'
43
43
:return: The corresponding model and its Index class (e.g., SiteIndex) field names or None.
44
44
"""
45
45
try :
46
- # Extract app_label and model_name from 'dcim.site'
47
- app_label , model_name = object_type .split ('.' )
46
+ if isinstance (object_type , str ):
47
+ app_label , model_name = object_type .split ('.' )
48
+ else :
49
+ app_label , model_name = object_type .app_label , object_type .model
48
50
49
- # Get the model class dynamically
50
51
model = apps .get_model (app_label , model_name )
51
52
52
- # TagIndex registered in the netbox_diode_plugin
53
53
if app_label == "extras" and model_name == "tag" :
54
54
app_label = "netbox_diode_plugin"
55
55
56
- # Import the module where index classes are defined (adjust if needed)
57
56
index_module = dynamic_import (f"{ app_label } .search.{ model .__name__ } Index" )
58
-
59
- # Retrieve the index class fields tuple
60
57
fields = getattr (index_module , "fields" , None )
61
-
62
- # Extract the field names list from the tuple
63
58
field_names = [field [0 ] for field in fields ]
64
59
65
60
return model , field_names
@@ -244,12 +239,13 @@ class ApplyChangeSetView(views.APIView):
244
239
permission_classes = [IsAuthenticated , IsDiodeWriter ]
245
240
246
241
@staticmethod
247
- def _get_object_type_model (object_type : str ):
242
+ def _get_object_type_model (object_type : str | NetBoxType ):
248
243
"""Get the object type model from object_type."""
249
- app_label , model_name = object_type .split ("." )
250
- object_content_type = NetBoxType .objects .get_by_natural_key (
251
- app_label , model_name
252
- )
244
+ if isinstance (object_type , str ):
245
+ app_label , model_name = object_type .split ("." )
246
+ object_content_type = NetBoxType .objects .get_by_natural_key (app_label , model_name )
247
+ else :
248
+ object_content_type = object_type
253
249
return object_content_type , object_content_type .model_class ()
254
250
255
251
def _get_assigned_object_type (self , model_name : str ):
@@ -274,19 +270,19 @@ def _get_serializer(
274
270
object_data : dict ,
275
271
):
276
272
"""Get the serializer for the object type."""
277
- object_type_model , object_type_model_class = self ._get_object_type_model (object_type )
273
+ _ , object_type_model_class = self ._get_object_type_model (object_type )
278
274
279
275
if change_type == "create" :
280
- return self ._get_serializer_to_create (object_data , object_type , object_type_model , object_type_model_class )
276
+ return self ._get_serializer_to_create (object_data , object_type , object_type_model_class )
281
277
282
278
if change_type == "update" :
283
279
return self ._get_serializer_to_update (object_data , object_id , object_type , object_type_model_class )
284
280
285
281
raise ValidationError ("Invalid change_type" )
286
282
287
- def _get_serializer_to_create (self , object_data , object_type , object_type_model , object_type_model_class ):
283
+ def _get_serializer_to_create (self , object_data , object_type , object_type_model_class ):
288
284
# Get object data fields that are not dictionaries or lists
289
- fields = self ._get_fields_to_find_existing_objects (object_data , object_type , object_type_model )
285
+ fields = self ._get_fields_to_find_existing_objects (object_data , object_type )
290
286
# Check if the object already exists
291
287
try :
292
288
instance = object_type_model_class .objects .get (** fields )
@@ -351,10 +347,11 @@ def _get_serializer_to_update(self, object_data, object_id, object_type, object_
351
347
)
352
348
return serializer
353
349
354
- def _get_fields_to_find_existing_objects (self , object_data , object_type , object_type_model ):
350
+ def _get_fields_to_find_existing_objects (self , object_data , object_type ):
355
351
fields = {}
356
352
for key , value in object_data .items ():
357
353
self ._add_nested_opts (fields , key , value )
354
+
358
355
match object_type :
359
356
case "dcim.interface" | "virtualization.vminterface" :
360
357
mac_address = fields .pop ("mac_address" , None )
@@ -364,7 +361,18 @@ def _get_fields_to_find_existing_objects(self, object_data, object_type, object_
364
361
fields .pop ("assigned_object_type" )
365
362
fields ["assigned_object_type_id" ] = fields .pop ("assigned_object_id" )
366
363
case "ipam.prefix" | "virtualization.cluster" :
367
- fields ["scope_type" ] = object_type_model
364
+ if scope_type := object_data .get ("scope_type" ):
365
+ scope_type_model , _ = self ._get_object_type_model (scope_type )
366
+ fields ["scope_type" ] = scope_type_model
367
+ case "virtualization.virtualmachine" :
368
+ if cluster_scope_type := fields .get ("cluster__scope_type" ):
369
+ cluster_scope_type_model , _ = self ._get_object_type_model (cluster_scope_type )
370
+ fields ["cluster__scope_type" ] = cluster_scope_type_model
371
+ case "virtualization.vminterface" :
372
+ if cluster_scope_type := fields .get ("virtual_machine__cluster__scope_type" ):
373
+ cluster_scope_type_model , _ = self ._get_object_type_model (cluster_scope_type )
374
+ fields ["virtual_machine__cluster__scope_type" ] = cluster_scope_type_model
375
+
368
376
return fields
369
377
370
378
def _retrieve_primary_ip_address (self , primary_ip_attr : str , object_data : dict ):
@@ -515,13 +523,18 @@ def _handle_interface_mac_address_compat(self, instance, object_type: str, obje
515
523
instance .save ()
516
524
return None
517
525
518
- def _handle_scope (self , object_data : dict ) -> Optional [Dict [str , Any ]]:
526
+ def _handle_scope (self , object_data : dict , is_nested : bool = False ) -> Optional [Dict [str , Any ]]:
519
527
"""Handle scope object."""
520
528
if object_data .get ("site" ):
521
529
site = object_data .pop ("site" )
522
530
scope_type = "dcim.site"
523
- _ , object_type_model_class = self ._get_object_type_model (scope_type )
524
- object_data ["scope_type" ] = scope_type
531
+ object_type_model , object_type_model_class = self ._get_object_type_model (scope_type )
532
+ # Scope type of the nested object happens to be resolved differently than in the top-level object
533
+ # and is expected to be a content type object instead of "app_label.model_name" string format
534
+ if is_nested :
535
+ object_data ["scope_type" ] = object_type_model
536
+ else :
537
+ object_data ["scope_type" ] = scope_type
525
538
site_id = site .get ("id" , None )
526
539
if site_id is None :
527
540
try :
@@ -544,9 +557,18 @@ def _transform_object_data(self, object_type: str, object_data: dict) -> Optiona
544
557
case "ipam.ipaddress" :
545
558
errors = self ._handle_ipaddress_assigned_object (object_data )
546
559
case "ipam.prefix" :
547
- errors = self ._handle_scope (object_data )
560
+ errors = self ._handle_scope (object_data , False )
548
561
case "virtualization.cluster" :
549
- errors = self ._handle_scope (object_data )
562
+ errors = self ._handle_scope (object_data , False )
563
+ case "virtualization.virtualmachine" :
564
+ if cluster_object_data := object_data .get ("cluster" ):
565
+ errors = self ._handle_scope (cluster_object_data , True )
566
+ object_data ["cluster" ] = cluster_object_data
567
+ case "virtualization.vminterface" :
568
+ cluster_object_data = object_data .get ("virtual_machine" , {}).get ("cluster" )
569
+ if cluster_object_data is not None :
570
+ errors = self ._handle_scope (cluster_object_data , True )
571
+ object_data ["virtual_machine" ]["cluster" ] = cluster_object_data
550
572
case _:
551
573
pass
552
574
@@ -651,3 +673,24 @@ class ApplyChangeSetException(Exception):
651
673
"""ApplyChangeSetException used to cause atomic transaction rollback."""
652
674
653
675
pass
676
+
677
+ #####
678
+
679
+ import logging
680
+ logger = logging .getLogger ("netbox.diode_data" )
681
+
682
+
683
+ class GenerateDiffView (views .APIView ):
684
+ """GenerateDiff view."""
685
+
686
+ permission_classes = [IsAuthenticated , IsDiodeWriter ]
687
+
688
+ def post (self , request , * args , ** kwargs ):
689
+ """Generate diff for entity."""
690
+
691
+ entity = request .data .get ("entity" )
692
+ object_type = request .data .get ("object_type" )
693
+
694
+ logger .error (f"generate diff called with entity: { entity } and object_type: { object_type } " )
695
+
696
+ return Response ({}, status = status .HTTP_500_INTERNAL_SERVER_ERROR )
0 commit comments