
[Zerobubble] Merge Main. #6107

Merged: 190 commits, Nov 5, 2024
Changes from 1 commit (of 190 commits)

Commits
f5a52e1
fp8 operators for compressed communication
BurkeHulk Jul 1, 2024
6991819
Merge branch 'hpcaitech:main' into feature/fp8_comm
BurkeHulk Jul 4, 2024
e17f835
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 4, 2024
dbfa7d3
fix typo
GuangyaoZhang Jul 10, 2024
1e19594
fix scaling algorithm in FP8 casting
BurkeHulk Jul 12, 2024
e881901
support fp8 communication in pipeline parallelism
BurkeHulk Jul 12, 2024
6601874
add fp8_communication flag in the script
BurkeHulk Jul 12, 2024
1f1b856
Merge remote-tracking branch 'origin/feature/fp8_comm' into feature/f…
BurkeHulk Jul 12, 2024
51f916b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2024
9470701
Merge pull request #5885 from BurkeHulk/feature/fp8_comm
BurkeHulk Jul 16, 2024
457a0de
shardformer fp8
GuangyaoZhang Jul 8, 2024
5a310b9
fix rebase
GuangyaoZhang Jul 17, 2024
6a20f07
remove all to all
GuangyaoZhang Jul 17, 2024
d0bdb51
Merge pull request #5899 from BurkeHulk/SP_fp8
GuangyaoZhang Jul 18, 2024
5b969fd
fix shardformer fp8 communication training degradation
GuangyaoZhang Jul 18, 2024
62661cd
Merge pull request #5921 from BurkeHulk/fp8_fix
GuangyaoZhang Jul 18, 2024
5fd0592
[fp8] support all-gather flat tensor (#5932)
ver217 Jul 24, 2024
ae486ce
[fp8] add fp8 comm for low level zero
ver217 Aug 2, 2024
91e596d
[test] add zero fp8 test case
ver217 Aug 2, 2024
c297e21
Merge pull request #5961 from ver217/feature/zeor-fp8
BurkeHulk Aug 2, 2024
53cb960
[Feature] llama shardformer fp8 support (#5938)
GuangyaoZhang Aug 5, 2024
0c10afd
[FP8] rebase main (#5963)
flybird11111 Aug 6, 2024
afb26de
[fp8]support all2all fp8 (#5953)
flybird11111 Aug 6, 2024
76ea164
[fp8] add fp8 linear (#5967)
ver217 Aug 7, 2024
ccabcf6
[fp8] support fp8 amp for hybrid parallel plugin (#5975)
ver217 Aug 7, 2024
7739629
fix (#5976)
flybird11111 Aug 7, 2024
b480eec
[Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)
BurkeHulk Aug 8, 2024
4b9bec8
[test ci]Feature/fp8 comm (#5981)
flybird11111 Aug 8, 2024
8241c0c
[fp8] support gemini plugin (#5978)
ver217 Aug 9, 2024
e4aadee
[fp8] use torch compile (torch >= 2.3.0) (#5979)
botbw Aug 9, 2024
f1a3a32
[fp8]Moe support fp8 communication (#5977)
flybird11111 Aug 9, 2024
b2483c8
[fp8] support hybrid parallel plugin (#5982)
wangbluo Aug 12, 2024
0978080
[fp8] refactor fp8 linear with compile (#5993)
ver217 Aug 13, 2024
597b206
[fp8] support asynchronous FP8 communication (#5997)
flybird11111 Aug 14, 2024
88fa096
[fp8] update torch.compile for linear_fp8 to >= 2.4.0 (#6004)
botbw Aug 15, 2024
1a2e90d
[fp8] linear perf enhancement
botbw Aug 15, 2024
20722a8
[fp8]update reduce-scatter test (#6002)
flybird11111 Aug 15, 2024
3f09a61
[fp8] add use_fp8 option for MoeHybridParallelPlugin (#6009)
wangbluo Aug 16, 2024
0a51319
[fp8] zero support fp8 linear. (#6006)
flybird11111 Aug 16, 2024
4cf79fa
merge
wangbluo Aug 17, 2024
81272e9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 17, 2024
02636c5
fix the merge
wangbluo Aug 19, 2024
52289e4
Merge branch 'fp8_merge' of https://github.com/wangbluo/ColossalAI in…
wangbluo Aug 19, 2024
1a5847e
fix the merge
wangbluo Aug 19, 2024
3353042
fix the merge
wangbluo Aug 19, 2024
64aad96
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2024
4c82bfc
fix the merge
wangbluo Aug 19, 2024
0d8e82a
Merge branch 'fp8_merge' of https://github.com/wangbluo/ColossalAI in…
wangbluo Aug 19, 2024
12b4401
fix
wangbluo Aug 19, 2024
2eb3683
fix
wangbluo Aug 19, 2024
88b3f06
fix the merge
wangbluo Aug 19, 2024
1f703e0
fix
wangbluo Aug 19, 2024
5382311
fix
wangbluo Aug 20, 2024
f7acfa1
fix
wangbluo Aug 20, 2024
2ee6235
fix
wangbluo Aug 20, 2024
2e4cbe3
fix
wangbluo Aug 20, 2024
2d362ac
fix merge
wangbluo Aug 20, 2024
eb5ba40
fix the merge
wangbluo Aug 21, 2024
193030f
fix
wangbluo Aug 21, 2024
6aface9
fix
wangbluo Aug 21, 2024
698c8b9
fix
wangbluo Aug 21, 2024
8b8e282
fix
wangbluo Aug 21, 2024
eea37da
[fp8] Merge feature/fp8_comm to main branch of Colossalai (#6016)
wangbluo Aug 22, 2024
d77e66a
Merge pull request #6023 from wangbluo/fp8_merge
wangbluo Aug 22, 2024
971b16a
fix
wangbluo Aug 22, 2024
a292554
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2024
afe845f
Merge pull request #6024 from wangbluo/fix_merge
wangbluo Aug 22, 2024
caab4a3
Merge branch 'main' into feature/fp8_comm
ver217 Aug 22, 2024
0bc9a87
Update train_dpo.py
flybird11111 Aug 23, 2024
3b0df30
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2024
9e76764
Update low_level_zero_plugin.py
flybird11111 Aug 23, 2024
0bf46c5
Merge pull request #6029 from hpcaitech/flybird11111-patch-1
wangbluo Aug 23, 2024
dae3999
fix
wangbluo Aug 26, 2024
80d24ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2024
4a6f31e
Merge pull request #6033 from wangbluo/fix
wangbluo Aug 26, 2024
17904cb
Merge pull request #6012 from hpcaitech/feature/fp8_comm
ver217 Aug 27, 2024
d383449
[CI] Remove triton version for compatibility bug; update req torch >=…
Edenzzzz Aug 27, 2024
cc1b0ef
[plugin] hotfix zero plugin (#6036)
ver217 Aug 28, 2024
4a68efb
[Colossal-LLaMA] Refactor latest APIs (#6030)
TongLi3701 Aug 28, 2024
0d3a85d
add fused norm (#6038)
TongLi3701 Aug 28, 2024
e96a076
[FP8] unsqueeze scale to make it compatible with torch.compile (#6040)
GuangyaoZhang Aug 29, 2024
e9032fb
[colossalai/checkpoint_io/...] fix bug in load_state_dict_into_model;…
flymin Sep 2, 2024
c650a90
[Hotfix] Remove deprecated install (#6042)
TongLi3701 Sep 3, 2024
c3b5caf
[fp8] optimize all-gather (#6043)
ver217 Sep 3, 2024
26e5539
[fp8] fix linear hook (#6046)
ver217 Sep 3, 2024
5ce6dd7
[fp8] disable all_to_all_fp8 in intranode (#6045)
BurkeHulk Sep 9, 2024
b3db105
[release] update version (#6041)
ver217 Sep 10, 2024
8fd25d6
[Feature] Split cross-entropy computation in SP (#5959)
Edenzzzz Sep 10, 2024
c54c4fc
[hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)
botbw Sep 10, 2024
13946c4
[fp8] hotfix backward hook (#6053)
ver217 Sep 11, 2024
a35a078
[doc] update sp doc (#6055)
flybird11111 Sep 11, 2024
fdd84b9
fix the sp
wangbluo Sep 13, 2024
216d54e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2024
0a01e2a
fix the attn
wangbluo Sep 13, 2024
683179c
fix
wangbluo Sep 13, 2024
6eb8832
fix
wangbluo Sep 13, 2024
f393867
fix
wangbluo Sep 13, 2024
dc03217
fix
wangbluo Sep 13, 2024
696fced
[fp8] fix missing fp8_comm flag in mixtral (#6057)
botbw Sep 13, 2024
0b14a55
fix
wangbluo Sep 13, 2024
0ad3129
fix
wangbluo Sep 13, 2024
b582319
fix
wangbluo Sep 13, 2024
f20b066
[fp8] Disable all_gather intranode. Disable Redundant all_gather fp8 …
GuangyaoZhang Sep 14, 2024
bdb125f
[doc] FP8 training and communication document (#6050)
GuangyaoZhang Sep 14, 2024
827ef3e
fix
wangbluo Sep 14, 2024
37e3523
Merge pull request #6061 from wangbluo/sp_fix
wangbluo Sep 14, 2024
10e4f7d
fix
wangbluo Sep 16, 2024
63314ce
Merge pull request #6064 from wangbluo/fix_attn
wangbluo Sep 18, 2024
4fa6b95
[moe] add parallel strategy for shared_expert && fix test for deepsee…
botbw Sep 18, 2024
f9546ba
[ColossalEval] support for vllm (#6056)
Camille7777 Sep 18, 2024
dabc2e7
[release] update version (#6062)
ver217 Sep 19, 2024
cbaa104
release FP8 news (#6068)
binmakeswell Sep 25, 2024
cfd9eda
fix the ring attn
wangbluo Sep 25, 2024
65c8297
fix the attn
wangbluo Sep 25, 2024
6fb1322
fix
wangbluo Sep 25, 2024
91ed32c
fix
wangbluo Sep 25, 2024
6705dad
fix
wangbluo Sep 25, 2024
f4daf04
add funding news (#6072)
binmakeswell Sep 26, 2024
3fab921
fix
wangbluo Sep 26, 2024
3532f77
fix
wangbluo Oct 9, 2024
3f5bec8
[feat] support zbv in mixtral benchmark;
duanjunwen Oct 9, 2024
b635dd0
fix
wangbluo Oct 9, 2024
9ee80fc
[fix] MixtralForCausalLMPolicy get_held_layer support zbv;
duanjunwen Oct 10, 2024
72b507a
[feat] update MixtralPipelineForwards --> mixtral_model_forward; supp…
duanjunwen Oct 10, 2024
646b3c5
[shardformer] fix linear 1d row and support uneven splits for fused q…
ver217 Oct 10, 2024
e234dfa
[feat] support MixtralPipelineForwards--> mixtral_for_causal_lm_forwa…
duanjunwen Oct 10, 2024
f98384a
fix
wangbluo Oct 10, 2024
5ecc27e
fix
wangbluo Oct 10, 2024
6b2c506
Update README.md (#6087)
supercooledith Oct 10, 2024
efe3042
fix
wangbluo Oct 10, 2024
dc2cdaf
[shardformer] optimize seq parallelism (#6086)
ver217 Oct 11, 2024
0002ae5
fix
wangbluo Oct 11, 2024
1507a75
fix
wangbluo Oct 11, 2024
0ca16d5
[fix] fix llama, mixtral benchmark zbv loss none bug; update mixtral …
duanjunwen Oct 11, 2024
4e0e99b
fix the test
wangbluo Oct 11, 2024
703bb5c
fix the test
wangbluo Oct 11, 2024
4c8e85e
[Coati] Train DPO using PP (#6054)
TongLi3701 Oct 11, 2024
e1e86f9
fix
wangbluo Oct 14, 2024
d891e50
fix
wangbluo Oct 14, 2024
cfade4c
[feat] Linear1D_COL/ROW support zbv WeightGradStore;
duanjunwen Oct 14, 2024
a11b4b5
[feat] support use_zbv in llama, mixtral modeling; only replace Linea…
duanjunwen Oct 14, 2024
abd4551
[fix] fix test case; moe error in second iter
duanjunwen Oct 14, 2024
160e9a4
[feat]EPMixtralSparseMoeBlock (op in MOE) support zbv;
duanjunwen Oct 14, 2024
23199e3
fix
wangbluo Oct 14, 2024
3201377
fix
wangbluo Oct 14, 2024
fe9208f
fix
wangbluo Oct 14, 2024
8ff7d0c
fix
wangbluo Oct 14, 2024
3dc08c8
fix
wangbluo Oct 15, 2024
6be9862
fix
wangbluo Oct 15, 2024
fd92789
fix
wangbluo Oct 15, 2024
bc7eead
fix
wangbluo Oct 15, 2024
9912cc8
[fix] fix bwd b; now bwd w only for Layer replaced by Linear1D_Col/Ro…
duanjunwen Oct 15, 2024
52dcc73
Merge branch 'feature/zerobubble' of github.com:hpcaitech/ColossalAI …
duanjunwen Oct 15, 2024
83cf2f8
fix
wangbluo Oct 15, 2024
dcd41d0
Merge pull request #6071 from wangbluo/ring_attention
wangbluo Oct 15, 2024
90939b7
[fix] debug zbv llama test;
duanjunwen Oct 15, 2024
62c13e7
[Ring Attention] Improve comments (#6085)
Edenzzzz Oct 16, 2024
e76308c
[fix] rm use_zbv flag in Shardconfig; rm debug info;
duanjunwen Oct 16, 2024
705b18e
[fix] add & fix llama test
duanjunwen Oct 16, 2024
cd61353
[pipeline] hotfix backward for multiple outputs (#6090)
ver217 Oct 16, 2024
2bcd0b6
[ckpt] add safetensors util
botbw Oct 14, 2024
3b1d7d1
[chore] refactor
botbw Oct 14, 2024
5ddad48
[fp8] add fallback and make compile option configurable (#6092)
ver217 Oct 18, 2024
58d8b8a
[misc] fit torch API upgrade and remove legacy import (#6093)
ver217 Oct 18, 2024
19baab5
[release] update version (#6094)
ver217 Oct 21, 2024
b10339d
fix lora ckpt save format (ColoTensor to Tensor)
BurkeHulk Oct 21, 2024
6d6cafa
pre-commit fix
BurkeHulk Oct 21, 2024
dee63cc
Merge pull request #6096 from BurkeHulk/hotfix/lora_ckpt
BurkeHulk Oct 21, 2024
80a8ca9
[extension] hotfix compile check (#6099)
ver217 Oct 24, 2024
4294ae8
[doc] sora solution news (#6100)
binmakeswell Oct 24, 2024
2eca112
[feat] support meta cache, meta_grad_send, meta_tensor_send; fix runt…
duanjunwen Oct 24, 2024
89a9a60
[MCTS] Add self-refined MCTS (#6098)
TongLi3701 Oct 24, 2024
d0ec221
[fix] fix fail case test_shard_llama
duanjunwen Oct 25, 2024
cc0dfdd
[fix] fix test_shard_llama
duanjunwen Oct 25, 2024
03fa79a
[fix] fix llama modeling policy;
duanjunwen Oct 25, 2024
6377aa0
[fix] fix test_shard_llama ci;
duanjunwen Oct 28, 2024
5aee426
[fix] fix test zerobubble
duanjunwen Oct 28, 2024
fafe049
[fix] fix handle name; rm useless comments;
duanjunwen Oct 29, 2024
fa3ccda
[fix] fix send recv signature;
duanjunwen Oct 29, 2024
982e4ee
[fix] fix comment in llama & benchmark
duanjunwen Oct 29, 2024
d2e05a9
[feat] support no tensor parallel Linear in shardformer; Add test for…
duanjunwen Oct 30, 2024
5f09243
[fix] fix linear (no tp) ops func name;
duanjunwen Oct 31, 2024
c2e8f61
[checkpointio] fix hybrid plugin model save (#6106)
ver217 Oct 31, 2024
2f583c1
[pre-commit.ci] pre-commit autoupdate (#6078)
pre-commit-ci[bot] Oct 31, 2024
1d328ff
Merge branch 'main' into dev/zero_bubble
duanjunwen Nov 1, 2024
c82c75a
Merge branch 'feature/zerobubble' of github.com:hpcaitech/ColossalAI …
duanjunwen Nov 1, 2024
3b5c314
[fix] fix fp8 args in HybridParallel
duanjunwen Nov 1, 2024
5b5fbcf
[fix] fix hybridparall use_fp8 config
duanjunwen Nov 1, 2024
0218e67
[fix] fix use_fp8 flag
duanjunwen Nov 1, 2024
8e40087
[fix] fix model zoo init
duanjunwen Nov 1, 2024
[fp8] support all-gather flat tensor (#5932)
ver217 authored Jul 24, 2024

commit 5fd0592767c1dcdf88f89d2c37cf399acb52c2b9
76 changes: 76 additions & 0 deletions colossalai/quantization/fp8.py
@@ -1,5 +1,6 @@
from typing import Any

import numpy as np
import torch
import torch.distributed as dist

@@ -202,3 +203,78 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2
        out = out.view(fp8_type)
        summed_out += cast_from_fp8(out, scale, input_type)
    output.data = summed_out


def split_chunk_by_channel(
    chunk: torch.Tensor, channel_size: int, num_channels: int, rank: int = 0, world_size: int = 1
):
    offset = chunk.numel() * rank
    end = offset + chunk.numel()
    break_points = [x for x in range(0, channel_size * num_channels + 1, channel_size) if offset <= x <= end]
    if len(break_points) == 0 or break_points[0] > offset:
        break_points.insert(0, offset)
    if break_points[-1] < end:
        break_points.append(end)
    sizes = [b - a for a, b in zip(break_points[:-1], break_points[1:])]
    return chunk.split(sizes)


def all_gather_into_tensor_flat_fp8(
    output_tensor: torch.Tensor,
    input_tensor: torch.Tensor,
    output_shape: torch.Size,
    group: dist.ProcessGroup,
    fp8_format: str = "e4m3",
):
    """all gather into tensor in fp8 format
    Args:
        output_tensor (torch.Tensor): output tensor, which is flattened
        input_tensor (torch.Tensor): input tensor, which is flattened
        group (dist.ProcessGroup): process group
        fp8_format (str, optional): fp8 format, e4m3 or e5m2. Defaults to "e4m3".
    """
    assert input_tensor.dim() == 1 and output_tensor.dim() == 1, "input/output tensor should be flattened"
    world_size = dist.get_world_size(group)
    assert (
        output_tensor.numel() == input_tensor.numel() * world_size
    ), "output tensor size should be world_size times of input tensor size"

    input_type = output_tensor.dtype

    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_type).max

    if len(output_shape) == 2:
        per_channel_max = torch.zeros(output_shape[0], device=output_tensor.device, dtype=torch.float)
        num_channels, channel_size = output_shape
        rank = dist.get_rank(group)
        channel_start_idx = (input_tensor.numel() * rank) // channel_size
        per_channel_splits = split_chunk_by_channel(input_tensor, channel_size, num_channels, rank, world_size)
        for i, per_channel_split in enumerate(per_channel_splits):
            idx = i + channel_start_idx
            if idx < num_channels:
                per_channel_max[idx] = per_channel_split.abs().max().float()
        dist.all_reduce(per_channel_max, op=dist.ReduceOp.MAX, group=group)
        per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
        scale = fp8_max / per_channel_max
        fp8_input = input_tensor.float()
        fp8_per_channel_splits = split_chunk_by_channel(fp8_input, channel_size, num_channels, rank, world_size)
        for i, per_channel_split in enumerate(fp8_per_channel_splits):
            idx = i + channel_start_idx
            if idx < num_channels:
                per_channel_split.mul_(scale[idx])
        fp8_input = fp8_input.to(fp8_type)
    else:
        per_tensor_max = input_tensor.abs().max().float()
        dist.all_reduce(per_tensor_max, op=dist.ReduceOp.MAX, group=group)
        per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
        scale = fp8_max / per_tensor_max
        fp8_input = (scale * input_tensor.float()).to(fp8_type)
    scale_inv = 1.0 / scale
    buffer = torch.empty_like(output_tensor, dtype=fp8_type)
    dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
    numel = np.prod(output_shape)
    valid_buffer = buffer[:numel].reshape(output_shape)
    valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type)
    output_tensor[:numel].copy_(valid_buffer.view(-1))
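
For intuition, the function above quantizes the local flat shard to FP8 with an amax-derived scale (one scale per output channel for 2-D weights, a single scale otherwise), gathers the raw FP8 bytes, and dequantizes on the receiving side. The following is a minimal, self-contained sketch of the per-tensor round trip that brackets the gather; it assumes only plain PyTorch 2.1 or later (for the float8 dtypes) and no ColossalAI helpers, and the function name fp8_roundtrip is illustrative, not part of the library.

import torch


def fp8_roundtrip(x: torch.Tensor, fp8_format: str = "e4m3") -> torch.Tensor:
    # Pick the FP8 dtype and its largest representable magnitude.
    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_type).max
    # Scale so the tensor's absolute maximum maps onto fp8_max (guard against all-zero input).
    amax = x.abs().max().float()
    amax = torch.where(amax > 0, amax, torch.ones_like(amax))
    scale = fp8_max / amax
    # Quantize to FP8, then dequantize with the inverse scale, mirroring what
    # all_gather_into_tensor_flat_fp8 does before and after transporting the bytes.
    x_fp8 = (x.float() * scale).to(fp8_type)
    return (x_fp8.float() / scale).to(x.dtype)


if __name__ == "__main__":
    t = torch.randn(4, 8, dtype=torch.bfloat16)
    # The loose tolerances match the ones used by the test added in this commit.
    torch.testing.assert_close(fp8_roundtrip(t), t, rtol=0.1, atol=0.1)
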
40 changes: 40 additions & 0 deletions tests/test_fp8/test_fp8_allgather.py
@@ -0,0 +1,40 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize("shape", [(3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_4gpu(shape, dtype):
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    flat_padded_x = x.view(-1)
    if flat_padded_x.size(0) % world_size != 0:
        pad_size = world_size - flat_padded_x.size(0) % world_size
        flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
    output = torch.empty_like(flat_padded_x)
    chunk = flat_padded_x.chunk(world_size)[rank].clone()
    all_gather_into_tensor_flat_fp8(output, chunk, x.shape, group=_get_default_group())
    assert_close(output[: x.numel()], x.view(-1), rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_4gpu()


@rerun_if_address_is_in_use()
def test_all_gather():
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_all_gather()