Commit b803401

[zero] add chunk_managerV2 for all-gather chunk (#1441)
Parent: 3b26516

3 files changed (+298, -0)


colossalai/gemini/update/__init__.py (+1)

@@ -1,2 +1,3 @@
 from .chunkv2 import ChunkV2
+from .chunk_mgrv2 import ChunkManagerV2
 from .search_utils import clasify_params, search_chunk_configuration
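With this re-export, the new manager can be imported from the package root, which is exactly how the test file added in this commit pulls it in:

from colossalai.gemini.update import ChunkManagerV2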
colossalai/gemini/update/chunk_mgrv2.py (+221)

@@ -0,0 +1,221 @@
import torch
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
from collections import deque

from colossalai.utils import get_current_device
from colossalai.tensor import ColoTensor
from colossalai.gemini.chunk import ChunkFullError, TensorState
from colossalai.gemini.update import ChunkV2 as Chunk


class ChunkManagerV2:
    """
    A manager class to manipulate the tensors in chunks.

    Args:
        chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager.
        init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
        pin_memory (bool): if true, all chunks have a piece of pinned memory in CPU.
    """

    def __init__(self, chunk_configuration: Dict[int, Dict],
                 init_device: Optional[torch.device] = None,
                 pin_memory: bool = False) -> None:

        self.device = init_device or get_current_device()
        self.size_config: Dict[int, int] = dict()
        self.kwargs_config = chunk_configuration
        for k, v in self.kwargs_config.items():
            # each entry supplies a chunk size; the remaining kwargs are forwarded to the Chunk constructor
            self.size_config[k] = v.pop('chunk_size')
            v['init_device'] = self.device
            v['pin_memory'] = pin_memory

        self.chunk_groups: Dict[str, Deque] = dict()
        self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
        self.accessed_chunks: Set[Chunk] = set()
        self.lazy_release_tensors: List[torch.Tensor] = list()
        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}

    def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int) -> None:
        """Append a tensor to a chunk.
        """
        assert tensor not in self.tensor_chunk_map
        assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
        assert config_key in self.size_config

        chunk_size = self.size_config[config_key]
        chunk_kwargs = self.kwargs_config[config_key]
        group_name = "{}_{}".format(group_type, config_key)
        chunk_group = self.__get_chunk_group(group_name)

        try:
            # append the tensor to the last chunk
            chunk_group[-1].append_tensor(tensor)
        except (IndexError, ChunkFullError):
            # the except statement will be triggered when there is no chunk or
            # the last chunk in the chunk group is full
            # this will create a new chunk and allocate this chunk to its corresponding process
            if chunk_group:
                # the chunk group is not empty
                # close the last chunk
                self.__close_one_chunk(chunk_group[-1])

            if tensor.numel() > chunk_size:
                chunk_size = tensor.numel()
            chunk = Chunk(
                chunk_size=chunk_size,
                process_group=tensor.process_group,
                dtype=tensor.dtype,
                **chunk_kwargs
            )

            chunk_group.append(chunk)
            chunk.append_tensor(tensor)
            self.__add_memory_usage(chunk.memory_usage)

        self.tensor_chunk_map[tensor] = chunk_group[-1]

    def close_all_groups(self):
        """Close all the chunks of all groups.
        """
        for group_name in self.chunk_groups:
            self.__close_one_chunk(self.chunk_groups[group_name][-1])

    def access_chunk(self, chunk: Chunk) -> None:
        """Make the chunk usable for computation.
        """
        if chunk in self.accessed_chunks:
            return
        self.__sub_memroy_usage(chunk.memory_usage)
        chunk.access_chunk()
        self.__add_memory_usage(chunk.memory_usage)
        self.accessed_chunks.add(chunk)

    def release_chunk(self, chunk: Chunk) -> None:
        """Scatter the chunk in CUDA.
        """
        if chunk not in self.accessed_chunks:
            return
        if chunk.can_release:
            self.__sub_memroy_usage(chunk.memory_usage)
            chunk.release_chunk()
            self.__add_memory_usage(chunk.memory_usage)
            self.accessed_chunks.remove(chunk)

    def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
        """Move the shard of the chunk to the target device.
        """
        if not chunk.can_move or chunk.device_type == device.type:
            return
        self.__sub_memroy_usage(chunk.memory_usage)
        chunk.shard_move(device)
        self.__add_memory_usage(chunk.memory_usage)

    def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
        """Transition the tensor's state according to the pre-defined state machine.
        """
        chunk = self.tensor_chunk_map[tensor]
        chunk.tensor_trans_state(tensor, state)

    def reduce_chunk(self, chunk: Chunk) -> bool:
        """Reduce or all-reduce the chunk.
        """
        if not chunk.can_reduce:
            return False
        self.__sub_memroy_usage(chunk.memory_usage)
        chunk.release_chunk()
        self.__add_memory_usage(chunk.memory_usage)
        return True

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
        """
        Copy data to the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data (torch.Tensor): the tensor to be copied to the chunk
        """
        chunk = self.tensor_chunk_map[tensor]
        chunk.copy_tensor_to_chunk_slice(tensor, data)

    def get_chunk(self, tensor: torch.Tensor) -> Chunk:
        """
        Return the chunk owning the tensor.

        Args:
            tensor (torch.Tensor): a torch tensor object
        """
        return self.tensor_chunk_map[tensor]

    def add_lazy_release_tensors(self, tensors: List[torch.Tensor]) -> None:
        """
        Add tensors to the buffer for lazy release.

        Args:
            tensors (List[torch.Tensor]): the tensors to be released lazily
        """
        self.lazy_release_tensors.extend(tensors)

    def exec_lazy_release(self) -> None:
        """
        Execute release for tensors added to the lazy release buffer.
        """

        for chunk in self.get_chunks(self.lazy_release_tensors):
            self.release_chunk(chunk)
        self.lazy_release_tensors.clear()

    def __repr__(self) -> str:
        msg = ['Chunk Manager Information:\n',
               'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n']
        for group_name, group in self.chunk_groups.items():
            msg.append(f'Group {group_name}:\n')
            for i, chunk in enumerate(group):
                msg.append(f'[{i}] {chunk}\n')
        return ''.join(msg)

    def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
        """
        Get all chunks owning the input tensors.

        Args:
            tensors (Iterable[torch.Tensor]): the tensors used to look for chunks
        """
        chunks = []
        for tensor in tensors:
            chunk = self.get_chunk(tensor)
            if chunk not in chunks:
                chunks.append(chunk)
        return tuple(chunks)

    def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
        """Add an extern static tensor to the chunk manager.
        Those tensors won't be managed by the chunk manager, but we still want to monitor their memory usage.
        They are "static", which means their shape, dtype, and device never change.
        Thus, their memory usage never changes.

        Args:
            tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
        """
        assert tensor not in self.tensor_chunk_map
        self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()

    def __get_chunk_group(self, group_name: str) -> Deque:
        """Register a chunk group if it does not exist yet and return it.
        """
        if group_name not in self.chunk_groups:
            self.chunk_groups[group_name] = deque()
        return self.chunk_groups[group_name]

    def __close_one_chunk(self, chunk: Chunk):
        self.__sub_memroy_usage(chunk.memory_usage)
        chunk.close_chunk(self.device)
        self.__add_memory_usage(chunk.memory_usage)

    def __sub_memroy_usage(self, usage: Dict[str, int]):
        for k, v in usage.items():
            self.total_mem[k] -= v

    def __add_memory_usage(self, usage: Dict[str, int]):
        for k, v in usage.items():
            self.total_mem[k] += v
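For orientation, here is a minimal usage sketch (not part of the commit) of the API this file introduces. It assumes a ColossalAI distributed environment has already been launched, as the test file below does; the group key, chunk size, and tensor shapes are arbitrary.

import torch
from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec
from colossalai.gemini.update import ChunkManagerV2

pg = ProcessGroup()
# one chunk group keyed by 2; 'chunk_size' is consumed by the manager,
# the remaining kwargs are forwarded to each ChunkV2 it creates
config = {2: dict(chunk_size=128, keep_gathered=False)}
chunk_manager = ChunkManagerV2(config, pin_memory=True)

params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)]
for p in params:
    chunk_manager.append_tensor(p, 'param', 2)    # group_type 'param', config key 2
chunk_manager.close_all_groups()                  # seal the still-open chunks

for chunk in chunk_manager.get_chunks(params):
    chunk_manager.access_chunk(chunk)             # gather the chunk so its tensors can be used
    # ... compute with the tensors ...
    chunk_manager.release_chunk(chunk)            # give the gathered payload back

print(chunk_manager)                              # __repr__ reports per-device memory totals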
New test file for ChunkManagerV2 (+76)

@@ -0,0 +1,76 @@
import torch
import colossalai
import pytest
import torch.multiprocessing as mp
from functools import partial
from colossalai.gemini.update import ChunkManagerV2
from colossalai.testing import rerun_if_address_is_in_use, parameterize
from colossalai.utils import free_port
from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec
from tests.test_tensor.common_utils import debug_print

# expected total_mem readings in bytes per rank (world_size = 2),
# keyed by keep_gathered (CPU_MEM is additionally keyed by pin_memory)
CUDA_MEM_0 = {False: 512, True: 1024}
CUDA_MEM_1 = {False: 0, True: 1024}
CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}


@parameterize('keep_gathered', [True, False])
@parameterize('pin_memory', [True, False])
def exam_chunk_memory(keep_gathered, pin_memory):
    pg = ProcessGroup()

    debug_print([0], "keep_gathered: {}, pin_memory: {}".format(
        keep_gathered, pin_memory))

    params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)]
    config = {
        2: dict(
            chunk_size=128,
            keep_gathered=keep_gathered
        )
    }

    chunk_manager = ChunkManagerV2(config, pin_memory=pin_memory)
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == 0

    for p in params:
        chunk_manager.append_tensor(p, 'param', 2)
    chunk_manager.close_all_groups()
    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]

    chunks = chunk_manager.get_chunks(params)

    for chunk in chunks:
        chunk_manager.access_chunk(chunk)
    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[True]

    for chunk in chunks:
        chunk_manager.release_chunk(chunk)

    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]

    for chunk in chunks:
        chunk_manager.move_chunk(chunk, torch.device('cpu'))
    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][True]
    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_1[keep_gathered]


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_chunk_memory()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_chunk_manager(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_chunk_manager(2)
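A rough derivation of the byte constants above, under my reading of the test: three 8x8 fp32 tensors (64 elements each) do not fit into a single 128-element chunk, so two chunks are created, and with world_size = 2 each rank holds half of every chunk when sharded.

# back-of-the-envelope check of the expected memory figures (assumes two 128-element chunks)
elements_per_chunk = 128
num_chunks = 2                      # 3 x 64 elements overflow one 128-element chunk
bytes_per_element = 4               # torch.rand() yields float32
world_size = 2

gathered = num_chunks * elements_per_chunk * bytes_per_element    # 1024 B, cf. CUDA_MEM_0[True]
sharded = gathered // world_size                                   # 512 B,  cf. CUDA_MEM_0[False]
print(gathered, sharded)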
