diff --git a/ds/infer_bf16.py b/ds/infer_bf16.py new file mode 100644 index 00000000000..6fc88d11d0a --- /dev/null +++ b/ds/infer_bf16.py @@ -0,0 +1,178 @@ +# ==--------------------------------------------------------------------------== +# Patch for loading DS models +import os +from typing import Optional, Union +from zipfile import is_zipfile + +import torch +from packaging import version +from transformers.integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled +from transformers.utils import is_safetensors_available, strtobool + +if is_safetensors_available(): + from safetensors import safe_open + from safetensors.torch import load_file as safe_load_file + from safetensors.torch import save_file as safe_save_file + + +def is_fsdp_enabled(): + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1 + and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1 + ) + + +def is_local_dist_rank_0(): + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and int(os.environ.get("LOCAL_RANK", -1)) == 0 + ) + + +def load_state_dict( + checkpoint_file: Union[str, os.PathLike], + is_quantized: bool = False, + map_location: Optional[Union[str, torch.device]] = None, + weights_only: bool = True, +): + """Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.""" + + if checkpoint_file.endswith(".safetensors") and is_safetensors_available(): + # Check format of the archive + with safe_open(checkpoint_file, framework="pt") as f: + metadata = f.metadata() + if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]: + raise OSError( + f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " + "you save your model with the `save_pretrained` method." + ) + return safe_load_file(checkpoint_file) + try: + if map_location is None: + if ( + ( + is_deepspeed_zero3_enabled() + and torch.distributed.is_initialized() + and torch.distributed.get_rank() > 0 + ) + or (is_fsdp_enabled() and not is_local_dist_rank_0()) + ) and not is_quantized: + map_location = "meta" + else: + map_location = "cpu" + extra_args = {} + # mmap can only be used with files serialized with zipfile-based format. + if ( + isinstance(checkpoint_file, str) + and map_location != "meta" + and version.parse(torch.__version__) >= version.parse("2.1.0") + and is_zipfile(checkpoint_file) + ): + extra_args = {"mmap": True} + weights_only_kwarg = {"weights_only": weights_only} + return torch.load( + checkpoint_file, + map_location=map_location, + **weights_only_kwarg, + **extra_args, + ) + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read(7) == "version": + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " + "model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' " + f"at '{checkpoint_file}'. " + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." 
+ ) + + +def set_initialized_submodules(model, state_dict_keys): + """Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state + dict.""" + state_dict_keys = set(state_dict_keys) + not_initialized_submodules = {} + for module_name, module in model.named_modules(): + if module_name == "": + # When checking if the root module is loaded there's no need to prepend module_name. + module_keys = set(module.state_dict()) + else: + module_keys = {f"{module_name}.{k}" for k in module.state_dict()} + if module_keys.issubset(state_dict_keys): + module._is_hf_initialized = True + else: + not_initialized_submodules[module_name] = module + return not_initialized_submodules + + +# ==--------------------------------------------------------------------------== + + +def patch_transformers(): + import transformers + + transformers.modeling_utils.load_state_dict = load_state_dict + transformers.modeling_utils.set_initialized_submodules = set_initialized_submodules + + +import logging + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def eval(model_path): + import transformers + from transformers.modeling_utils import no_init_weights + + # from patch_for_ds import patch_transformers + # if not not_patch_lin: + # patch_lin() + + def _patch__initialize_weights(self, module): + print("Skipping init_weights") + module._is_hf_initialized = True + + transformers.modeling_utils.PreTrainedModel._initialize_weights = _patch__initialize_weights + patch_transformers() + with no_init_weights(): + model = transformers.AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype="auto", + low_cpu_mem_usage=True, + trust_remote_code=True, + ) + logger.info(f"Patched model: {model}") + model.eval() + tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) + prompt = "Hi, who" + encode = tokenizer.encode(prompt, return_tensors="pt") + with torch.no_grad(): + output_tokens = model.generate(encode, max_length=10) + output = tokenizer.decode(output_tokens[0], skip_special_tokens=True) + logger.info(f"Prompt: {prompt}") + logger.info(f"Output: {output}") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("-m", "--qmodel_path", type=str, required=True) + parser.add_argument("--not_patch_lin", action="store_true", help="Measure float model") + args = parser.parse_args() + eval(args.qmodel_path) diff --git a/examples/ds/README.md b/examples/ds/README.md new file mode 100644 index 00000000000..45df4738c7d --- /dev/null +++ b/examples/ds/README.md @@ -0,0 +1,70 @@ +# Notes for quantizing the DeepSeek model + +## Prerequisites + +```bash +pip install -r requirements.txt +``` + +## Usage + +### Step 1. Quantize the model weights + +- Option 1 (Recommended): Quantize the weights directly + +```bash +python quant.py --model_path /path/to/DeepSeek/R1/BF16/ --qmodel_path /path/to/DeepSeek/R1-Dynamic-FP8 --low_cpu_mem +``` + +- Option 2: Load the model using Transformers (requires ~700 GB of DRAM) + +```bash +python quant.py --model_path /path/to/DeepSeek/R1/BF16/ --qmodel_path /path/to/DeepSeek/R1/Dynamic-FP8 +``` + +> [!NOTE] +> - The weight dtype is `torch.float8_e4m3fn` (full range: `-448` to `448`) +> - `WEIGHT_BACKOFF = 0.5` +> - `SCALE_DTYPE = torch.bfloat16` + +### Step 2. Copy model files for inference + +Since DeepSeek V3 and R1 are not yet natively supported by Transformers, we need to manually copy some model files into the quantized model directory.
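+ +The files copied by `post_process.py` are those listed in its `MODEL_FILE_LST` (`configuration_deepseek.py`, `modeling_deepseek.py`, `generation_config.json`, `tokenizer.json`, and `tokenizer_config.json`); the script also writes a `quantization_config` section into the quantized model's `config.json`.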
+ +```bash +python post_process.py --model_path /path/to/DeepSeek/R1/BF16/ --qmodel_path /path/to/DeepSeek/R1/Dynamic-FP8 +``` + +## More details + +1. Naming convention: + - weight scale name: `prefix.scale_weight` + - input scale name: `prefix.scale_input` (static quantization only) +2. A JSON index file, `model.safetensors.index.json`, that maps each tensor name to the safetensors file that stores it. + +For example, given the toy module below, the quantized checkpoint contains the following entries: + +```python +class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(10, 5, bias=False) + + def forward(self, inp): + x1 = self.fc1(inp) + return x1 +``` + +```text +1. State dict: +{ + "fc1.weight": torch.Tensor(...), + "fc1.scale_weight": torch.Tensor(...), + "fc1.scale_input": torch.Tensor(...), +} + +2. JSON index file (`model.safetensors.index.json`): +{ + "fc1.weight": "qmodel.safetensors", + "fc1.scale_weight": "qmodel.safetensors", + "fc1.scale_input": "qmodel.safetensors" +} +``` diff --git a/examples/ds/eval.py b/examples/ds/eval.py new file mode 100644 index 00000000000..16d51f2697a --- /dev/null +++ b/examples/ds/eval.py @@ -0,0 +1,143 @@ +import os +import torch +import tqdm +from loguru import logger +import logging +import safetensors +from safetensors import safe_open +from safetensors.torch import save_file +import json + +logging.basicConfig(level=logging.DEBUG) +torch.set_grad_enabled(False) + +# CONSTANTS +SAFETENSORS = "safetensors" +WEIGHT_SCALE_NAME = "scale_weight" +INPUT_SCALE_NAME = "scale_input" +SCALE_DTYPE = torch.bfloat16 +SCALE_FILE_NAME = f"scales.{SAFETENSORS}" +FULL_RANGE = torch.finfo(torch.float8_e4m3fn).max +WEIGHT_BACKOFF = 0.5 +QUANT_MODULE_TYPES = (torch.nn.Linear,) +SKIP_WEIGHT_LST = { + "model.norm", + "layernorm", + "e_score_correction_bias", + # "lm_head.weight", + "embed_tokens", + "mlp.gate.weight", # mlp.gate is not linear +} +""" +# https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html?highlight=backoff#supported-json-config-file-options +Similarly, the maxabs value of a weight is scaled to weight_backoff*FP8_143_FULLSCALE. The default values are input_backoff=0.25 and weight_backoff=0.5.
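+For example (illustrative numbers): with FULL_RANGE = 448 (E4M3) and WEIGHT_BACKOFF = 0.5, a weight tensor whose maxabs value is 10.0 gets scale = 10.0 / (448 * 0.5) ≈ 0.0446, and the quantized tensor is weight / scale clamped to [-448, 448] before the cast to float8_e4m3fn.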
+""" +MODEL_STATE_DICT_MAPPING_FILENAME = "model.safetensors.index.json" + + +def skip_weight(weight_name): + return any([skip_name in weight_name for skip_name in SKIP_WEIGHT_LST]) + + +def get_cpu_mem_size_in_gb(): + import psutil + + mem = psutil.virtual_memory() + return mem.available + + +from quant import quant_tensor + + +from torch import nn + + +# Adapted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/1d044fd82b15f1cedb197a288e50cc96a2c27205/inference/model.py#L91-L108 +class FP8QDQLinear(torch.nn.Linear): + dtype = torch.bfloat16 + fp8_dtype = torch.float8_e4m3fn + + def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None): + super().__init__(in_features, out_features, bias=bias) + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter( + torch.empty(out_features, in_features, dtype=FP8QDQLinear.fp8_dtype), requires_grad=True + ) + self.scale_weight = nn.Parameter(torch.tensor(0, dtype=FP8QDQLinear.dtype), requires_grad=False) + if bias: + self.bias = nn.Parameter(torch.empty(out_features)) + else: + self.register_parameter("bias", None) + + def dequant_weight_online(self): + fp8_weight = self.weight + qdq_weight = fp8_weight.to(FP8QDQLinear.dtype) * self.scale_weight + return qdq_weight + + def qdq_input(self, bf16_input: torch.Tensor): + input_scale, input_fp8 = quant_tensor(bf16_input) + qdq_input_bf16 = input_fp8.to(FP8QDQLinear.dtype) * input_scale + return qdq_input_bf16 + + @classmethod + def create_from_linear(cls, linear: nn.Linear): + qdq_linear = cls(linear.in_features, linear.out_features) + qdq_linear.weight.data = linear.weight.data + if linear.bias is not None: + qdq_linear.bias = linear.bias + return qdq_linear + + def forward(self, bf16_input: torch.Tensor) -> torch.Tensor: + qdq_input = self.qdq_input(bf16_input) + qdq_weight = self.dequant_weight_online() + out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias) + return out + + +def patch_lin(): + logger.warning("Patching torch.nn.Linear to FP8QDQLinear") + torch.nn.Linear = FP8QDQLinear + + +def qdq_eval(model_path, not_patch_lin=False): + import transformers + from transformers.modeling_utils import no_init_weights + from patch_for_ds import patch_transformers + + if not not_patch_lin: + patch_lin() + + def _patch__initialize_weights(self, module): + print(f"Skipping init_weights ") + module._is_hf_initialized = True + + transformers.modeling_utils.PreTrainedModel._initialize_weights = _patch__initialize_weights + patch_transformers() + with no_init_weights(): + model = transformers.AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype="auto", + low_cpu_mem_usage=True, + trust_remote_code=True, + ) + logger.info(f"Patched model: {model}") + model.eval() + tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) + prompt = "Hi, who" + encode = tokenizer.encode(prompt, return_tensors="pt") + with torch.no_grad(): + output_tokens = model.generate(encode, max_length=10) + output = tokenizer.decode(output_tokens[0], skip_special_tokens=True) + logger.info(f"Prompt: {prompt}") + logger.info(f"Output: {output}") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--qmodel_path", type=str, required=True) + parser.add_argument("--not_patch_lin", action="store_true", help="Measure float model") + args = parser.parse_args() + qdq_eval(args.qmodel_path, not_patch_lin=args.not_patch_lin) diff --git a/examples/ds/patch_for_ds.py 
b/examples/ds/patch_for_ds.py new file mode 100644 index 00000000000..d82251e3e62 --- /dev/null +++ b/examples/ds/patch_for_ds.py @@ -0,0 +1,132 @@ +# ==--------------------------------------------------------------------------== +# Patch for loading DS models from transformers +from typing import Union, Optional +import torch +import os +from packaging import version +from zipfile import is_zipfile +from transformers.utils import is_safetensors_available, strtobool +from transformers.integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled + +if is_safetensors_available(): + from safetensors import safe_open + from safetensors.torch import load_file as safe_load_file + from safetensors.torch import save_file as safe_save_file + + +def is_fsdp_enabled(): + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1 + and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1 + ) + + +def is_local_dist_rank_0(): + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and int(os.environ.get("LOCAL_RANK", -1)) == 0 + ) + + +def load_state_dict( + checkpoint_file: Union[str, os.PathLike], + is_quantized: bool = False, + map_location: Optional[Union[str, torch.device]] = None, + weights_only: bool = True, +): + """ + Reads a PyTorch checkpoint file, returning properly formatted errors if they arise. + """ + + if checkpoint_file.endswith(".safetensors") and is_safetensors_available(): + # Check format of the archive + with safe_open(checkpoint_file, framework="pt") as f: + metadata = f.metadata() + if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]: + raise OSError( + f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " + "you save your model with the `save_pretrained` method." + ) + return safe_load_file(checkpoint_file) + try: + if map_location is None: + if ( + ( + is_deepspeed_zero3_enabled() + and torch.distributed.is_initialized() + and torch.distributed.get_rank() > 0 + ) + or (is_fsdp_enabled() and not is_local_dist_rank_0()) + ) and not is_quantized: + map_location = "meta" + else: + map_location = "cpu" + extra_args = {} + # mmap can only be used with files serialized with zipfile-based format. + if ( + isinstance(checkpoint_file, str) + and map_location != "meta" + and version.parse(torch.__version__) >= version.parse("2.1.0") + and is_zipfile(checkpoint_file) + ): + extra_args = {"mmap": True} + weights_only_kwarg = {"weights_only": weights_only} + return torch.load( + checkpoint_file, + map_location=map_location, + **weights_only_kwarg, + **extra_args, + ) + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read(7) == "version": + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " + "model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' " + f"at '{checkpoint_file}'. " + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." 
+ ) + +# https://github.com/huggingface/transformers/pull/35493 +def set_initialized_submodules(model, state_dict_keys): + """ + Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state + dict. + """ + state_dict_keys = set(state_dict_keys) + not_initialized_submodules = {} + for module_name, module in model.named_modules(): + if module_name == "": + # When checking if the root module is loaded there's no need to prepend module_name. + module_keys = set(module.state_dict()) + else: + module_keys = {f"{module_name}.{k}" for k in module.state_dict()} + if module_keys.issubset(state_dict_keys): + module._is_hf_initialized = True + else: + not_initialized_submodules[module_name] = module + return not_initialized_submodules + + +# ==--------------------------------------------------------------------------== + + +def patch_transformers(): + import transformers + + transformers.modeling_utils.load_state_dict = load_state_dict + transformers.modeling_utils.set_initialized_submodules = set_initialized_submodules diff --git a/examples/ds/post_process.py b/examples/ds/post_process.py new file mode 100644 index 00000000000..a9fe929fbeb --- /dev/null +++ b/examples/ds/post_process.py @@ -0,0 +1,101 @@ +import json +from loguru import logger + +quantization_config = { + "_json_file": "/tmp/tmpe3ckugb_.json", + "allowlist": { + "names": [], + "types": [ + "Matmul", + "Linear", + "ParallelLMHead", + "RowParallelLinear", + "ColumnParallelLinear", + "MergedColumnParallelLinear", + "QKVParallelLinear", + "FalconLinear", + "KVCache", + "VLLMKVCache", + "Conv2d", + "LoRACompatibleLinear", + "LoRACompatibleConv", + "Softmax", + "ModuleFusedSDPA", + "MoeMatmul", + "ReplicatedLinear", + "FusedMoE", + "GaudiMixtralSparseMoeBlock", + "VllmMixtureOfExpertsOp", + "LinearLayer", + "LinearAllreduce", + "ScopedLinearAllReduce", + "LmHeadLinearAllreduce", + ], + }, + "blocklist": {}, + "dump_stats_path": "./hqt_output/measure", + "fake_quant": "False", + "fp8_config": "E4M3", + "hp_dtype": "bf16", + "measure_on_hpu": True, + "mod_dict": {}, + "mode": "LOAD", + "observer": "maxabs", + "scale_format": "const", + "scale_method": "maxabs_pow2_dynamic", + "scale_params": {}, + "use_qdq": "False", +} + + +# add the quantization config to config.json +def update_config(model_path, qmodel_path): + import json + import os + + with open(os.path.join(model_path, "config.json"), "r") as f: + config = json.load(f) + config["quantization_config"] = quantization_config + logger.info(f"Updated config: {config}") + config_filepath = os.path.join(qmodel_path, "config.json") + logger.debug(f"Saving config to {config_filepath}") + with open(config_filepath, "w") as f: + json.dump(config, f, indent=4) + + +MODEL_FILE_LST = [ + "configuration_deepseek.py", + "generation_config.json", + "modeling_deepseek.py", + "tokenizer.json", + "tokenizer_config.json", +] + + +def cp_model_files(model_path, qmodel_path): + # copy model files + import shutil + import os + + for file in MODEL_FILE_LST: + logger.debug(f"Copying {file} from {model_path} to {qmodel_path}") + file_path = os.path.join(model_path, file) + # check if file exists + if not os.path.exists(file_path): + logger.error(f"File {file_path} does not exist") + raise FileNotFoundError(f"File {file_path} does not exist") + shutil.copy(os.path.join(model_path, file), qmodel_path) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, required=True) + 
parser.add_argument("--qmodel_path", type=str, required=True) + parser.add_argument("--low_cpu_mem", action="store_true", help="Load weight file one by one to reduce memory usage") + args = parser.parse_args() + # update the config + update_config(args.model_path, args.qmodel_path) + # copy model files + cp_model_files(args.model_path, args.qmodel_path) diff --git a/examples/ds/quant.py b/examples/ds/quant.py new file mode 100644 index 00000000000..615b6bf33bf --- /dev/null +++ b/examples/ds/quant.py @@ -0,0 +1,170 @@ +import os +import torch +import tqdm +from loguru import logger +import logging +import safetensors +from safetensors import safe_open +from safetensors.torch import save_file +import json + +logging.basicConfig(level=logging.DEBUG) +torch.set_grad_enabled(False) + +# CONSTANTS +SAFETENSORS = "safetensors" +WEIGHT_SCALE_NAME = "scale_weight" +INPUT_SCALE_NAME = "scale_input" +SCALE_DTYPE = torch.bfloat16 +SCALE_FILE_NAME = f"scales.{SAFETENSORS}" +FULL_RANGE = torch.finfo(torch.float8_e4m3fn).max +WEIGHT_BACKOFF = 0.5 +QUANT_MODULE_TYPES = (torch.nn.Linear,) +SKIP_WEIGHT_LST = { + "model.norm", + "layernorm", + "e_score_correction_bias", + # "lm_head.weight", + "embed_tokens", + "mlp.gate.weight", # mlp.gate is not linear +} +""" +# https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html?highlight=backoff#supported-json-config-file-options +Similarly, the maxabs value of a weight is scaled to weight_backoff*FP8_143_FULLSCALE. The default values are input_backoff=0.25 and weight_backoff=0.5. +""" +MODEL_STATE_DICT_MAPPING_FILENAME = "model.safetensors.index.json" + + +def skip_weight(weight_name): + return any([skip_name in weight_name for skip_name in SKIP_WEIGHT_LST]) + + +def get_cpu_mem_size_in_gb(): + import psutil + + mem = psutil.virtual_memory() + return mem.available + + +def get_all_weight_filename(model_path): + all_files = os.listdir(model_path) + all_weight_filename = [] + for file in all_files: + if file.endswith(f".{SAFETENSORS}"): + all_weight_filename.append(file) + return all_weight_filename + + +# from _fp8_quant/_core/fp_utils.py +def calc_maxabs_scale(xmaxabs, fullscale, backoff=1): + scale = xmaxabs / (fullscale * backoff) + return scale + + +def quant_tensor(tensor): + # Note: + # 1. Check the scale dtype + # 2. 
Check the scale shape + amax = tensor.abs().max() + scale = calc_maxabs_scale(amax, FULL_RANGE, WEIGHT_BACKOFF) + scale = scale.to(SCALE_DTYPE) + qtensor = tensor / scale + clipped_qtensor = torch.clamp(qtensor, -FULL_RANGE, FULL_RANGE) + clipped_qtensor_fp8 = clipped_qtensor.to(torch.float8_e4m3fn) + return scale, clipped_qtensor_fp8 + + +def _maybe_create_dir(qmodel_path): + if not os.path.exists(qmodel_path): + os.makedirs(qmodel_path) + + +def quant_model_weight_with_low_cpu_usage(model_path, qmodel_path): + _maybe_create_dir(qmodel_path) + all_weight_filename = get_all_weight_filename(model_path) + files_cnt = len(all_weight_filename) + logger.info(f"Got {len(all_weight_filename)} weight files") + qtensor_mapping = {} + for i, filename in enumerate(all_weight_filename): + logger.info(f"Processing {i + 1}/{len(all_weight_filename)}: {filename}") + file_path = os.path.join(model_path, filename) + qmodel_file_name = filename + qmodel_file_path = os.path.join(qmodel_path, qmodel_file_name) + qtensors = {} + with safe_open(file_path, framework="pt", device="cpu") as f: + for weight_name in f.keys(): + weight = f.get_tensor(weight_name) + if skip_weight(weight_name): + logger.debug(f"Skipping quantization of {weight_name}") + qtensors[weight_name] = weight + qtensor_mapping[weight_name] = qmodel_file_name + continue + logger.debug(f"[{i+1}/{files_cnt}] Processing {weight_name}") + scale, qtensor = quant_tensor(weight) + prefix_name = weight_name[: -len(".weight")] + scale_name = f"{prefix_name}.{WEIGHT_SCALE_NAME}" + qtensors[scale_name] = scale + qtensors[weight_name] = qtensor + qtensor_mapping[scale_name] = qmodel_file_name + qtensor_mapping[weight_name] = qmodel_file_name + logger.debug(f"[{i+1}/{files_cnt}] Saving {len(qtensors)} tensors to {qmodel_file_path}") + save_file(qtensors, qmodel_file_path) + # Dump tensor mapping into json file + model_state_dict_mapping_file_path = os.path.join(qmodel_path, MODEL_STATE_DICT_MAPPING_FILENAME) + logger.info(f"Saving tensor mapping to {model_state_dict_mapping_file_path}") + state_dict_mapping = { + "metadata": {}, + "weight_map": qtensor_mapping, + } + with open(model_state_dict_mapping_file_path, "w") as f: + json.dump(state_dict_mapping, f, indent=4) + + +def _import_oh(): + import transformers + from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi + + orig_check_support_param_buffer_assignment = transformers.modeling_utils.check_support_param_buffer_assignment + adapt_transformers_to_gaudi() + transformers.modeling_utils.check_support_param_buffer_assignment = orig_check_support_param_buffer_assignment + + +@torch.no_grad() +def static_quant_model_tran(model_path, qmodel_path): + # assert get_cpu_mem_size_in_gb(800), "Not enough memory, please use quant_model_weight_with_low_cpu_usage" + import transformers + from patch_for_ds import patch_transformers + + # _import_oh() + patch_transformers() + model = transformers.AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype="auto", + low_cpu_mem_usage=True, + trust_remote_code=True, + ) + for name, module in model.named_modules(): + if not isinstance(module, QUANT_MODULE_TYPES) or skip_weight(name): + logger.debug(f"Skipping quantization of {name}") + continue + logger.debug(f"Processing {name}") + weight = module.weight + scale, qtensor = quant_tensor(weight) + module.weight.data = qtensor + setattr(module, "scale_weight", torch.nn.Parameter(scale, requires_grad=False)) + logger.info(f"Saving quantized model to {qmodel_path}") +
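# Note: `module.weight` now holds FP8 data and each quantized module carries a registered `scale_weight` parameter, so both end up in the checkpoint written by `save_pretrained` below. +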
model.save_pretrained(qmodel_path) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, required=True) + parser.add_argument("--qmodel_path", type=str, required=True) + parser.add_argument("--low_cpu_mem", action="store_true", help="Load weight file one by one to reduce memory usage") + args = parser.parse_args() + if args.low_cpu_mem: + quant_model_weight_with_low_cpu_usage(args.model_path, args.qmodel_path) + else: + static_quant_model_tran(args.model_path, args.qmodel_path) diff --git a/examples/ds/requirements.txt b/examples/ds/requirements.txt new file mode 100644 index 00000000000..c224c095b1b --- /dev/null +++ b/examples/ds/requirements.txt @@ -0,0 +1,6 @@ +loguru +torch +safetensors +tqdm +transformers +psutil \ No newline at end of file