[NVIDIA] Enable Flashinfer MoE blockscale fp8 backend for TP MoE (#8450)
Co-authored-by: kushanam <42385577+kushanam@users.noreply.github.com>
This commit is contained in:
@@ -25,14 +25,22 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
silu_and_mul_triton_kernel,
|
||||
tma_align_input_scale,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import (
|
||||
FlashInferFusedMoE,
|
||||
FusedMoE,
|
||||
should_use_flashinfer_trtllm_moe,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
||||
from sglang.srt.layers.quantization.fp8 import (
|
||||
Fp8Config,
|
||||
Fp8MoEMethod,
|
||||
get_tile_tokens_dim,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
is_fp8_fnuz,
|
||||
sglang_per_token_group_quant_fp8,
|
||||
@@ -49,7 +57,6 @@ from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
is_hip,
|
||||
is_npu,
|
||||
next_power_of_2,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -63,10 +70,7 @@ _is_hip = is_hip()
|
||||
_is_npu = is_npu()
|
||||
_is_fp8_fnuz = is_fp8_fnuz()
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
use_flashinfer_trtllm_moe = (
|
||||
global_server_args_dict["enable_flashinfer_trtllm_moe"]
|
||||
and global_server_args_dict["enable_ep_moe"]
|
||||
)
|
||||
|
||||
|
||||
if not (_is_npu or _is_hip):
|
||||
from sgl_kernel import silu_and_mul
|
||||
@@ -76,26 +80,9 @@ if _use_aiter:
|
||||
from aiter.fused_moe import fused_moe
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
|
||||
if use_flashinfer_trtllm_moe:
|
||||
try:
|
||||
import flashinfer.fused_moe as fi_fused_moe
|
||||
except ImportError:
|
||||
fi_fused_moe = None
|
||||
use_flashinfer_trtllm_moe = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
    """Choose the ``tile_tokens_dim`` value for the blockscale FP8 MoE kernel.

    Args:
        num_tokens: Number of tokens in the current batch.
        top_k: Number of experts each token is routed to.
        num_experts: Total number of experts.

    Returns:
        The estimated per-expert token count, passed through
        ``next_power_of_2`` and clamped to the inclusive range 8..64.
    """
    # Estimate the per-expert load as if routing were perfectly balanced
    # (floor division, so small batches may estimate zero).
    balanced_load = (num_tokens * top_k) // num_experts
    # Pad the estimate to the next power of 2.
    padded_load = next_power_of_2(balanced_load)
    # The kernel supports 8-64 tokens per CTA tile, so clamp into that range.
    return min(max(padded_load, 8), 64)
|
||||
|
||||
|
||||
class EPMoE(FusedMoE):
|
||||
"""
|
||||
MoE Expert Parallel Impl
|
||||
@@ -731,10 +718,10 @@ class FlashInferEPMoE(EPMoE):
|
||||
self.num_expert_group = num_expert_group
|
||||
self.topk_group = topk_group
|
||||
self.correction_bias = correction_bias
|
||||
self.use_flashinfer_trtllm_moe = use_flashinfer_trtllm_moe
|
||||
self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
|
||||
assert use_flashinfer_trtllm_moe
|
||||
assert self.use_flashinfer_trtllm_moe
|
||||
assert (
|
||||
self.activation == "silu"
|
||||
), "Only silu is supported for flashinfer blockscale fp8 moe"
|
||||
@@ -747,8 +734,9 @@ class FlashInferEPMoE(EPMoE):
|
||||
a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
|
||||
# NOTE: scales of hidden states have to be transposed!
|
||||
a_sf_t = a_sf.t().contiguous()
|
||||
assert fi_fused_moe is not None
|
||||
return fi_fused_moe.trtllm_fp8_block_scale_moe(
|
||||
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
|
||||
|
||||
return trtllm_fp8_block_scale_moe(
|
||||
routing_logits=router_logits.to(torch.float32),
|
||||
routing_bias=self.correction_bias.to(hidden_states.dtype),
|
||||
hidden_states=a_q,
|
||||
@@ -765,7 +753,7 @@ class FlashInferEPMoE(EPMoE):
|
||||
local_expert_offset=self.start_expert_id,
|
||||
local_num_experts=self.num_local_experts,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
tile_tokens_dim=_get_tile_tokens_dim(
|
||||
tile_tokens_dim=get_tile_tokens_dim(
|
||||
hidden_states.shape[0], self.top_k, self.num_experts
|
||||
),
|
||||
routing_method_type=2, # DeepSeek-styled routing method
|
||||
@@ -779,9 +767,6 @@ def get_moe_impl_class():
|
||||
if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
|
||||
# Must come before EPMoE because FusedMoE also supports enable_ep_moe
|
||||
return FusedMoE
|
||||
if use_flashinfer_trtllm_moe:
|
||||
# Must come before EPMoE because FusedMoE also supports enable_ep_moe
|
||||
return FlashInferEPMoE
|
||||
if global_server_args_dict["enable_ep_moe"]:
|
||||
return EPMoE
|
||||
return FusedMoE
|
||||
return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
|
||||
return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from packaging import version as pkg_version
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_moe_expert_parallel_rank,
|
||||
@@ -33,6 +36,15 @@ _is_cpu = is_cpu()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe():
    """Decide whether the FlashInfer TRT-LLM MoE backend should be used.

    The decision is memoized by ``lru_cache``: it is evaluated on the first
    call and that same result is returned afterwards, so later changes to
    ``global_server_args_dict`` are not observed.

    Returns:
        The (falsy) value of the ``enable_flashinfer_trtllm_moe`` server
        argument when it is disabled; otherwise ``True`` if the
        ``flashinfer`` package cannot be located or its version is at least
        ``0.2.9rc1``, and ``False`` if an older version is installed.
    """
    enabled = global_server_args_dict["enable_flashinfer_trtllm_moe"]
    if not enabled:
        # Return the raw flag value, exactly as a short-circuiting `and` would.
        return enabled

    if not importlib.util.find_spec("flashinfer"):
        # NOTE(review): a missing flashinfer package yields True here, so the
        # failure surfaces later when flashinfer is actually imported.
        # Confirm this is intended rather than an inverted check.
        return True

    installed_version = pkg_version.parse(__import__("flashinfer").__version__)
    return installed_version >= pkg_version.parse("0.2.9rc1")
|
||||
|
||||
|
||||
class FusedMoeWeightScaleSupported(Enum):
|
||||
TENSOR = "tensor"
|
||||
CHANNEL = "channel"
|
||||
@@ -455,7 +467,7 @@ class FusedMoE(torch.nn.Module):
|
||||
)
|
||||
|
||||
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
|
||||
if getattr(self, "use_flashinfer_trtllm_moe", False):
|
||||
if should_use_flashinfer_trtllm_moe():
|
||||
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
|
||||
|
||||
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
|
||||
@@ -687,3 +699,44 @@ class FusedMoE(torch.nn.Module):
|
||||
for expert_id in range(num_experts)
|
||||
for shard_id in ["w1", "w2", "w3"]
|
||||
]
|
||||
|
||||
|
||||
class FlashInferFusedMoE(FusedMoE):
    """FusedMoE variant that hands raw router logits to the quant method.

    Routing-related keyword arguments are stripped from ``kwargs`` before the
    base ``FusedMoE`` constructor runs and are stored on the instance instead.
    ``forward`` then delegates to ``quant_method.apply_with_router_logits``.
    """

    def __init__(self, *args, **kwargs):
        """Pull out routing options, build the base layer, then store them.

        Keyword Args:
            renormalize: Defaults to True; ``forward`` asserts it is truthy.
            num_fused_shared_experts: Defaults to 0; ``forward`` asserts it
                is 0.
            use_grouped_topk: Defaults to False. When truthy,
                ``num_expert_group`` and ``topk_group`` must both be given.
            num_expert_group: Defaults to None.
            topk_group: Defaults to None.
            correction_bias: Defaults to None.

        All remaining positional and keyword arguments are forwarded to
        ``FusedMoE.__init__`` unchanged.
        """
        # These keys must be removed before the base constructor sees kwargs.
        renorm_opt = kwargs.pop("renormalize", True)
        shared_experts_opt = kwargs.pop("num_fused_shared_experts", 0)
        grouped_topk_opt = kwargs.pop("use_grouped_topk", False)
        expert_group_opt = kwargs.pop("num_expert_group", None)
        topk_group_opt = kwargs.pop("topk_group", None)
        bias_opt = kwargs.pop("correction_bias", None)

        super().__init__(*args, **kwargs)

        self.renormalize = renorm_opt
        self.num_fused_shared_experts = shared_experts_opt
        self.use_grouped_topk = grouped_topk_opt
        if grouped_topk_opt:
            # Grouped top-k needs both grouping parameters.
            assert expert_group_opt is not None and topk_group_opt is not None
        self.num_expert_group = expert_group_opt
        self.topk_group = topk_group_opt
        self.correction_bias = bias_opt

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        """Run the MoE layer on ``hidden_states`` using ``router_logits``.

        Returns:
            The result of ``quant_method.apply_with_router_logits``,
            all-reduced across the tensor-parallel group when
            ``reduce_results`` is set and either ``tp_size`` or ``ep_size``
            is greater than 1.
        """
        assert self.quant_method is not None
        assert (
            self.renormalize
        ), "Renormalize is required for flashinfer blockscale fp8 moe"
        assert (
            self.num_fused_shared_experts == 0
        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"

        # Routing and the expert matmuls both happen inside the quant method.
        moe_output = self.quant_method.apply_with_router_logits(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            activation=self.activation,
            routed_scaling_factor=self.routed_scaling_factor,
        )

        if not self.reduce_results:
            return moe_output
        if self.tp_size > 1 or self.ep_size > 1:
            moe_output = tensor_model_parallel_all_reduce(moe_output)
        return moe_output
|
||||
|
||||
@@ -72,6 +72,7 @@ from sglang.srt.utils import (
|
||||
is_hip,
|
||||
is_npu,
|
||||
log_info_on_rank0,
|
||||
next_power_of_2,
|
||||
print_warning_once,
|
||||
set_weight_attrs,
|
||||
use_intel_amx_backend,
|
||||
@@ -490,6 +491,16 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
)
|
||||
|
||||
|
||||
def get_tile_tokens_dim(num_tokens, top_k, num_experts):
    """Compute the CTA tile size (in tokens) for the blockscale FP8 MoE kernel.

    Args:
        num_tokens: Number of tokens in the current batch.
        top_k: Number of experts selected per token.
        num_experts: Total number of experts.

    Returns:
        ``next_power_of_2`` of the estimated tokens per expert, limited to
        no less than 8 and no more than 64.
    """
    # First guess: tokens per expert under a perfect expert distribution.
    tokens_per_expert = (num_tokens * top_k) // num_experts
    # Pad that guess to the next power of 2.
    tile = next_power_of_2(tokens_per_expert)
    # 8-64 tokens per CTA tile is the range supported by the kernel.
    tile = max(tile, 8)
    tile = min(tile, 64)
    return tile
|
||||
|
||||
|
||||
class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
@@ -1076,6 +1087,47 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
def apply_with_router_logits(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    *,
    activation: str = "silu",
    routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
    """Run the FlashInfer TRT-LLM blockscale FP8 MoE kernel from router logits.

    The input is quantized per token group (group size taken from
    ``self.quant_config.weight_block_size[1]``), its scales are transposed,
    and everything is handed to ``trtllm_fp8_block_scale_moe`` together with
    the layer's weights, weight scales and routing configuration.

    Args:
        layer: The MoE layer. Must expose ``w13_weight``, ``w2_weight``,
            their ``*_weight_scale_inv`` tensors, ``num_experts``,
            ``top_k``, ``num_expert_group``, ``topk_group``,
            ``moe_ep_rank``, ``num_local_experts`` and ``correction_bias``
            (which may be ``None``).
        x: Hidden states to process.
        router_logits: Router logits; cast to float32 before the kernel call.
        activation: Activation name; only ``"silu"`` is accepted.
        routed_scaling_factor: Forwarded unchanged to the kernel.

    Returns:
        The tensor returned by ``trtllm_fp8_block_scale_moe``.

    Raises:
        AssertionError: If ``activation`` is not ``"silu"``.
    """
    assert (
        activation == "silu"
    ), "Only silu is supported for flashinfer blockscale fp8 moe"
    a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
    # NOTE: scales of hidden states have to be transposed!
    a_sf_t = a_sf.t().contiguous()

    # The layer may have been built without a correction bias (its
    # constructor defaults it to None); only cast it when it exists instead
    # of failing with AttributeError on ``None.to(...)``.
    correction_bias = layer.correction_bias
    routing_bias = None if correction_bias is None else correction_bias.to(x.dtype)

    # Imported lazily so that this module can be loaded without flashinfer.
    from flashinfer.fused_moe import trtllm_fp8_block_scale_moe

    return trtllm_fp8_block_scale_moe(
        routing_logits=router_logits.to(torch.float32),
        routing_bias=routing_bias,
        hidden_states=a_q,
        hidden_states_scale=a_sf_t,
        gemm1_weights=layer.w13_weight,
        gemm1_weights_scale=layer.w13_weight_scale_inv,
        gemm2_weights=layer.w2_weight,
        gemm2_weights_scale=layer.w2_weight_scale_inv,
        num_experts=layer.num_experts,
        top_k=layer.top_k,
        n_group=layer.num_expert_group,
        topk_group=layer.topk_group,
        intermediate_size=layer.w2_weight.shape[2],
        # First expert id owned by this expert-parallel rank.
        local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
        local_num_experts=layer.num_local_experts,
        routed_scaling_factor=routed_scaling_factor,
        tile_tokens_dim=get_tile_tokens_dim(
            x.shape[0], layer.top_k, layer.num_experts
        ),
        routing_method_type=2,  # DeepSeek-styled routing method
        use_shuffled_weight=False,
    )
|
||||
|
||||
def maybe_apply_hip_fused_experts(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
|
||||
@@ -59,7 +59,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.ep_moe.layer import (
|
||||
DeepEPMoE,
|
||||
get_moe_impl_class,
|
||||
use_flashinfer_trtllm_moe,
|
||||
should_use_flashinfer_trtllm_moe,
|
||||
)
|
||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
@@ -317,7 +317,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
)
|
||||
if not use_flashinfer_trtllm_moe
|
||||
if not should_use_flashinfer_trtllm_moe()
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -352,11 +352,10 @@ class DeepseekV2MoE(nn.Module):
|
||||
renormalize=config.norm_topk_prob,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
topk_group=config.topk_group,
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
)
|
||||
if use_flashinfer_trtllm_moe
|
||||
if should_use_flashinfer_trtllm_moe()
|
||||
else {}
|
||||
),
|
||||
)
|
||||
|
||||
@@ -52,7 +52,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.ep_moe.layer import (
|
||||
DeepEPMoE,
|
||||
get_moe_impl_class,
|
||||
use_flashinfer_trtllm_moe,
|
||||
should_use_flashinfer_trtllm_moe,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
@@ -426,7 +426,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
)
|
||||
if not use_flashinfer_trtllm_moe
|
||||
if not should_use_flashinfer_trtllm_moe()
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -465,7 +465,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
topk_group=config.topk_group,
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
)
|
||||
if use_flashinfer_trtllm_moe
|
||||
if should_use_flashinfer_trtllm_moe()
|
||||
else {}
|
||||
),
|
||||
)
|
||||
|
||||
@@ -460,10 +460,6 @@ class ServerArgs:
|
||||
f"Flashinfer cutlass MoE and EP MoE are enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
|
||||
)
|
||||
|
||||
if self.enable_flashinfer_trtllm_moe:
|
||||
assert self.enable_ep_moe, "EP MoE is required for Flashinfer TRTLLM MOE"
|
||||
logger.warning(f"Flashinfer TRTLLM MoE is enabled.")
|
||||
|
||||
# DeepEP MoE
|
||||
if self.enable_deepep_moe:
|
||||
if self.deepep_mode == "normal":
|
||||
|
||||
Reference in New Issue
Block a user