[NVIDIA] Fix breakage of using trtllm-gen fp8 moe (#8773)
@@ -673,66 +673,6 @@ class DeepEPMoE(EPMoE):
         return down_output
 
 
-class FlashInferEPMoE(EPMoE):
-    def __init__(self, *args, **kwargs):
-        renormalize = kwargs.pop("renormalize", True)
-        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
-        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
-        num_expert_group = kwargs.pop("num_expert_group", None)
-        topk_group = kwargs.pop("topk_group", None)
-        correction_bias = kwargs.pop("correction_bias", None)
-        super().__init__(*args, **kwargs)
-        self.renormalize = renormalize
-        self.num_fused_shared_experts = num_fused_shared_experts
-        self.use_grouped_topk = use_grouped_topk
-        if self.use_grouped_topk:
-            assert num_expert_group is not None and topk_group is not None
-        self.num_expert_group = num_expert_group
-        self.topk_group = topk_group
-        self.correction_bias = correction_bias
-        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
-
-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert self.use_flashinfer_trtllm_moe
-        assert (
-            self.activation == "silu"
-        ), "Only silu is supported for flashinfer blockscale fp8 moe"
-        assert (
-            self.renormalize
-        ), "Renormalize is required for flashinfer blockscale fp8 moe"
-        assert (
-            self.num_fused_shared_experts == 0
-        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
-        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
-        # NOTE: scales of hidden states have to be transposed!
-        a_sf_t = a_sf.t().contiguous()
-        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
-
-        return trtllm_fp8_block_scale_moe(
-            routing_logits=router_logits.to(torch.float32),
-            routing_bias=self.correction_bias.to(hidden_states.dtype),
-            hidden_states=a_q,
-            hidden_states_scale=a_sf_t,
-            gemm1_weights=self.w13_weight,
-            gemm1_weights_scale=self.w13_weight_scale_inv,
-            gemm2_weights=self.w2_weight,
-            gemm2_weights_scale=self.w2_weight_scale_inv,
-            num_experts=self.num_experts,
-            top_k=self.top_k,
-            n_group=self.num_expert_group,
-            topk_group=self.topk_group,
-            intermediate_size=self.w2_weight.shape[2],
-            local_expert_offset=self.start_expert_id,
-            local_num_experts=self.num_local_experts,
-            routed_scaling_factor=self.routed_scaling_factor,
-            tile_tokens_dim=get_tile_tokens_dim(
-                hidden_states.shape[0], self.top_k, self.num_experts
-            ),
-            routing_method_type=2,  # DeepSeek-styled routing method
-            use_shuffled_weight=False,
-        )
-
-
 def get_moe_impl_class():
     if global_server_args_dict["moe_a2a_backend"].is_deepep():
         return DeepEPMoE
@@ -752,8 +692,10 @@ def get_moe_impl_class():
         except:
             pass
 
+    if should_use_flashinfer_trtllm_moe():
+        return FlashInferFusedMoE
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         return FusedMoE
     if get_moe_expert_parallel_world_size() > 1:
-        return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
-    return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
+        return EPMoE
+    return FusedMoE
@@ -763,8 +763,13 @@ class FlashInferFusedMoE(FusedMoE):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
         self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
 
-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+    def forward(self, hidden_states: torch.Tensor, topk_output: tuple):
+        assert self.use_flashinfer_trtllm_moe
+        assert (
+            self.activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
+        assert self.quant_method is not None
         assert (
             self.renormalize
@@ -772,6 +777,14 @@ class FlashInferFusedMoE(FusedMoE):
         assert (
             self.num_fused_shared_experts == 0
         ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+
+        # TRTLLM mode expects (TopK_config, router_logits) tuple
+        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
+            raise ValueError(
+                f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
+            )
+        _, router_logits = topk_output
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply_with_router_logits(
             layer=self,
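For context, a minimal standalone sketch of the calling convention this patch enforces: FlashInferFusedMoE.forward now receives a (TopK_config, router_logits) tuple rather than raw router logits and only consumes the logits. The helper name unpack_router_logits below is illustrative, not part of the SGLang API.

import torch


def unpack_router_logits(topk_output) -> torch.Tensor:
    # Mirrors the validation added to FlashInferFusedMoE.forward(): the TRTLLM
    # fp8 block-scale path expects a (TopK_config, router_logits) 2-tuple and
    # discards the TopK config, keeping only the logits tensor.
    if not isinstance(topk_output, tuple) or len(topk_output) != 2:
        raise ValueError(
            f"expected (TopK_config, router_logits) tuple, got {type(topk_output)}"
        )
    _, router_logits = topk_output
    return router_logits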