Skip to content

Commit d452c4f

Browse files
daveliddell (Dave Liddell)
and
Dave Liddell
authored
Fix onnx importer to treat Constant values as static (#2780)
Fixes #2764. In the case of OPT, there are ConstantOfShape ops whose input shape is not static (that is, not an initializer), but rather comes from a Constant op. The importer can't handle such non-static input shapes. The fix here is to create initializers for a subset of Constant ops (ones with "value" attributes), so that their outputs can be used statically. Additionally, there was no case for creating a splat of int64, so I added that as well. --------- Co-authored-by: Dave Liddell <[email protected]>
1 parent cad98e8 commit d452c4f

File tree

1 file changed

+37
-8
lines changed

1 file changed

+37
-8
lines changed

python/torch_mlir/extras/onnx_importer.py

+37-8
Original file line numberDiff line numberDiff line change
@@ -276,10 +276,13 @@ def import_node(self, node: onnx.NodeProto):
276276
with InsertionPoint(self._b), Location.name(node.name):
277277
op_type = node.op_type
278278
# Handle special op types that materialize to non-op IR constructs.
279+
# Handlers return True if the op was handled, else this function
280+
# should process it as a general node.
279281
special_key = f"_handle_node_{op_type}"
280282
if hasattr(self, special_key):
281-
getattr(self, special_key)(node)
282-
return
283+
was_handled = getattr(self, special_key)(node)
284+
if was_handled:
285+
return
283286

284287
# General node import.
285288
input_values = []
@@ -333,16 +336,19 @@ def import_attributes(
333336
)
334337
attrs[f"torch.onnx.{onnx_attr.name}"] = handler(onnx_attr, self._cc)
335338

336-
def import_initializer(self, initializer: onnx.TensorProto) -> Value:
337-
with InsertionPoint(self._b), Location.name(initializer.name):
339+
def import_initializer(self, initializer: onnx.TensorProto, extern_name: str = None) -> Value:
340+
# If an explicitly specified name is given, use that; otherwise, pick
341+
# up the name from the tensor proto itself
342+
iname = extern_name if extern_name else initializer.name
343+
with InsertionPoint(self._b), Location.name(iname):
338344
value_attr = self._cc.tensor_proto_to_attr(initializer)
339345
vtensor_type = self._cc.tensor_proto_to_type(initializer)
340346
literal_op = Operation.create(
341347
name="torch.vtensor.literal",
342348
results=[vtensor_type],
343349
attributes={"value": value_attr},
344350
)
345-
self._nv_map[initializer.name] = literal_op.result
351+
self._nv_map[iname] = literal_op.result
346352
return literal_op.result
347353

348354
def _get_immediate_tensor(self, name: str) -> np.array:
@@ -366,7 +372,23 @@ def _get_immediate_tensor(self, name: str) -> np.array:
366372
f"Unhandled ONNX TensorProto immediate data: {initializer}"
367373
)
368374

369-
def _handle_node_ConstantOfShape(self, node: onnx.NodeProto):
375+
def _handle_node_Constant(self, node: onnx.NodeProto) -> bool:
376+
# Special case only for constants specified by value attribute (for now)
377+
value_proto = _get_attr(node, "value", False)
378+
if not value_proto:
379+
return False
380+
381+
# Produce an initializer for the constant, so that it can be used in
382+
# combination with other ops, such as ConstantOfShape, requiring
383+
# a constant input
384+
assert value_proto.type == onnx.AttributeProto.AttributeType.TENSOR
385+
assert len(node.output) == 1
386+
const_name = node.output[0]
387+
self.import_initializer(value_proto.t, const_name)
388+
self._gi.initializer_map[const_name] = value_proto.t
389+
return True
390+
391+
def _handle_node_ConstantOfShape(self, node: onnx.NodeProto) -> bool:
370392
# This op is special: It has an input of the shape, and in full generality
371393
# could involve eager production of constants of variable size. In
372394
# practice, the DNN profile for ONNX makes this very difficult to do
@@ -394,6 +416,7 @@ def _handle_node_ConstantOfShape(self, node: onnx.NodeProto):
394416
attributes={"value": value_attr},
395417
)
396418
self._nv_map[node.output[0]] = literal_op.result
419+
return True
397420

398421

399422
class ContextCache:
@@ -515,6 +538,11 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
515538
onnx.TensorProto.DataType.FLOAT: lambda tp, shape: DenseElementsAttr.get_splat(
516539
RankedTensorType.get(shape, F32Type.get()), FloatAttr.get_f32(tp.float_data[0])
517540
),
541+
onnx.TensorProto.DataType.INT64: lambda tp, shape: DenseElementsAttr.get_splat(
542+
RankedTensorType.get(shape, IntegerType.get_signed(64)), IntegerAttr.get(
543+
IntegerType.get_signed(64), int.from_bytes(tp.raw_data, "little",
544+
signed=True) if tp.HasField("raw_data") else tp.int64_data[0])
545+
),
518546
# TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB
519547
}
520548

@@ -605,9 +633,10 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
605633
}
606634

607635

608-
def _get_attr(node: onnx.NodeProto, attr_name: str) -> onnx.AttributeProto:
636+
def _get_attr(node: onnx.NodeProto, attr_name: str, is_required: bool = True) -> onnx.AttributeProto:
609637
for attr in node.attribute:
610638
if attr.name == attr_name:
611639
return attr
612-
else:
640+
if is_required:
613641
raise OnnxImportError(f"Required attribute {attr_name} not found in {node}")
642+
return None

0 commit comments

Comments (0)