
Commit 63fbba3

[zero] add L2 gradient clipping for ZeRO (#2112)
* [zero] add L2 gradient clipping
* [testing] add MlpModel
* [zero] add unit test for grad clipping
* fix atol
1 parent 70a8556 commit 63fbba3
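
For context, the feature added here mirrors standard L2 gradient clipping as PyTorch's torch.nn.utils.clip_grad_norm_ performs it on a dense model; the new unit test below compares ZeRO against exactly that. A minimal single-process reference sketch (not part of this commit, toy model chosen purely for illustration):

import torch

# Toy dense model and optimizer; no ZeRO sharding involved.
model = torch.nn.Linear(8, 4)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

loss = torch.nn.functional.mse_loss(model(torch.randn(2, 8)), torch.randn(2, 4))
loss.backward()

# Clip the global L2 norm of all gradients to 1.0 before stepping, matching
# torch.nn.utils.clip_grad_norm_(apex_amp.master_params(torch_optim), 1.0) in the new test.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optim.step()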

5 files changed: +194 -11 lines changed
colossalai/gemini/chunk/chunk.py

+18 -3

@@ -51,7 +51,6 @@ def alloc_storage(tensor: torch.Tensor) -> None:


 class Chunk:
-
     _total_number = 0

     def __init__(self,
@@ -140,6 +139,10 @@ def __init__(self,
         # if the cpu_shard has been visited during the training step, the flag is True
         self.cpu_vis_flag = False

+        # whether to record l2 norm for the gradient clipping calculation
+        self.l2_norm_flag = False
+        self.l2_norm = None
+
     @property
     def memory_usage(self) -> Dict[str, int]:
         cuda_memory = 0
@@ -213,16 +216,28 @@ def can_reduce(self):

     @property
     def has_inf_or_nan(self) -> bool:
-        """Check if the chunk has inf or nan values in CUDA.
+        """Check if the chunk has inf or nan values on CUDA.
         """
         if self.is_gathered:
             valid_tensor = self.chunk_total[:self.utilized_size]
         else:
-            assert self.cuda_shard is not None    # only check in CUDA
+            assert self.cuda_shard is not None    # only check on CUDA
             valid_tensor = self.cuda_shard[:self.valid_end]

         return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()

+    def set_l2_norm(self) -> None:
+        """Record the L2 norm of this chunk on CUDA.
+        """
+        assert self.l2_norm is None, "you are calculating the l2 norm twice"
+        if self.is_gathered:
+            valid_tensor = self.chunk_total[:self.utilized_size]
+        else:
+            assert self.cuda_shard is not None    # calculate on CUDA
+            valid_tensor = self.cuda_shard[:self.valid_end]
+        chunk_l2_norm = valid_tensor.data.float().norm(2)
+        self.l2_norm = chunk_l2_norm.item()**2
+
     def append_tensor(self, tensor: torch.Tensor):
         """Add a tensor to the chunk.

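A side note on why set_l2_norm stores chunk_l2_norm.item()**2 rather than the norm itself: squared L2 norms of disjoint shards are additive, so shard-local values can simply be summed (and later all-reduced) with a single square root taken at the very end. A tiny standalone check, not part of the commit:

import torch

grad = torch.randn(10)
shard_a, shard_b = grad[:6], grad[6:]    # two disjoint shards of the same gradient

# The sum of per-shard squared L2 norms equals the squared L2 norm of the full tensor.
total = (shard_a.norm(2)**2 + shard_b.norm(2)**2).sqrt()
assert torch.allclose(total, grad.norm(2))
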
colossalai/nn/optimizer/zero_optimizer.py

+54 -7

@@ -1,3 +1,4 @@
+import math
 from enum import Enum
 from typing import Any, Dict, Set, Tuple

@@ -56,6 +57,8 @@ def __init__(self,
                  growth_interval: int = 1000,
                  hysteresis: int = 2,
                  max_scale: float = 2**32,
+                 clipping_norm: float = 0.0,
+                 norm_type: float = 2.0,
                  **defaults: Any):
         super().__init__(optim)
         assert isinstance(module, ZeroDDP)
@@ -66,11 +69,17 @@ def __init__(self,
         self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
         self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
         self.chunk16_set: Set[Chunk] = set()
+        self.clipping_flag = clipping_norm > 0.0
+        self.max_norm = clipping_norm
+
+        if self.clipping_flag:
+            assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"

         params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
         for p, fp32_p in zip(params_list, module.fp32_params):
             chunk_16 = self.chunk_manager.get_chunk(p)
             if chunk_16 not in self.chunk16_set:
+                chunk_16.l2_norm_flag = self.clipping_flag
                 self.chunk16_set.add(chunk_16)

         self.__init__optimizer()
@@ -128,12 +137,45 @@ def _check_overflow(self):

         return self._found_overflow.item() > 0

-    def _unscale_grads(self):
+    def _calc_global_norm(self) -> float:
+        norm_sqr: float = 0.0
+        group_to_norm = dict()
+        for c16 in self.chunk16_set:
+            assert c16.l2_norm is not None
+
+            if c16.is_gathered:
+                norm_sqr += c16.l2_norm
+            else:
+                # this chunk is sharded, use communication to collect total norm
+                if c16.torch_pg not in group_to_norm:
+                    group_to_norm[c16.torch_pg] = 0.0
+                group_to_norm[c16.torch_pg] += c16.l2_norm
+
+            c16.l2_norm = None    # clear l2 norm
+
+        comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
+        for group, part_norm in group_to_norm.items():
+            comm_buffer.fill_(part_norm)
+            dist.all_reduce(comm_buffer, group=group)
+            norm_sqr += comm_buffer.item()
+
+        global_norm = math.sqrt(norm_sqr)
+        return global_norm
+
+    def _unscale_and_clip_grads(self):
         assert self.optim_state == OptimState.SCALED
+
+        combined_scale = self.loss_scale
+        if self.clipping_flag:
+            total_norm = self._calc_global_norm()
+            clip = ((total_norm / self.loss_scale) + 1e-6) / self.max_norm
+            if clip > 1:
+                combined_scale = clip * self.loss_scale
+
         for group in self.optim.param_groups:
             for p in group['params']:
                 if p.grad is not None:
-                    p.grad.data.div_(self.loss_scale)
+                    p.grad.data.div_(combined_scale)
         self.optim_state = OptimState.UNSCALED

     @property
@@ -147,16 +189,21 @@ def zero_grad(self, *args, **kwargs):
     def step(self, *args, **kwargs):
         self._maybe_move_fp32_params()
         self._set_grad_ptr()
-        # unscale grads if scaled
-        if self.optim_state == OptimState.SCALED:
-            self._unscale_grads()
+
         found_inf = self._check_overflow()
-        self.grad_scaler.update(found_inf)
         if found_inf:
+            self.optim_state = OptimState.UNSCALED    # no need to unscale grad
+            self.grad_scaler.update(found_inf)    # update gradient scaler
             self._logger.info(f'Found overflow. Skip step')
-            self.zero_grad()
+            self.zero_grad()    # reset all gradients
             self._update_fp16_params()
             return
+
+        # unscale grads if scaled
+        if self.optim_state == OptimState.SCALED:
+            self._unscale_and_clip_grads()
+        self.grad_scaler.update(found_inf)
+
         ret = self.optim.step(*args, **kwargs)
         self._register_states()
         self.zero_grad()
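
To make the arithmetic above easier to follow, here is a single-process sketch of what _calc_global_norm and _unscale_and_clip_grads compute together. The function name and toy gradients are illustrative only; the real code additionally all-reduces the squared norms of sharded chunks over their process groups.

import math
import torch

def unscale_and_clip_(grads, loss_scale, clipping_norm):
    # Global L2 norm over all (still loss-scaled) gradients:
    # sum of squared norms, one sqrt at the end.
    norm_sqr = sum(float(g.data.float().norm(2))**2 for g in grads)
    total_norm = math.sqrt(norm_sqr)

    # Fold the clipping coefficient into the loss-scale division,
    # so unscaling and clipping happen in a single in-place div_.
    combined_scale = loss_scale
    clip = ((total_norm / loss_scale) + 1e-6) / clipping_norm
    if clip > 1:
        combined_scale = clip * loss_scale

    for g in grads:
        g.data.div_(combined_scale)

# Pretend a loss scale of 32 was already applied to these toy gradients.
grads = [torch.randn(4) * 32, torch.randn(3) * 32]
unscale_and_clip_(grads, loss_scale=32.0, clipping_norm=1.0)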

colossalai/nn/parallel/data_parallel.py

+4

@@ -302,7 +302,11 @@ def grad_handle(self, p, grad):
                 chunk.chunk_total.div_(chunk.pg_size)
             else:
                 chunk.cuda_shard.div_(chunk.pg_size)
+            # check overflow elements
             self.overflow_counter += chunk.has_inf_or_nan
+            # record l2 norm for gradient clipping
+            if chunk.l2_norm_flag:
+                chunk.set_l2_norm()
             self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
         return empty_grad

+117
@@ -0,0 +1,117 @@
+from functools import partial
+from time import time
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.amp import convert_to_apex_amp
+from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
+from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
+from colossalai.nn.parallel import ZeroDDP
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from tests.components_to_test import run_fwd_bwd
+from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.test_tensor.common_utils import debug_print, set_seed
+
+
+def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
+    zero_dict = model.state_dict(only_rank_0=False)
+    torch_dict = torch_model.state_dict()
+
+    for key, value in torch_dict.items():
+        # key is 'module.model.PARAMETER', so we truncate it
+        key = key[7:]
+        if key == 'model.lm_head.weight':
+            continue
+        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
+        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
+        # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
+        assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
+
+
+@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
+@parameterize('model_name', ['gpt2'])
+def exam_grad_clipping(placement_policy, model_name: str):
+    set_seed(1912)
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+
+    torch_model = model_builder().cuda()
+    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32)
+    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
+    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
+    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
+
+    init_dev = get_current_device()
+    with ColoInitContext(device=init_dev):
+        model = model_builder()
+
+    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
+        p.data.copy_(torch_p.data)
+
+    world_size = torch.distributed.get_world_size()
+    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict[world_size]['chunk_size'] = 5000
+    config_dict[world_size]['keep_gathered'] = False
+    if placement_policy != 'cuda':
+        init_device = torch.device('cpu')
+    else:
+        init_device = None
+    chunk_manager = ChunkManager(config_dict, init_device=init_device)
+    gemini_manager = GeminiManager(placement_policy, chunk_manager)
+    model = ZeroDDP(model, gemini_manager, pin_memory=True)
+
+    optimizer = HybridAdam(model.parameters(), lr=1e-3)
+    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)
+
+    model.train()
+    torch_model.train()
+
+    set_seed(dist.get_rank() * 3 + 128)
+    for i, (data, label) in enumerate(train_dataloader):
+        if i > 2:
+            break
+        data = data.cuda()
+        label = label.cuda()
+
+        zero_optim.zero_grad()
+        torch_optim.zero_grad()
+
+        torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim)
+        loss = run_fwd_bwd(model, data, label, criterion, zero_optim)
+        assert_close(torch_loss, loss)
+
+        import apex.amp as apex_amp
+        torch.nn.utils.clip_grad_norm_(apex_amp.master_params(torch_optim), 1.0)
+        torch_optim.step()
+        zero_optim.step()
+
+        check_param(model, torch_model)
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    exam_grad_clipping()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 2])
+@rerun_if_address_is_in_use()
+def test_grad_clip(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_grad_clip(2)

tests/test_gemini/update/test_optim.py

+1 -1

@@ -42,7 +42,7 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
         assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
         temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
         # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
-        assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-2)
+        assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)


 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
