Skip to content

Commit

Permalink
MatMulNBitsToQDQ: add nodes_to_exclude option (#1613)
Browse files Browse the repository at this point in the history
## Describe your changes
Add an option to give a list of nodes to exclude from the operator
conversion. This is useful when the model is already fully WOQ using
matmul4quantizer but we want to keep some nodes like the lmheads as
`MatMulNBits`.

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link
  • Loading branch information
jambayk authored Feb 14, 2025
1 parent 4153b37 commit 481807e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
13 changes: 12 additions & 1 deletion olive/passes/onnx/mnb_to_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
" False."
),
),
"nodes_to_exclude": PassConfigParam(
type_=list,
default_value=None,
description=(
"List of node names to exclude from the conversion. The node names should be the names of the"
" MatMulNBits nodes. Default is None."
),
),
**get_external_data_config(),
}

Expand All @@ -74,10 +82,13 @@ def _run_for_config(
int_np_dtype = np.int8 if config["use_int4"] else np.uint8
int_elem_type = onnx.TensorProto.INT4 if config["use_int4"] else onnx.TensorProto.UINT4

# set of nodes to exclude from the conversion
nodes_to_exclude = set(config["nodes_to_exclude"] or [])

num_modified = 0
for node_name in dag.get_node_names():
op_type = dag.get_node_op_type(node_name)
if op_type != "MatMulNBits":
if op_type != "MatMulNBits" or node_name in nodes_to_exclude:
continue

node_inputs = dag.get_node_inputs(node_name)
Expand Down
25 changes: 23 additions & 2 deletions test/unit_test/passes/onnx/test_mnb_to_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from olive.model import ONNXModelHandler
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.onnx.mnb_to_qdq import MatMulNBitsToQDQ
from olive.passes.onnx.onnx_dag import OnnxDAG


@pytest.fixture(params=[True, False], ids=["symmetric", "asymmetric"], name="create_mnb_model")
Expand Down Expand Up @@ -65,8 +66,11 @@ def forward(self, x):
@pytest.mark.parametrize("use_transpose_op", [True, False])
@pytest.mark.parametrize("use_int4", [True, False])
@pytest.mark.parametrize("add_zero_point", [True, False])
@pytest.mark.parametrize("nodes_to_exclude", [None, ["/f1/MatMul_Q4"]])
@pytest.mark.parametrize("execution_provider", ["CPUExecutionProvider"])
def test_mnb_to_qdq(create_mnb_model, execution_provider, add_zero_point, use_int4, use_transpose_op, tmp_path):
def test_mnb_to_qdq(
create_mnb_model, execution_provider, nodes_to_exclude, add_zero_point, use_int4, use_transpose_op, tmp_path
):
available_providers = onnxruntime.get_available_providers()
if execution_provider not in available_providers:
pytest.skip(f"{execution_provider} is not available on this system {available_providers}")
Expand All @@ -77,14 +81,31 @@ def test_mnb_to_qdq(create_mnb_model, execution_provider, add_zero_point, use_in
# setup
p = create_pass_from_dict(
MatMulNBitsToQDQ,
{"use_transpose_op": use_transpose_op, "use_int4": use_int4, "add_zero_point": add_zero_point},
{
"use_transpose_op": use_transpose_op,
"use_int4": use_int4,
"add_zero_point": add_zero_point,
"nodes_to_exclude": nodes_to_exclude,
},
disable_search=True,
)
output_folder = tmp_path / "qdq-model"

# execute
qdq_model: ONNXModelHandler = p.run(input_model, output_folder)

# count ops
num_matmuls = 0
num_mnbs = 0
dag = OnnxDAG.from_model_path(qdq_model.model_path)
for name in dag.get_node_names():
op_type = dag.get_node_op_type(name)
if op_type == "MatMul":
num_matmuls += 1
elif op_type == "MatMulNBits":
num_mnbs += 1
assert num_matmuls == 3 - len(nodes_to_exclude or [])
assert num_mnbs == len(nodes_to_exclude or [])
# validate
original_session = onnxruntime.InferenceSession(str(mnb_path), providers=[execution_provider])
original_session.disable_fallback()
Expand Down

0 comments on commit 481807e

Please sign in to comment.