Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test for node adaptation on Squeeze11 #155

Merged
merged 12 commits into from
Aug 1, 2024
75 changes: 55 additions & 20 deletions tests/test_adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import onnx
import onnx.parser
import onnxruntime as ort
import pytest

import spox.opset.ai.onnx.v18 as op18
Expand Down Expand Up @@ -71,32 +72,34 @@ def inline_old_identity_twice_graph(old_identity):
return results(final=z).with_opset(("ai.onnx", 17))


@pytest.fixture
def old_squeeze_graph(old_squeeze):
class Squeeze11(StandardNode):
@dataclass
class Attributes(BaseAttributes):
axes: AttrInt64s
class Squeeze11(StandardNode):
@dataclass
class Attributes(BaseAttributes):
axes: AttrInt64s

@dataclass
class Inputs(BaseInputs):
data: Var

@dataclass
class Outputs(BaseOutputs):
squeezed: Var

@dataclass
class Inputs(BaseInputs):
data: Var
op_type = OpType("Squeeze", "", 11)

@dataclass
class Outputs(BaseOutputs):
squeezed: Var
attrs: Attributes
inputs: Inputs
outputs: Outputs

op_type = OpType("Squeeze", "", 11)

attrs: Attributes
inputs: Inputs
outputs: Outputs
def squeeze11(_data: Var, _axes: Iterable[int]):
return Squeeze11(
Squeeze11.Attributes(AttrInt64s(_axes, "axes")), Squeeze11.Inputs(_data)
).outputs.squeezed

def squeeze11(_data: Var, _axes: Iterable[int]):
return Squeeze11(
Squeeze11.Attributes(AttrInt64s(_axes, "axes")), Squeeze11.Inputs(_data)
).outputs.squeezed

@pytest.fixture
def old_squeeze_graph(old_squeeze):
(data,) = arguments(
data=Tensor(
np.float32,
Expand Down Expand Up @@ -233,3 +236,35 @@ def test_inline_model_custom_node_nested(old_squeeze: onnx.ModelProto):
# Add another node to the model to trigger the adaption logic
c = op18.identity(b)
build({"a": a}, {"c": c})


def test_if_adapatation_squeeze():
cond = argument(Tensor(np.bool_, ()))
b = argument(Tensor(np.float32, (1,)))
squeezed = squeeze11(b, [0])
out = op18.if_(
cond,
then_branch=lambda: [squeezed],
else_branch=lambda: [squeeze11(b, [0])],
)
model = build({"b": b, "cond": cond}, {"out": out[0]})

# predict on model
b = np.array([1.1], dtype=np.float32)
cond = np.array(True, dtype=np.bool_)
out = ort.InferenceSession(model.SerializeToString()).run(
None, {"b": b, "cond": cond}
)


def test_if_adaptation_const():
sq = op19.const(1.1453, dtype=np.float32)
b = argument(Tensor(np.float32, ("N",)))
cond = op18.equal(sq, b)
out = op18.if_(cond, then_branch=lambda: [sq], else_branch=lambda: [sq])
model = build({"b": b}, {"out": out[0]})
assert model.domain == "" or model.domain == "ai.onnx"
assert (
model.opset_import[0].domain == "ai.onnx" or model.opset_import[0].domain == ""
)
assert model.opset_import[0].version > 11