[deepgemm] adapt deepgemm #6245

Draft: wants to merge 4 commits into main
3 changes: 3 additions & 0 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -435,6 +435,7 @@ class GeminiPlugin(DPPluginBase):
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
use_deep_gemm (bool, optional): Whether to use deep_gemm for fp8 matmul. Defaults to False.
verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
"""
@@ -479,6 +480,7 @@ def __init__(
enable_jit_fused: bool = False,
enable_async_reduce: bool = True,
use_fp8: bool = False,
use_deep_gemm: bool = False,
verbose: bool = False,
fp8_communication: bool = False,
) -> None:
@@ -517,6 +519,7 @@ def __init__(
enable_async_reduce=enable_async_reduce,
fp8_communication=fp8_communication,
use_fp8=use_fp8,
use_deep_gemm=use_deep_gemm,
)
self.zero_optim_config = dict(
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
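For context, a minimal usage sketch of the new flag in the Gemini plugin, assuming the usual Booster workflow (model, optimizer, criterion, and dataloader are placeholders defined elsewhere, not part of this diff):

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# Enable fp8 matmul and route it through deep_gemm kernels (illustrative sketch only).
plugin = GeminiPlugin(precision="bf16", use_fp8=True, use_deep_gemm=True)
booster = Booster(plugin=plugin)
model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)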
12 changes: 10 additions & 2 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -33,7 +33,7 @@
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp
from colossalai.shardformer.policies.base_policy import Policy
@@ -70,6 +70,7 @@ def __init__(
custom_policy: Policy,
overlap_allgather: bool = False,
use_fp8: bool = False,
use_deep_gemm: bool = False,
) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.shard_config = shard_config
@@ -80,6 +81,7 @@ def __init__(
self.require_grad_sync = True
self.overlap_allgather = overlap_allgather
self.use_fp8 = use_fp8
self.use_deep_gemm = use_deep_gemm

shardformer = ShardFormer(shard_config)
if custom_policy is not None:
@@ -119,7 +121,10 @@ def __init__(
super().__init__(module)
self.op_hooks = []
if use_fp8:
self.op_hooks.append(FP8Hook())
if use_deep_gemm:
self.op_hooks.append(FP8DeepGemmHook())
else:
self.op_hooks.append(FP8Hook())
if overlap_allgather:
self.op_hooks.append(ZeroOpHook())
if use_fp8 or overlap_allgather:
@@ -1044,6 +1049,7 @@ def __init__(
overlap_allgather: bool = False,
fp8_communication: bool = False,
use_fp8: bool = False,
use_deep_gemm: bool = False,
inner_ring_size: int = None,
) -> None:
super().__init__()
@@ -1097,6 +1103,7 @@ def __init__(
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.use_fp8 = use_fp8
self.use_deep_gemm = use_deep_gemm
if dp_outside:
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
@@ -1323,6 +1330,7 @@ def configure(
custom_policy=self.custom_policy,
overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
use_fp8=self.use_fp8,
use_deep_gemm=self.use_deep_gemm,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if zero_stage == 0:
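With this change, the hybrid-parallel model wrapper registers FP8DeepGemmHook instead of FP8Hook whenever both use_fp8 and use_deep_gemm are set. A minimal configuration sketch (the parallelism sizes are placeholders, not taken from this PR):

from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    use_fp8=True,
    use_deep_gemm=True,  # falls back to the plain FP8Hook when False
)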
12 changes: 10 additions & 2 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -37,7 +37,7 @@
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero import LowLevelZeroOptimizer
@@ -72,6 +72,7 @@ def __init__(
overlap_allgather: bool = False,
cast_inputs: bool = True,
use_fp8: bool = False,
use_deep_gemm: bool = False,
) -> None:
super().__init__(module)
self.dtype = None
@@ -92,7 +93,10 @@ def __init__(
if overlap_allgather:
self.op_hooks.append(ZeroOpHook())
if use_fp8:
self.op_hooks.append(FP8Hook())
if use_deep_gemm:
self.op_hooks.append(FP8DeepGemmHook())
else:
self.op_hooks.append(FP8Hook())
if overlap_allgather or use_fp8:
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
@@ -400,6 +404,7 @@ class LowLevelZeroPlugin(DPPluginBase):
cpu_offload (bool, optional): whether to offload grad, master weight and optimizer state to cpu. Defaults to False.
verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
use_deep_gemm (bool, optional): Whether to use deep_gemm for fp8 matmul. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
extra_dp_size (int, optional): The number of extra data parallel groups. Defaults to 1.
"""
@@ -427,6 +432,7 @@ def __init__(
cast_inputs: bool = True,
fp8_communication: bool = False,
use_fp8: bool = False,
use_deep_gemm: bool = False,
extra_dp_size: int = 1,
) -> None:
super().__init__()
@@ -466,6 +472,7 @@ def __init__(
self.cast_inputs = cast_inputs

self.use_fp8 = use_fp8
self.use_deep_gemm = use_deep_gemm
# set class name with stage, for better error message
setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")

@@ -585,6 +592,7 @@ def configure(
overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
cast_inputs=self.cast_inputs,
use_fp8=self.use_fp8,
use_deep_gemm=self.use_deep_gemm,
)

# TODO: Support Galore + ZeRO
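The LowLevelZero wrapper follows the same hook-selection logic. A minimal stage-2 sketch with the new flag (values are illustrative only):

from colossalai.booster.plugin import LowLevelZeroPlugin

plugin = LowLevelZeroPlugin(
    stage=2,
    precision="bf16",
    use_fp8=True,
    use_deep_gemm=True,  # registers FP8DeepGemmHook instead of FP8Hook
)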
4 changes: 4 additions & 0 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -171,6 +171,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
use_deep_gemm (bool, optional): Whether to use deep_gemm for fp8 training. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
"""

@@ -222,6 +223,7 @@ def __init__(
overlap_allgather: bool = False,
fp8_communication: bool = False,
use_fp8: bool = False,
use_deep_gemm: bool = False,
) -> None:
self.logger = get_dist_logger()
if overlap_communication or zero_stage == 2:
@@ -359,6 +361,7 @@ def __init__(
self.mixed_dp_group = self.dp_group

self.use_fp8 = use_fp8
self.use_deep_gemm = use_deep_gemm

self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
@@ -465,6 +468,7 @@ def configure(
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
use_fp8=self.use_fp8,
use_deep_gemm=self.use_deep_gemm,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
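As with the other plugins, the flag is simply forwarded to the shared model wrapper. A sketch assuming the usual MoE parallel-size arguments (placeholder values, not from this PR):

from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,
    use_fp8=True,
    use_deep_gemm=True,
)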
94 changes: 94 additions & 0 deletions colossalai/quantization/deep_gemm_utils.py
@@ -0,0 +1,94 @@
# This file was modified from https://github.com/deepseek-ai/DeepGEMM
# as these utils are not included in the library
# Thanks for developing and open-sourcing the performant kernel

# Original LICENSE:

# MIT License

# Copyright (c) 2025 DeepSeek

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import warnings
from typing import Tuple

import torch

__WARNING_MSG = "Couldn't find deep_gemm library, please install from https://github.com/deepseek-ai/DeepGEMM and run corresponding tests"
try:
from deep_gemm import ceil_div, gemm_fp8_fp8_bf16_nt

IS_DEEP_GEMM_AVAIL = True
except ImportError:
IS_DEEP_GEMM_AVAIL = False
warnings.warn(__WARNING_MSG)

def ceil_div(*args, **kwargs):  # placeholder so the name exists (and lint passes) when deep_gemm is absent
raise NotImplementedError(__WARNING_MSG)

def gemm_fp8_fp8_bf16_nt(*args, **kwargs):
raise NotImplementedError(__WARNING_MSG)


def deepgemm_fp8_gemm(
lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor], out: torch.Tensor
) -> None:
gemm_fp8_fp8_bf16_nt(lhs, rhs, out)


# TODO: There seems to be a better kernel implemented in Triton
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Cast the input tensor to float8_e4m3fn precision with the corresponding scale, in a token-wise manner.
Args:
x (`torch.Tensor`):
The x operand in x @ y.t(), where x.shape is (m, k)

Returns:
`Tuple[torch.Tensor, torch.Tensor]`: x cast to float8_e4m3fn and the corresponding scale
"""
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)


# TODO: There seems to be a better kernel implemented in Triton
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def per_block_cast_to_fp8(y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Cast the input tensor to float8_e4m3fn precision with the corresponding scale, in a block-wise (128x128) manner.
Args:
y (`torch.Tensor`):
The y operand in x @ y.t(), where y.shape is (n, k)

Returns:
`Tuple[torch.Tensor, torch.Tensor]`: y cast to float8_e4m3fn and the corresponding scale
"""
assert y.dim() == 2
m, n = y.shape
x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=y.dtype, device=y.device)
x_padded[:m, :n] = y
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
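Taken together, the two cast helpers produce the operand/scale pairs that deepgemm_fp8_gemm forwards to gemm_fp8_fp8_bf16_nt. A rough usage sketch follows (shapes are illustrative; the real DeepGEMM kernel may additionally require the lhs scales in a TMA-aligned layout, which this sketch omits):

import torch

from colossalai.quantization.deep_gemm_utils import (
    deepgemm_fp8_gemm,
    per_block_cast_to_fp8,
    per_token_cast_to_fp8,
)

m, k, n = 256, 4096, 2048  # k must be a multiple of 128 for the per-token cast
x = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)  # activation, used as x in x @ y.t()
y = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)  # weight, used as y in x @ y.t()

x_fp8 = per_token_cast_to_fp8(x)  # (fp8 tensor of shape (m, k), per-token scales of shape (m, k // 128))
y_fp8 = per_block_cast_to_fp8(y)  # (fp8 tensor of shape (n, k), 128x128 block scales)

out = torch.empty(m, n, device="cuda", dtype=torch.bfloat16)
deepgemm_fp8_gemm(x_fp8, y_fp8, out)  # out approximates x @ y.t(), computed by the deep_gemm kernel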