-
Notifications
You must be signed in to change notification settings - Fork 4.3k
/
Copy pathmodule_quantize.py
executable file
·72 lines (54 loc) · 3.07 KB
/
module_quantize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
def quantize_transformer_layer(orig_layer_impl, model, megatron=False, preln=False):
    """Cast the linear weights of matching transformer layers in ``model`` to int8.

    Every sub-module of ``model`` whose class is exactly ``orig_layer_impl`` has the
    ``weight.data`` of its attention and feed-forward linears replaced by an int8
    copy. The layer objects themselves are kept; only their weight storage changes.

    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer class to look for,
            e.g., transformers.models.bert.modeling_bert.BertLayer or transformers.BertLayer
        model (torch.nn.Module): user's nn.module representing their model
        megatron (bool): treat matched layers as megatron model-parallel layers
            (``attention.query_key_value`` / ``mlp.dense_h_to_4h`` attribute layout)
            instead of the bert-style layout
        preln (bool): bert-style layout only; when true the intermediate projection is
            read from ``intermediate.dense_act`` rather than ``intermediate.dense``

    Returns:
        The same ``model`` object, with the matched layers' weights cast to int8
    """

    def cast_weight_to_int8(linear):
        # Plain dtype cast through Tensor.to(); no scale or zero-point is computed here.
        linear.weight.data = linear.weight.data.to(torch.int8)

    def quantize_megatron_layer(layer):
        # Megatron GPT2 / GPT3 layout: fused QKV projection plus a two-layer MLP.
        cast_weight_to_int8(layer.attention.query_key_value)
        cast_weight_to_int8(layer.attention.dense)
        cast_weight_to_int8(layer.mlp.dense_h_to_4h)
        cast_weight_to_int8(layer.mlp.dense_4h_to_h)

    def quantize_bert_layer(layer):
        # DeepSpeed / HuggingFace bert-style layout: separate Q, K and V projections.
        cast_weight_to_int8(layer.attention.self.query)
        cast_weight_to_int8(layer.attention.self.key)
        cast_weight_to_int8(layer.attention.self.value)
        cast_weight_to_int8(layer.attention.output.dense)
        # The intermediate projection lives under a different attribute name
        # depending on the pre/post layer-norm variant.
        if preln:
            cast_weight_to_int8(layer.intermediate.dense_act)
        else:
            cast_weight_to_int8(layer.intermediate.dense)
        cast_weight_to_int8(layer.output.dense)

    # The layout flag is fixed for the whole call, so pick the quantizer once.
    layer_quantizer = quantize_megatron_layer if megatron else quantize_bert_layer

    def quantize_fn(child):
        layer_quantizer(child)
        # The layer is modified in place and handed back so it stays in the model.
        return child

    return quantize_module(model=model, orig_class=orig_layer_impl, quantize_fn=quantize_fn)
def quantize_module(model, orig_class, quantize_fn):
    """Apply ``quantize_fn`` to every sub-module of ``model`` of class ``orig_class``.

    Arguments:
        model (torch.nn.Module): root module to traverse
        orig_class (type): class of the sub-modules to quantize
        quantize_fn (callable): receives a matching sub-module and returns the
            module to store in its place

    Returns:
        Whatever ``_quantize_module`` returns for ``model`` and the single-entry
        policy built from the arguments
    """
    # A single-entry policy table: one class mapped to its quantization callback.
    return _quantize_module(model, {orig_class: quantize_fn})
def _quantize_module(model, policies):
    """Recursively replace the children of ``model`` according to ``policies``.

    Arguments:
        model (torch.nn.Module): module whose direct and nested children are visited
        policies (dict): maps a module class to a callable that receives the matching
            child and returns the module to store under the same attribute name

    Returns:
        The ``model`` object that was passed in, after its children were processed

    Note:
        Matching is on the child's exact class (``child.__class__``), so subclasses of
        a policy key are not matched. A matched child is not descended into; only
        unmatched children are traversed further.
    """
    for name, child in model.named_children():
        policy = policies.get(child.__class__)
        if policy is None:
            # No policy for this class: keep looking deeper in the tree.
            _quantize_module(child, policies)
        else:
            # Store whatever the policy returns under the same child name.
            setattr(model, name, policy(child))
    return model