[Perf] Auto enable best flashinfer mxfp4 kernel in b200 (#8898)
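This change replaces the SGLANG_USE_FLASHINFER_MXFP4_MOE / SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE environment variables with a new --enable-flashinfer-mxfp4-moe server argument and auto-enables the FlashInfer MXFP4 MoE kernel on SM100 (B200) GPUs, falling back to the Triton MoE kernel elsewhere.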
@@ -206,13 +206,13 @@ class FusedMoE(torch.nn.Module):
         assert self.quant_method is not None
 
         self.quant_config = quant_config
+        self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
+            "enable_flashinfer_mxfp4_moe", False
+        )
         if (
             self.quant_config is not None
             and self.quant_config.get_name() == "mxfp4"
-            and (
-                get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE")
-                or get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE")
-            )
+            and self.use_enable_flashinfer_mxfp4_moe
         ):
            hidden_size = round_up(hidden_size, 256)
         self.hidden_size = hidden_size
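The padding above relies on a ceiling-to-multiple helper. A minimal sketch, assuming round_up follows the conventional definition (the numbers below are illustrative, not taken from the commit):

# Minimal sketch (not the sglang implementation) of the ceiling-to-multiple helper
# used for the hidden_size padding above.
def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

assert round_up(2880, 256) == 3072   # padded up for the FlashInfer MXFP4 kernels
assert round_up(3072, 256) == 3072   # already aligned, left unchanged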
@@ -3,22 +3,20 @@
 
 from __future__ import annotations
 
 import importlib
 import importlib.util
 import logging
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
 
 # from vllm.model_executor.layers.fused_moe import (
 #     FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
 #     FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -32,11 +30,6 @@ from sglang.srt.utils import (
 
 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
 
-# Environment variables for FlashInfer MXFP4 MoE backend
-USE_FLASHINFER_MXFP4_MOE = get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE", "false")
-USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
-    "SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
-)
 
 if is_flashinfer_available():
     # from flashinfer.fused_moe import cutlass_fused_moe
@@ -193,7 +186,12 @@ class Mxfp4Config(QuantizationConfig):
         ):
             return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
-            return Mxfp4MoEMethod(use_triton_kernels=True, with_bias=True)
+            use_flashinfer = global_server_args_dict.get(
+                "enable_flashinfer_mxfp4_moe", False
+            )
+            return Mxfp4MoEMethod(
+                use_triton_kernels=True, with_bias=True, use_flashinfer=use_flashinfer
+            )
         else:
             raise NotImplementedError("Mxfp4 attention layer is not implemented")
         return None
@@ -204,11 +202,18 @@ class Mxfp4Config(QuantizationConfig):
 class Mxfp4MoEMethod(FusedMoEMethodBase):
 
-    def __init__(self, use_triton_kernels: bool = True, with_bias: bool = True):
+    def __init__(
+        self,
+        use_triton_kernels: bool = True,
+        with_bias: bool = True,
+        use_flashinfer: bool = False,
+    ):
         super().__init__()
         self.topk_indices_dtype = None
         self.use_triton_kernels = use_triton_kernels
         self.with_bias = with_bias
+        self.use_flashinfer = use_flashinfer
 
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
         if torch.cuda.is_available() and has_triton_kernels:
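For illustration, a standalone sketch of how the widened constructor is driven by the server args; the stub class and dict below are stand-ins, not the sglang objects:

# Stand-in sketch, not the sglang class: the FlashInfer choice now travels with
# the method instance instead of being read from environment variables.
class Mxfp4MoEMethodStub:
    def __init__(self, use_triton_kernels=True, with_bias=True, use_flashinfer=False):
        self.use_triton_kernels = use_triton_kernels
        self.with_bias = with_bias
        self.use_flashinfer = use_flashinfer

global_server_args_dict = {"enable_flashinfer_mxfp4_moe": True}  # hypothetical contents
method = Mxfp4MoEMethodStub(
    use_triton_kernels=True,
    with_bias=True,
    use_flashinfer=global_server_args_dict.get("enable_flashinfer_mxfp4_moe", False),
)
assert method.use_flashinfer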
@@ -239,7 +244,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         # pad the intermediate size to be a multiple of 2 * mxfp4_block
         # for to hold non-uniform sharded tensor as well as swizzling
-        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
+        if self.use_flashinfer:
             intermediate_size_per_partition_after_pad = round_up(intermediate_size, 256)
             hidden_size = round_up(hidden_size, 256)
         elif is_hip():
@@ -319,7 +324,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_bias, extra_weight_attrs)
 
     def process_weights_after_loading(self, layer):
-        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
+        if self.use_flashinfer:
             logger.info(
                 "Shuffling MoE weights for FlashInfer, it might take a while..."
             )
@@ -544,20 +549,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         activation_alpha: Optional[float] = None,
         swiglu_limit: Optional[float] = None,
     ) -> torch.Tensor:
-        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
-            # When USE_FLASHINFER_MXFP4_BF16_MOE is enabled, we don't need to quantize the input,
-            # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
-            # which can theoretically improve performance
-            if USE_FLASHINFER_MXFP4_BF16_MOE:
-                assert x.dtype == torch.bfloat16
-                x_quant = x
-                x_scale = None
-            else:
-                x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
-                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+        if self.use_flashinfer:
+            # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
+            x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
+            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
 
-            topk_weights, topk_ids, router_logits = topk_output
-            top_k = topk_weights.shape[-1]
+            top_k, router_logits = topk_output
 
             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 router_logits.to(torch.bfloat16),
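As the hunk above shows, the FlashInfer path now always quantizes the activations to mxfp8 and expects topk_output to be the raw (top_k, router_logits) pair rather than the already-routed (topk_weights, topk_ids, router_logits) triple, since routing happens inside the TRT-LLM kernel. A toy sketch of the two shapes, using dummy tensors rather than the sglang types:

import torch

def unpack_topk_output(topk_output, use_flashinfer: bool):
    # FlashInfer path: routing is deferred to the fused kernel, so only the raw
    # router logits and the top-k count are carried through.
    if use_flashinfer:
        top_k, router_logits = topk_output
        return top_k, router_logits
    # Pre-routed path: the TopK module already produced weights and expert ids.
    topk_weights, topk_ids, router_logits = topk_output
    return topk_weights.shape[-1], router_logits

router_logits = torch.randn(4, 32)                       # [tokens, num_experts]
top_k, logits = unpack_topk_output((4, router_logits), use_flashinfer=True)
print(top_k, logits.shape)                               # 4 torch.Size([4, 32])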
@@ -107,6 +107,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
     "enable_triton_kernel_moe",
+    "enable_flashinfer_mxfp4_moe",
     "enable_multimodal",
     "enable_symm_mem",
     "quantization",
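Adding the key here is what makes the global_server_args_dict.get("enable_flashinfer_mxfp4_moe", ...) lookups in the layers above see the flag at all. A hedged sketch of the plumbing, with a toy dataclass standing in for sglang's ServerArgs:

from dataclasses import dataclass

# Toy stand-ins; the real list and dict live in sglang.srt.managers.schedule_batch.
GLOBAL_SERVER_ARGS_KEYS = ["enable_triton_kernel_moe", "enable_flashinfer_mxfp4_moe"]

@dataclass
class ToyServerArgs:
    enable_triton_kernel_moe: bool = False
    enable_flashinfer_mxfp4_moe: bool = False

server_args = ToyServerArgs(enable_flashinfer_mxfp4_moe=True)
global_server_args_dict = {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
assert global_server_args_dict.get("enable_flashinfer_mxfp4_moe", False)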
@@ -102,11 +102,15 @@ class GptOssSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_local_experts}."
             )
 
-        self.topk = TopK(
-            top_k=config.num_experts_per_tok,
-            renormalize=True,
-        )
+        if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
+            self.topk = None
+        else:
+            self.topk = TopK(
+                top_k=config.num_experts_per_tok,
+                renormalize=True,
+            )
 
+        self.top_k = config.num_experts_per_tok
         experts_type = get_moe_impl_class()
         extra_kwargs = {}
         if experts_type.__name__ == "FusedMoE":
@@ -176,7 +180,7 @@ class GptOssSparseMoeBlock(nn.Module):
         if self.topk is not None:
             kwargs["topk_output"] = self.topk(hidden_states, router_logits)
         else:
-            kwargs["router_logits"] = router_logits
+            kwargs["topk_output"] = (self.top_k, router_logits)
         final_hidden_states = self.experts(**kwargs)
 
         if self.tp_size > 1:
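When the FlashInfer MXFP4 path is active, self.topk is None and the raw (self.top_k, router_logits) pair is handed to the experts layer, letting the fused TRT-LLM kernel do the routing itself. A toy illustration (not the sglang TopK module) of the eager routing that is being skipped:

import torch

def eager_routing(router_logits: torch.Tensor, top_k: int):
    # Roughly what a TopK(top_k=..., renormalize=True) module computes per token
    # when routing is done eagerly on the Python side.
    probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize
    return topk_weights, topk_ids

logits = torch.randn(2, 8)                 # [tokens, num_experts]
weights, ids = eager_routing(logits, top_k=4)
print(weights.shape, ids.shape)            # torch.Size([2, 4]) torch.Size([2, 4])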
@@ -248,6 +248,7 @@ class ServerArgs:
     disable_fast_image_processor: bool = False
     enable_return_hidden_states: bool = False
     enable_triton_kernel_moe: bool = False
+    enable_flashinfer_mxfp4_moe: bool = False
 
     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
@@ -476,18 +477,10 @@ class ServerArgs:
                 or self.attention_backend == "triton"
             )
 
-            # Check if FlashInfer MXFP4 MoE is enabled
-            from sglang.srt.utils import get_bool_env_var
-
-            USE_FLASHINFER_MXFP4_MOE = get_bool_env_var(
-                "SGLANG_USE_FLASHINFER_MXFP4_MOE", "false"
-            )
-            USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
-                "SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
-            )
-
-            # Only enable Triton kernel MoE if FlashInfer is not enabled
-            if not (USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE):
+            if is_sm100_supported():
+                self.enable_flashinfer_mxfp4_moe = True
+                self.enable_triton_kernel_moe = False
+            else:
                 self.enable_triton_kernel_moe = True
 
             self.disable_hybrid_swa_memory = True
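The auto-enable logic above keys off is_sm100_supported(), so on B200 (SM100) machines the FlashInfer MXFP4 kernels become the default while other GPUs fall back to the Triton MoE kernels. A rough sketch of the kind of capability check this implies; reducing it to compute capability 10.x is my assumption, and the real helper in sglang may check more:

import torch

def is_sm100_like() -> bool:
    # Hypothetical approximation of is_sm100_supported(): treat any device with
    # CUDA compute capability 10.x (e.g. B200) as SM100.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 10

print(is_sm100_like())   # True on B200-class GPUs, False otherwise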
@@ -1846,6 +1839,11 @@ class ServerArgs:
             action="store_true",
             help="Use triton moe grouped gemm kernel.",
         )
+        parser.add_argument(
+            "--enable-flashinfer-mxfp4-moe",
+            action="store_true",
+            help="Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
+        )
 
         # Debug tensor dumps
         parser.add_argument(
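Even with the auto-enable path, the flag can still be set explicitly at launch time. A minimal standalone argparse sketch mirroring the new option (not sglang's actual parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-flashinfer-mxfp4-moe",
    action="store_true",
    help="Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
)

args = parser.parse_args(["--enable-flashinfer-mxfp4-moe"])
print(args.enable_flashinfer_mxfp4_moe)   # True; defaults to False when omitted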