@@ -276,10 +276,13 @@ def import_node(self, node: onnx.NodeProto):
276
276
with InsertionPoint (self ._b ), Location .name (node .name ):
277
277
op_type = node .op_type
278
278
# Handle special op types that materialize to non-op IR constructs.
279
+ # Handlers return True if the op was handled, else this function
280
+ # should process it as a general node.
279
281
special_key = f"_handle_node_{ op_type } "
280
282
if hasattr (self , special_key ):
281
- getattr (self , special_key )(node )
282
- return
283
+ was_handled = getattr (self , special_key )(node )
284
+ if was_handled :
285
+ return
283
286
284
287
# General node import.
285
288
input_values = []
@@ -333,16 +336,19 @@ def import_attributes(
333
336
)
334
337
attrs [f"torch.onnx.{ onnx_attr .name } " ] = handler (onnx_attr , self ._cc )
335
338
336
- def import_initializer (self , initializer : onnx .TensorProto ) -> Value :
337
- with InsertionPoint (self ._b ), Location .name (initializer .name ):
339
+ def import_initializer (self , initializer : onnx .TensorProto , extern_name : str = None ) -> Value :
340
+ # If an explicitly specified name is given, use that; otherwise, pick
341
+ # up the name from the tensor proto itself
342
+ iname = extern_name if extern_name else initializer .name
343
+ with InsertionPoint (self ._b ), Location .name (iname ):
338
344
value_attr = self ._cc .tensor_proto_to_attr (initializer )
339
345
vtensor_type = self ._cc .tensor_proto_to_type (initializer )
340
346
literal_op = Operation .create (
341
347
name = "torch.vtensor.literal" ,
342
348
results = [vtensor_type ],
343
349
attributes = {"value" : value_attr },
344
350
)
345
- self ._nv_map [initializer . name ] = literal_op .result
351
+ self ._nv_map [iname ] = literal_op .result
346
352
return literal_op .result
347
353
348
354
def _get_immediate_tensor (self , name : str ) -> np .array :
@@ -366,7 +372,23 @@ def _get_immediate_tensor(self, name: str) -> np.array:
366
372
f"Unhandled ONNX TensorProto immediate data: { initializer } "
367
373
)
368
374
369
- def _handle_node_ConstantOfShape (self , node : onnx .NodeProto ):
375
+ def _handle_node_Constant (self , node : onnx .NodeProto ) -> bool :
376
+ # Special case only for constants specified by value attribute (for now)
377
+ value_proto = _get_attr (node , "value" , False )
378
+ if not value_proto :
379
+ return False
380
+
381
+ # Produce an initializer for the constant, so that it can be used in
382
+ # combination with other ops, such as ConstantOfShape, requiring
383
+ # a constant input
384
+ assert value_proto .type == onnx .AttributeProto .AttributeType .TENSOR
385
+ assert len (node .output ) == 1
386
+ const_name = node .output [0 ]
387
+ self .import_initializer (value_proto .t , const_name )
388
+ self ._gi .initializer_map [const_name ] = value_proto .t
389
+ return True
390
+
391
+ def _handle_node_ConstantOfShape (self , node : onnx .NodeProto ) -> bool :
370
392
# This op is special: It has an input of the shape, and in full generality
371
393
# could involve eager production of constants of variable size. In
372
394
# practice, the DNN profile for ONNX makes this very difficult to do
@@ -394,6 +416,7 @@ def _handle_node_ConstantOfShape(self, node: onnx.NodeProto):
394
416
attributes = {"value" : value_attr },
395
417
)
396
418
self ._nv_map [node .output [0 ]] = literal_op .result
419
+ return True
397
420
398
421
399
422
class ContextCache :
@@ -515,6 +538,11 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
515
538
onnx .TensorProto .DataType .FLOAT : lambda tp , shape : DenseElementsAttr .get_splat (
516
539
RankedTensorType .get (shape , F32Type .get ()), FloatAttr .get_f32 (tp .float_data [0 ])
517
540
),
541
+ onnx .TensorProto .DataType .INT64 : lambda tp , shape : DenseElementsAttr .get_splat (
542
+ RankedTensorType .get (shape , IntegerType .get_signed (64 )), IntegerAttr .get (
543
+ IntegerType .get_signed (64 ), int .from_bytes (tp .raw_data , "little" ,
544
+ signed = True ) if tp .HasField ("raw_data" ) else tp .int64_data [0 ])
545
+ ),
518
546
# TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB
519
547
}
520
548
@@ -605,9 +633,10 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
605
633
}
606
634
607
635
608
- def _get_attr (node : onnx .NodeProto , attr_name : str ) -> onnx .AttributeProto :
636
+ def _get_attr (node : onnx .NodeProto , attr_name : str , is_required : bool = True ) -> onnx .AttributeProto :
609
637
for attr in node .attribute :
610
638
if attr .name == attr_name :
611
639
return attr
612
- else :
640
+ if is_required :
613
641
raise OnnxImportError (f"Required attribute { attr_name } not found in { node } " )
642
+ return None
0 commit comments