diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index a81f9b05d7d7..af135db44632 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -435,6 +435,7 @@ class GeminiPlugin(DPPluginBase): enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False. + use_deep_gemm (bool, optional): Whether to use deep_gemm for fp8 matmul. Defaults to False. verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False. fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False. """ @@ -479,6 +480,7 @@ def __init__( enable_jit_fused: bool = False, enable_async_reduce: bool = True, use_fp8: bool = False, + use_deep_gemm: bool = False, verbose: bool = False, fp8_communication: bool = False, ) -> None: @@ -517,6 +519,7 @@ def __init__( enable_async_reduce=enable_async_reduce, fp8_communication=fp8_communication, use_fp8=use_fp8, + use_deep_gemm=use_deep_gemm, ) self.zero_optim_config = dict( gpu_margin_mem_ratio=gpu_margin_mem_ratio, diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 1e0f7be240f6..ac8c2ab9daf3 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -33,7 +33,7 @@ from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.quantization import BnbQuantizationConfig, quantize_model -from colossalai.quantization.fp8_hook import FP8Hook +from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp from colossalai.shardformer.policies.base_policy import Policy @@ -70,6 +70,7 @@ def __init__( custom_policy: Policy, overlap_allgather: bool = False, use_fp8: bool = False, + use_deep_gemm: bool = False, ) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.shard_config = shard_config @@ -80,6 +81,7 @@ def __init__( self.require_grad_sync = True self.overlap_allgather = overlap_allgather self.use_fp8 = use_fp8 + self.use_deep_gemm = use_deep_gemm shardformer = ShardFormer(shard_config) if custom_policy is not None: @@ -119,7 +121,10 @@ def __init__( super().__init__(module) self.op_hooks = [] if use_fp8: - self.op_hooks.append(FP8Hook()) + if use_deep_gemm: + self.op_hooks.append(FP8DeepGemmHook()) + else: + self.op_hooks.append(FP8Hook()) if overlap_allgather: self.op_hooks.append(ZeroOpHook()) if use_fp8 or overlap_allgather: @@ -1044,6 +1049,7 @@ def __init__( overlap_allgather: bool = False, fp8_communication: bool = False, use_fp8: bool = False, + use_deep_gemm: bool = False, inner_ring_size: int = None, ) -> None: super().__init__() @@ -1097,6 +1103,7 @@ def __init__( self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism self.use_fp8 = use_fp8 + self.use_deep_gemm = use_deep_gemm if dp_outside: self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) @@ -1323,6 +1330,7 @@ def configure( custom_policy=self.custom_policy, overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]), use_fp8=self.use_fp8, + use_deep_gemm=self.use_deep_gemm, ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if zero_stage == 0: diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 642969be3a68..0c02189bbb19 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -37,7 +37,7 @@ from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.quantization import BnbQuantizationConfig, quantize_model -from colossalai.quantization.fp8_hook import FP8Hook +from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.zero import LowLevelZeroOptimizer @@ -72,6 +72,7 @@ def __init__( overlap_allgather: bool = False, cast_inputs: bool = True, use_fp8: bool = False, + use_deep_gemm: bool = False, ) -> None: super().__init__(module) self.dtype = None @@ -92,7 +93,10 @@ def __init__( if overlap_allgather: self.op_hooks.append(ZeroOpHook()) if use_fp8: - self.op_hooks.append(FP8Hook()) + if use_deep_gemm: + self.op_hooks.append(FP8DeepGemmHook()) + else: + self.op_hooks.append(FP8Hook()) if overlap_allgather or use_fp8: for p in module.parameters(): if p.requires_grad and type(p) is not ColoParameter: @@ -400,6 +404,7 @@ class LowLevelZeroPlugin(DPPluginBase): cpu_offload (bool, optional): whether to offload grad, master weight and optimizer state to cpu. Defaults to False. verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False. use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False. + use_deep_gemm (bool, optional): Whether to use deep_gemm matmul. Defaults to False. fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False. extra_dp_size (int, optional): The number of extra data parallel groups. Defaults to 1. """ @@ -427,6 +432,7 @@ def __init__( cast_inputs: bool = True, fp8_communication: bool = False, use_fp8: bool = False, + use_deep_gemm: bool = False, extra_dp_size: int = 1, ) -> None: super().__init__() @@ -466,6 +472,7 @@ def __init__( self.cast_inputs = cast_inputs self.use_fp8 = use_fp8 + self.use_deep_gemm = use_deep_gemm # set class name with stage, for better error message setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") @@ -585,6 +592,7 @@ def configure( overlap_allgather=self.zero_optim_kwargs["overlap_allgather"], cast_inputs=self.cast_inputs, use_fp8=self.use_fp8, + use_deep_gemm=self.use_deep_gemm, ) # TODO: Support Galore + ZeRO diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index a733fc5f5261..8e0ac6bcfff2 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -171,6 +171,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64. overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism. use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False. + use_deep_gemm (bool, optional): Whether to use deep gemm for fp8 training. Defaults to False. fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False. """ @@ -222,6 +223,7 @@ def __init__( overlap_allgather: bool = False, fp8_communication: bool = False, use_fp8: bool = False, + use_deep_gemm: bool = False, ) -> None: self.logger = get_dist_logger() if overlap_communication or zero_stage == 2: @@ -359,6 +361,7 @@ def __init__( self.mixed_dp_group = self.dp_group self.use_fp8 = use_fp8 + self.use_deep_gemm = use_deep_gemm self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, @@ -465,6 +468,7 @@ def configure( ddp_config=self.ddp_config, custom_policy=self.custom_policy, use_fp8=self.use_fp8, + use_deep_gemm=self.use_deep_gemm, ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: diff --git a/colossalai/quantization/deep_gemm_utils.py b/colossalai/quantization/deep_gemm_utils.py new file mode 100644 index 000000000000..72a103175ad7 --- /dev/null +++ b/colossalai/quantization/deep_gemm_utils.py @@ -0,0 +1,94 @@ +# This file was modifed from https://github.com/deepseek-ai/DeepGEMM +# as the utils are not included in library +# Thanks for developing and open-sourcing the performant kernel + +# Original LICENSE: + +# MIT License + +# Copyright (c) 2025 DeepSeek + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import warnings +from typing import Tuple + +import torch + +__WARNING_MSG = "Couldn't find deep_gemm library, please install from https://github.com/deepseek-ai/DeepGEMM and run corresponding tests" +try: + from deep_gemm import ceil_div, gemm_fp8_fp8_bf16_nt + + IS_DEEP_GEMM_AVAIL = True +except ImportError: + IS_DEEP_GEMM_AVAIL = False + warnings.warn(__WARNING_MSG) + + def ceil_dev(*args, **kwargs): # to surpass code lint + raise NotImplementedError(__WARNING_MSG) + + def gemm_fp8_fp8_bf16_nt(*args, **kwargs): + raise NotImplementedError(__WARNING_MSG) + + +def deepgemm_fp8_gemm( + lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor], out: torch.Tensor +) -> None: + gemm_fp8_fp8_bf16_nt(lhs, rhs, out) + + +# TODO: There seems to be better kernel implemented in triton +@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Casting input tensor to float8_e4m3fn percicision and cooresponding scaler in token-wise mannar + Args: + x (`torch.Tensor`): + Matmul x in x @ y.t(), where x.shape() is (m, k) + + Returns: + `Tuple[torch.Tensor, torch.Tensor]`: x_float8_e4m3fn and scaler + """ + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) + + +# TODO: There seems to be better kernel implemented in triton +@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +def per_block_cast_to_fp8(y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Casting input tensor to float8_e4m3fn percicision and cooresponding scaler in block-wise mannar + Args: + y (`torch.Tensor`): + Matmul y in x @ y.t(), where y.shape() is (n, k) + + Returns: + `Tuple[torch.Tensor, torch.Tensor]`: y_float8_e4m3fn and scaler + """ + assert y.dim() == 2 + m, n = y.shape + x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=y.dtype, device=y.device) + x_padded[:m, :n] = y + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index e23da5cccd4d..4b7c4835411a 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -8,6 +8,7 @@ from packaging.version import Version from torch.distributed import ReduceOp +from .deep_gemm_utils import deepgemm_fp8_gemm, per_block_cast_to_fp8, per_token_cast_to_fp8 from .fp8_config import dynamic_kernel SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0") @@ -699,17 +700,11 @@ def all_gather_fp8_lagacy( ret = cur_buffer[SCALE_BYTES:].view(fp8_type) ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret) cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale - # cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8) dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op) for out, buf in zip(output_list, combined_buffers): scale = buf[:SCALE_BYTES].clone().view(scale.dtype) output = buf[SCALE_BYTES:].view(fp8_type) cast_from_fp8(output.view(shape), scale, input_type, out=out) - # output = combined_buffer.view(world_size, -1)[:, SCALE_BYTES:].view(fp8_type) - # scales = combined_buffer.view(world_size, -1)[:, :SCALE_BYTES].view(torch.float) - # output = output.float() * scales - # for i, out in enumerate(output_list): - # out.copy_(output[i].view(shape)) @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89) @@ -805,7 +800,11 @@ def forward( scale_a=inv_scale_x, scale_b=inv_scale_w, use_fast_accum=True, - )[0] + ) + + if isinstance(out, tuple): + out = out[0] + return out.reshape(*ctx.x_shape[:-1], w.shape[0]) @staticmethod @@ -819,7 +818,11 @@ def backward(ctx: Any, out_grad) -> Any: scale_a=out_grad_scale, scale_b=ctx.inv_scale_w, use_fast_accum=True, - )[0] + ) + + if isinstance(x_grad, tuple): + x_grad = x_grad[0] + w_grad = torch._scaled_mm( out_grad_fp8.t().contiguous(), ctx.x_fp8.t().contiguous().t(), @@ -827,13 +830,78 @@ def backward(ctx: Any, out_grad) -> Any: scale_a=out_grad_scale, scale_b=ctx.inv_scale_x, use_fast_accum=True, - )[0] + ) + + if isinstance(w_grad, tuple): + w_grad = w_grad[0] + bias_grad = None if ctx.has_bias: bias_grad = out_grad.sum(0) return x_grad.reshape(ctx.x_shape), w_grad, bias_grad +class _LinearFp8DeepGemm(torch.autograd.Function): + """ + Behave similar to torch.nn.functional.linear + """ + + def forward(ctx: Any, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + has_batch_dim = False + if x.dim() == 3: + has_batch_dim = True + if x.size(1) != 1: + raise ValueError(f"Batched fp8 deep_gemm is not supported, found x shape: {x.shape}") + x = x.squeeze(1) + ctx.has_batch_dim = has_batch_dim + + # x: (m, k), w: (n, k) + # x @ w_t -> (m, k) @ (k, n) -> deep_gemm((m, k), (n, k)) + (m, k), (n, _) = x.shape, w.shape + x_per_tok, w_per_blk = per_token_cast_to_fp8(x), per_block_cast_to_fp8(w) + + out = torch.empty((m, n), dtype=torch.bfloat16, device=x.device) # NOTE: DeepGemm only supports bf16 output + deepgemm_fp8_gemm(x_per_tok, w_per_blk, out) + + ctx.w_t_per_plk = per_block_cast_to_fp8(w.t()) + ctx.x_t_per_blk = per_block_cast_to_fp8(x.t()) + ctx.mnk = (m, n, k) + if has_batch_dim: + out = out.unsqueeze(1) + return out + + def backward(ctx: Any, o_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # o_grad: (m, n) + # x_grad: (m, k) -> (m, n) @ (n, k) -> deep_gemm((m, n), (k, n)) + # w_grad: (n, k) -> (m, n).t() @ (m, k) -> deep_gemm((m, n).t(), (k, m)) + if ctx.has_batch_dim: + o_grad = o_grad.squeeze(1) + + m, n, k = ctx.mnk + o_per_tok = per_token_cast_to_fp8(o_grad) + + x_grad = torch.empty((m, k), dtype=torch.bfloat16, device=o_grad.device) + deepgemm_fp8_gemm(o_per_tok, ctx.w_t_per_plk, x_grad) + + o_grad_t_per_tok = per_token_cast_to_fp8(o_grad.t()) + w_grad = torch.empty((n, k), dtype=torch.bfloat16, device=o_grad.device) + deepgemm_fp8_gemm(o_grad_t_per_tok, ctx.x_t_per_blk, w_grad) + + if ctx.has_batch_dim: + x_grad = x_grad.unsqueeze(1) + + return x_grad, w_grad + + +def linear_fp8_deep_gemm( + input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None +) -> torch.Tensor: + o = _LinearFp8DeepGemm.apply(input, weight) + if bias is not None: + o += bias + return o + + @torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel) def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: return _LinearFp8.apply(input, weight, bias) diff --git a/colossalai/quantization/fp8_hook.py b/colossalai/quantization/fp8_hook.py index 6171dd755a9d..8569358a4672 100644 --- a/colossalai/quantization/fp8_hook.py +++ b/colossalai/quantization/fp8_hook.py @@ -1,6 +1,6 @@ import torch.nn.functional as F -from colossalai.quantization.fp8 import linear_fp8 +from colossalai.quantization.fp8 import linear_fp8, linear_fp8_deep_gemm from colossalai.tensor.param_op_hook import ColoParamOpHook @@ -21,3 +21,23 @@ def rewrite_op(self, func): if func is F.linear: return linear_fp8 return func + + +class FP8DeepGemmHook(ColoParamOpHook): + + def pre_forward(self, params) -> None: + pass + + def post_forward(self, params) -> None: + pass + + def pre_backward(self, params) -> None: + pass + + def post_backward(self, params) -> None: + pass + + def rewrite_op(self, func): + if func is F.linear: + return linear_fp8_deep_gemm + return func diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 9e89e88272e0..f4e2ce0aea1c 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -15,7 +15,7 @@ from colossalai.interface import ModelWrapper from colossalai.lazy import LazyTensor from colossalai.logging import get_dist_logger -from colossalai.quantization.fp8_hook import FP8Hook +from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.d_tensor import ( distribute_tensor, @@ -101,6 +101,7 @@ def __init__( enable_async_reduce: bool = True, fp8_communication: bool = False, use_fp8: bool = False, + use_deep_gemm: bool = False, ) -> None: assert mixed_precision in (torch.float16, torch.bfloat16) reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False @@ -142,7 +143,10 @@ def __init__( self.param_op_hook = GeminiZeROHook(self.gemini_manager) self.hooks = [self.param_op_hook] if use_fp8: - self.hooks.append(FP8Hook()) + if use_deep_gemm: + self.hooks.append(FP8DeepGemmHook()) + else: + self.hooks.append(FP8Hook()) self.fp32_params: List[torch.Tensor] = list() self.fp16_params: List[ColoParameter] = list() self.grads_device: Dict[torch.Tensor, torch.device] = dict() diff --git a/docs/source/en/features/mixed_precision_training_with_booster.md b/docs/source/en/features/mixed_precision_training_with_booster.md index 1e17c2bb584d..1af0745780e5 100644 --- a/docs/source/en/features/mixed_precision_training_with_booster.md +++ b/docs/source/en/features/mixed_precision_training_with_booster.md @@ -63,7 +63,7 @@ However, there are other operations, like reductions, which require the dynamic We supported three AMP training methods and allowed the user to train with AMP with no code. If you want to train with amp, just assign `mixed_precision` with `fp16` when you instantiate the `Booster`. Next we will support `bf16`. -Currently we only support `fp8` mixed precision training for the `Linear` layer. Please specify the `use_fp8` parameter when create the plugin object. +Currently we only support `fp8` mixed precision training for the `Linear` layer, please specify the `use_fp8` parameter when create the plugin object. `deep_gemm` fp8 matmul is adopted which can be enabled by specifying `use_deep_gemm`. To reduce the communication volume inter nodes in low-bandwidth scenarios, we support FP8 communication compression. Please specify the `fp8_communication` parameter when create the plugin object. diff --git a/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md index 93a69830cadf..83e2adbfb418 100644 --- a/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md +++ b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md @@ -59,7 +59,7 @@ AMP 代表自动混合精度训练。 我们支持三种 AMP 训练方法,并允许用户在没有改变代码的情况下使用 AMP 进行训练。booster 支持 amp 特性注入,如果您要使用混合精度训练,则在创建 booster 实例时指定`mixed_precision`参数; 后续将会拓展`bf16`. -我们目前只支持`Linear`层的`fp8`混合精度训练,如果您需要使用,请在创建 plugin实例时指定`use_fp8`参数。 +我们目前只支持`Linear`层的`fp8`混合精度训练,如果您需要使用,请在创建 plugin实例时指定`use_fp8`参数,`deep_gemm`fp8矩阵乘法适配请指定`use_deep_gemm`参数。 为了减少低带宽场景下多机之间的通讯负载,我们还支持了FP8通讯。如果您需要使用,请在创建 plugin实例时指定`fp8_communication`参数。 diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 2964f83f4f86..231173696696 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -107,6 +107,7 @@ def main(): parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") parser.add_argument("--no_cache", action="store_true") parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication") + parser.add_argument("--use_deep_gemm", action="store_true", default=False, help="for using deep gemm") parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear") parser.add_argument("--overlap_p2p", action="store_true", default=True, help="for using overlap p2p") parser.add_argument("--overlap_allgather", action="store_true") @@ -159,6 +160,7 @@ def empty_init(): max_prefetch=args.prefetch_num, enable_async_reduce=not args.disable_async_reduce, use_fp8=args.use_fp8, + use_deep_gemm=args.use_deep_gemm, fp8_communication=args.use_fp8_comm, ) elif args.plugin == "gemini_auto": @@ -173,6 +175,7 @@ def empty_init(): enable_async_reduce=not args.disable_async_reduce, enable_flash_attention=args.xformers, use_fp8=args.use_fp8, + use_deep_gemm=args.use_deep_gemm, fp8_communication=args.use_fp8_comm, ) elif args.plugin == "fsdp": @@ -252,6 +255,7 @@ def empty_init(): enable_metadata_cache=not args.no_cache, overlap_allgather=args.overlap_allgather, use_fp8=args.use_fp8, + use_deep_gemm=args.use_deep_gemm, fp8_communication=args.use_fp8_comm, scheduler_nodes=scheduler_nodes, **hybrid_kwargs, @@ -271,6 +275,7 @@ def empty_init(): precision="bf16", overlap_p2p=args.overlap_p2p, use_fp8=args.use_fp8, + use_deep_gemm=args.use_deep_gemm, fp8_communication=args.use_fp8_comm, ) else: diff --git a/tests/test_fp8/test_fp8_deepgemm.py b/tests/test_fp8/test_fp8_deepgemm.py new file mode 100644 index 000000000000..5f1de4f9c0ae --- /dev/null +++ b/tests/test_fp8/test_fp8_deepgemm.py @@ -0,0 +1,31 @@ +import pytest +import torch +import torch.nn.functional as F +from torch.testing import assert_close + +from colossalai.accelerator import get_accelerator +from colossalai.quantization.fp8 import linear_fp8_deep_gemm +from colossalai.utils import get_current_device + +m, k, n = 128, 384, 256 +DTYPE = torch.bfloat16 + + +@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0") +def test_fp8_linear(): + # create tensors + x = torch.rand((m, k), device=get_current_device(), dtype=DTYPE, requires_grad=True) + w = torch.rand((n, k), device=get_current_device(), dtype=DTYPE, requires_grad=True) + bias = torch.rand(n, device=get_current_device(), dtype=DTYPE, requires_grad=True) + ref_w = w.clone().detach().requires_grad_() + ref_x = x.clone().detach().requires_grad_() + + out = linear_fp8_deep_gemm(x, w, bias) + assert out.shape == x.shape[:-1] + (n,) + out.sum().backward() + ref_out = F.linear(ref_x, ref_w, bias) + ref_out.sum().backward() + + assert_close(out, ref_out) + assert_close(x.grad, ref_x.grad) + assert_close(w.grad, ref_w.grad) diff --git a/tests/test_fp8/test_fp8_hook.py b/tests/test_fp8/test_fp8_hook.py index abd5d09e128e..a7bc3b7b92f6 100644 --- a/tests/test_fp8/test_fp8_hook.py +++ b/tests/test_fp8/test_fp8_hook.py @@ -4,8 +4,8 @@ import torch.nn.functional as F from colossalai.accelerator import get_accelerator -from colossalai.quantization.fp8 import linear_fp8 -from colossalai.quantization.fp8_hook import FP8Hook +from colossalai.quantization.fp8 import linear_fp8, linear_fp8_deep_gemm +from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.utils import get_current_device @@ -20,6 +20,12 @@ def new_linear_fp8(x, w, bias=None): return linear_fp8(x, w, bias) +def new_deepgemm_fp8_gemm(lhs, rhs, out=None): + global TRIGGERED + TRIGGERED = True + return linear_fp8_deep_gemm(lhs, rhs, out) + + class FP8TestHook(FP8Hook): def rewrite_op(self, func): func = super().rewrite_op(func) @@ -30,13 +36,26 @@ def rewrite_op(self, func): return func -D_IN, D_OUT = 16, 32 +class DeepGemmTestHook(FP8DeepGemmHook): + def rewrite_op(self, func): + func = super().rewrite_op(func) + if func is linear_fp8_deep_gemm: + global REPLACED + REPLACED = True + return new_deepgemm_fp8_gemm + return func + + +D_IN, D_OUT = 128, 128 B, S = 2, 64 DTYPE = torch.bfloat16 @pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0") def test_fp8_hook(): + global REPLACED, TRIGGERED + REPLACED = False + TRIGGERED = False # create tensors w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE)) x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True) @@ -48,3 +67,21 @@ def test_fp8_hook(): assert o.shape == (B, S, D_OUT) assert REPLACED assert TRIGGERED + + +@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0") +def test_fp8_deep_gemm_hook(): + global REPLACED, TRIGGERED + REPLACED = False + TRIGGERED = False + # create tensors + w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE)) + x = torch.rand(S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True) + w.__class__ = ColoParameter + w.__init__(w, requires_grad=True) + hook = DeepGemmTestHook() + with ColoParamOpHookManager.use_hooks(hook): + o = F.linear(x, w) + assert o.shape == (S, D_OUT) + assert REPLACED + assert TRIGGERED