[NVIDIA] [3/N] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked (#9199)
@@ -459,6 +459,8 @@ class DeepEPMoE(EPMoE):
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_contiguous(dispatch_output)
         elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
+            if get_moe_runner_backend().is_flashinfer_cutedsl():
+                return self.forward_flashinfer_cutedsl(dispatch_output)
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_masked(dispatch_output)
         else:
@@ -638,6 +640,22 @@ class DeepEPMoE(EPMoE):

         return gather_out

+    def forward_flashinfer_cutedsl(
+        self,
+        dispatch_output: DeepEPLLOutput,
+    ):
+        hidden_states, _, _, masked_m, _ = dispatch_output
+        assert self.quant_method is not None
+        assert self.moe_runner_config.activation == "silu"
+
+        output = self.quant_method.apply_without_routing_weights(
+            layer=self,
+            x=hidden_states,
+            masked_m=masked_m,
+            moe_runner_config=self.moe_runner_config,
+        )
+        return output
+
     def forward_deepgemm_masked(
         self,
         dispatch_output: DeepEPLLOutput,
python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py (new file, +156)
@@ -0,0 +1,156 @@
+from typing import Any, Dict, Optional
+
+import torch
+from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
+from sgl_kernel.gemm import (
+    scaled_fp4_grouped_quant,
+    silu_and_mul_scaled_fp4_grouped_quant,
+)
+
+
+def get_cute_dtype(input: torch.Tensor) -> str:
+    if input.dtype == torch.bfloat16:
+        return "bfloat16"
+    elif input.dtype == torch.float16:
+        return "float16"
+    elif input.dtype == torch.float32:
+        return "float32"
+    else:
+        raise ValueError(f"Unsupported cute dtype {input.dtype}")
+
+
+def flashinfer_cutedsl_moe_masked(
+    hidden_states: torch.Tensor,
+    input_global_scale: torch.Tensor,
+    w1: torch.Tensor,
+    w1_blockscale: torch.Tensor,
+    w1_alpha,
+    w2: torch.Tensor,
+    a2_global_scale: torch.Tensor,
+    w2_blockscale: torch.Tensor,
+    w2_alpha,
+    masked_m: torch.Tensor,
+):
+    """
+    Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
+    kernels.
+
+    Args:
+        hidden_states (torch.Tensor): [num_experts, m, k], bf16
+        input_global_scale (torch.Tensor): (l,)
+        w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
+        w1_blockscale (torch.Tensor): blockscale factors, e4m3
+        w1_alpha (torch.Tensor): (l,)
+        w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
+        a2_global_scale (torch.Tensor): (l,)
+        w2_blockscale (torch.Tensor): blockscale factors, e4m3
+        w2_alpha (torch.Tensor): (l,)
+        masked_m (torch.Tensor): (l,) number of valid rows per expert
+
+    Notes:
+        - Assumes max(masked_m) <= m.
+    """
+
+    # === Assertions on dtypes ===
+    assert (
+        input_global_scale.dtype == torch.float32
+    ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
+    assert w1.dtype == torch.uint8, f"w1 must be uint8 (fp4 packed), got {w1.dtype}"
+    assert (
+        w1_blockscale.dtype == torch.float8_e4m3fn
+    ), f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
+    assert (
+        w1_alpha.dtype == torch.float32
+    ), f"w1_alpha must be float32, got {w1_alpha.dtype}"
+    assert w2.dtype == torch.uint8, f"w2 must be uint8 (fp4 packed), got {w2.dtype}"
+    assert (
+        a2_global_scale.dtype == torch.float32
+    ), f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
+    assert (
+        w2_blockscale.dtype == torch.float8_e4m3fn
+    ), f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
+    assert (
+        w2_alpha.dtype == torch.float32
+    ), f"w2_alpha must be float32, got {w2_alpha.dtype}"
+
+    # === Assertions on shapes ===
+    n = w2.shape[-1] * 2  # intermediate dimension
+    num_experts, m, k = hidden_states.shape
+
+    assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
+    assert (
+        w1.shape[-1] * 2 == k
+    ), f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
+    assert w2.shape[-2:] == (
+        k,
+        n // 2,
+    ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n//2)}"
+
+    assert input_global_scale.shape == (
+        num_experts,
+    ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
+    assert w1_alpha.shape == (
+        num_experts,
+    ), f"w1_alpha must be (l,), got {w1_alpha.shape}"
+    assert a2_global_scale.shape == (
+        num_experts,
+    ), f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
+    assert w2_alpha.shape == (
+        num_experts,
+    ), f"w2_alpha must be (l,), got {w2_alpha.shape}"
+
+    aq, aq_sf = scaled_fp4_grouped_quant(
+        hidden_states,
+        input_global_scale,
+        masked_m,
+    )
+    gateup_output = torch.empty(
+        (num_experts, m, n * 2), dtype=hidden_states.dtype, device=aq.device
+    )
+    gateup_output = gateup_output.permute(1, 2, 0)  # requirement of kernel
+    sf_vec_size = 16
+    assert aq_sf.dtype == torch.float8_e4m3fn
+    assert aq.dtype == torch.uint8
+    ab_dtype = "float4_e2m1fn"
+    sf_dtype = "float8_e4m3fn"
+
+    c_dtype = get_cute_dtype(hidden_states)
+
+    # Gemm1
+    grouped_gemm_nt_masked(
+        (aq, aq_sf),
+        (w1.permute(1, 2, 0), w1_blockscale),
+        gateup_output,
+        masked_m,
+        ab_dtype=ab_dtype,
+        sf_dtype=sf_dtype,
+        c_dtype=c_dtype,
+        sf_vec_size=sf_vec_size,
+        alpha=w1_alpha.view(1, 1, num_experts),
+        alpha_dtype=get_cute_dtype(w1_alpha),
+    )  # in logical [m, n, l]
+
+    # SILU and quantization
+    diq, diq_sf = silu_and_mul_scaled_fp4_grouped_quant(
+        gateup_output.permute(2, 0, 1),
+        a2_global_scale,
+        masked_m,
+    )
+
+    # Gemm2
+    out = torch.empty_like(hidden_states)
+    out = out.permute(1, 2, 0)  # requirement of kernel
+    grouped_gemm_nt_masked(
+        (diq, diq_sf),
+        (w2.permute(1, 2, 0), w2_blockscale),
+        out,
+        masked_m,
+        ab_dtype=ab_dtype,
+        sf_dtype=sf_dtype,
+        c_dtype=c_dtype,
+        sf_vec_size=sf_vec_size,
+        alpha=w2_alpha.view(1, 1, num_experts),
+        alpha_dtype=get_cute_dtype(w2_alpha),
+    )  # in logical [m, k, l]
+    return out.permute(2, 0, 1)
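To make the new kernel's calling convention concrete, here is a minimal usage sketch (not part of the diff, mirroring the unit test further below). Weight quantization via scaled_fp4_grouped_quant and the permute(2, 0, 1) weight layout follow the test exactly; the tensor sizes and values are illustrative.

import torch
from sgl_kernel.gemm import scaled_fp4_grouped_quant
from sglang.srt.layers.moe.flashinfer_cutedsl_moe import flashinfer_cutedsl_moe_masked

l, m, k, n = 8, 16, 128, 256  # experts, max rows per expert, hidden dim, intermediate dim
x = torch.randn(l, m, k, dtype=torch.bfloat16, device="cuda") / 5.0
w1 = torch.randn(l, 2 * n, k, dtype=torch.bfloat16, device="cuda") / 10.0
w2 = torch.randn(l, k, n, dtype=torch.bfloat16, device="cuda") / 10.0
masked_m = torch.randint(1, m + 1, (l,), dtype=torch.int32, device="cuda")

# Per-expert global scales: 448 * 6 / amax (fp8-e4m3 max times fp4-e2m1 max).
a1_gs = torch.ones(l, dtype=torch.float32, device="cuda")
a2_gs = torch.ones(l, dtype=torch.float32, device="cuda")
w1_gs = 448.0 * 6.0 / w1.abs().amax(dim=(1, 2)).float()
w2_gs = 448.0 * 6.0 / w2.abs().amax(dim=(1, 2)).float()

rows = lambda r: torch.full((l,), r, dtype=torch.int32, device="cuda")
w1_fp4, w1_bs = scaled_fp4_grouped_quant(w1, w1_gs, rows(2 * n))
w2_fp4, w2_bs = scaled_fp4_grouped_quant(w2, w2_gs, rows(k))

out = flashinfer_cutedsl_moe_masked(
    x, a1_gs,
    w1_fp4.permute(2, 0, 1), w1_bs, 1.0 / (a1_gs * w1_gs),
    w2_fp4.permute(2, 0, 1), a2_gs, w2_bs, 1.0 / (a2_gs * w2_gs),
    masked_m,
)  # out: [l, m, k]; only out[i, : masked_m[i]] is meaningful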
@@ -508,7 +508,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
             topk_idx,
-            use_fp8=True,
+            # TODO(shuw): pending https://github.com/deepseek-ai/DeepEP/pull/341
+            use_fp8=not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"),
         )
         return (
             hidden_states,
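The new environment variable is read through get_bool_env_var at dispatch time, so bf16 dispatch can be toggled without a code change. A one-line sketch (illustrative only; set it before the dispatcher is constructed):

import os
os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "1"  # dispatch activations in bf16 instead of fp8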
@@ -49,6 +49,7 @@ class MoeRunnerBackend(Enum):
     FLASHINFER = "flashinfer_trtllm"
     FLASHINFER_CUTLASS = "flashinfer_cutlass"
     FLASHINFER_MXFP4 = "flashinfer_mxfp4"
+    FLASHINFER_CUTEDSL = "flashinfer_cutedsl"

     def is_auto(self):
         return self == MoeRunnerBackend.AUTO
@@ -65,6 +66,9 @@ class MoeRunnerBackend(Enum):
     def is_flashinfer_cutlass(self):
         return self == MoeRunnerBackend.FLASHINFER_CUTLASS

+    def is_flashinfer_cutedsl(self):
+        return self == MoeRunnerBackend.FLASHINFER_CUTEDSL
+
     def is_flashinfer_mxfp4(self):
         return self == MoeRunnerBackend.FLASHINFER_MXFP4
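A small sketch of how callers consult the new enum member (assumed usage, mirroring the DeepEPMoE dispatch change above):

from sglang.srt.layers.moe import get_moe_runner_backend

if get_moe_runner_backend().is_flashinfer_cutedsl():
    ...  # take the NVFP4 masked-GEMM path added by this PR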
@@ -878,6 +878,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         """Access the global enable_flashinfer_cutlass_moe setting."""
         return get_moe_runner_backend().is_flashinfer_cutlass()

+    @property
+    def enable_flashinfer_cutedsl_moe(self) -> bool:
+        from sglang.srt.layers.moe import get_moe_runner_backend
+
+        """Access the global enable_flashinfer_cutedsl_moe setting."""
+        return get_moe_runner_backend().is_flashinfer_cutedsl()
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -1398,5 +1405,38 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
+        # Scale by routed_scaling_factor is fused into select_experts.

         return StandardCombineInput(hidden_states=output)

+    def apply_without_routing_weights(
+        self,
+        layer: FusedMoE,
+        x: torch.Tensor,
+        masked_m: torch.Tensor,
+        moe_runner_config: MoeRunnerConfig,
+    ) -> torch.Tensor:
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        assert self.enable_flashinfer_cutedsl_moe, "only support flashinfer cutedsl moe"
+        assert (
+            not moe_runner_config.apply_router_weight_on_input
+        ), "apply_router_weight_on_input is not supported for Flashinfer"
+
+        from sglang.srt.layers.moe.flashinfer_cutedsl_moe import (
+            flashinfer_cutedsl_moe_masked,
+        )
+
+        out = flashinfer_cutedsl_moe_masked(
+            hidden_states=x,
+            input_global_scale=layer.w13_input_scale_quant,
+            w1=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alpha=layer.g1_alphas,
+            w2=layer.w2_weight,
+            a2_global_scale=layer.w2_input_scale_quant,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alpha=layer.g2_alphas,
+            masked_m=masked_m,
+        )
+        return out
@@ -673,10 +673,14 @@ class DeepseekV2MoE(nn.Module):

         if shared_output is not None:
             x = shared_output
-            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            if self.experts.should_fuse_routed_scaling_factor_in_topk():
+                x.add_(final_hidden_states)
+            else:
+                x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
             final_hidden_states = x
         else:
-            final_hidden_states *= self.routed_scaling_factor
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk():
+                final_hidden_states *= self.routed_scaling_factor

         return final_hidden_states
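This branch avoids double-scaling: when should_fuse_routed_scaling_factor_in_topk() reports that routed_scaling_factor was already folded into the top-k weights, the combine step must not multiply by it again. A toy numeric check of the equivalence the branch relies on (illustrative only, not from the diff):

import torch

routed_scaling_factor = 2.5
topk_weights = torch.tensor([0.6, 0.4])
expert_outs = torch.randn(2, 8)

# Unfused: combine with raw top-k weights, then scale once afterwards.
unfused = (topk_weights[:, None] * expert_outs).sum(0) * routed_scaling_factor
# Fused: pre-scaled top-k weights, so no extra multiply at combine time.
fused = ((topk_weights * routed_scaling_factor)[:, None] * expert_outs).sum(0)
torch.testing.assert_close(unfused, fused)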
@@ -399,6 +399,7 @@ class ServerArgs:
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     enable_flashinfer_cutlass_moe: bool = False
+    enable_flashinfer_cutedsl_moe: bool = False
     enable_flashinfer_trtllm_moe: bool = False
     enable_triton_kernel_moe: bool = False
     enable_flashinfer_mxfp4_moe: bool = False
@@ -420,6 +421,11 @@ class ServerArgs:
             print_deprecated_warning(
                 "NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
             )
+        if self.enable_flashinfer_cutedsl_moe:
+            self.moe_runner_backend = "flashinfer_cutedsl"
+            print_deprecated_warning(
+                "NOTE: --enable-flashinfer-cutedsl-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutedsl' instead."
+            )
         if self.enable_flashinfer_cutlass_moe:
             self.moe_runner_backend = "flashinfer_cutlass"
             print_deprecated_warning(
@@ -1622,6 +1628,7 @@ class ServerArgs:
                 "flashinfer_trtllm",
                 "flashinfer_cutlass",
                 "flashinfer_mxfp4",
+                "flashinfer_cutedsl",
             ],
             default=ServerArgs.moe_runner_backend,
             help="Choose the runner backend for MoE.",
@@ -2204,6 +2211,11 @@ class ServerArgs:
             action="store_true",
             help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
         )
+        parser.add_argument(
+            "--enable-flashinfer-cutedsl-moe",
+            action="store_true",
+            help="(Deprecated) Enable FlashInfer CuteDSL MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
+        )
         parser.add_argument(
             "--enable-flashinfer-trtllm-moe",
             action="store_true",
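In short: new deployments should pass `--moe-runner-backend flashinfer_cutedsl` on the server command line; `--enable-flashinfer-cutedsl-moe` is kept only as a deprecated alias that rewrites `moe_runner_backend` as shown above.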
@@ -3,12 +3,15 @@ from typing import Callable

 import pytest
 import torch
 from flashinfer import fp4_quantize
 from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
-from sgl_kernel import scaled_fp4_quant
+from sgl_kernel import scaled_fp4_grouped_quant, scaled_fp4_quant
 from torch.nn import functional as F

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.flashinfer_cutedsl_moe import flashinfer_cutedsl_moe_masked
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts

 if torch.cuda.get_device_capability() < (10, 0):
@@ -78,6 +81,37 @@ def break_fp4_bytes(a, dtype):
     return values.reshape(m, n * 2).to(dtype=dtype)


+def compute_routing(router_logits: torch.Tensor, top_k: int):
+    routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float)
+    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+    routing_weights = routing_weights.float()
+    return routing_weights, selected_experts
+
+
+def prepare_inputs(
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    num_experts: int,
+    topk: int,
+):
+    routing_weights, topk_idx = compute_routing(router_logits, topk)
+
+    masked_m = []
+    for i in range(num_experts):
+        mask = topk_idx.view(-1) == i
+        masked_m.append(mask.sum())
+
+    masked_m = torch.tensor(masked_m, dtype=torch.int32)
+    hidden_states_3d = torch.empty(
+        (num_experts, max(masked_m), hidden_states.shape[1]), dtype=hidden_states.dtype
+    )
+    for i in range(num_experts):
+        hidden_states_3d[i, : masked_m[i], :] = hidden_states[topk_idx.view(-1) == i]
+
+    return hidden_states_3d, masked_m, topk_idx, routing_weights
+
+
 MNK_FACTORS = [
     (2, 1024, 1024),
     (2, 1024, 1536),
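To make the masked layout that prepare_inputs builds concrete, a toy illustration (not part of the test):

import torch

topk_idx = torch.tensor([[0], [1], [0]])  # 3 tokens, topk=1, experts {0, 1}
masked_m = torch.bincount(topk_idx.view(-1), minlength=2).to(torch.int32)
print(masked_m)  # tensor([2, 1], dtype=torch.int32)
# hidden_states_3d then has shape [2, 2, k]; slot [1, 1] is padding that the
# masked GEMM never reads.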
@@ -114,6 +148,99 @@ def torch_moe(a, w1, w2, score, topk, expert_map):
     ).sum(dim=1)


+def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
+    B, D = a.shape
+    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+
+    topk_weight = topk_weight.view(-1)
+    topk_ids = topk_ids.view(-1)
+
+    for i in range(w1.shape[0]):
+        mask = topk_ids == i
+        if mask.sum():
+            m = w1[i].shape[0]
+            assert m % 2 == 0
+            # Note: w1 and w3 are swapped!
+            w3_expert, w1_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :]
+            inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
+            inter_gs = torch.tensor(1.0).cuda()
+            inter_q, inter_blockscale = fp4_quantize(inter, inter_gs)
+            inter = dequantize_nvfp4_to_dtype(
+                inter_q,
+                inter_blockscale,
+                inter_gs,
+                dtype=inter.dtype,
+                device=inter.device,
+                block_size=16,
+            ).cuda()
+            out[mask] = inter @ w2[i].transpose(0, 1)
+    return (
+        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
+    ).sum(dim=1)
+
+
+def flashinfer_cutedsl_grouped_gemm_nt_masked(
+    hidden_states: torch.Tensor,  # 3d
+    input_global_scale: torch.Tensor,  # (l,)
+    weights: torch.Tensor,
+    w_global_scale: torch.Tensor,  # (l,)
+    masked_m: torch.Tensor,
+):
+    from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
+
+    # hidden_states: [l, m, k]
+    # weights: [l, n, k]
+    aq, aq_sf = scaled_fp4_grouped_quant(
+        hidden_states,
+        input_global_scale,
+        masked_m.to(hidden_states.device),
+    )
+    num_experts, n, k = weights.shape
+    bq, bq_sf = scaled_fp4_grouped_quant(
+        weights,
+        w_global_scale,
+        torch.ones(num_experts, device=weights.device, dtype=torch.int32) * n,
+    )
+
+    out = torch.zeros(
+        (num_experts, max(masked_m), n), dtype=weights.dtype, device=aq.device
+    )
+    out = out.permute(1, 2, 0)  # requirement of kernel
+    sf_vec_size = 16
+    ab_dtype = "float4_e2m1fn"
+    sf_dtype = "float8_e4m3fn"
+    c_dtype = "bfloat16"
+    alpha = 1.0 / (input_global_scale * w_global_scale).to(out.dtype).view(
+        1, 1, num_experts
+    )
+
+    def get_cute_dtype(input: torch.Tensor) -> str:
+        if input.dtype == torch.bfloat16:
+            return "bfloat16"
+        elif input.dtype == torch.float16:
+            return "float16"
+        elif input.dtype == torch.float32:
+            return "float32"
+        else:
+            raise ValueError(f"Unsupported cute dtype {input.dtype}")
+
+    grouped_gemm_nt_masked(
+        (aq, aq_sf),
+        (bq, bq_sf),
+        out,
+        masked_m.to(aq.device),
+        ab_dtype=ab_dtype,
+        sf_dtype=sf_dtype,
+        c_dtype=c_dtype,
+        sf_vec_size=sf_vec_size,
+        alpha=alpha,
+        alpha_dtype=get_cute_dtype(alpha),
+    )
+
+    return out
+
+
 def check_moe(
     m: int,
     n: int,
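The alpha passed to grouped_gemm_nt_masked undoes the two per-tensor global scales applied at quantization time. Spelled out (the 448 and 6 constants are the true fp8-e4m3 and fp4-e2m1 maxima; the amax values are illustrative):

import torch

FLOAT8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value
FLOAT4_E2M1_MAX = 6.0    # largest finite fp4 e2m1 value

a_amax, b_amax = torch.tensor(3.2), torch.tensor(0.8)
a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax  # input global scale
b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax  # weight global scale
# Quantization stores a * a_gs and b * b_gs (block-scaled), so the GEMM of the
# quantized operands is off by a factor a_gs * b_gs; alpha restores the scale.
alpha = 1.0 / (a_gs * b_gs)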
@@ -324,6 +451,248 @@ def test_flashinfer_fp4_moe_no_graph(
     check_moe(m, n, k, e, topk, dtype, flashinfer_moe_impl, flip_w13=True)


+@pytest.mark.parametrize("bs, hidden_dim, inter_dim", [(2, 128, 256), (16, 128, 512)])
+@pytest.mark.parametrize("topk", [1, 2, 4])
+@torch.inference_mode()
+def test_flashinfer_cutedsl_moe_masked(
+    bs: int, hidden_dim: int, inter_dim: int, topk: int
+):
+    torch.manual_seed(42)
+    device = "cuda"
+    dtype = torch.bfloat16
+    num_experts = 8
+    hidden_states = (
+        torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device) / 5.0
+    )
+    w1 = (
+        torch.randn(
+            num_experts, 2 * inter_dim, hidden_dim, dtype=torch.bfloat16, device=device
+        )
+        / 10.0
+    )
+    w2 = (
+        torch.randn(
+            num_experts, hidden_dim, inter_dim, dtype=torch.bfloat16, device=device
+        )
+        / 10.0
+    )
+    router_logits = torch.randn(bs, num_experts, dtype=torch.float32)
+
+    hidden_states_expanded = (
+        hidden_states.view(bs, -1, hidden_dim)
+        .repeat(1, topk, 1)
+        .reshape(-1, hidden_dim)
+    )
+    hidden_states_3d, masked_m, topk_idx, routing_weights = prepare_inputs(
+        hidden_states_expanded, router_logits, num_experts, topk
+    )
+
+    w1_amax = w1.abs().amax(dim=(1, 2)).to(torch.float32).to(w1.device)
+    w2_amax = w2.abs().amax(dim=(1, 2)).to(torch.float32).to(w2.device)
+    input_global_scale = torch.ones(
+        (num_experts,), dtype=torch.float32, device=hidden_states.device
+    )
+
+    w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
+    w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
+    a2_global_scale = torch.ones(
+        (num_experts,), dtype=torch.float32, device=hidden_states.device
+    )  # assume intermediate scale is 1.0
+
+    w1_fp4, w1_blockscale = scaled_fp4_grouped_quant(
+        w1,
+        w1_global_scale,
+        torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim,
+    )
+    w2_fp4, w2_blockscale = scaled_fp4_grouped_quant(
+        w2,
+        w2_global_scale,
+        torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim,
+    )
+
+    w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
+    w2_alpha = 1.0 / (a2_global_scale * w2_global_scale)
+
+    out = flashinfer_cutedsl_moe_masked(
+        hidden_states_3d.to(hidden_states.device),
+        input_global_scale,
+        w1_fp4.permute(2, 0, 1),
+        w1_blockscale,
+        w1_alpha,
+        w2_fp4.permute(2, 0, 1),
+        a2_global_scale,
+        w2_blockscale,
+        w2_alpha,
+        masked_m.to(hidden_states.device),
+    )
+
+    # reference
+    a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, input_global_scale)
+    a_in_dtype = dequantize_nvfp4_to_dtype(
+        a_fp4,
+        a_scale_interleaved,
+        input_global_scale,
+        dtype=hidden_states.dtype,
+        device=hidden_states.device,
+        block_size=16,
+    )
+    w1_d = torch.empty(
+        (num_experts, 2 * inter_dim, hidden_dim), device=w1.device, dtype=w1.dtype
+    )
+    w2_d = torch.empty(
+        (num_experts, hidden_dim, inter_dim), device=w2.device, dtype=w2.dtype
+    )
+
+    for idx in range(0, num_experts):
+        w1_fp4_sliced, w1_blockscale_sliced = fp4_quantize(
+            w1[idx], w1_global_scale[idx]
+        )
+        w2_fp4_sliced, w2_blockscale_sliced = fp4_quantize(
+            w2[idx], w2_global_scale[idx]
+        )
+        w1_d[idx] = dequantize_nvfp4_to_dtype(
+            w1_fp4_sliced,
+            w1_blockscale_sliced,
+            w1_global_scale[idx],
+            dtype=w1.dtype,
+            device=w1.device,
+            block_size=16,
+        )
+        w2_d[idx] = dequantize_nvfp4_to_dtype(
+            w2_fp4_sliced,
+            w2_blockscale_sliced,
+            w2_global_scale[idx],
+            dtype=w2.dtype,
+            device=w2.device,
+            block_size=16,
+        )
+
+    ref_output = torch_moe_nvfp4(
+        a_in_dtype,
+        w1_d,
+        w2_d,
+        topk,
+        routing_weights.to(a_in_dtype.device),
+        topk_idx.to(a_in_dtype.device),
+    )
+    out_weighted = torch.zeros_like(ref_output, device=out.device, dtype=out.dtype)
+
+    positions = torch.nonzero(masked_m[topk_idx], as_tuple=False)
+    rows, cols = positions[:, 0], positions[:, 1]
+    experts = topk_idx[rows, cols]
+    for i in range(num_experts):
+        mask = experts == i
+        if mask.any():
+            idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
+            r, c = rows[idx], cols[idx]
+            out_weighted[r] += out[i, : len(r), :] * routing_weights[r, c].to(
+                out.device
+            ).unsqueeze(-1)
+    torch.testing.assert_close(
+        out_weighted.cpu(), ref_output.cpu(), atol=5e-2, rtol=5e-2
+    )
+
+
+@pytest.mark.parametrize(
+    "bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)]
+)
+@torch.inference_mode()
+def test_grouped_gemm_nt_masked(
+    bs: int, hidden_dim: int, inter_dim: int, topk: int
+) -> None:
+    torch.manual_seed(42)
+    B = bs
+    D = hidden_dim
+    N = inter_dim
+    num_experts = 8
+    hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
+    weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
+    router_logits = torch.randn(B, num_experts, dtype=torch.float32)
+
+    hidden_states_expanded = (
+        hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+    )
+    hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
+        hidden_states_expanded, router_logits, num_experts, topk
+    )
+
+    # reference
+    out = torch.zeros(
+        (B * topk, weights.shape[1]), dtype=weights.dtype, device=weights.device
+    )
+    for i in range(num_experts):
+        mask = topk_idx.view(-1) == i
+        if mask.sum():
+            lhs = hidden_states_expanded[mask]
+            rhs = weights[i]
+            a_amax = lhs.abs().max().to(torch.float32).to(hidden_states.device)
+            b_amax = rhs.abs().amax().to(torch.float32).to(weights.device)
+            a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
+            b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
+
+            lhsq, lhsq_sf = fp4_quantize(
+                lhs,
+                a_gs,
+            )
+            rhsq, rhsq_sf = fp4_quantize(
+                rhs,
+                b_gs,
+            )
+
+            lhs_in_dtype = dequantize_nvfp4_to_dtype(
+                lhsq,
+                lhsq_sf,
+                a_gs,
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+                block_size=16,
+            )
+
+            rhs_in_dtype = dequantize_nvfp4_to_dtype(
+                rhsq,
+                rhsq_sf,
+                b_gs,
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+                block_size=16,
+            )
+            out[mask] = lhs_in_dtype @ rhs_in_dtype.t()
+
+    a_amax = (
+        hidden_states_3d.abs()
+        .amax(dim=(1, 2))
+        .to(torch.float32)
+        .to(hidden_states.device)
+    )
+    b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
+    a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
+    b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
+    out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
+        hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
+    )
+
+    # re-pack out into [num_experts, max_m, n]
+    out_ref = torch.zeros(
+        (num_experts, max(masked_m), weights.shape[1]), dtype=out.dtype
+    )
+    expert_slot = [0] * num_experts
+    for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
+        out_ref[expert_id, expert_slot[expert_id], :] = out[i]
+        expert_slot[expert_id] += 1
+
+    # Note: only compare the masked positions, since the CuteDSL kernel may
+    # write NaN into unmasked positions.
+    for i in range(num_experts):
+        torch.testing.assert_close(
+            out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]],
+            out_ref.to(out_flashinfer.device)[i, : masked_m[i]],
+            atol=1e-1,
+            rtol=5e-2,
+        )


 if __name__ == "__main__":
     test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
     test_flashinfer_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
+    test_flashinfer_cutedsl_moe_masked(16, 128, 512, 4)
+    test_grouped_gemm_nt_masked(16, 128, 512, 4)
@@ -53,6 +53,9 @@ DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instru
 DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
 DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN = "lmsys/sglang-ci-dsv3-test-NextN"

+# NVFP4 models
+DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST = "nvidia/DeepSeek-R1-0528-FP4"
+
 # FP8 models
 DEFAULT_MODEL_NAME_FOR_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
 DEFAULT_MODEL_NAME_FOR_ACCURACY_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"