Support dispatch low latency (#10263)

Co-authored-by: Kaixi Hou <4001424+kaixih@users.noreply.github.com>
This commit is contained in:
fzyzcjy
2025-10-02 18:02:19 +08:00
committed by GitHub
parent 6a29003410
commit 0b9dfba787
5 changed files with 80 additions and 29 deletions

View File

@@ -80,6 +80,10 @@ CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var(
"SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"
)
# TODO make it true by default when the DeepEP PR is merged
CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
"SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH", "false"
)
# Supported activation schemes for the current configuration
ACTIVATION_SCHEMES = ["static"]
@@ -1234,6 +1238,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
w13_input_scale = _slice_scale(w13_input_scale)
w2_input_scale = _slice_scale(w2_input_scale)
if CUTEDSL_MOE_NVFP4_DISPATCH:
assert torch.all(w13_input_scale == w13_input_scale[0])
w13_input_scale = w13_input_scale[0]
else:
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
w2_input_scale = layer.w2_input_scale
@@ -1476,7 +1484,9 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
out = flashinfer_cutedsl_moe_masked(
hidden_states=x,
input_global_scale=layer.w13_input_scale_quant,
input_global_scale=(
None if CUTEDSL_MOE_NVFP4_DISPATCH else layer.w13_input_scale_quant
),
w1=layer.w13_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w1_alpha=layer.g1_alphas,