Add enable_flashinfer_mxfp4_bf16_moe for higher precision and slower moe backend (#9004)
@@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -262,6 +263,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
         self.with_bias = False
         self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
+        self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
+            "flashinfer_mxfp4_moe_precision"
+        ]
 
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
@@ -615,11 +619,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.moe.topk import TopKOutputChecker
 
         if self.use_flashinfer:
-            # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
-            x_quant, x_scale = mxfp8_quantize(
-                x, False, alignment=self.hidden_size
-            )  # to mxfp8
-            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            # When bf16 mode is enabled, we don't need to quantize the input;
+            # TRT-LLM automatically handles quantization inside the kernel and pipelines it with the GEMM operations,
+            # which can theoretically improve performance.
+            if self.flashinfer_mxfp4_moe_precision == "bf16":
+                assert x.dtype == torch.bfloat16
+                x_quant = x
+                x_scale = None
+
+                # May be fused later if this code branch is frequently needed
+                origin_hidden_states_dim = x_quant.shape[-1]
+                if self.hidden_size != origin_hidden_states_dim:
+                    x_quant = torch.nn.functional.pad(
+                        x_quant,
+                        (0, self.hidden_size - origin_hidden_states_dim),
+                        mode="constant",
+                        value=0.0,
+                    )
+            elif self.flashinfer_mxfp4_moe_precision == "default":
+                x_quant, x_scale = mxfp8_quantize(x, False, alignment=self.hidden_size)
+                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            else:
+                raise NotImplementedError
 
             assert x_quant.shape[-1] == self.hidden_size
             assert TopKOutputChecker.format_is_bypassed(topk_output)
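For intuition, here is a minimal standalone sketch (not part of the patch) of what the bf16 branch above does to the activations: they stay in bfloat16, and only the last dimension is right-padded with zeros until it matches the hidden size expected by the kernel. The shapes below are made up for illustration.

```python
import torch
import torch.nn.functional as F

hidden_size = 2944                               # hypothetical hidden size expected by the kernel
x = torch.randn(4, 2880, dtype=torch.bfloat16)   # hypothetical unpadded activations

# F.pad with a (left, right) pair pads only the last dimension;
# right-padding with zeros leaves the existing values untouched.
pad = hidden_size - x.shape[-1]
if pad > 0:
    x = F.pad(x, (0, pad), mode="constant", value=0.0)

assert x.shape[-1] == hidden_size
print(x.shape)  # torch.Size([4, 2944])
```

The default path keeps the existing behaviour of quantizing the input to mxfp8 before the kernel call.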
@@ -87,6 +87,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "disable_flashinfer_cutlass_moe_fp4_allgather",
     "disable_radix_cache",
     "enable_dp_lm_head",
+    "flashinfer_mxfp4_moe_precision",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
@@ -190,6 +190,7 @@ class ServerArgs:
         "flashinfer_cutlass",
         "flashinfer_mxfp4",
     ] = "auto"
+    flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default"
     enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
     ep_num_redundant_experts: int = 0
@@ -1496,10 +1497,18 @@ class ServerArgs:
                 "triton_kernel",
                 "flashinfer_trtllm",
                 "flashinfer_cutlass",
                 "flashinfer_mxfp4",
             ],
             default=ServerArgs.moe_runner_backend,
             help="Choose the runner backend for MoE.",
         )
+        parser.add_argument(
+            "--flashinfer-mxfp4-moe-precision",
+            type=str,
+            choices=["default", "bf16"],
+            default=ServerArgs.flashinfer_mxfp4_moe_precision,
+            help="Choose the computation precision of the flashinfer mxfp4 MoE backend.",
+        )
         parser.add_argument(
             "--enable-flashinfer-allreduce-fusion",
             action="store_true",
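With both pieces in place, the higher-precision (and, per the commit title, slower) path can be selected at launch time, e.g. --moe-runner-backend flashinfer_mxfp4 --flashinfer-mxfp4-moe-precision bf16; leaving the precision flag at "default" keeps the existing mxfp8 input quantization.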