[Bug] Fix input arguments of flashinfer_trtllm_moe (#9317)
This commit is contained in:
@@ -932,11 +932,11 @@ class FlashInferFusedMoE(FusedMoE):
|
|||||||
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
||||||
assert self.use_flashinfer_trtllm_moe
|
assert self.use_flashinfer_trtllm_moe
|
||||||
assert (
|
assert (
|
||||||
self.activation == "silu"
|
self.moe_runner_config.activation == "silu"
|
||||||
), "Only silu is supported for flashinfer blockscale fp8 moe"
|
), "Only silu is supported for flashinfer blockscale fp8 moe"
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
assert (
|
assert (
|
||||||
self.renormalize
|
topk_output.topk_config.renormalize
|
||||||
), "Renormalize is required for flashinfer blockscale fp8 moe"
|
), "Renormalize is required for flashinfer blockscale fp8 moe"
|
||||||
assert (
|
assert (
|
||||||
self.num_fused_shared_experts == 0
|
self.num_fused_shared_experts == 0
|
||||||
|
|||||||
@@ -85,8 +85,8 @@ if _is_npu:
|
|||||||
class TopKConfig:
|
class TopKConfig:
|
||||||
top_k: int
|
top_k: int
|
||||||
use_grouped_topk: bool = False
|
use_grouped_topk: bool = False
|
||||||
topk_group: int = 0
|
topk_group: Optional[int] = None
|
||||||
num_expert_group: int = 0
|
num_expert_group: Optional[int] = None
|
||||||
renormalize: bool = True
|
renormalize: bool = True
|
||||||
num_fused_shared_experts: int = 0
|
num_fused_shared_experts: int = 0
|
||||||
custom_routing_function: Optional[Callable] = None
|
custom_routing_function: Optional[Callable] = None
|
||||||
@@ -189,8 +189,8 @@ class TopK(CustomOp):
|
|||||||
top_k: int,
|
top_k: int,
|
||||||
*,
|
*,
|
||||||
use_grouped_topk: bool = False,
|
use_grouped_topk: bool = False,
|
||||||
topk_group: int = 0,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: int = 0,
|
num_expert_group: Optional[int] = None,
|
||||||
renormalize: bool = True,
|
renormalize: bool = True,
|
||||||
num_fused_shared_experts: int = 0,
|
num_fused_shared_experts: int = 0,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
@@ -427,8 +427,8 @@ def grouped_topk_gpu(
|
|||||||
gating_output: torch.Tensor,
|
gating_output: torch.Tensor,
|
||||||
topk: int,
|
topk: int,
|
||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
num_expert_group: int = 0,
|
num_expert_group: Optional[int] = None,
|
||||||
topk_group: int = 0,
|
topk_group: Optional[int] = None,
|
||||||
num_fused_shared_experts: int = 0,
|
num_fused_shared_experts: int = 0,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||||
@@ -492,8 +492,8 @@ def grouped_topk_cpu(
|
|||||||
gating_output: torch.Tensor,
|
gating_output: torch.Tensor,
|
||||||
topk: int,
|
topk: int,
|
||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
num_expert_group: int = 0,
|
num_expert_group: Optional[int] = None,
|
||||||
topk_group: int = 0,
|
topk_group: Optional[int] = None,
|
||||||
num_fused_shared_experts: int = 0,
|
num_fused_shared_experts: int = 0,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||||
@@ -522,8 +522,8 @@ def biased_grouped_topk_impl(
|
|||||||
correction_bias: torch.Tensor,
|
correction_bias: torch.Tensor,
|
||||||
topk: int,
|
topk: int,
|
||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
num_expert_group: int = 0,
|
num_expert_group: Optional[int] = None,
|
||||||
topk_group: int = 0,
|
topk_group: Optional[int] = None,
|
||||||
num_fused_shared_experts: int = 0,
|
num_fused_shared_experts: int = 0,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||||
@@ -615,8 +615,8 @@ def biased_grouped_topk_gpu(
|
|||||||
correction_bias: torch.Tensor,
|
correction_bias: torch.Tensor,
|
||||||
topk: int,
|
topk: int,
|
||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
num_expert_group: int = 0,
|
num_expert_group: Optional[int] = None,
|
||||||
topk_group: int = 0,
|
topk_group: Optional[int] = None,
|
||||||
num_fused_shared_experts: int = 0,
|
num_fused_shared_experts: int = 0,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||||
@@ -690,8 +690,8 @@ def biased_grouped_topk_cpu(
|
|||||||
correction_bias: torch.Tensor,
|
correction_bias: torch.Tensor,
|
||||||
topk: int,
|
topk: int,
|
||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
num_expert_group: int = 0,
|
num_expert_group: Optional[int] = None,
|
||||||
topk_group: int = 0,
|
topk_group: Optional[int] = None,
|
||||||
compiled: bool = True,
|
compiled: bool = True,
|
||||||
num_fused_shared_experts: int = 0,
|
num_fused_shared_experts: int = 0,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
|||||||
@@ -445,7 +445,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor] = None,
|
bias: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
return apply_fp8_marlin_linear(
|
return apply_fp8_marlin_linear(
|
||||||
input=x,
|
input=x,
|
||||||
@@ -1087,7 +1086,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
moe_runner_config: MoeRunnerConfig,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
activation = moe_runner_config.activation
|
activation = moe_runner_config.activation
|
||||||
routed_scaling_factor = moe_runner_config.routed_scaling_factor
|
routed_scaling_factor = moe_runner_config.routed_scaling_factor
|
||||||
|
|
||||||
@@ -1105,9 +1103,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
# NOTE: scales of hidden states have to be transposed!
|
# NOTE: scales of hidden states have to be transposed!
|
||||||
a_sf_t = a_sf.t().contiguous()
|
a_sf_t = a_sf.t().contiguous()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
topk_config.num_expert_group is not None
|
||||||
|
and topk_config.topk_group is not None
|
||||||
|
), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None"
|
||||||
|
|
||||||
|
if topk_config.correction_bias is None:
|
||||||
|
correction_bias = topk_config.correction_bias.to(x.dtype)
|
||||||
|
else:
|
||||||
|
correction_bias = None
|
||||||
return trtllm_fp8_block_scale_moe(
|
return trtllm_fp8_block_scale_moe(
|
||||||
routing_logits=router_logits.to(torch.float32),
|
routing_logits=router_logits.to(torch.float32),
|
||||||
routing_bias=layer.correction_bias.to(x.dtype),
|
routing_bias=correction_bias,
|
||||||
hidden_states=a_q,
|
hidden_states=a_q,
|
||||||
hidden_states_scale=a_sf_t,
|
hidden_states_scale=a_sf_t,
|
||||||
gemm1_weights=layer.w13_weight,
|
gemm1_weights=layer.w13_weight,
|
||||||
@@ -1121,9 +1128,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
intermediate_size=layer.w2_weight.shape[2],
|
intermediate_size=layer.w2_weight.shape[2],
|
||||||
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
|
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
|
||||||
local_num_experts=layer.num_local_experts,
|
local_num_experts=layer.num_local_experts,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=(
|
||||||
|
routed_scaling_factor if routed_scaling_factor is not None else 1.0
|
||||||
|
),
|
||||||
tile_tokens_dim=get_tile_tokens_dim(
|
tile_tokens_dim=get_tile_tokens_dim(
|
||||||
x.shape[0], layer.top_k, layer.num_experts
|
x.shape[0], topk_config.top_k, layer.num_experts
|
||||||
),
|
),
|
||||||
routing_method_type=2, # DeepSeek-styled routing method
|
routing_method_type=2, # DeepSeek-styled routing method
|
||||||
use_shuffled_weight=False,
|
use_shuffled_weight=False,
|
||||||
|
|||||||
Reference in New Issue
Block a user