Set num_fused_shared_experts as num_shared_experts when shared_experts fusion is not disabled (#6736)

This commit is contained in:
Cheng Wan
2025-06-04 15:53:22 -07:00
committed by GitHub
parent f0f84975f4
commit 81964328b7
22 changed files with 381 additions and 45 deletions

View File

@@ -156,6 +156,7 @@ class EPMoE(torch.nn.Module):
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
@@ -190,6 +191,7 @@ class EPMoE(torch.nn.Module):
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.num_fused_shared_experts = num_fused_shared_experts
self.topk_group = topk_group
self.correction_bias = correction_bias
self.custom_routing_function = custom_routing_function
@@ -250,6 +252,7 @@ class EPMoE(torch.nn.Module):
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
routed_scaling_factor=self.routed_scaling_factor,

View File

@@ -21,6 +21,7 @@ def fused_moe_forward_native(
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
@@ -41,6 +42,7 @@ def fused_moe_forward_native(
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
@@ -71,6 +73,7 @@ def moe_forward_native(
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
@@ -84,6 +87,7 @@ def moe_forward_native(
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
torch_native=True,

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
}
}

View File

@@ -1540,6 +1540,7 @@ def fused_moe(
activation: str = "silu",
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False,
@@ -1609,6 +1610,7 @@ def fused_moe(
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
routed_scaling_factor=routed_scaling_factor,
)

View File

@@ -127,6 +127,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
@@ -144,6 +145,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
activation=activation,
@@ -163,6 +165,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
@@ -179,6 +182,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
@@ -232,6 +236,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
inplace: bool = True,
@@ -245,6 +250,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
)
@@ -289,6 +295,7 @@ class FusedMoE(torch.nn.Module):
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
@@ -321,6 +328,7 @@ class FusedMoE(torch.nn.Module):
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.num_fused_shared_experts = num_fused_shared_experts
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
@@ -651,6 +659,7 @@ class FusedMoE(torch.nn.Module):
use_grouped_topk=self.use_grouped_topk,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
activation=self.activation,

View File

@@ -303,6 +303,7 @@ def select_experts(
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
torch_native: bool = False,
@@ -310,7 +311,6 @@ def select_experts(
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
num_fused_shared_experts = global_server_args_dict["num_fused_shared_experts"]
router_logits, correction_bias = (
expert_location_dispatch.transform_select_experts_inputs(

View File

@@ -289,6 +289,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",

View File

@@ -367,6 +367,7 @@ class BlockInt8MoEMethod:
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
@@ -387,6 +388,7 @@ class BlockInt8MoEMethod:
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,

View File

@@ -272,6 +272,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
@@ -294,6 +295,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
@@ -627,6 +629,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
@@ -651,6 +654,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
correction_bias=correction_bias,

View File

@@ -937,6 +937,7 @@ class Fp8MoEMethod:
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
@@ -957,6 +958,7 @@ class Fp8MoEMethod:
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,

View File

@@ -341,6 +341,7 @@ class MoeWNA16Method:
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
@@ -362,6 +363,7 @@ class MoeWNA16Method:
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,

View File

@@ -287,6 +287,7 @@ class W8A8FP8MoEMethod:
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
@@ -306,6 +307,7 @@ class W8A8FP8MoEMethod:
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,

View File

@@ -225,6 +225,7 @@ class W8A8Int8MoEMethod:
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
@@ -245,6 +246,7 @@ class W8A8Int8MoEMethod:
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,

View File

@@ -89,7 +89,7 @@ global_server_args_dict = {
"max_micro_batch_size": ServerArgs.max_micro_batch_size,
"moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
"ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm,
"num_fused_shared_experts": ServerArgs.num_fused_shared_experts,
"disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
"sampling_backend": ServerArgs.sampling_backend,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,

View File

@@ -204,7 +204,7 @@ class ModelRunner:
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"moe_dense_tp_size": server_args.moe_dense_tp_size,
"ep_dispatch_algorithm": server_args.ep_dispatch_algorithm,
"num_fused_shared_experts": server_args.num_fused_shared_experts,
"disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"torchao_config": server_args.torchao_config,
"sampling_backend": server_args.sampling_backend,

View File

@@ -224,9 +224,11 @@ class DeepseekV2MoE(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.num_fused_shared_experts = global_server_args_dict[
"num_fused_shared_experts"
]
self.num_fused_shared_experts = (
0
if global_server_args_dict["disable_shared_experts_fusion"]
else config.n_shared_experts
)
self.config = config
self.layer_id = layer_id
@@ -248,7 +250,7 @@ class DeepseekV2MoE(nn.Module):
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
+ global_server_args_dict["ep_num_redundant_experts"],
top_k=config.num_experts_per_tok + min(self.num_fused_shared_experts, 1),
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
layer_id=self.layer_id,
@@ -256,6 +258,7 @@ class DeepseekV2MoE(nn.Module):
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
@@ -363,6 +366,7 @@ class DeepseekV2MoE(nn.Module):
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=forward_batch.num_token_non_padded,
@@ -456,6 +460,7 @@ class DeepseekV2MoE(nn.Module):
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=state.forward_batch.num_token_non_padded,
@@ -1679,9 +1684,11 @@ class DeepseekV2ForCausalLM(nn.Module):
def determine_num_fused_shared_experts(
self, architecture: str = "DeepseekV3ForCausalLM"
):
self.num_fused_shared_experts = global_server_args_dict[
"num_fused_shared_experts"
]
self.num_fused_shared_experts = (
0
if global_server_args_dict["disable_shared_experts_fusion"]
else self.config.n_shared_experts
)
if self.num_fused_shared_experts > 0:
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
if (
@@ -1690,15 +1697,11 @@ class DeepseekV2ForCausalLM(nn.Module):
or self.config.n_routed_experts != 256
):
self.num_fused_shared_experts = 0
global_server_args_dict["num_fused_shared_experts"] = 0
global_server_args_dict["disable_shared_experts_fusion"] = 1
log_info_on_rank0(
logger,
"Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
)
else:
assert (
self.num_fused_shared_experts == self.tp_size
), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
elif self.num_fused_shared_experts == 0:
if (
_is_cuda
@@ -1707,8 +1710,8 @@ class DeepseekV2ForCausalLM(nn.Module):
and self.config.n_routed_experts == 256
and (not global_server_args_dict["enable_deepep_moe"])
):
self.num_fused_shared_experts = self.tp_size
global_server_args_dict["num_fused_shared_experts"] = self.tp_size
self.num_fused_shared_experts = self.config.n_shared_experts
global_server_args_dict["disable_shared_experts_fusion"] = 0
log_info_on_rank0(
logger,
"Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
@@ -1910,6 +1913,7 @@ class DeepseekV2ForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1),
]
if self.num_fused_shared_experts > 0:
assert self.num_fused_shared_experts == 1
weights_list = list(weights)
weights_dict = dict(weights_list)
if self.quant_config is not None:
@@ -1971,22 +1975,21 @@ class DeepseekV2ForCausalLM(nn.Module):
for moe_layer in tqdm(
moe_layers,
desc=f"Cloning {self.num_fused_shared_experts} "
"replicas of the shared expert into MoE",
"shared expert into MoE",
):
for suffix in suffix_list:
shared_expert_weight_name = (
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
)
for num_repeat in range(self.num_fused_shared_experts):
weights_list.append(
(
f"model.layers.{moe_layer}."
f"mlp.experts."
f"{self.config.n_routed_experts + num_repeat}"
f".{suffix}",
weights_dict[shared_expert_weight_name],
)
weights_list.append(
(
f"model.layers.{moe_layer}."
f"mlp.experts."
f"{self.config.n_routed_experts + 0}"
f".{suffix}",
weights_dict[shared_expert_weight_name],
)
)
names_to_remove += [shared_expert_weight_name]
weights = [w for w in weights_list if w[0] not in names_to_remove]

View File

@@ -207,7 +207,7 @@ class ServerArgs:
flashinfer_mla_disable_ragged: bool = False
warmups: Optional[str] = None
moe_dense_tp_size: Optional[int] = None
num_fused_shared_experts: int = 0
disable_shared_experts_fusion: bool = False
disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False
mm_attention_backend: Optional[str] = None
@@ -1384,13 +1384,10 @@ class ServerArgs:
default=ServerArgs.deepep_config,
help="Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path.",
)
parser.add_argument(
"--num-fused-shared-experts",
type=int,
default=0,
help="The number of shared_experts need to be replicated to fuse with normal experts in deepseek v3/r1, "
"set it to tp_size can get best optimized performance. Note that for architectures with SM==90, we have enabled the shared experts fusion optimization by default for DeepSeek V3/R1, with num_fused_shared_experts automatically set to the TP size.",
"--disable-shared-experts-fusion",
action="store_true",
help="Disable shared experts fusion optimization for deepseek v3/r1.",
)
parser.add_argument(
"--disable-chunked-prefix-cache",