Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test for node adaptation on Squeeze11 #155

Merged
merged 12 commits into from
Aug 1, 2024
70 changes: 52 additions & 18 deletions tests/test_adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy
import onnx
import onnx.parser
import onnxruntime as ort
import pytest

import spox.opset.ai.onnx.v18 as op18
Expand Down Expand Up @@ -71,32 +72,33 @@ def inline_old_identity_twice_graph(old_identity):
return results(final=z).with_opset(("ai.onnx", 17))


@pytest.fixture
def old_squeeze_graph(old_squeeze):
class Squeeze11(StandardNode):
@dataclass
class Attributes(BaseAttributes):
axes: AttrInt64s
class Squeeze11(StandardNode):
@dataclass
class Attributes(BaseAttributes):
axes: AttrInt64s

@dataclass
class Inputs(BaseInputs):
data: Var
@dataclass
class Inputs(BaseInputs):
data: Var

@dataclass
class Outputs(BaseOutputs):
squeezed: Var
@dataclass
class Outputs(BaseOutputs):
squeezed: Var

op_type = OpType("Squeeze", "", 11)
op_type = OpType("Squeeze", "", 11)

attrs: Attributes
inputs: Inputs
outputs: Outputs
attrs: Attributes
inputs: Inputs
outputs: Outputs

def squeeze11(_data: Var, _axes: Iterable[int]):
def squeeze11(cls, _data: Var, _axes: Iterable[int]):
aivanoved marked this conversation as resolved.
Show resolved Hide resolved
return Squeeze11(
Squeeze11.Attributes(AttrInt64s(_axes, "axes")), Squeeze11.Inputs(_data)
).outputs.squeezed


@pytest.fixture
def old_squeeze_graph(old_squeeze):
(data,) = arguments(
data=Tensor(
numpy.float32,
Expand All @@ -106,7 +108,7 @@ def squeeze11(_data: Var, _axes: Iterable[int]):
),
)
)
result = squeeze11(data, [0])
result = Squeeze11.squeeze11(data, [0])
return results(final=result).with_opset(("ai.onnx", 17))


Expand Down Expand Up @@ -233,3 +235,35 @@ def test_inline_model_custom_node_nested(old_squeeze: onnx.ModelProto):
# Add another node to the model to trigger the adaption logic
c = op18.identity(b)
build({"a": a}, {"c": c})


def test_if_adapatation_squeeze():
cond = argument(Tensor(numpy.bool_, ()))
b = argument(Tensor(numpy.float32, (1,)))
squeezed = Squeeze11.squeeze11(b, [0])
out = op18.if_(
cond,
then_branch=lambda: [squeezed],
else_branch=lambda: [Squeeze11.squeeze11(b, [0])],
)
model = build({"b": b, "cond": cond}, {"out": out[0]})

# predict on model
b = numpy.array([1.1], dtype=numpy.float32)
cond = numpy.array(True, dtype=numpy.bool_)
out = ort.InferenceSession(model.SerializeToString()).run(
None, {"b": b, "cond": cond}
)


def test_if_adaptation_const():
sq = op19.const(1.1453, dtype=numpy.float32)
b = argument(Tensor(numpy.float32, ("N",)))
cond = op18.equal(sq, b)
out = op18.if_(cond, then_branch=lambda: [sq], else_branch=lambda: [sq])
model = build({"b": b}, {"out": out[0]})
assert model.domain == "" or model.domain == "ai.onnx"
assert (
model.opset_import[0].domain == "ai.onnx" or model.opset_import[0].domain == ""
)
assert model.opset_import[0].version > 11
Loading