Dev/ds/quant #2118

Draft: wants to merge 25 commits into master

178 changes: 178 additions & 0 deletions ds/infer_bf16.py
@@ -0,0 +1,178 @@
# ==--------------------------------------------------------------------------==
# Patch for loading DS models
import os
from typing import Optional, Union
from zipfile import is_zipfile

import torch
from packaging import version
from transformers.integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from transformers.utils import is_safetensors_available, strtobool

if is_safetensors_available():
from safetensors import safe_open
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file


def is_fsdp_enabled():
return (
torch.distributed.is_available()
and torch.distributed.is_initialized()
and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
)


def is_local_dist_rank_0():
return (
torch.distributed.is_available()
and torch.distributed.is_initialized()
and int(os.environ.get("LOCAL_RANK", -1)) == 0
)


def load_state_dict(
checkpoint_file: Union[str, os.PathLike],
is_quantized: bool = False,
map_location: Optional[Union[str, torch.device]] = None,
weights_only: bool = True,
):
"""Reads a PyTorch checkpoint file, returning properly formatted errors if they arise."""

if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
# Check format of the archive
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_pretrained` method."
)
return safe_load_file(checkpoint_file)
try:
if map_location is None:
if (
(
is_deepspeed_zero3_enabled()
and torch.distributed.is_initialized()
and torch.distributed.get_rank() > 0
)
or (is_fsdp_enabled() and not is_local_dist_rank_0())
) and not is_quantized:
map_location = "meta"
else:
map_location = "cpu"
extra_args = {}
# mmap can only be used with files serialized with zipfile-based format.
if (
isinstance(checkpoint_file, str)
and map_location != "meta"
and version.parse(torch.__version__) >= version.parse("2.1.0")
and is_zipfile(checkpoint_file)
):
extra_args = {"mmap": True}
weights_only_kwarg = {"weights_only": weights_only}
return torch.load(
checkpoint_file,
map_location=map_location,
**weights_only_kwarg,
**extra_args,
)
except Exception as e:
try:
with open(checkpoint_file) as f:
if f.read(7) == "version":
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please install "
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
"you cloned."
)
else:
raise ValueError(
f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
"model. Make sure you have saved the model properly."
) from e
except (UnicodeDecodeError, ValueError):
raise OSError(
f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
f"at '{checkpoint_file}'. "
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
)


def set_initialized_submodules(model, state_dict_keys):
"""Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
dict."""
state_dict_keys = set(state_dict_keys)
not_initialized_submodules = {}
for module_name, module in model.named_modules():
if module_name == "":
# When checking if the root module is loaded there's no need to prepend module_name.
module_keys = set(module.state_dict())
else:
module_keys = {f"{module_name}.{k}" for k in module.state_dict()}
if module_keys.issubset(state_dict_keys):
module._is_hf_initialized = True
else:
not_initialized_submodules[module_name] = module
return not_initialized_submodules


# ==--------------------------------------------------------------------------==


def patch_transformers():
import transformers

transformers.modeling_utils.load_state_dict = load_state_dict
transformers.modeling_utils.set_initialized_submodules = set_initialized_submodules


import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def eval(model_path):
import transformers
from transformers.modeling_utils import no_init_weights

# from patch_for_ds import patch_transformers
# if not not_patch_lin:
# patch_lin()

def _patch__initialize_weights(self, module):
print("Skipping init_weights ")
module._is_hf_initialized = True

transformers.modeling_utils.PreTrainedModel._initialize_weights = _patch__initialize_weights
patch_transformers()
with no_init_weights():
model = transformers.AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype="auto",
low_cpu_mem_usage=True,
trust_remote_code=True,
)
logger.info(f"Patched model: {model}")
model.eval()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
prompt = "Hi, who"
encode = tokenizer.encode(prompt, return_tensors="pt")
with torch.no_grad():
output_tokens = model.generate(encode, max_length=10)
output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
logger.info(f"Prompt: {prompt}")
logger.info(f"Output: {output}")


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-m", "--qmodel_path", type=str, required=True)
parser.add_argument("--not_patch_lin", action="store_true", help="Measure float model")
args = parser.parse_args()
eval(args.qmodel_path)
70 changes: 70 additions & 0 deletions examples/ds/README.md
@@ -0,0 +1,70 @@
# Notes on quantizing the DeepSeek model

## Prerequisite

```bash
pip install -r requirements.txt
```

## Usage

### Step 1. Quantize model weights

- Option 1 (Recommended): Quantize weights directly

```bash
python quant.py --model_path /path/to/DeepSeek/R1/BF16/ --qmodel_path /path/to/DeepSeek/R1-Dynamic-FP8 --low_cpu_mem
```

- Option 2: Load the model using transformers (requires ~700 GB of DRAM)

```bash
python quant.py --model_path /path/to/DeepSeek/R1/BF16/ --qmodel_path /path/to/DeepSeek/R1/Dynamic-FP8
```

> [!NOTE]
> - weight dtype is `torch.float8_e4m3fn` (full range is `-448` to `448`)
> - `WEIGHT_BACKOFF = 0.5`
> - `SCALE_DTYPE = torch.bfloat16`
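
The constants above pin down the arithmetic. As a rough illustration, per-tensor quantization consistent with them can be sketched as below; `quant_tensor_sketch` is a hypothetical stand-in for `quant_tensor` in `quant.py`, whose actual implementation may differ:

```python
import torch

FULL_RANGE = torch.finfo(torch.float8_e4m3fn).max  # 448.0
WEIGHT_BACKOFF = 0.5
SCALE_DTYPE = torch.bfloat16


def quant_tensor_sketch(tensor: torch.Tensor):
    """Per-tensor FP8 quantization: maxabs is mapped to WEIGHT_BACKOFF * FULL_RANGE."""
    amax = tensor.abs().max().float().clamp(min=1e-12)
    scale = (amax / (FULL_RANGE * WEIGHT_BACKOFF)).to(SCALE_DTYPE)  # dequantization scale
    tensor_fp8 = (tensor / scale).clamp(-FULL_RANGE, FULL_RANGE).to(torch.float8_e4m3fn)
    return scale, tensor_fp8
```

Dequantization is then simply `tensor_fp8.to(torch.bfloat16) * scale`, which is how the FP8 QDQ linear in `eval.py` reconstructs weights on the fly.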

### Step 2. Copy model files for inference

Since DeepSeek V3 and R1 are not yet supported by Transformers, we need to manually copy some model files.

```bash
python post_process.py --model_path /path/to/DeepSeek/R1/BF16/ --qmodel_path /path/to/DeepSeek/R1/Dynamic-FP8
```
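
`post_process.py` is not shown in this diff, so the sketch below is only an assumption of the kind of copying this step performs: placing the remote-code modules (`*.py`), configs, and tokenizer assets from the BF16 checkpoint next to the quantized weights. The hypothetical `copy_model_files` helper and its file patterns are illustrative; the real script may handle a different file set.

```python
import shutil
from pathlib import Path


def copy_model_files(model_path: str, qmodel_path: str) -> None:
    """Copy modeling code, configs, and tokenizer assets beside the quantized weights."""
    patterns = ("*.py", "*.json", "tokenizer*")  # assumed file set; weight shards are excluded
    for pattern in patterns:
        for src in Path(model_path).glob(pattern):
            dst = Path(qmodel_path) / src.name
            if not dst.exists():  # never clobber files already produced by quant.py
                shutil.copy2(src, dst)
```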

## More details

1. Naming convention:
   - weight scale: `prefix.scale_weight`
   - input scale: `prefix.scale_input` (static quantization only)
2. A JSON file, `model.safetensors.index.json`, maps each tensor name to the safetensors file that stores it (a minimal saving sketch follows the example below).

```python
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc1 = torch.nn.Linear(10, 5, bias=False)

def forward(self, inp):
x1 = self.fc1(inp)
return x1
```

The resulting state dict:

```python
{
    "fc1.weight": torch.Tensor(...),        # FP8 weight
    "fc1.scale_weight": torch.Tensor(...),  # weight scale
    "fc1.scale_input": torch.Tensor(...),   # input scale (static quantization only)
}
```

The mapping in `model.safetensors.index.json`:

```json
{
    "fc1.weight": "qmodel.safetensors",
    "fc1.scale_weight": "qmodel.safetensors",
    "fc1.scale_input": "qmodel.safetensors"
}
```
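
For reference, here is a minimal sketch of how such a single-shard checkpoint and its index could be written with `safetensors`. The hypothetical `save_quantized_sketch` helper is not part of `quant.py`, which may shard the checkpoint across several files rather than one `qmodel.safetensors`:

```python
import json
import os

from safetensors.torch import save_file


def save_quantized_sketch(state_dict, output_dir):
    """Write all tensors to one safetensors shard and record the name -> file mapping."""
    shard_name = "qmodel.safetensors"
    save_file(state_dict, os.path.join(output_dir, shard_name))
    mapping = {name: shard_name for name in state_dict}
    with open(os.path.join(output_dir, "model.safetensors.index.json"), "w") as f:
        json.dump(mapping, f, indent=4)
```
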
143 changes: 143 additions & 0 deletions examples/ds/eval.py
@@ -0,0 +1,143 @@
import os
import torch
import tqdm
from loguru import logger
import logging
import safetensors
from safetensors import safe_open
from safetensors.torch import save_file
import json

logging.basicConfig(level=logging.DEBUG)
torch.set_grad_enabled(False)

# CONSTANTS
SAFETENSORS = "safetensors"
WEIGHT_SCALE_NAME = "scale_weight"
INPUT_SCALE_NAME = "scale_input"
SCALE_DTYPE = torch.bfloat16
SCALE_FILE_NAME = f"scales.{SAFETENSORS}"
FULL_RANGE = torch.finfo(torch.float8_e4m3fn).max
WEIGHT_BACKOFF = 0.5
QUANT_MODULE_TYPES = (torch.nn.Linear,)
SKIP_WEIGHT_LST = {
"model.norm",
"layernorm",
"e_score_correction_bias",
# "lm_head.weight",
"embed_tokens",
"mlp.gate.weight", # mlp.gate is not linear
}
"""
# https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html?highlight=backoff#supported-json-config-file-options
Similarly, the maxabs value of a weight is scaled to weight_backoff*FP8_143_FULLSCALE. The default values are input_backoff=0.25 and weight_backoff=0.5.
"""
MODEL_STATE_DICT_MAPPING_FILENAME = "model.safetensors.index.json"


def skip_weight(weight_name):
    return any(skip_name in weight_name for skip_name in SKIP_WEIGHT_LST)


def get_cpu_mem_size_in_gb():
import psutil

mem = psutil.virtual_memory()
    return mem.available / (1024**3)  # convert bytes to GB


from torch import nn

from quant import quant_tensor


# Adapted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/1d044fd82b15f1cedb197a288e50cc96a2c27205/inference/model.py#L91-L108
class FP8QDQLinear(torch.nn.Linear):
dtype = torch.bfloat16
fp8_dtype = torch.float8_e4m3fn

def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None):
super().__init__(in_features, out_features, bias=bias)
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(
torch.empty(out_features, in_features, dtype=FP8QDQLinear.fp8_dtype), requires_grad=True
)
self.scale_weight = nn.Parameter(torch.tensor(0, dtype=FP8QDQLinear.dtype), requires_grad=False)
if bias:
self.bias = nn.Parameter(torch.empty(out_features))
else:
self.register_parameter("bias", None)

def dequant_weight_online(self):
fp8_weight = self.weight
qdq_weight = fp8_weight.to(FP8QDQLinear.dtype) * self.scale_weight
return qdq_weight

def qdq_input(self, bf16_input: torch.Tensor):
input_scale, input_fp8 = quant_tensor(bf16_input)
qdq_input_bf16 = input_fp8.to(FP8QDQLinear.dtype) * input_scale
return qdq_input_bf16

@classmethod
def create_from_linear(cls, linear: nn.Linear):
        qdq_linear = cls(linear.in_features, linear.out_features, bias=linear.bias is not None)
qdq_linear.weight.data = linear.weight.data
if linear.bias is not None:
qdq_linear.bias = linear.bias
return qdq_linear

def forward(self, bf16_input: torch.Tensor) -> torch.Tensor:
qdq_input = self.qdq_input(bf16_input)
qdq_weight = self.dequant_weight_online()
out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias)
return out


def patch_lin():
logger.warning("Patching torch.nn.Linear to FP8QDQLinear")
torch.nn.Linear = FP8QDQLinear


def qdq_eval(model_path, not_patch_lin=False):
import transformers
from transformers.modeling_utils import no_init_weights
from patch_for_ds import patch_transformers

if not not_patch_lin:
patch_lin()

def _patch__initialize_weights(self, module):
print(f"Skipping init_weights ")
module._is_hf_initialized = True

transformers.modeling_utils.PreTrainedModel._initialize_weights = _patch__initialize_weights
patch_transformers()
with no_init_weights():
model = transformers.AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype="auto",
low_cpu_mem_usage=True,
trust_remote_code=True,
)
logger.info(f"Patched model: {model}")
model.eval()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
prompt = "Hi, who"
encode = tokenizer.encode(prompt, return_tensors="pt")
with torch.no_grad():
output_tokens = model.generate(encode, max_length=10)
output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
logger.info(f"Prompt: {prompt}")
logger.info(f"Output: {output}")


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--qmodel_path", type=str, required=True)
parser.add_argument("--not_patch_lin", action="store_true", help="Measure float model")
args = parser.parse_args()
qdq_eval(args.qmodel_path, not_patch_lin=args.not_patch_lin)