[NVIDIA] Update to leverage flashinfer trtllm FP4 MOE throughput kernel (#11563)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
This commit is contained in:
@@ -49,7 +49,6 @@ from sglang.srt.utils import (
|
||||
is_cpu,
|
||||
is_flashinfer_available,
|
||||
is_hip,
|
||||
next_power_of_2,
|
||||
round_up,
|
||||
)
|
||||
|
||||
@@ -72,16 +71,6 @@ if should_use_flashinfer_trtllm_moe():
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
    """Pick the CTA tile size (tokens per tile) for the TRT-LLM MoE kernel.

    Estimates the per-expert token count under perfectly balanced routing,
    rounds that estimate up to a power of two, and clamps the result to the
    [8, 64] range the kernel supports.
    """
    # Expected tokens per expert if routing were perfectly uniform.
    expected_per_expert = (num_tokens * top_k) // num_experts
    # Kernel tiles must be power-of-two sized.
    padded = next_power_of_2(expected_per_expert)
    # Clamp into the 8..64 tokens-per-CTA-tile window supported by the kernel.
    return min(64, max(8, padded))
|
||||
|
||||
|
||||
def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher:
|
||||
a2a_backend = get_moe_a2a_backend()
|
||||
if a2a_backend.is_none():
|
||||
@@ -1080,9 +1069,7 @@ class FlashInferFP4MoE(FusedMoE):
|
||||
local_expert_offset=self.moe_ep_rank * self.num_local_experts,
|
||||
local_num_experts=self.num_local_experts,
|
||||
routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
|
||||
tile_tokens_dim=_get_tile_tokens_dim(
|
||||
hidden_states.shape[0], topk_config.top_k, self.num_local_experts
|
||||
),
|
||||
tile_tokens_dim=None,
|
||||
routing_method_type=RoutingMethodType.DeepSeekV3,
|
||||
do_finalize=True,
|
||||
)[0]
|
||||
|
||||
@@ -41,7 +41,6 @@ from sglang.srt.utils import (
|
||||
is_triton_kernels_available,
|
||||
log_info_on_rank0,
|
||||
mxfp_supported,
|
||||
next_power_of_2,
|
||||
round_up,
|
||||
set_weight_attrs,
|
||||
)
|
||||
@@ -597,30 +596,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
    """Choose the per-CTA tile size (in tokens) for the MXFP4 MoE kernel.

    Scales the balanced per-expert token estimate by an imbalance factor,
    rounds it up to a power of two, and clamps to the kernel-supported
    [8, 64] range.
    """
    # Number of tokens in the input tensor.
    token_count = x.shape[0]
    # Headroom for uneven expert routing; equals
    # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert.
    # 1.0 means perfectly balanced routing; > 1.0 means some experts
    # receive more tokens than the uniform share; < 1.0 makes no sense.
    imbalance_factor = 1.3
    # Per-expert token count assuming perfectly uniform routing.
    balanced_per_expert = (token_count * top_k) // self.num_experts
    # Inflate the estimate by the imbalance headroom.
    estimated_per_expert = int(balanced_per_expert * imbalance_factor)
    # Kernel tiles are power-of-two sized; clamp into the supported
    # 8..64 tokens-per-CTA-tile window.
    return min(64, max(8, next_power_of_2(estimated_per_expert)))
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
@@ -696,7 +671,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
|
||||
layer.num_local_experts, # local num experts
|
||||
None,
|
||||
self._get_tile_tokens_dim(x, top_k),
|
||||
None, # tile_tokens_dim
|
||||
1, # routing_method_type, renormalize
|
||||
True, # do finalize
|
||||
)[0]
|
||||
|
||||
Reference in New Issue
Block a user