diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md
index f22685454..e332aac1f 100644
--- a/docs/references/environment_variables.md
+++ b/docs/references/environment_variables.md
@@ -40,6 +40,12 @@ SGLang supports various environment variables that can be used to configure its
 | `SGL_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `"0"` |
 | `SGL_USE_DEEPGEMM_BMM` | Use DeepGEMM for Batched Matrix Multiplication (BMM) operations | `"false"` |
 
+## DeepEP Configuration
+
+| Environment Variable | Description | Default Value |
+| --- | --- | --- |
+| `SGLANG_DEEPEP_BF16_DISPATCH` | Use bfloat16 (instead of FP8) for DeepEP low-latency dispatch | `"false"` |
+
 ## Memory Management
 
 | Environment Variable | Description | Default Value |
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index ef33665c3..4303fcd4e 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -459,6 +459,8 @@ class DeepEPMoE(EPMoE):
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_contiguous(dispatch_output)
         elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
+            if get_moe_runner_backend().is_flashinfer_cutedsl():
+                return self.forward_flashinfer_cutedsl(dispatch_output)
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_masked(dispatch_output)
         else:
@@ -638,6 +640,22 @@ class DeepEPMoE(EPMoE):
 
         return gather_out
 
+    def forward_flashinfer_cutedsl(
+        self,
+        dispatch_output: DeepEPLLOutput,
+    ):
+        hidden_states, _, _, masked_m, _ = dispatch_output
+        assert self.quant_method is not None
+        assert self.moe_runner_config.activation == "silu"
+
+        output = self.quant_method.apply_without_routing_weights(
+            layer=self,
+            x=hidden_states,
+            masked_m=masked_m,
+            moe_runner_config=self.moe_runner_config,
+        )
+        return output
+
     def forward_deepgemm_masked(
         self,
         dispatch_output: DeepEPLLOutput,
diff --git a/python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py b/python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py
new file mode 100644
index 000000000..c8813ff6f
--- /dev/null
+++ b/python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py
@@ -0,0 +1,156 @@
+from typing import Any, Dict, Optional
+
+import torch
+from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
+from sgl_kernel.gemm import (
+    scaled_fp4_grouped_quant,
+    silu_and_mul_scaled_fp4_grouped_quant,
+)
+
+
+def get_cute_dtype(input: torch.Tensor) -> str:
+    if input.dtype == torch.bfloat16:
+        return "bfloat16"
+    elif input.dtype == torch.float16:
+        return "float16"
+    elif input.dtype == torch.float32:
+        return "float32"
+    else:
+        raise ValueError(f"Unsupported cute dtype {input.dtype}")
+
+
+def flashinfer_cutedsl_moe_masked(
+    hidden_states: torch.Tensor,
+    input_global_scale: torch.Tensor,
+    w1: torch.Tensor,
+    w1_blockscale: torch.Tensor,
+    w1_alpha,
+    w2: torch.Tensor,
+    a2_global_scale: torch.Tensor,
+    w2_blockscale: torch.Tensor,
+    w2_alpha,
+    masked_m: torch.Tensor,
+):
+    """
+    Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
+    kernels.
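+
+    The pipeline is: NVFP4 block-scale quantization of the dispatched activations,
+    a grouped gate/up GEMM, fused SiLU-and-mul with NVFP4 re-quantization, then a
+    grouped down-projection GEMM. Only the first masked_m[i] rows of each expert
+    slot hold valid tokens.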
+
+    Args:
+        hidden_states (torch.Tensor): [num_experts, m, k], bf16
+        input_global_scale (torch.Tensor): (l,)
+        w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
+        w1_blockscale (torch.Tensor): blockscale factors, e4m3
+        w1_alpha (torch.Tensor): (l,)
+        w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
+        a2_global_scale (torch.Tensor): (l,)
+        w2_blockscale (torch.Tensor): blockscale factors, e4m3
+        w2_alpha (torch.Tensor): (l,)
+        masked_m (torch.Tensor): (l,) number of valid rows per expert
+
+    Notes:
+        - Assumes max(masked_m) <= m.
+    """
+
+    # === Assertions on dtypes ===
+    assert (
+        input_global_scale.dtype == torch.float32
+    ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
+    assert w1.dtype == torch.uint8, f"w1 must be uint8 (fp4 packed), got {w1.dtype}"
+    assert (
+        w1_blockscale.dtype == torch.float8_e4m3fn
+    ), f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
+    assert (
+        w1_alpha.dtype == torch.float32
+    ), f"w1_alpha must be float32, got {w1_alpha.dtype}"
+    assert w2.dtype == torch.uint8, f"w2 must be uint8 (fp4 packed), got {w2.dtype}"
+    assert (
+        a2_global_scale.dtype == torch.float32
+    ), f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
+    assert (
+        w2_blockscale.dtype == torch.float8_e4m3fn
+    ), f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
+    assert (
+        w2_alpha.dtype == torch.float32
+    ), f"w2_alpha must be float32, got {w2_alpha.dtype}"
+
+    # === Assertions on shapes ===
+    n = w2.shape[-1] * 2  # intermediate dimension
+    num_experts, m, k = hidden_states.shape
+
+    assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
+    assert (
+        w1.shape[-1] * 2 == k
+    ), f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
+    assert w2.shape[-2:] == (
+        k,
+        n // 2,
+    ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n//2)}"
+
+    assert input_global_scale.shape == (
+        num_experts,
+    ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
+    assert w1_alpha.shape == (
+        num_experts,
+    ), f"w1_alpha must be (l,), got {w1_alpha.shape}"
+    assert a2_global_scale.shape == (
+        num_experts,
+    ), f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
+    assert w2_alpha.shape == (
+        num_experts,
+    ), f"w2_alpha must be (l,), got {w2_alpha.shape}"
+
+    aq, aq_sf = scaled_fp4_grouped_quant(
+        hidden_states,
+        input_global_scale,
+        masked_m,
+    )
+    gateup_output = torch.empty(
+        (num_experts, m, n * 2), dtype=hidden_states.dtype, device=aq.device
+    )
+    gateup_output = gateup_output.permute(1, 2, 0)  # requirement of kernel
+    sf_vec_size = 16
+    assert aq_sf.dtype == torch.float8_e4m3fn
+    assert aq.dtype == torch.uint8
+    ab_dtype = "float4_e2m1fn"
+    sf_dtype = "float8_e4m3fn"
+
+    c_dtype = get_cute_dtype(hidden_states)
+
+    # Gemm1
+
+    grouped_gemm_nt_masked(
+        (aq, aq_sf),
+        (w1.permute(1, 2, 0), w1_blockscale),
+        gateup_output,
+        masked_m,
+        ab_dtype=ab_dtype,
+        sf_dtype=sf_dtype,
+        c_dtype=c_dtype,
+        sf_vec_size=sf_vec_size,
+        alpha=w1_alpha.view(1, 1, num_experts),
+        alpha_dtype=get_cute_dtype(w1_alpha),
+    )  # in logical [m, n, l]
+
+    # SILU and quantization
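+    # Fuse SiLU(gate) * up with NVFP4 re-quantization so that Gemm2 consumes fp4
+    # inputs directly; a2_global_scale is the global scale used to quantize the
+    # Gemm2 input activations.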
+    diq, diq_sf = silu_and_mul_scaled_fp4_grouped_quant(
+        gateup_output.permute(2, 0, 1),
+        a2_global_scale,
+        masked_m,
+    )
+
+    # Gemm2
+    out = torch.empty_like(hidden_states)
+    out = out.permute(1, 2, 0)  # requirement of kernel
+    grouped_gemm_nt_masked(
+        (diq, diq_sf),
+        (w2.permute(1, 2, 0), w2_blockscale),
+        out,
+        masked_m,
+        ab_dtype=ab_dtype,
+        sf_dtype=sf_dtype,
+        c_dtype=c_dtype,
+        sf_vec_size=sf_vec_size,
+        alpha=w2_alpha.view(1, 1, num_experts),
+        alpha_dtype=get_cute_dtype(w2_alpha),
+    )  # in logical [m, k, l]
+    return out.permute(2, 0, 1)
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
index 450cff0cb..64ade6546 100644
--- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
+++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
@@ -508,7 +508,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
             topk_idx,
-            use_fp8=True,
+            # TODO(shuw): pending https://github.com/deepseek-ai/DeepEP/pull/341
+            use_fp8=not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"),
         )
         return (
             hidden_states,
diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py
index b4e4ec424..9fd6e2646 100644
--- a/python/sglang/srt/layers/moe/utils.py
+++ b/python/sglang/srt/layers/moe/utils.py
@@ -49,6 +49,7 @@ class MoeRunnerBackend(Enum):
     FLASHINFER = "flashinfer_trtllm"
     FLASHINFER_CUTLASS = "flashinfer_cutlass"
     FLASHINFER_MXFP4 = "flashinfer_mxfp4"
+    FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
 
     def is_auto(self):
         return self == MoeRunnerBackend.AUTO
@@ -65,6 +66,9 @@ class MoeRunnerBackend(Enum):
     def is_flashinfer_cutlass(self):
         return self == MoeRunnerBackend.FLASHINFER_CUTLASS
 
+    def is_flashinfer_cutedsl(self):
+        return self == MoeRunnerBackend.FLASHINFER_CUTEDSL
+
     def is_flashinfer_mxfp4(self):
         return self == MoeRunnerBackend.FLASHINFER_MXFP4
 
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index 89ecd44f5..45a4ba559 100755
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -878,6 +878,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         """Access the global enable_flashinfer_cutlass_moe setting."""
         return get_moe_runner_backend().is_flashinfer_cutlass()
 
+    @property
+    def enable_flashinfer_cutedsl_moe(self) -> bool:
+        """Access the global enable_flashinfer_cutedsl_moe setting."""
+        from sglang.srt.layers.moe import get_moe_runner_backend
+
+        return get_moe_runner_backend().is_flashinfer_cutedsl()
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -1398,5 +1405,38 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
         # Scale by routed_scaling_factor is fused into select_experts.
-
         return StandardCombineInput(hidden_states=output)
+
+    def apply_without_routing_weights(
+        self,
+        layer: FusedMoE,
+        x: torch.Tensor,
+        masked_m: torch.Tensor,
+        moe_runner_config: MoeRunnerConfig,
+    ) -> torch.Tensor:
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        assert self.enable_flashinfer_cutedsl_moe, "only support flashinfer cutedsl moe"
+        assert (
+            not moe_runner_config.apply_router_weight_on_input
+        ), "apply_router_weight_on_input is not supported for Flashinfer"
+
+        from sglang.srt.layers.moe.flashinfer_cutedsl_moe import (
+            flashinfer_cutedsl_moe_masked,
+        )
+
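+        # The layer attributes below hold packed NVFP4 weights (uint8), swizzled
+        # e4m3 block scales, and per-expert alpha factors that fold the activation
+        # and weight global scales for each GEMM.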
+        out = flashinfer_cutedsl_moe_masked(
+            hidden_states=x,
+            input_global_scale=layer.w13_input_scale_quant,
+            w1=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alpha=layer.g1_alphas,
+            w2=layer.w2_weight,
+            a2_global_scale=layer.w2_input_scale_quant,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alpha=layer.g2_alphas,
+            masked_m=masked_m,
+        )
+        return out
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index b5535f6d3..5452e7e8c 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -673,10 +673,14 @@ class DeepseekV2MoE(nn.Module):
 
         if shared_output is not None:
             x = shared_output
-            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            if self.experts.should_fuse_routed_scaling_factor_in_topk():
+                x.add_(final_hidden_states)
+            else:
+                x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
             final_hidden_states = x
         else:
-            final_hidden_states *= self.routed_scaling_factor
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk():
+                final_hidden_states *= self.routed_scaling_factor
 
         return final_hidden_states
 
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 7d713bedf..68061ae97 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -399,6 +399,7 @@ class ServerArgs:
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     enable_flashinfer_cutlass_moe: bool = False
+    enable_flashinfer_cutedsl_moe: bool = False
     enable_flashinfer_trtllm_moe: bool = False
     enable_triton_kernel_moe: bool = False
     enable_flashinfer_mxfp4_moe: bool = False
@@ -420,6 +421,11 @@ class ServerArgs:
             print_deprecated_warning(
                 "NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
             )
+        if self.enable_flashinfer_cutedsl_moe:
+            self.moe_runner_backend = "flashinfer_cutedsl"
+            print_deprecated_warning(
+                "NOTE: --enable-flashinfer-cutedsl-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutedsl' instead."
+            )
         if self.enable_flashinfer_cutlass_moe:
             self.moe_runner_backend = "flashinfer_cutlass"
             print_deprecated_warning(
@@ -1622,6 +1628,7 @@
                 "flashinfer_trtllm",
                 "flashinfer_cutlass",
                 "flashinfer_mxfp4",
+                "flashinfer_cutedsl",
             ],
             default=ServerArgs.moe_runner_backend,
             help="Choose the runner backend for MoE.",
@@ -2204,6 +2211,11 @@
             action="store_true",
             help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
         )
+        parser.add_argument(
+            "--enable-flashinfer-cutedsl-moe",
+            action="store_true",
+            help="(Deprecated) Enable FlashInfer CuteDSL MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
+        )
         parser.add_argument(
             "--enable-flashinfer-trtllm-moe",
             action="store_true",
diff --git a/python/sglang/test/test_fp4_moe.py b/python/sglang/test/test_fp4_moe.py
index 8f8c8e8a7..e0c616807 100644
--- a/python/sglang/test/test_fp4_moe.py
+++ b/python/sglang/test/test_fp4_moe.py
@@ -3,12 +3,15 @@
 from typing import Callable
 
 import pytest
 import torch
+from flashinfer import fp4_quantize
 from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
-from sgl_kernel import scaled_fp4_quant
+from sgl_kernel import scaled_fp4_grouped_quant, scaled_fp4_quant
+from torch.nn import functional as F
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.flashinfer_cutedsl_moe import flashinfer_cutedsl_moe_masked
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 
 if torch.cuda.get_device_capability() < (10, 0):
@@ -78,6 +81,37 @@ def break_fp4_bytes(a, dtype):
     return values.reshape(m, n * 2).to(dtype=dtype)
 
 
+def compute_routing(router_logits: torch.Tensor, top_k: int):
+    routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float)
+    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+    routing_weights = routing_weights.float()
+    return routing_weights, selected_experts
+
+
+def prepare_inputs(
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    num_experts: int,
+    topk: int,
+):
+    routing_weights, topk_idx = compute_routing(router_logits, topk)
+
+    masked_m = []
+    for i in range(num_experts):
+        mask = topk_idx.view(-1) == i
+        masked_m.append(mask.sum())
+
+    masked_m = torch.tensor(masked_m, dtype=torch.int32)
+    hidden_states_3d = torch.empty(
+        (num_experts, max(masked_m), hidden_states.shape[1]), dtype=hidden_states.dtype
+    )
+    for i in range(num_experts):
+        hidden_states_3d[i, : masked_m[i], :] = hidden_states[topk_idx.view(-1) == i]
+
+    return hidden_states_3d, masked_m, topk_idx, routing_weights
+
+
 MNK_FACTORS = [
     (2, 1024, 1024),
     (2, 1024, 1536),
@@ -114,6 +148,99 @@ def torch_moe(a, w1, w2, score, topk, expert_map):
     ).sum(dim=1)
 
 
+def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
+    B, D = a.shape
+    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+
+    topk_weight = topk_weight.view(-1)
+    topk_ids = topk_ids.view(-1)
+
+    for i in range(w1.shape[0]):
+        mask = topk_ids == i
+        if mask.sum():
+            m = w1[i].shape[0]
+            assert m % 2 == 0
+            # Note: w1 and w3 are swapped!
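+            # The packed weight stacks [gate; up] along dim 0: the first half is
+            # the gate projection (SiLU branch), the second half the up projection.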
+            w3_expert, w1_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :]
+            inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
+            inter_gs = torch.tensor(1.0).cuda()
+            inter_q, inter_blockscale = fp4_quantize(inter, inter_gs)
+            inter = dequantize_nvfp4_to_dtype(
+                inter_q,
+                inter_blockscale,
+                inter_gs,
+                dtype=inter.dtype,
+                device=inter.device,
+                block_size=16,
+            ).cuda()
+            out[mask] = inter @ w2[i].transpose(0, 1)
+    return (
+        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
+    ).sum(dim=1)
+
+
+def flashinfer_cutedsl_grouped_gemm_nt_masked(
+    hidden_states: torch.Tensor,  # 3d
+    input_global_scale: torch.Tensor,  # (l,)
+    weights: torch.Tensor,
+    w_global_scale: torch.Tensor,  # (l,)
+    masked_m: torch.Tensor,
+):
+    from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
+
+    # hidden_states: [l, m, k]
+    # weights: [l, n, k]
+    aq, aq_sf = scaled_fp4_grouped_quant(
+        hidden_states,
+        input_global_scale,
+        masked_m.to(hidden_states.device),
+    )
+    num_experts, n, k = weights.shape
+    bq, bq_sf = scaled_fp4_grouped_quant(
+        weights,
+        w_global_scale,
+        torch.ones(num_experts, device=weights.device, dtype=torch.int32) * n,
+    )
+
+    out = torch.zeros(
+        (num_experts, max(masked_m), n), dtype=weights.dtype, device=aq.device
+    )
+    out = out.permute(1, 2, 0)  # requirement of kernel
+    sf_vec_size = 16
+    ab_dtype = "float4_e2m1fn"
+    sf_dtype = "float8_e4m3fn"
+    c_dtype = "bfloat16"
+    alpha = 1.0 / (input_global_scale * w_global_scale).to(out.dtype).view(
+        1, 1, num_experts
+    )
+
+    def get_cute_dtype(input: torch.Tensor) -> str:
+        if input.dtype == torch.bfloat16:
+            return "bfloat16"
+        elif input.dtype == torch.float16:
+            return "float16"
+        elif input.dtype == torch.float32:
+            return "float32"
+        else:
+            raise ValueError(f"Unsupported cute dtype {input.dtype}")
+
+    grouped_gemm_nt_masked(
+        (aq, aq_sf),
+        (bq, bq_sf),
+        out,
+        masked_m.to(aq.device),
+        ab_dtype=ab_dtype,
+        sf_dtype=sf_dtype,
+        c_dtype=c_dtype,
+        sf_vec_size=sf_vec_size,
+        alpha=alpha,
+        alpha_dtype=get_cute_dtype(alpha),
+    )
+
+    return out
+
+
 def check_moe(
     m: int,
     n: int,
@@ -324,6 +451,248 @@ def test_flashinfer_fp4_moe_no_graph(
     check_moe(m, n, k, e, topk, dtype, flashinfer_moe_impl, flip_w13=True)
 
 
+@pytest.mark.parametrize("bs, hidden_dim, inter_dim", [(2, 128, 256), (16, 128, 512)])
+@pytest.mark.parametrize("topk", [1, 2, 4])
+@torch.inference_mode()
+def test_flashinfer_cutedsl_moe_masked(
+    bs: int, hidden_dim: int, inter_dim: int, topk: int
+):
+    torch.manual_seed(42)
+    device = "cuda"
+    dtype = torch.bfloat16
+    num_experts = 8
+    hidden_states = (
+        torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device) / 5.0
+    )
+    w1 = (
+        torch.randn(
+            num_experts, 2 * inter_dim, hidden_dim, dtype=torch.bfloat16, device=device
+        )
+        / 10.0
+    )
+    w2 = (
+        torch.randn(
+            num_experts, hidden_dim, inter_dim, dtype=torch.bfloat16, device=device
+        )
+        / 10.0
+    )
+    router_logits = torch.randn(bs, num_experts, dtype=torch.float32)
+
+    hidden_states_expanded = (
+        hidden_states.view(bs, -1, hidden_dim)
+        .repeat(1, topk, 1)
+        .reshape(-1, hidden_dim)
+    )
+    hidden_states_3d, masked_m, topk_idx, routing_weights = prepare_inputs(
+        hidden_states_expanded, router_logits, num_experts, topk
+    )
+
+    w1_amax = w1.abs().amax(dim=(1, 2)).to(torch.float32).to(w1.device)
+    w2_amax = w2.abs().amax(dim=(1, 2)).to(torch.float32).to(w2.device)
+    input_global_scale = torch.ones(
+        (num_experts,), dtype=torch.float32, device=hidden_states.device
+    )
+
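+    # NVFP4 global scale maps each tensor's amax onto the largest representable
+    # product of an e4m3 block scale and an e2m1 value, i.e. 448 * 6 / amax.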
+    w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
+    w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
+    a2_global_scale = torch.ones(
+        (num_experts,), dtype=torch.float32, device=hidden_states.device
+    )  # assume intermediate scale is 1.0
+
+    w1_fp4, w1_blockscale = scaled_fp4_grouped_quant(
+        w1,
+        w1_global_scale,
+        torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim,
+    )
+    w2_fp4, w2_blockscale = scaled_fp4_grouped_quant(
+        w2,
+        w2_global_scale,
+        torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim,
+    )
+
+    w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
+    w2_alpha = 1.0 / (a2_global_scale * w2_global_scale)
+
+    out = flashinfer_cutedsl_moe_masked(
+        hidden_states_3d.to(hidden_states.device),
+        input_global_scale,
+        w1_fp4.permute(2, 0, 1),
+        w1_blockscale,
+        w1_alpha,
+        w2_fp4.permute(2, 0, 1),
+        a2_global_scale,
+        w2_blockscale,
+        w2_alpha,
+        masked_m.to(hidden_states.device),
+    )
+
+    # reference
+    a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, input_global_scale)
+    a_in_dtype = dequantize_nvfp4_to_dtype(
+        a_fp4,
+        a_scale_interleaved,
+        input_global_scale,
+        dtype=hidden_states.dtype,
+        device=hidden_states.device,
+        block_size=16,
+    )
+    w1_d = torch.empty(
+        (num_experts, 2 * inter_dim, hidden_dim), device=w1.device, dtype=w1.dtype
+    )
+    w2_d = torch.empty(
+        (num_experts, hidden_dim, inter_dim), device=w2.device, dtype=w2.dtype
+    )
+
+    for idx in range(0, num_experts):
+        w1_fp4_sliced, w1_blockscale_sliced = fp4_quantize(
+            w1[idx], w1_global_scale[idx]
+        )
+        w2_fp4_sliced, w2_blockscale_sliced = fp4_quantize(
+            w2[idx], w2_global_scale[idx]
+        )
+        w1_d[idx] = dequantize_nvfp4_to_dtype(
+            w1_fp4_sliced,
+            w1_blockscale_sliced,
+            w1_global_scale[idx],
+            dtype=w1.dtype,
+            device=w1.device,
+            block_size=16,
+        )
+        w2_d[idx] = dequantize_nvfp4_to_dtype(
+            w2_fp4_sliced,
+            w2_blockscale_sliced,
+            w2_global_scale[idx],
+            dtype=w2.dtype,
+            device=w2.device,
+            block_size=16,
+        )
+
+    ref_output = torch_moe_nvfp4(
+        a_in_dtype,
+        w1_d,
+        w2_d,
+        topk,
+        routing_weights.to(a_in_dtype.device),
+        topk_idx.to(a_in_dtype.device),
+    )
+    out_weighted = torch.zeros_like(ref_output, device=out.device, dtype=out.dtype)
+
+    positions = torch.nonzero(masked_m[topk_idx], as_tuple=False)
+    rows, cols = positions[:, 0], positions[:, 1]
+    experts = topk_idx[rows, cols]
+    for i in range(num_experts):
+        mask = experts == i
+        if mask.any():
+            idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
+            r, c = rows[idx], cols[idx]
+            out_weighted[r] += out[i, : len(r), :] * routing_weights[r, c].to(
+                out.device
+            ).unsqueeze(-1)
+    torch.testing.assert_close(
+        out_weighted.cpu(), ref_output.cpu(), atol=5e-2, rtol=5e-2
+    )
+
+
+@pytest.mark.parametrize(
+    "bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)]
+)
+@torch.inference_mode()
+def test_grouped_gemm_nt_masked(
+    bs: int, hidden_dim: int, inter_dim: int, topk: int
+) -> None:
+    torch.manual_seed(42)
+    B = bs
+    D = hidden_dim
+    N = inter_dim
+    num_experts = 8
+    hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
+    weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
+    router_logits = torch.randn(B, num_experts, dtype=torch.float32)
+
+    hidden_states_expanded = (
+        hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+    )
+    hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
+        hidden_states_expanded, router_logits, num_experts, topk
+    )
+
+    # reference
+    out = torch.zeros(
+        (B * topk, weights.shape[1]), dtype=weights.dtype, device=weights.device
+    )
+    for i in range(num_experts):
+        mask = topk_idx.view(-1) == i
+        if mask.sum():
+            lhs = hidden_states_expanded[mask]
+            rhs = weights[i]
+            a_amax = lhs.abs().max().to(torch.float32).to(hidden_states.device)
+            b_amax = rhs.abs().amax().to(torch.float32).to(weights.device)
+            a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
+            b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
+
+            lhsq, lhsq_sf = fp4_quantize(
+                lhs,
+                a_gs,
+            )
+            rhsq, rhsq_sf = fp4_quantize(
+                rhs,
+                b_gs,
+            )
+
+            lhs_in_dtype = dequantize_nvfp4_to_dtype(
+                lhsq,
+                lhsq_sf,
+                a_gs,
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+                block_size=16,
+            )
+
+            rhs_in_dtype = dequantize_nvfp4_to_dtype(
+                rhsq,
+                rhsq_sf,
+                b_gs,
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+                block_size=16,
+            )
+            out[mask] = lhs_in_dtype @ rhs_in_dtype.t()
+
+    a_amax = (
+        hidden_states_3d.abs()
+        .amax(dim=(1, 2))
+        .to(torch.float32)
+        .to(hidden_states.device)
+    )
+    b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
+    a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
+    b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
+    out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
+        hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
+    )
+
+    # re-pack out into [num_experts, max_m, n]
+    out_ref = torch.zeros(
+        (num_experts, max(masked_m), weights.shape[1]), dtype=out.dtype
+    )
+    expert_slot = [0] * num_experts
+    for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
+        out_ref[expert_id, expert_slot[expert_id], :] = out[i]
+        expert_slot[expert_id] += 1
+
+    # Note: only compare the masked positions, since the CuteDSL kernel may write
+    # NaN into the unmasked positions.
+    for i in range(num_experts):
+        torch.testing.assert_close(
+            out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]],
+            out_ref.to(out_flashinfer.device)[i, : masked_m[i]],
+            atol=1e-1,
+            rtol=5e-2,
+        )
+
+
 if __name__ == "__main__":
     test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
     test_flashinfer_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
+    test_flashinfer_cutedsl_moe_masked(16, 128, 512, 4)
+    test_grouped_gemm_nt_masked(16, 128, 512, 4)
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 0d3d769f4..90a8ef5e1 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -53,6 +53,9 @@ DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instru
 DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
 DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN = "lmsys/sglang-ci-dsv3-test-NextN"
 
+# NVFP4 models
+DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST = "nvidia/DeepSeek-R1-0528-FP4"
+
 # FP8 models
 DEFAULT_MODEL_NAME_FOR_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
 DEFAULT_MODEL_NAME_FOR_ACCURACY_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
diff --git a/test/srt/test_cutedsl_flashinfer_8gpu.py b/test/srt/test_cutedsl_flashinfer_8gpu.py
new file mode 100644
index 000000000..f062f3589
--- /dev/null
+++ b/test/srt/test_cutedsl_flashinfer_8gpu.py
@@ -0,0 +1,77 @@
+import os
+import unittest
+from types import SimpleNamespace
+
+import requests
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
+from sglang.test.test_utils import (
+    DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+    try_cached_model,
+)
+
+
+class TestDeepseekR1Nvfp4CuteDSLDeepEP(CustomTestCase):
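+    """GSM8K smoke test for the DeepSeek-R1 NVFP4 checkpoint using the FlashInfer
+    CuteDSL MoE backend with DeepEP low-latency (bf16) dispatch on 8 GPUs."""
+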
+    @classmethod
+    def setUpClass(cls):
+        cls.model = try_cached_model(DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST)
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        other_args = [
+            "--trust-remote-code",
+            "--disable-radix-cache",
+            "--max-running-requests",
+            "256",
+            "--chunked-prefill-size",
+            "2048",
+            "--tp",
+            "8",
+            "--dp",
+            "8",
+            "--enable-dp-attention",
+            "--enable-ep-moe",
+            "--quantization",
+            "modelopt_fp4",
+            "--enable-flashinfer-cutedsl-moe",
+            "--enable-deepep-moe",
+            "--deepep-mode",
+            "low_latency",
+        ]
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=other_args,
+            env={
+                **os.environ,
+                "SGLANG_DEEPEP_BF16_DISPATCH": "1",
+                "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256",
+            },
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=512,
+            parallel=512,
+            max_new_tokens=512,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(f"Eval accuracy of GSM8K: {metrics=}")
+
+        self.assertGreater(metrics["accuracy"], 0.92)
+
+
+if __name__ == "__main__":
+    unittest.main()
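+
+# Typical local invocation (assuming an 8-GPU Blackwell node with access to the
+# NVFP4 checkpoint): python3 test/srt/test_cutedsl_flashinfer_8gpu.py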