[NVIDIA] Fix breakage of using trtllm-gen fp8 moe (#8773)
This commit is contained in:
@@ -673,66 +673,6 @@ class DeepEPMoE(EPMoE):
|
|||||||
return down_output
|
return down_output
|
||||||
|
|
||||||
|
|
||||||
class FlashInferEPMoE(EPMoE):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
renormalize = kwargs.pop("renormalize", True)
|
|
||||||
num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
|
|
||||||
use_grouped_topk = kwargs.pop("use_grouped_topk", False)
|
|
||||||
num_expert_group = kwargs.pop("num_expert_group", None)
|
|
||||||
topk_group = kwargs.pop("topk_group", None)
|
|
||||||
correction_bias = kwargs.pop("correction_bias", None)
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.renormalize = renormalize
|
|
||||||
self.num_fused_shared_experts = num_fused_shared_experts
|
|
||||||
self.use_grouped_topk = use_grouped_topk
|
|
||||||
if self.use_grouped_topk:
|
|
||||||
assert num_expert_group is not None and topk_group is not None
|
|
||||||
self.num_expert_group = num_expert_group
|
|
||||||
self.topk_group = topk_group
|
|
||||||
self.correction_bias = correction_bias
|
|
||||||
self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
|
|
||||||
assert self.use_flashinfer_trtllm_moe
|
|
||||||
assert (
|
|
||||||
self.activation == "silu"
|
|
||||||
), "Only silu is supported for flashinfer blockscale fp8 moe"
|
|
||||||
assert (
|
|
||||||
self.renormalize
|
|
||||||
), "Renormalize is required for flashinfer blockscale fp8 moe"
|
|
||||||
assert (
|
|
||||||
self.num_fused_shared_experts == 0
|
|
||||||
), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
|
|
||||||
a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
|
|
||||||
# NOTE: scales of hidden states have to be transposed!
|
|
||||||
a_sf_t = a_sf.t().contiguous()
|
|
||||||
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
|
|
||||||
|
|
||||||
return trtllm_fp8_block_scale_moe(
|
|
||||||
routing_logits=router_logits.to(torch.float32),
|
|
||||||
routing_bias=self.correction_bias.to(hidden_states.dtype),
|
|
||||||
hidden_states=a_q,
|
|
||||||
hidden_states_scale=a_sf_t,
|
|
||||||
gemm1_weights=self.w13_weight,
|
|
||||||
gemm1_weights_scale=self.w13_weight_scale_inv,
|
|
||||||
gemm2_weights=self.w2_weight,
|
|
||||||
gemm2_weights_scale=self.w2_weight_scale_inv,
|
|
||||||
num_experts=self.num_experts,
|
|
||||||
top_k=self.top_k,
|
|
||||||
n_group=self.num_expert_group,
|
|
||||||
topk_group=self.topk_group,
|
|
||||||
intermediate_size=self.w2_weight.shape[2],
|
|
||||||
local_expert_offset=self.start_expert_id,
|
|
||||||
local_num_experts=self.num_local_experts,
|
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
|
||||||
tile_tokens_dim=get_tile_tokens_dim(
|
|
||||||
hidden_states.shape[0], self.top_k, self.num_experts
|
|
||||||
),
|
|
||||||
routing_method_type=2, # DeepSeek-styled routing method
|
|
||||||
use_shuffled_weight=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_moe_impl_class():
|
def get_moe_impl_class():
|
||||||
if global_server_args_dict["moe_a2a_backend"].is_deepep():
|
if global_server_args_dict["moe_a2a_backend"].is_deepep():
|
||||||
return DeepEPMoE
|
return DeepEPMoE
|
||||||
@@ -752,8 +692,10 @@ def get_moe_impl_class():
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if should_use_flashinfer_trtllm_moe():
|
||||||
|
return FlashInferFusedMoE
|
||||||
if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
|
if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
|
||||||
return FusedMoE
|
return FusedMoE
|
||||||
if get_moe_expert_parallel_world_size() > 1:
|
if get_moe_expert_parallel_world_size() > 1:
|
||||||
return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
|
return EPMoE
|
||||||
return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
|
return FusedMoE
|
||||||
|
|||||||
@@ -763,8 +763,13 @@ class FlashInferFusedMoE(FusedMoE):
|
|||||||
self.num_expert_group = num_expert_group
|
self.num_expert_group = num_expert_group
|
||||||
self.topk_group = topk_group
|
self.topk_group = topk_group
|
||||||
self.correction_bias = correction_bias
|
self.correction_bias = correction_bias
|
||||||
|
self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
|
def forward(self, hidden_states: torch.Tensor, topk_output: tuple):
|
||||||
|
assert self.use_flashinfer_trtllm_moe
|
||||||
|
assert (
|
||||||
|
self.activation == "silu"
|
||||||
|
), "Only silu is supported for flashinfer blockscale fp8 moe"
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
assert (
|
assert (
|
||||||
self.renormalize
|
self.renormalize
|
||||||
@@ -772,6 +777,14 @@ class FlashInferFusedMoE(FusedMoE):
|
|||||||
assert (
|
assert (
|
||||||
self.num_fused_shared_experts == 0
|
self.num_fused_shared_experts == 0
|
||||||
), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
|
), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
|
||||||
|
|
||||||
|
# TRTLLM mode expects (TopK_config, router_logits) tuple
|
||||||
|
if not isinstance(topk_output, tuple) or len(topk_output) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
|
||||||
|
)
|
||||||
|
_, router_logits = topk_output
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
final_hidden_states = self.quant_method.apply_with_router_logits(
|
final_hidden_states = self.quant_method.apply_with_router_logits(
|
||||||
layer=self,
|
layer=self,
|
||||||
|
|||||||
Reference in New Issue
Block a user