[Bug] Fix input arguments of flashinfer_trtllm_moe (#9317)
This commit is contained in:
@@ -932,11 +932,11 @@ class FlashInferFusedMoE(FusedMoE):
|
||||
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
||||
assert self.use_flashinfer_trtllm_moe
|
||||
assert (
|
||||
self.activation == "silu"
|
||||
self.moe_runner_config.activation == "silu"
|
||||
), "Only silu is supported for flashinfer blockscale fp8 moe"
|
||||
assert self.quant_method is not None
|
||||
assert (
|
||||
self.renormalize
|
||||
topk_output.topk_config.renormalize
|
||||
), "Renormalize is required for flashinfer blockscale fp8 moe"
|
||||
assert (
|
||||
self.num_fused_shared_experts == 0
|
||||
|
||||
@@ -85,8 +85,8 @@ if _is_npu:
|
||||
class TopKConfig:
|
||||
top_k: int
|
||||
use_grouped_topk: bool = False
|
||||
topk_group: int = 0
|
||||
num_expert_group: int = 0
|
||||
topk_group: Optional[int] = None
|
||||
num_expert_group: Optional[int] = None
|
||||
renormalize: bool = True
|
||||
num_fused_shared_experts: int = 0
|
||||
custom_routing_function: Optional[Callable] = None
|
||||
@@ -189,8 +189,8 @@ class TopK(CustomOp):
|
||||
top_k: int,
|
||||
*,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int = 0,
|
||||
num_expert_group: int = 0,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
renormalize: bool = True,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
@@ -427,8 +427,8 @@ def grouped_topk_gpu(
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
num_expert_group: int = 0,
|
||||
topk_group: int = 0,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
@@ -492,8 +492,8 @@ def grouped_topk_cpu(
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
num_expert_group: int = 0,
|
||||
topk_group: int = 0,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
@@ -522,8 +522,8 @@ def biased_grouped_topk_impl(
|
||||
correction_bias: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
num_expert_group: int = 0,
|
||||
topk_group: int = 0,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
@@ -615,8 +615,8 @@ def biased_grouped_topk_gpu(
|
||||
correction_bias: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
num_expert_group: int = 0,
|
||||
topk_group: int = 0,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
@@ -690,8 +690,8 @@ def biased_grouped_topk_cpu(
|
||||
correction_bias: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
num_expert_group: int = 0,
|
||||
topk_group: int = 0,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
compiled: bool = True,
|
||||
num_fused_shared_experts: int = 0,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
|
||||
@@ -445,7 +445,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if self.use_marlin:
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
@@ -1087,7 +1086,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
|
||||
activation = moe_runner_config.activation
|
||||
routed_scaling_factor = moe_runner_config.routed_scaling_factor
|
||||
|
||||
@@ -1105,9 +1103,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
# NOTE: scales of hidden states have to be transposed!
|
||||
a_sf_t = a_sf.t().contiguous()
|
||||
|
||||
assert (
|
||||
topk_config.num_expert_group is not None
|
||||
and topk_config.topk_group is not None
|
||||
), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None"
|
||||
|
||||
if topk_config.correction_bias is None:
|
||||
correction_bias = topk_config.correction_bias.to(x.dtype)
|
||||
else:
|
||||
correction_bias = None
|
||||
return trtllm_fp8_block_scale_moe(
|
||||
routing_logits=router_logits.to(torch.float32),
|
||||
routing_bias=layer.correction_bias.to(x.dtype),
|
||||
routing_bias=correction_bias,
|
||||
hidden_states=a_q,
|
||||
hidden_states_scale=a_sf_t,
|
||||
gemm1_weights=layer.w13_weight,
|
||||
@@ -1121,9 +1128,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
intermediate_size=layer.w2_weight.shape[2],
|
||||
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
|
||||
local_num_experts=layer.num_local_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
routed_scaling_factor=(
|
||||
routed_scaling_factor if routed_scaling_factor is not None else 1.0
|
||||
),
|
||||
tile_tokens_dim=get_tile_tokens_dim(
|
||||
x.shape[0], layer.top_k, layer.num_experts
|
||||
x.shape[0], topk_config.top_k, layer.num_experts
|
||||
),
|
||||
routing_method_type=2, # DeepSeek-styled routing method
|
||||
use_shuffled_weight=False,
|
||||
|
||||
Reference in New Issue
Block a user