Commit

jambayk committed Feb 19, 2025
1 parent 2d6c026 commit 6e5c3b4
Showing 3 changed files with 86 additions and 62 deletions.
11 changes: 6 additions & 5 deletions examples/utils/generator.py
@@ -267,7 +267,7 @@ def generate(
session = self.sessions["iterator"]
io_binding = session.io_binding() if use_io_binding else None

-print(np_buffers["attention_mask"])
+# print(np_buffers["attention_mask"])

if use_io_binding:
if idx < 2:
@@ -388,9 +388,9 @@ def generate(
first_cache = outputs[1]
if use_io_binding:
first_cache = first_cache.numpy()
-print(first_cache[:, 0, :, 0])
-if idx == 2:
-    sdcd
+# print(first_cache[:, 0, :, 0])
+# if idx == 2:
+#     sdcd

# update cache
cache.update(outputs[1:])
@@ -439,7 +439,7 @@ def get_initial_inputs(
else:
padding_args = {"padding": "longest"}
# encode prompt
-encodings_dict = self.tokenizer(prompt, return_tensors="np", **padding_args)
+encodings_dict = self.tokenizer(prompt, return_tensors="np", add_special_tokens=False, **padding_args)
input_ids = encodings_dict["input_ids"].astype(self.input_info["input_ids"]["dtype"])
batch_size, prompt_length = input_ids.shape
attention_mask = encodings_dict["attention_mask"]
@@ -457,6 +457,7 @@
int(attention_mask.sum(axis=-1).max()),
)
if isinstance(cache, GQASharedCache):
+print(cache.max_cache_len, prompt_length)
attention_mask = np.concatenate(
[attention_mask, np.zeros((batch_size, cache.max_cache_len - prompt_length), dtype=np.int32)], 1
)
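Note on the tokenizer change above: passing add_special_tokens=False keeps the tokenizer from inserting special tokens (such as a BOS token) into the encoded prompt, so only the raw prompt ids reach the generator. A minimal sketch of the difference, assuming a BOS-prepending tokenizer; the tokenizer id below is a placeholder, not part of this commit:

# Sketch only; the tokenizer id is a placeholder chosen because it prepends BOS.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
with_special = tokenizer("hello world", return_tensors="np")["input_ids"]
without_special = tokenizer("hello world", return_tensors="np", add_special_tokens=False)["input_ids"]
# For BOS-prepending tokenizers, the first encoding is one token longer.
print(with_special.shape, without_special.shape)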
34 changes: 30 additions & 4 deletions olive/passes/onnx/model_builder.py
@@ -8,10 +8,11 @@
import json
import logging
from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union

import onnx
import transformers
+from packaging import version

from olive.common.utils import IntEnumBase, StrEnumBase
from olive.hardware.accelerator import AcceleratorSpec, Device
@@ -66,8 +67,24 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
"int4_accuracy_level": PassConfigParam(
type_=ModelBuilder.AccuracyLevel,
required=False,
description="Specify the minimum accuracy level for activation of MatMul in int4 quantization.",
description=(
"Specify the minimum accuracy level for activation of MatMul in int4 quantization. Defualt is None"
" (= 0)."
),
),
# "int4_is_symmetric": PassConfigParam(
# type_=bool,
# required=False,
# description="Whether to use symmetric quantization for int4 quantization. Default is None (= True).",
# ),
# "int4_op_types_to_quantize": PassConfigParam(
# type_=List[str],
# required=False,
# description=(
# 'Specify the op types to quantize for int4 quantization. Default is None (= [ "MatMul" ]). Example:'
# ' ["MatMul", "Gemm"]'
# ),
# ),
"exclude_embeds": PassConfigParam(
type_=bool,
default_value=False,
@@ -127,6 +144,7 @@ def _run_for_config(
config: Dict[str, Any],
output_model_path: str,
) -> ONNXModelHandler:
+from onnxruntime_genai import __version__ as genai_version
from onnxruntime_genai.models.builder import create_model

precision = config["precision"]
@@ -174,16 +192,24 @@
if config.get("int4_accuracy_level"):
extra_args["int4_accuracy_level"] = config["int4_accuracy_level"].value

-# args that are only checked for presence, not value
+# args that are only checked for presence, not value, until 0.6.0
for arg in ["exclude_embeds", "exclude_lm_head"]:
if config[arg]:
extra_args[arg] = True

# args that are checked for presence and value (if present)
for arg in ["enable_cuda_graph"]:
# for arg, min_version in [("enable_cuda_graph", None), ("int4_is_symmetric", 0.6.0)]:
for arg, min_version in [("enable_cuda_graph", None)]:
if config[arg] is not None:
+if min_version and version.parse(genai_version) < version.parse(min_version):
+    raise ValueError(
+        f"{arg} is not supported in genai version {genai_version}. Minimum version: {min_version}"
+    )
extra_args[arg] = "1" if config[arg] else "0"

# if config["int4_op_types_to_quantize"]:
# extra_args["int4_op_types_to_quantize"] = "/".join(config["int4_op_types_to_quantize"])

model_attributes = copy.deepcopy(model.model_attributes or {})

try:
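The version gate added above follows a simple pattern: compare the installed onnxruntime-genai version with packaging.version before forwarding an argument the model builder may not support. A standalone sketch of the same pattern with placeholder names (forward_arg is illustrative, not the pass's actual API):

# Standalone sketch of version-gated argument forwarding; all names are placeholders.
from packaging import version

def forward_arg(extra_args, name, value, installed, minimum=None):
    # Forward `name` only when it is set and the installed version supports it.
    if value is None:
        return
    if minimum is not None and version.parse(installed) < version.parse(minimum):
        raise ValueError(f"{name} requires version >= {minimum}, got {installed}")
    extra_args[name] = "1" if value else "0"

extra_args = {}
forward_arg(extra_args, "enable_cuda_graph", True, installed="0.5.2")
print(extra_args)  # {'enable_cuda_graph': '1'}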
103 changes: 50 additions & 53 deletions olive/passes/onnx/quantization.py
@@ -678,6 +678,16 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
default_value=None,
description="List of node names to exclude from quantization.",
),
"nodes_to_include": PassConfigParam(
type_=list,
default_value=None,
description="List of node names to include in quantization.",
),
"op_types_to_quantize": PassConfigParam(
type_=list,
default_value=None,
description='List of operator types to quantize. Default value is None (= [ "MatMul" ]).',
),
"accuracy_level": PassConfigParam(
# TODO(trajep): to make it searchable
type_=int,
@@ -765,6 +775,19 @@ def _validators(cls) -> Dict[str, Callable]:
def _run_for_config(
self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str
) -> ONNXModelHandler:
+from onnxruntime import __version__ as OrtVersion
+
+if version.parse(OrtVersion) < version.parse("1.18.0"):
+    raise ValueError("MatMul4BitsQuantizer is only supported for onnxruntime>=1.18.0")
+
+from onnxruntime.quantization.matmul_4bits_quantizer import (
+    DefaultWeightOnlyQuantConfig,
+    GPTQWeightOnlyQuantConfig,
+    HQQWeightOnlyQuantConfig,
+    MatMul4BitsQuantizer,
+    RTNWeightOnlyQuantConfig,
+)
+
if model_has_adapters(model.model_path) and config["algorithm"] not in {None, "DEFAULT"}:
logger.info(
"Model has adapters which should only be quantized with algorithm=None or DEFAULT. Got %s. Returning"
@@ -773,65 +796,38 @@ def _run_for_config(
)
return model

-from onnxruntime import __version__ as OrtVersion
-
-if version.parse(OrtVersion) < version.parse("1.16.2"):
-    raise ValueError("MatMul4BitsQuantizer is only supported in onnxruntime >= 1.16.2")
-
-from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer

output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name)

weight_only_quant_config_class = None
weight_only_quant_config = None
algo_config = deepcopy(config["weight_only_quant_configs"] or {})
if version.parse(OrtVersion) >= version.parse("1.17.0"):
from onnxruntime.quantization.matmul_4bits_quantizer import (
GPTQWeightOnlyQuantConfig,
RTNWeightOnlyQuantConfig,
)

if config["algorithm"] == "RTN":
weight_only_quant_config_class = RTNWeightOnlyQuantConfig
elif config["algorithm"] == "GPTQ":
if "block_size" in algo_config and version.parse(OrtVersion) < version.parse("1.18.0"):
# ort 1.17.0+ uses blocksize instead of block_size :(
algo_config["blocksize"] = algo_config["block_size"]
algo_config.pop("block_size")
dataloader = get_calibration_dataloader(config, model.model_path, model.io_config)
weight_only_quant_config_class = partial(GPTQWeightOnlyQuantConfig, calibration_data_reader=dataloader)

if version.parse(OrtVersion) >= version.parse("1.18.0"):
from onnxruntime.quantization.matmul_4bits_quantizer import (
DefaultWeightOnlyQuantConfig,
HQQWeightOnlyQuantConfig,
)

if config["algorithm"] == "DEFAULT":
weight_only_quant_config_class = DefaultWeightOnlyQuantConfig
elif config["algorithm"] == "HQQ":
weight_only_quant_config_class = HQQWeightOnlyQuantConfig
elif config["algorithm"] in ("HQQ", "DEFAULT"):
raise ValueError("HQQ and DEFAULT algorithm are only supported in onnxruntime >= 1.18.0")

if weight_only_quant_config_class:
weight_only_quant_config = weight_only_quant_config_class(**algo_config)
quant = MatMul4BitsQuantizer(
model.load_model(),
block_size=config["block_size"],
is_symmetric=config["is_symmetric"],
nodes_to_exclude=config["nodes_to_exclude"],
accuracy_level=config["accuracy_level"],
algo_config=weight_only_quant_config,
)
else:
# TODO(trajep): remove this block once we migrate customer to onnxruntime>=1.17.0 all
quant = MatMul4BitsQuantizer(
model.load_model(),
block_size=config["block_size"],
is_symmetric=config["is_symmetric"],
nodes_to_exclude=config["nodes_to_exclude"],
)
if config["algorithm"] == "RTN":
weight_only_quant_config_class = RTNWeightOnlyQuantConfig
elif config["algorithm"] == "GPTQ":
dataloader = get_calibration_dataloader(config, model.model_path, model.io_config)
weight_only_quant_config_class = partial(GPTQWeightOnlyQuantConfig, calibration_data_reader=dataloader)
elif config["algorithm"] == "DEFAULT":
weight_only_quant_config_class = DefaultWeightOnlyQuantConfig
elif config["algorithm"] == "HQQ":
weight_only_quant_config_class = HQQWeightOnlyQuantConfig

if weight_only_quant_config_class:
weight_only_quant_config = weight_only_quant_config_class(**algo_config)

kwargs = {
"block_size": config["block_size"],
"is_symmetric": config["is_symmetric"],
"nodes_to_exclude": config["nodes_to_exclude"],
"accuracy_level": config["accuracy_level"],
"algo_config": weight_only_quant_config,
}
for key in ["nodes_to_include", "op_types_to_quantize"]:
if config[key]:
if version.parse(OrtVersion) < version.parse("1.20.0"):
raise ValueError(f"MatMul4BitsQuantizer {key} is only supported for onnxruntime>=1.20.0")
kwargs[key] = config[key]
quant = MatMul4BitsQuantizer(model.load_model(), **kwargs)
quant.process()
# topologically sort the graph at the end since previous optimizations may have broken it
quant.model.topological_sort()
@@ -878,6 +874,7 @@ def _validate_weight_only_quant_config(v, values, field):
default_config_keys = ["block_size", "bits", "axis"]
elif values["algorithm"] == "GPTQ":
default_config_keys = ["percdamp", "block_size", "actorder", "mse", "perchannel"]
+default_config_keys.append("op_types_to_quantize")

if not all(key in default_config_keys for key in config_keys):
invalid_config_keys = set(config_keys) - set(default_config_keys)
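For context on the quantization.py changes: the new nodes_to_include and op_types_to_quantize pass options map onto MatMul4BitsQuantizer keyword arguments that require onnxruntime>=1.20.0, which is why the pass now raises for older versions. A hedged sketch of calling the quantizer directly with one of the new options (model paths and the extra op type are placeholders):

# Sketch: exercising the newly exposed options directly on onnxruntime's
# MatMul4BitsQuantizer; requires onnxruntime>=1.20.0. Paths are placeholders.
import onnx
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer

model = onnx.load("model.onnx")  # placeholder input model
quant = MatMul4BitsQuantizer(
    model,
    block_size=32,
    is_symmetric=True,
    nodes_to_exclude=[],
    op_types_to_quantize=("MatMul", "Gemm"),  # default is ("MatMul",) only
)
quant.process()
quant.model.topological_sort()  # re-sort the graph, as the pass does above
onnx.save(quant.model.model, "model_int4.onnx")  # placeholder output path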
