Add fused scaled_silu_and_mul (for Mixtral) #221

Merged: 2 commits, Mar 22, 2025
aiter/fused_moe_bf16_asm.py (12 additions & 5 deletions)
@@ -206,17 +206,24 @@ def ck_moe_2stages(a1,
                          fc1_scale, a1_scale, block_size)
 
     # g1u0
-    if w2.shape[2] == w1.shape[1]:
-        a2 = F.gelu(a2)
-    # g1u1
+    if (w2.shape[2] != w1.shape[1]) and (w2.dtype == torch.float8_e4m3fnuz):
+        tmp = torch.empty((M, topk, inter_dim), dtype=torch.float8_e4m3fnuz, device=device)
+        if a2_scale == None:
+            a2_scale = torch.empty(1, dtype=torch.float, device=device)
+        aiter.scaled_silu_and_mul(tmp, a2, a2_scale)
+        a2 = tmp
+    else:
+        if w2.shape[2] == w1.shape[1]:
+            a2 = F.gelu(a2)
+        # g1u1
         else:
             tmp = torch.empty((M, topk, inter_dim), dtype=dtype, device=device)
             aiter.silu_and_mul(tmp, a2)
             a2 = tmp
-    if w2.dtype == torch.float8_e4m3fnuz:
+        if w2.dtype == torch.float8_e4m3fnuz:
             a2, a2_scale = aiter.per_tensor_quant_fp8_hip(a2, a2_scale)
             # a2, a2_scale = aiter.per_tensor_quant(a2, quant_dtype=w2.dtype)
-    else:
+        else:
             if not hasattr(ck_moe_2stages, "one_float_tensor"):
                 ck_moe_2stages.one_float_tensor = torch.tensor(
                     1.0, dtype=torch.float, device=device)
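For context on what the fused call replaces: the old g1u1 fp8 path in the hunk above ran aiter.silu_and_mul and then aiter.per_tensor_quant_fp8_hip as two kernels, while aiter.scaled_silu_and_mul produces the scaled fp8 result in a single pass. A minimal PyTorch sketch of the intended math, not part of this PR; the helper name scaled_silu_and_mul_ref and the explicit clamp via torch.finfo are illustrative assumptions:

import torch
import torch.nn.functional as F

def scaled_silu_and_mul_ref(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # input: [..., 2 * d]; scale: one-element float tensor (per-tensor scale).
    d = input.shape[-1] // 2
    x, y = input[..., :d], input[..., d:]
    r = F.silu(x.float()) * y.float() / scale.float()  # the kernel multiplies by 1/scale
    finfo = torch.finfo(torch.float8_e4m3fnuz)
    r = r.clamp(finfo.min, finfo.max)  # saturate, mirroring __HIP_SATFINITE
    return r.to(torch.float8_e4m3fnuz)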
aiter/ops/activation.py (2 additions & 2 deletions)
@@ -8,14 +8,14 @@
 
 MD_NAME = "module_activation"
 
 
 @compile_ops("module_activation")
 def silu_and_mul(out: Tensor, input: Tensor): ...
 
+@compile_ops("module_activation")
+def scaled_silu_and_mul(out: Tensor, input: Tensor, scale: Tensor): ...
 
 @compile_ops("module_activation")
 def gelu_and_mul(out: Tensor, input: Tensor): ...
 
 
 @compile_ops("module_activation")
 def gelu_tanh_and_mul(out: Tensor, input: Tensor): ...
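A hedged usage sketch of the new binding, based on the call site in ck_moe_2stages: the op fills a preallocated fp8 output of shape [..., d] from a [..., 2 * d] input and a one-element float scale tensor (the kernel applies 1/scale before converting). The shapes and scale value below are illustrative only:

import torch
import aiter

M, topk, inter_dim = 4, 2, 128
device = "cuda"  # ROCm builds of PyTorch also use the "cuda" device string
a2 = torch.randn(M, topk, 2 * inter_dim, dtype=torch.bfloat16, device=device)
a2_scale = torch.tensor([0.05], dtype=torch.float, device=device)
out = torch.empty(M, topk, inter_dim, dtype=torch.float8_e4m3fnuz, device=device)
aiter.scaled_silu_and_mul(out, a2, a2_scale)  # out now holds fp8 silu(x1) * x2 / scale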
csrc/include/activation.h (2 additions & 1 deletion)
@@ -4,5 +4,6 @@
 #include <torch/extension.h>
 
 void silu_and_mul(torch::Tensor &out, torch::Tensor &input);
+void scaled_silu_and_mul(torch::Tensor &out, torch::Tensor &input, torch::Tensor &scale);
 void gelu_and_mul(torch::Tensor &out, torch::Tensor &input);
-void gelu_tanh_and_mul(torch::Tensor &out, torch::Tensor &input);
+void gelu_tanh_and_mul(torch::Tensor &out, torch::Tensor &input);
csrc/include/rocm_ops.hpp (1 addition & 0 deletions)
@@ -3,6 +3,7 @@
 
 #define ACTIVATION_PYBIND                                                                              \
     m.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");                      \
+    m.def("scaled_silu_and_mul", &scaled_silu_and_mul, "Activation function used in scaled SwiGLU."); \
     m.def("gelu_and_mul", &gelu_and_mul, "Activation function used in GELU.");                        \
     m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Activation function used in GELU tanh.");
 
csrc/kernels/activation_kernels.cu (49 additions & 0 deletions)
@@ -23,6 +23,10 @@
 #include "hip_compat.h"
 #include "dispatch_utils.h"
 
+#ifdef USE_ROCM
+#include "quant_utils.cuh"
+#endif
+
 namespace vllm
 {
 
@@ -42,6 +46,26 @@ namespace vllm
         }
     }
 
+    // Scaled activation and gating kernel template.
+#ifdef USE_ROCM
+    using fp8_type = __hip_fp8_e4m3;
+    template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+    __global__ void scaled_act_and_mul_kernel(
+        c10::Float8_e4m3fnuz* __restrict__ out, // [..., d]
+        const scalar_t* __restrict__ input,     // [..., 2, d]
+        const int d, const float scale) {
+        const int64_t token_idx = blockIdx.x;
+        for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+            const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
+            const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
+            float r = ACT_FN(x) * y * scale;
+            out[token_idx * d + idx] = c10::Float8_e4m3fnuz(
+                __hip_cvt_float_to_fp8(r, __HIP_SATFINITE, __HIP_E4M3_FNUZ),
+                c10::Float8_e4m3fnuz::from_bits());
+        }
+    }
+#endif
+
     template <typename T>
     __device__ __forceinline__ T silu_kernel(const T &x)
     {
@@ -88,13 +112,38 @@ namespace vllm
         input.scalar_type(), "act_and_mul_kernel", [&] { vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>        \
                                                              <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),  \
                                                                                           input.data_ptr<scalar_t>(), d); });
+// Launch activation and gating kernel.
+#ifdef USE_ROCM
+#define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL)                      \
+    int d = input.size(-1) / 2;                                           \
+    int64_t num_tokens = input.numel() / input.size(-1);                  \
+    dim3 grid(num_tokens);                                                \
+    dim3 block(std::min(d, 1024));                                        \
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));     \
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();         \
+    VLLM_DISPATCH_FLOATING_TYPES(                                         \
+        input.scalar_type(), "scaled_act_and_mul_kernel", [&] {           \
+            vllm::scaled_act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>   \
+                <<<grid, block, 0, stream>>>(                             \
+                    out.data_ptr<c10::Float8_e4m3fnuz>(),                 \
+                    input.data_ptr<scalar_t>(), d,                        \
+                    1.0 / (*scale.data_ptr<float>()));                    \
+        });
+#endif
 
 void silu_and_mul(torch::Tensor &out,   // [..., d]
                   torch::Tensor &input) // [..., 2 * d]
 {
     LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
 }
 
+void scaled_silu_and_mul(torch::Tensor &out,   // [..., d]
+                         torch::Tensor &input, // [..., 2 * d]
+                         torch::Tensor &scale)
+{
+    LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+}
+
 void gelu_and_mul(torch::Tensor &out,   // [..., d]
                   torch::Tensor &input) // [..., 2 * d]
 {
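One way to sanity-check the fused kernel against the ops that already exist in the extension: a test sketch, assuming both paths use the same per-tensor scale; the tolerances and shapes are arbitrary and not part of the PR:

import torch
import aiter

tokens, d = 32, 256
x = torch.randn(tokens, 2 * d, dtype=torch.float16, device="cuda")
scale = torch.tensor([0.1], dtype=torch.float, device="cuda")

# Fused path: silu(x1) * x2, multiplied by 1/scale, stored as fp8 in one launch.
fused = torch.empty(tokens, d, dtype=torch.float8_e4m3fnuz, device="cuda")
aiter.scaled_silu_and_mul(fused, x, scale)

# Unfused path: existing silu_and_mul, then the same scaling and fp8 cast in PyTorch.
ref = torch.empty(tokens, d, dtype=torch.float16, device="cuda")
aiter.silu_and_mul(ref, x)
finfo = torch.finfo(torch.float8_e4m3fnuz)
ref = (ref.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fnuz)

torch.testing.assert_close(fused.float(), ref.float(), atol=0.1, rtol=0.1)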