# sglang/python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py
# (157 lines, 4.9 KiB, Python)

from typing import Any, Dict, Optional
import torch
from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
from sgl_kernel.gemm import (
scaled_fp4_grouped_quant,
silu_and_mul_scaled_fp4_grouped_quant,
)
# torch dtypes that have a CuteDSL dtype name, keyed by dtype.
_CUTE_DTYPE_NAMES = {
    torch.bfloat16: "bfloat16",
    torch.float16: "float16",
    torch.float32: "float32",
}


def get_cute_dtype(input: torch.Tensor) -> str:
    """Return the CuteDSL dtype name matching the dtype of ``input``.

    Args:
        input: Tensor whose ``dtype`` is translated. Only bfloat16, float16
            and float32 are supported.

    Returns:
        One of ``"bfloat16"``, ``"float16"`` or ``"float32"``.

    Raises:
        ValueError: If ``input.dtype`` is not one of the supported dtypes.
    """
    name = _CUTE_DTYPE_NAMES.get(input.dtype)
    if name is None:
        raise ValueError(f"Unsupported cute dtype {input.dtype}")
    return name
def flashinfer_cutedsl_moe_masked(
    hidden_states: torch.Tensor,
    input_global_scale: torch.Tensor,
    w1: torch.Tensor,
    w1_blockscale: torch.Tensor,
    w1_alpha: torch.Tensor,
    w2: torch.Tensor,
    a2_global_scale: torch.Tensor,
    w2_blockscale: torch.Tensor,
    w2_alpha: torch.Tensor,
    masked_m: torch.Tensor,
) -> torch.Tensor:
    """
    Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
    kernels.

    The pipeline is: grouped FP4 quantization of the activations, a masked
    grouped GEMM against ``w1`` (gate/up projection), fused SiLU-and-mul with
    FP4 re-quantization, and a second masked grouped GEMM against ``w2``.

    Args:
        hidden_states (torch.Tensor): [num_experts, m, k], bf16. Its dtype is
            also used for the intermediate and output buffers, so it must be
            one that ``get_cute_dtype`` accepts.
        input_global_scale (torch.Tensor): (l,), float32
        w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
        w1_blockscale (torch.Tensor): blockscale factors, e4m3,
        w1_alpha (torch.Tensor): (l,), float32
        w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
        a2_global_scale (torch.Tensor): (l,), float32
        w2_blockscale (torch.Tensor): blockscale factors, e4m3,
        w2_alpha (torch.Tensor): (l,), float32
        masked_m (torch.Tensor): Masked dimension indices

    Returns:
        torch.Tensor: A [num_experts, m, k] view with the dtype of
        ``hidden_states``. The backing buffer comes from ``torch.empty_like``,
        so any entries the GEMM kernel does not write are uninitialized
        (presumably the rows beyond ``masked_m`` for each expert -- confirm
        against the kernel documentation).

    Raises:
        AssertionError: If a dtype or shape check on the inputs fails.
        ValueError: If ``hidden_states`` is not 3-D (shape unpacking), or if a
            dtype cannot be mapped by ``get_cute_dtype``.

    Notes:
        - Assumes max(masked_m) <= m.
        - ``l`` in the shapes above is the number of experts.
    """
    # === Assertions on dtypes ===
    assert (
        input_global_scale.dtype == torch.float32
    ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
    assert w1.dtype == torch.uint8, f"w1 must be uint8 (fp4 packed), got {w1.dtype}"
    assert (
        w1_blockscale.dtype == torch.float8_e4m3fn
    ), f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
    assert (
        w1_alpha.dtype == torch.float32
    ), f"w1_alpha must be float32, got {w1_alpha.dtype}"
    assert w2.dtype == torch.uint8, f"w2 must be uint8 (fp4 packed), got {w2.dtype}"
    assert (
        a2_global_scale.dtype == torch.float32
    ), f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
    assert (
        w2_blockscale.dtype == torch.float8_e4m3fn
    ), f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
    assert (
        w2_alpha.dtype == torch.float32
    ), f"w2_alpha must be float32, got {w2_alpha.dtype}"

    # === Assertions on shapes ===
    # w2 packs two fp4 values per uint8, so its last dim is n // 2.
    n = w2.shape[-1] * 2  # intermediate dimension
    num_experts, m, k = hidden_states.shape

    assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
    assert (
        w1.shape[-1] * 2 == k
    ), f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
    assert w2.shape[-2:] == (
        k,
        n // 2,
    ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n//2)}"

    assert input_global_scale.shape == (
        num_experts,
    ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
    assert w1_alpha.shape == (
        num_experts,
    ), f"w1_alpha must be (l,), got {w1_alpha.shape}"
    assert a2_global_scale.shape == (
        num_experts,
    ), f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
    assert w2_alpha.shape == (
        num_experts,
    ), f"w2_alpha must be (l,), got {w2_alpha.shape}"

    # Quantize the activations to packed fp4 plus e4m3 scale factors.
    aq, aq_sf = scaled_fp4_grouped_quant(
        hidden_states,
        input_global_scale,
        masked_m,
    )

    # Allocate as [l, m, 2n] and hand the kernel a permuted [m, 2n, l] view.
    gateup_output = torch.empty(
        (num_experts, m, n * 2), dtype=hidden_states.dtype, device=aq.device
    )
    gateup_output = gateup_output.permute(1, 2, 0)  # requirement of kernel

    sf_vec_size = 16
    assert aq_sf.dtype == torch.float8_e4m3fn
    assert aq.dtype == torch.uint8
    ab_dtype = "float4_e2m1fn"
    sf_dtype = "float8_e4m3fn"
    c_dtype = get_cute_dtype(hidden_states)

    # Gemm1
    grouped_gemm_nt_masked(
        (aq, aq_sf),
        (w1.permute(1, 2, 0), w1_blockscale),
        gateup_output,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w1_alpha.view(1, 1, num_experts),
        alpha_dtype=get_cute_dtype(w1_alpha),
    )  # in logical [m, n, l]

    # SILU and quantization
    # Permute back to [l, m, 2n] before the fused activation + fp4 quant.
    diq, diq_sf = silu_and_mul_scaled_fp4_grouped_quant(
        gateup_output.permute(2, 0, 1),
        a2_global_scale,
        masked_m,
    )

    # Gemm2
    out = torch.empty_like(hidden_states)
    out = out.permute(1, 2, 0)  # requirement of kernel
    grouped_gemm_nt_masked(
        (diq, diq_sf),
        (w2.permute(1, 2, 0), w2_blockscale),
        out,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w2_alpha.view(1, 1, num_experts),
        alpha_dtype=get_cute_dtype(w2_alpha),
    )  # in logical [m, k, l]

    # Undo the kernel layout so callers get [l, m, k] again.
    return out.permute(2, 0, 1)