[Perf] Auto enable best flashinfer mxfp4 kernel in b200 (#8898)
This commit is contained in:
@@ -206,13 +206,13 @@ class FusedMoE(torch.nn.Module):
|
|||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
|
||||||
|
"enable_flashinfer_mxfp4_moe", False
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
self.quant_config is not None
|
self.quant_config is not None
|
||||||
and self.quant_config.get_name() == "mxfp4"
|
and self.quant_config.get_name() == "mxfp4"
|
||||||
and (
|
and self.use_enable_flashinfer_mxfp4_moe
|
||||||
get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE")
|
|
||||||
or get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE")
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
hidden_size = round_up(hidden_size, 256)
|
hidden_size = round_up(hidden_size, 256)
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
|
|||||||
@@ -3,22 +3,20 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib
|
import importlib.util
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Callable, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
# from vllm.model_executor.layers.fused_moe import (
|
|
||||||
# FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
|
|
||||||
# FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
|
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
FusedMoEMethodBase,
|
FusedMoEMethodBase,
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
direct_register_custom_op,
|
direct_register_custom_op,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
@@ -32,11 +30,6 @@ from sglang.srt.utils import (
|
|||||||
|
|
||||||
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
||||||
|
|
||||||
# Environment variables for FlashInfer MXFP4 MoE backend
|
|
||||||
USE_FLASHINFER_MXFP4_MOE = get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE", "false")
|
|
||||||
USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
|
|
||||||
"SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_flashinfer_available():
|
if is_flashinfer_available():
|
||||||
# from flashinfer.fused_moe import cutlass_fused_moe
|
# from flashinfer.fused_moe import cutlass_fused_moe
|
||||||
@@ -193,7 +186,12 @@ class Mxfp4Config(QuantizationConfig):
|
|||||||
):
|
):
|
||||||
return UnquantizedLinearMethod()
|
return UnquantizedLinearMethod()
|
||||||
elif isinstance(layer, FusedMoE):
|
elif isinstance(layer, FusedMoE):
|
||||||
return Mxfp4MoEMethod(use_triton_kernels=True, with_bias=True)
|
use_flashinfer = global_server_args_dict.get(
|
||||||
|
"enable_flashinfer_mxfp4_moe", False
|
||||||
|
)
|
||||||
|
return Mxfp4MoEMethod(
|
||||||
|
use_triton_kernels=True, with_bias=True, use_flashinfer=use_flashinfer
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Mxfp4 attention layer is not implemented")
|
raise NotImplementedError("Mxfp4 attention layer is not implemented")
|
||||||
return None
|
return None
|
||||||
@@ -204,11 +202,18 @@ class Mxfp4Config(QuantizationConfig):
|
|||||||
|
|
||||||
class Mxfp4MoEMethod(FusedMoEMethodBase):
|
class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||||
|
|
||||||
def __init__(self, use_triton_kernels: bool = True, with_bias: bool = True):
|
def __init__(
|
||||||
|
self,
|
||||||
|
use_triton_kernels: bool = True,
|
||||||
|
with_bias: bool = True,
|
||||||
|
use_flashinfer: bool = False,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.topk_indices_dtype = None
|
self.topk_indices_dtype = None
|
||||||
self.use_triton_kernels = use_triton_kernels
|
self.use_triton_kernels = use_triton_kernels
|
||||||
self.with_bias = with_bias
|
self.with_bias = with_bias
|
||||||
|
self.use_flashinfer = use_flashinfer
|
||||||
|
|
||||||
self.triton_kernel_moe_forward = None
|
self.triton_kernel_moe_forward = None
|
||||||
self.triton_kernel_moe_with_bias_forward = None
|
self.triton_kernel_moe_with_bias_forward = None
|
||||||
if torch.cuda.is_available() and has_triton_kernels:
|
if torch.cuda.is_available() and has_triton_kernels:
|
||||||
@@ -239,7 +244,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
# pad the intermediate size to be a multiple of 2 * mxfp4_block
|
# pad the intermediate size to be a multiple of 2 * mxfp4_block
|
||||||
# for to hold non-uniform sharded tensor as well as swizzling
|
# for to hold non-uniform sharded tensor as well as swizzling
|
||||||
if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
|
if self.use_flashinfer:
|
||||||
intermediate_size_per_partition_after_pad = round_up(intermediate_size, 256)
|
intermediate_size_per_partition_after_pad = round_up(intermediate_size, 256)
|
||||||
hidden_size = round_up(hidden_size, 256)
|
hidden_size = round_up(hidden_size, 256)
|
||||||
elif is_hip():
|
elif is_hip():
|
||||||
@@ -319,7 +324,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
set_weight_attrs(w2_weight_bias, extra_weight_attrs)
|
set_weight_attrs(w2_weight_bias, extra_weight_attrs)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer):
|
def process_weights_after_loading(self, layer):
|
||||||
if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
|
if self.use_flashinfer:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Shuffling MoE weights for FlashInfer, it might take a while..."
|
"Shuffling MoE weights for FlashInfer, it might take a while..."
|
||||||
)
|
)
|
||||||
@@ -544,20 +549,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
activation_alpha: Optional[float] = None,
|
activation_alpha: Optional[float] = None,
|
||||||
swiglu_limit: Optional[float] = None,
|
swiglu_limit: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
|
if self.use_flashinfer:
|
||||||
# When USE_FLASHINFER_MXFP4_BF16_MOE is enabled, we don't need to quantize the input,
|
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
|
||||||
# TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
|
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
|
||||||
# which can theoretically improve performance
|
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
|
||||||
if USE_FLASHINFER_MXFP4_BF16_MOE:
|
|
||||||
assert x.dtype == torch.bfloat16
|
|
||||||
x_quant = x
|
|
||||||
x_scale = None
|
|
||||||
else:
|
|
||||||
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
|
|
||||||
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
|
|
||||||
|
|
||||||
topk_weights, topk_ids, router_logits = topk_output
|
top_k, router_logits = topk_output
|
||||||
top_k = topk_weights.shape[-1]
|
|
||||||
|
|
||||||
trtllm_gen_output = trtllm_fp4_block_scale_moe(
|
trtllm_gen_output = trtllm_fp4_block_scale_moe(
|
||||||
router_logits.to(torch.bfloat16),
|
router_logits.to(torch.bfloat16),
|
||||||
|
|||||||
@@ -107,6 +107,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|||||||
"num_reserved_decode_tokens",
|
"num_reserved_decode_tokens",
|
||||||
"weight_loader_disable_mmap",
|
"weight_loader_disable_mmap",
|
||||||
"enable_triton_kernel_moe",
|
"enable_triton_kernel_moe",
|
||||||
|
"enable_flashinfer_mxfp4_moe",
|
||||||
"enable_multimodal",
|
"enable_multimodal",
|
||||||
"enable_symm_mem",
|
"enable_symm_mem",
|
||||||
"quantization",
|
"quantization",
|
||||||
|
|||||||
@@ -102,11 +102,15 @@ class GptOssSparseMoeBlock(nn.Module):
|
|||||||
f"the number of experts {config.num_local_experts}."
|
f"the number of experts {config.num_local_experts}."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.topk = TopK(
|
if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
|
||||||
top_k=config.num_experts_per_tok,
|
self.topk = None
|
||||||
renormalize=True,
|
else:
|
||||||
)
|
self.topk = TopK(
|
||||||
|
top_k=config.num_experts_per_tok,
|
||||||
|
renormalize=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.top_k = config.num_experts_per_tok
|
||||||
experts_type = get_moe_impl_class()
|
experts_type = get_moe_impl_class()
|
||||||
extra_kwargs = {}
|
extra_kwargs = {}
|
||||||
if experts_type.__name__ == "FusedMoE":
|
if experts_type.__name__ == "FusedMoE":
|
||||||
@@ -176,7 +180,7 @@ class GptOssSparseMoeBlock(nn.Module):
|
|||||||
if self.topk is not None:
|
if self.topk is not None:
|
||||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||||
else:
|
else:
|
||||||
kwargs["router_logits"] = router_logits
|
kwargs["topk_output"] = (self.top_k, router_logits)
|
||||||
final_hidden_states = self.experts(**kwargs)
|
final_hidden_states = self.experts(**kwargs)
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
|
|||||||
@@ -248,6 +248,7 @@ class ServerArgs:
|
|||||||
disable_fast_image_processor: bool = False
|
disable_fast_image_processor: bool = False
|
||||||
enable_return_hidden_states: bool = False
|
enable_return_hidden_states: bool = False
|
||||||
enable_triton_kernel_moe: bool = False
|
enable_triton_kernel_moe: bool = False
|
||||||
|
enable_flashinfer_mxfp4_moe: bool = False
|
||||||
|
|
||||||
# Debug tensor dumps
|
# Debug tensor dumps
|
||||||
debug_tensor_dump_output_folder: Optional[str] = None
|
debug_tensor_dump_output_folder: Optional[str] = None
|
||||||
@@ -476,18 +477,10 @@ class ServerArgs:
|
|||||||
or self.attention_backend == "triton"
|
or self.attention_backend == "triton"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if FlashInfer MXFP4 MoE is enabled
|
if is_sm100_supported():
|
||||||
from sglang.srt.utils import get_bool_env_var
|
self.enable_flashinfer_mxfp4_moe = True
|
||||||
|
self.enable_triton_kernel_moe = False
|
||||||
USE_FLASHINFER_MXFP4_MOE = get_bool_env_var(
|
else:
|
||||||
"SGLANG_USE_FLASHINFER_MXFP4_MOE", "false"
|
|
||||||
)
|
|
||||||
USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
|
|
||||||
"SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only enable Triton kernel MoE if FlashInfer is not enabled
|
|
||||||
if not (USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE):
|
|
||||||
self.enable_triton_kernel_moe = True
|
self.enable_triton_kernel_moe = True
|
||||||
|
|
||||||
self.disable_hybrid_swa_memory = True
|
self.disable_hybrid_swa_memory = True
|
||||||
@@ -1846,6 +1839,11 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Use triton moe grouped gemm kernel.",
|
help="Use triton moe grouped gemm kernel.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-flashinfer-mxfp4-moe",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
|
||||||
|
)
|
||||||
|
|
||||||
# Debug tensor dumps
|
# Debug tensor dumps
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user