[NVIDIA] Add Flashinfer MoE blockscale fp8 backend (#8036)
This commit is contained in:
@@ -47,12 +47,17 @@ from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
is_hip,
|
||||
is_npu,
|
||||
next_power_of_2,
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_npu = is_npu()
|
||||
_is_fp8_fnuz = is_fp8_fnuz()
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
use_flashinfer_trtllm_moe = (
|
||||
global_server_args_dict["enable_flashinfer_trtllm_moe"]
|
||||
and global_server_args_dict["enable_ep_moe"]
|
||||
)
|
||||
|
||||
if not (_is_npu or _is_hip):
|
||||
from sgl_kernel import silu_and_mul
|
||||
@@ -64,6 +69,13 @@ if _use_aiter:
|
||||
from aiter.fused_moe import fused_moe
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
|
||||
if use_flashinfer_trtllm_moe:
    # FlashInfer is only imported when the TRT-LLM MoE backend was requested,
    # so installations without it can still load this module.
    try:
        import flashinfer.fused_moe as fi_fused_moe
    except ImportError:
        # Keep the best-effort fallback, but do not disable a backend the user
        # explicitly asked for without telling them. The module-level `logger`
        # is bound further down, so fetch the logger directly here.
        logging.getLogger(__name__).warning(
            "enable_flashinfer_trtllm_moe is set but flashinfer.fused_moe "
            "could not be imported; disabling the FlashInfer TRTLLM MoE backend."
        )
        fi_fused_moe = None
        use_flashinfer_trtllm_moe = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -140,6 +152,16 @@ class GroupedGemmRunner(torch.nn.Module):
|
||||
return c
|
||||
|
||||
|
||||
def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
    """Choose the per-CTA token tile size passed to the FlashInfer MoE kernel.

    The ``num_tokens * top_k`` routed token slots are assumed to be spread
    evenly over ``num_experts``. That per-expert estimate is rounded with
    ``next_power_of_2`` and then clamped to 8-64, the range the kernel
    supports.
    """
    # Even-split estimate of routed tokens per expert (floor division).
    even_share = (num_tokens * top_k) // num_experts
    padded = next_power_of_2(even_share)
    # Clamp into [8, 64]: never below 8, never above 64.
    return max(8, min(64, padded))
|
||||
|
||||
|
||||
class EPMoE(torch.nn.Module):
|
||||
"""
|
||||
MoE Expert Parallel Impl
|
||||
@@ -776,14 +798,20 @@ class EPMoE(torch.nn.Module):
|
||||
)
|
||||
return
|
||||
|
||||
if shard_id == "w2":
|
||||
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
|
||||
if use_flashinfer_trtllm_moe:
|
||||
actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
|
||||
else:
|
||||
actual_shard_id = shard_id
|
||||
|
||||
if actual_shard_id == "w2":
|
||||
param.data[expert_id] = loaded_weight
|
||||
elif shard_id == "w1":
|
||||
elif actual_shard_id == "w1":
|
||||
param.data[expert_id][: self.intermediate_size, :] = loaded_weight
|
||||
elif shard_id == "w3":
|
||||
elif actual_shard_id == "w3":
|
||||
param.data[expert_id][self.intermediate_size :, :] = loaded_weight
|
||||
else:
|
||||
raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}")
|
||||
raise ValueError(f"Expected shard_id w1,w2 or w3 but got {actual_shard_id}")
|
||||
|
||||
def _load_fp8_scale(
|
||||
self,
|
||||
@@ -820,12 +848,18 @@ class EPMoE(torch.nn.Module):
|
||||
# Weight scales
|
||||
elif "weight_scale" in weight_name:
|
||||
if self.use_block_quant:
|
||||
if use_flashinfer_trtllm_moe:
|
||||
actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
|
||||
else:
|
||||
actual_shard_id = shard_id
|
||||
|
||||
block_n, block_k = self.block_shape[0], self.block_shape[1]
|
||||
if shard_id == "w1":
|
||||
|
||||
if actual_shard_id == "w1":
|
||||
param_data[expert_id][
|
||||
: (self.intermediate_size + block_n - 1) // block_n, :
|
||||
] = loaded_weight
|
||||
elif shard_id == "w3":
|
||||
elif actual_shard_id == "w3":
|
||||
param_data[expert_id][
|
||||
(self.intermediate_size + block_n - 1) // block_n :, :
|
||||
] = loaded_weight
|
||||
@@ -1315,12 +1349,73 @@ class DeepEPMoE(EPMoE):
|
||||
return down_output
|
||||
|
||||
|
||||
class FlashInferEPMoE(EPMoE):
    """EPMoE variant backed by FlashInfer's TRT-LLM block-scale FP8 MoE kernel.

    ``forward`` takes the raw router logits instead of a precomputed top-k
    output: the logits and the routing parameters are handed to
    ``fi_fused_moe.trtllm_fp8_block_scale_moe`` in a single call.
    """

    def __init__(self, *args, **kwargs):
        """Pop the routing-related kwargs, then build the underlying EPMoE.

        The following keyword arguments are removed from ``kwargs`` before
        ``EPMoE.__init__`` is called:

        - ``renormalize`` (default ``True``); ``forward`` asserts it is truthy.
        - ``num_fused_shared_experts`` (default ``0``); ``forward`` asserts 0.
        - ``use_grouped_topk`` (default ``False``); when truthy,
          ``num_expert_group`` and ``topk_group`` must both be given.
        - ``num_expert_group`` / ``topk_group`` (default ``None``); passed to
          the kernel as ``n_group`` / ``topk_group``.
        - ``correction_bias`` (default ``None``); passed to the kernel as
          ``routing_bias``.
        """
        renormalize = kwargs.pop("renormalize", True)
        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
        num_expert_group = kwargs.pop("num_expert_group", None)
        topk_group = kwargs.pop("topk_group", None)
        correction_bias = kwargs.pop("correction_bias", None)
        super().__init__(*args, **kwargs)
        self.renormalize = renormalize
        self.num_fused_shared_experts = num_fused_shared_experts
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        # Always bind these attributes (possibly to None): forward() reads
        # them unconditionally, so they must exist even without grouped top-k.
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.correction_bias = correction_bias

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        """Quantize the activations and run the fused FlashInfer MoE kernel.

        Args:
            hidden_states: activations to route through the experts.
            router_logits: gate output for the same tokens; cast to float32
                before being passed to the kernel.

        Returns:
            Whatever ``fi_fused_moe.trtllm_fp8_block_scale_moe`` returns.

        Raises:
            AssertionError: if the backend is disabled or unavailable, or if
                the layer is configured with an unsupported activation,
                ``renormalize=False`` or fused shared experts.
        """
        assert use_flashinfer_trtllm_moe
        assert fi_fused_moe is not None
        assert (
            self.activation == "silu"
        ), "Only silu is supported for flashinfer blockscale fp8 moe"
        assert (
            self.renormalize
        ), "Renormalize is required for flashinfer blockscale fp8 moe"
        assert (
            self.num_fused_shared_experts == 0
        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"

        # correction_bias defaults to None in __init__, so only cast it when a
        # tensor was supplied instead of calling .to() on None.
        # NOTE(review): assumes the kernel accepts routing_bias=None - confirm
        # against the pinned FlashInfer version.
        routing_bias = (
            self.correction_bias.to(hidden_states.dtype)
            if self.correction_bias is not None
            else None
        )

        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
        # NOTE: scales of hidden states have to be transposed!
        a_sf_t = a_sf.t().contiguous()

        return fi_fused_moe.trtllm_fp8_block_scale_moe(
            routing_logits=router_logits.to(torch.float32),
            routing_bias=routing_bias,
            hidden_states=a_q,
            hidden_states_scale=a_sf_t,
            gemm1_weights=self.w13_weight,
            gemm1_weights_scale=self.w13_weight_scale_inv,
            gemm2_weights=self.w2_weight,
            gemm2_weights_scale=self.w2_weight_scale_inv,
            num_experts=self.num_experts,
            top_k=self.top_k,
            n_group=self.num_expert_group,
            topk_group=self.topk_group,
            intermediate_size=self.w2_weight.shape[2],
            local_expert_offset=self.start_expert_id,
            local_num_experts=self.num_experts_per_partition,
            routed_scaling_factor=self.routed_scaling_factor,
            tile_tokens_dim=_get_tile_tokens_dim(
                hidden_states.shape[0], self.top_k, self.num_experts
            ),
            routing_method_type=2,  # DeepSeek-styled routing method
            use_shuffled_weight=False,
        )
|
||||
|
||||
|
||||
def get_moe_impl_class():
    """Return the MoE layer class selected by the enabled server flags.

    The checks run in priority order: DeepEP, FlashInfer CUTLASS, FlashInfer
    TRT-LLM, then plain expert parallelism. ``FusedMoE`` is the fallback when
    none of them is enabled.
    """
    if global_server_args_dict["enable_deepep_moe"]:
        return DeepEPMoE
    if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
        return FusedMoE
    if use_flashinfer_trtllm_moe:
        # Must come before the EPMoE check: use_flashinfer_trtllm_moe is only
        # true when enable_ep_moe is also set, so the plain EPMoE branch below
        # would otherwise win.
        return FlashInferEPMoE
    if global_server_args_dict["enable_ep_moe"]:
        return EPMoE
    return FusedMoE
|
||||
|
||||
@@ -75,7 +75,7 @@ class FusedMoE(torch.nn.Module):
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
enable_flashinfer_moe: Optional[bool] = False,
|
||||
enable_flashinfer_cutlass_moe: Optional[bool] = False,
|
||||
enable_ep_moe: Optional[bool] = False,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -92,16 +92,16 @@ class FusedMoE(torch.nn.Module):
|
||||
self.num_experts = num_experts
|
||||
self.expert_map = None
|
||||
|
||||
if enable_flashinfer_moe and quant_config is None:
|
||||
if enable_flashinfer_cutlass_moe and quant_config is None:
|
||||
logger.warning("Disable flashinfer MoE when quantization config is None.")
|
||||
enable_flashinfer_moe = False
|
||||
enable_flashinfer_cutlass_moe = False
|
||||
enable_ep_moe = False
|
||||
|
||||
self.enable_flashinfer_moe = enable_flashinfer_moe
|
||||
self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
|
||||
if enable_ep_moe:
|
||||
assert (
|
||||
self.enable_flashinfer_moe
|
||||
), "FusedMoE only supports EP with --enable-flashinfer-moe"
|
||||
self.enable_flashinfer_cutlass_moe
|
||||
), "FusedMoE only supports EP with --enable-flashinfer-cutlass-moe"
|
||||
self.ep_size = self.tp_size
|
||||
self.ep_rank = self.tp_rank
|
||||
self.tp_size = 1
|
||||
@@ -141,7 +141,9 @@ class FusedMoE(torch.nn.Module):
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||
if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
|
||||
self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
|
||||
self.quant_method.enable_flashinfer_cutlass_moe = (
|
||||
self.enable_flashinfer_cutlass_moe
|
||||
)
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.quant_config = quant_config
|
||||
|
||||
@@ -711,7 +711,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
" quantization. Please use Blackwell and"
|
||||
" above."
|
||||
)
|
||||
self.enable_flashinfer_moe = False
|
||||
self.enable_flashinfer_cutlass_moe = False
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -865,7 +865,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
|
||||
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
|
||||
|
||||
if self.enable_flashinfer_moe:
|
||||
if self.enable_flashinfer_cutlass_moe:
|
||||
w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
|
||||
else:
|
||||
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
|
||||
@@ -894,7 +894,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
|
||||
|
||||
# GEMM 2
|
||||
if self.enable_flashinfer_moe:
|
||||
if self.enable_flashinfer_cutlass_moe:
|
||||
w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
|
||||
else:
|
||||
w2_input_scale = layer.w2_input_scale
|
||||
@@ -934,7 +934,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
@property
|
||||
def load_up_proj_weight_first(self) -> bool:
|
||||
# FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
|
||||
return self.enable_flashinfer_moe
|
||||
return self.enable_flashinfer_cutlass_moe
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -954,7 +954,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
if self.enable_flashinfer_moe:
|
||||
if self.enable_flashinfer_cutlass_moe:
|
||||
assert (
|
||||
not apply_router_weight_on_input
|
||||
), "apply_router_weight_on_input is not supported for Flashinfer"
|
||||
|
||||
@@ -88,7 +88,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
||||
"enable_deepep_moe",
|
||||
"deepep_mode",
|
||||
"enable_ep_moe",
|
||||
"enable_flashinfer_moe",
|
||||
"enable_flashinfer_cutlass_moe",
|
||||
"enable_flashinfer_trtllm_moe",
|
||||
"enable_flashinfer_allreduce_fusion",
|
||||
"moe_dense_tp_size",
|
||||
"ep_dispatch_algorithm",
|
||||
|
||||
@@ -56,7 +56,11 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
|
||||
from sglang.srt.layers.moe.ep_moe.layer import (
|
||||
DeepEPMoE,
|
||||
get_moe_impl_class,
|
||||
use_flashinfer_trtllm_moe,
|
||||
)
|
||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
@@ -302,15 +306,19 @@ class DeepseekV2MoE(nn.Module):
|
||||
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
|
||||
)
|
||||
|
||||
self.topk = TopK(
|
||||
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
||||
renormalize=config.norm_topk_prob,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
topk_group=config.topk_group,
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
self.topk = (
|
||||
TopK(
|
||||
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
||||
renormalize=config.norm_topk_prob,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
topk_group=config.topk_group,
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
)
|
||||
if not use_flashinfer_trtllm_moe
|
||||
else None
|
||||
)
|
||||
|
||||
self.experts = get_moe_impl_class()(
|
||||
@@ -332,10 +340,22 @@ class DeepseekV2MoE(nn.Module):
|
||||
# Additional args for FusedMoE
|
||||
**(
|
||||
dict(
|
||||
enable_flashinfer_moe=True,
|
||||
enable_flashinfer_cutlass_moe=True,
|
||||
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
|
||||
)
|
||||
if global_server_args_dict["enable_flashinfer_moe"]
|
||||
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
|
||||
else {}
|
||||
),
|
||||
**(
|
||||
dict(
|
||||
renormalize=config.norm_topk_prob,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
topk_group=config.topk_group,
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
)
|
||||
if use_flashinfer_trtllm_moe
|
||||
else {}
|
||||
),
|
||||
)
|
||||
@@ -455,10 +475,12 @@ class DeepseekV2MoE(nn.Module):
|
||||
with torch.cuda.stream(self.alt_stream):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
topk_output = self.topk(hidden_states, router_logits)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states, topk_output=topk_output
|
||||
)
|
||||
kwargs = {"hidden_states": hidden_states}
|
||||
if self.topk is not None:
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
else:
|
||||
kwargs["router_logits"] = router_logits
|
||||
final_hidden_states = self.experts(**kwargs)
|
||||
if not _is_cuda:
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
current_stream.wait_stream(self.alt_stream)
|
||||
@@ -478,10 +500,12 @@ class DeepseekV2MoE(nn.Module):
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
topk_output = self.topk(hidden_states, router_logits)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states, topk_output=topk_output
|
||||
)
|
||||
kwargs = {"hidden_states": hidden_states}
|
||||
if self.topk is not None:
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
else:
|
||||
kwargs["router_logits"] = router_logits
|
||||
final_hidden_states = self.experts(**kwargs)
|
||||
if not _is_cuda and not _use_aiter:
|
||||
# fused in biased_grouped_topk so we can skip here
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
|
||||
@@ -147,10 +147,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
# Additional args for FusedMoE
|
||||
**(
|
||||
dict(
|
||||
enable_flashinfer_moe=True,
|
||||
enable_flashinfer_cutlass_moe=True,
|
||||
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
|
||||
)
|
||||
if global_server_args_dict["enable_flashinfer_moe"]
|
||||
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
|
||||
else {}
|
||||
),
|
||||
)
|
||||
|
||||
@@ -120,10 +120,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
# Additional args for FusedMoE
|
||||
**(
|
||||
dict(
|
||||
enable_flashinfer_moe=True,
|
||||
enable_flashinfer_cutlass_moe=True,
|
||||
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
|
||||
)
|
||||
if global_server_args_dict["enable_flashinfer_moe"]
|
||||
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
|
||||
else {}
|
||||
),
|
||||
)
|
||||
|
||||
@@ -169,7 +169,8 @@ class ServerArgs:
|
||||
ep_size: int = 1
|
||||
enable_ep_moe: bool = False
|
||||
enable_deepep_moe: bool = False
|
||||
enable_flashinfer_moe: bool = False
|
||||
enable_flashinfer_cutlass_moe: bool = False
|
||||
enable_flashinfer_trtllm_moe: bool = False
|
||||
enable_flashinfer_allreduce_fusion: bool = False
|
||||
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
|
||||
ep_num_redundant_experts: int = 0
|
||||
@@ -428,12 +429,16 @@ class ServerArgs:
|
||||
), "Please enable dp attention when setting enable_dp_lm_head. "
|
||||
|
||||
# MoE kernel
|
||||
if self.enable_flashinfer_moe:
|
||||
if self.enable_flashinfer_cutlass_moe:
|
||||
assert (
|
||||
self.quantization == "modelopt_fp4"
|
||||
), "modelopt_fp4 quantization is required for Flashinfer MOE"
|
||||
os.environ["TRTLLM_ENABLE_PDL"] = "1"
|
||||
|
||||
if self.enable_flashinfer_trtllm_moe:
|
||||
assert self.enable_ep_moe, "EP MoE is required for Flashinfer TRTLLM MOE"
|
||||
logger.warning(f"Flashinfer TRTLLM MoE is enabled.")
|
||||
|
||||
# DeepEP MoE
|
||||
if self.enable_deepep_moe:
|
||||
if self.deepep_mode == "normal":
|
||||
@@ -1293,10 +1298,15 @@ class ServerArgs:
|
||||
help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-flashinfer-moe",
|
||||
"--enable-flashinfer-cutlass-moe",
|
||||
action="store_true",
|
||||
help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-flashinfer-trtllm-moe",
|
||||
action="store_true",
|
||||
help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP with --enable-ep-moe",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-flashinfer-allreduce-fusion",
|
||||
action="store_true",
|
||||
|
||||
Reference in New Issue
Block a user