Support dispatch low latency (#10263)
Co-authored-by: Kaixi Hou <4001424+kaixih@users.noreply.github.com>
This commit is contained in:
@@ -80,6 +80,10 @@ CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
|
||||
USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var(
|
||||
"SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"
|
||||
)
|
||||
# TODO make it true by default when the DeepEP PR is merged
|
||||
CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
|
||||
"SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH", "false"
|
||||
)
|
||||
|
||||
# Supported activation schemes for the current configuration
|
||||
ACTIVATION_SCHEMES = ["static"]
|
||||
@@ -1234,6 +1238,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
w13_input_scale = _slice_scale(w13_input_scale)
|
||||
w2_input_scale = _slice_scale(w2_input_scale)
|
||||
|
||||
if CUTEDSL_MOE_NVFP4_DISPATCH:
|
||||
assert torch.all(w13_input_scale == w13_input_scale[0])
|
||||
w13_input_scale = w13_input_scale[0]
|
||||
else:
|
||||
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
|
||||
w2_input_scale = layer.w2_input_scale
|
||||
@@ -1476,7 +1484,9 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
out = flashinfer_cutedsl_moe_masked(
|
||||
hidden_states=x,
|
||||
input_global_scale=layer.w13_input_scale_quant,
|
||||
input_global_scale=(
|
||||
None if CUTEDSL_MOE_NVFP4_DISPATCH else layer.w13_input_scale_quant
|
||||
),
|
||||
w1=layer.w13_weight,
|
||||
w1_blockscale=layer.w13_blockscale_swizzled,
|
||||
w1_alpha=layer.g1_alphas,
|
||||
|
||||
Reference in New Issue
Block a user