diff --git a/olive/passes/onnx/mnb_to_qdq.py b/olive/passes/onnx/mnb_to_qdq.py index 6b40fb1ca..bafafbc8d 100644 --- a/olive/passes/onnx/mnb_to_qdq.py +++ b/olive/passes/onnx/mnb_to_qdq.py @@ -56,6 +56,14 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon " False." ), ), + "nodes_to_exclude": PassConfigParam( + type_=list, + default_value=None, + description=( + "List of node names to exclude from the conversion. The node names should be the names of the" + " MatMulNBits nodes. Default is None." + ), + ), **get_external_data_config(), } @@ -74,10 +82,13 @@ def _run_for_config( int_np_dtype = np.int8 if config["use_int4"] else np.uint8 int_elem_type = onnx.TensorProto.INT4 if config["use_int4"] else onnx.TensorProto.UINT4 + # set of nodes to exclude from the conversion + nodes_to_exclude = set(config["nodes_to_exclude"] or []) + num_modified = 0 for node_name in dag.get_node_names(): op_type = dag.get_node_op_type(node_name) - if op_type != "MatMulNBits": + if op_type != "MatMulNBits" or node_name in nodes_to_exclude: continue node_inputs = dag.get_node_inputs(node_name) diff --git a/test/unit_test/passes/onnx/test_mnb_to_qdq.py b/test/unit_test/passes/onnx/test_mnb_to_qdq.py index 8934023ff..380bacacf 100644 --- a/test/unit_test/passes/onnx/test_mnb_to_qdq.py +++ b/test/unit_test/passes/onnx/test_mnb_to_qdq.py @@ -12,6 +12,7 @@ from olive.model import ONNXModelHandler from olive.passes.olive_pass import create_pass_from_dict from olive.passes.onnx.mnb_to_qdq import MatMulNBitsToQDQ +from olive.passes.onnx.onnx_dag import OnnxDAG @pytest.fixture(params=[True, False], ids=["symmetric", "asymmetric"], name="create_mnb_model") @@ -65,8 +66,11 @@ def forward(self, x): @pytest.mark.parametrize("use_transpose_op", [True, False]) @pytest.mark.parametrize("use_int4", [True, False]) @pytest.mark.parametrize("add_zero_point", [True, False]) +@pytest.mark.parametrize("nodes_to_exclude", [None, ["/f1/MatMul_Q4"]]) @pytest.mark.parametrize("execution_provider", ["CPUExecutionProvider"]) -def test_mnb_to_qdq(create_mnb_model, execution_provider, add_zero_point, use_int4, use_transpose_op, tmp_path): +def test_mnb_to_qdq( + create_mnb_model, execution_provider, nodes_to_exclude, add_zero_point, use_int4, use_transpose_op, tmp_path +): available_providers = onnxruntime.get_available_providers() if execution_provider not in available_providers: pytest.skip(f"{execution_provider} is not available on this system {available_providers}") @@ -77,7 +81,12 @@ def test_mnb_to_qdq(create_mnb_model, execution_provider, add_zero_point, use_in # setup p = create_pass_from_dict( MatMulNBitsToQDQ, - {"use_transpose_op": use_transpose_op, "use_int4": use_int4, "add_zero_point": add_zero_point}, + { + "use_transpose_op": use_transpose_op, + "use_int4": use_int4, + "add_zero_point": add_zero_point, + "nodes_to_exclude": nodes_to_exclude, + }, disable_search=True, ) output_folder = tmp_path / "qdq-model" @@ -85,6 +94,18 @@ def test_mnb_to_qdq(create_mnb_model, execution_provider, add_zero_point, use_in # execute qdq_model: ONNXModelHandler = p.run(input_model, output_folder) + # count ops + num_matmuls = 0 + num_mnbs = 0 + dag = OnnxDAG.from_model_path(qdq_model.model_path) + for name in dag.get_node_names(): + op_type = dag.get_node_op_type(name) + if op_type == "MatMul": + num_matmuls += 1 + elif op_type == "MatMulNBits": + num_mnbs += 1 + assert num_matmuls == 3 - len(nodes_to_exclude or []) + assert num_mnbs == len(nodes_to_exclude or []) # validate original_session = onnxruntime.InferenceSession(str(mnb_path), providers=[execution_provider]) original_session.disable_fallback()