[Refactor] Rename n_share_experts_fusion as num_fused_shared_experts (#6735)

This commit is contained in:
Cheng Wan
2025-06-03 17:48:24 -07:00
committed by GitHub
parent b6d0ce9f78
commit 8a5480528d
14 changed files with 82 additions and 93 deletions

View File

@@ -103,7 +103,7 @@ def grouped_topk(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
-n_share_experts_fusion: int = 0,
+num_fused_shared_experts: int = 0,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
@@ -128,10 +128,10 @@ def grouped_topk(
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-if n_share_experts_fusion:
+if num_fused_shared_experts:
topk_ids[:, -1] = torch.randint(
low=num_experts,
-high=num_experts + n_share_experts_fusion,
+high=num_experts + num_fused_shared_experts,
size=(topk_ids.size(0),),
dtype=topk_ids.dtype,
device=topk_ids.device,
@@ -141,7 +141,7 @@ def grouped_topk(
if renormalize:
topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
-if n_share_experts_fusion == 0
+if num_fused_shared_experts == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
@@ -160,7 +160,7 @@ def biased_grouped_topk_impl(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
-n_share_experts_fusion: int = 0,
+num_fused_shared_experts: int = 0,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
@@ -192,10 +192,10 @@ def biased_grouped_topk_impl(
_, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
topk_weights = scores.gather(1, topk_ids)
-if n_share_experts_fusion:
+if num_fused_shared_experts:
topk_ids[:, -1] = torch.randint(
low=num_experts,
-high=num_experts + n_share_experts_fusion,
+high=num_experts + num_fused_shared_experts,
size=(topk_ids.size(0),),
dtype=topk_ids.dtype,
device=topk_ids.device,
@@ -205,7 +205,7 @@ def biased_grouped_topk_impl(
if renormalize:
topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
-if n_share_experts_fusion == 0
+if num_fused_shared_experts == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
@@ -239,7 +239,7 @@ def biased_grouped_topk(
num_expert_group: int = 0,
topk_group: int = 0,
compiled: bool = True,
-n_share_experts_fusion: int = 0,
+num_fused_shared_experts: int = 0,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
@@ -247,7 +247,7 @@ def biased_grouped_topk(
assert (
routed_scaling_factor is not None
), "routed_scaling_factor is required for biased_grouped_topk"
-# TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
+# TODO: moe_fused_gate kernel is not supported for num_fused_shared_experts > 0 now.
if (
_is_cuda
and gating_output.shape[1] // num_expert_group
@@ -260,7 +260,7 @@ def biased_grouped_topk(
num_expert_group,
topk_group,
topk,
-n_share_experts_fusion,
+num_fused_shared_experts,
routed_scaling_factor,
)
# TODO merge into kernel for this branch
@@ -288,7 +288,7 @@ def biased_grouped_topk(
renormalize,
num_expert_group,
topk_group,
-n_share_experts_fusion=n_share_experts_fusion,
+num_fused_shared_experts=num_fused_shared_experts,
routed_scaling_factor=routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
@@ -310,7 +310,7 @@ def select_experts(
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
-n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+num_fused_shared_experts = global_server_args_dict["num_fused_shared_experts"]
router_logits, correction_bias = (
expert_location_dispatch.transform_select_experts_inputs(
@@ -332,7 +332,7 @@ def select_experts(
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
-n_share_experts_fusion=n_share_experts_fusion,
+num_fused_shared_experts=num_fused_shared_experts,
routed_scaling_factor=routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
@@ -346,7 +346,7 @@ def select_experts(
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
-n_share_experts_fusion=n_share_experts_fusion,
+num_fused_shared_experts=num_fused_shared_experts,
routed_scaling_factor=routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,

View File

@@ -89,7 +89,7 @@ global_server_args_dict = {
"max_micro_batch_size": ServerArgs.max_micro_batch_size,
"moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
"ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm,
-"n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
+"num_fused_shared_experts": ServerArgs.num_fused_shared_experts,
"sampling_backend": ServerArgs.sampling_backend,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,

View File

@@ -204,7 +204,7 @@ class ModelRunner:
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"moe_dense_tp_size": server_args.moe_dense_tp_size,
"ep_dispatch_algorithm": server_args.ep_dispatch_algorithm,
-"n_share_experts_fusion": server_args.n_share_experts_fusion,
+"num_fused_shared_experts": server_args.num_fused_shared_experts,
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"torchao_config": server_args.torchao_config,
"sampling_backend": server_args.sampling_backend,

View File

@@ -122,7 +122,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
-self.determine_n_share_experts_fusion("DeepseekV3ForCausalLMNextN")
+self.determine_num_fused_shared_experts("DeepseekV3ForCausalLMNextN")
self.model = DeepseekModelNextN(
config, quant_config, prefix=add_prefix("model", prefix)

View File

@@ -224,7 +224,9 @@ class DeepseekV2MoE(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
-self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+self.num_fused_shared_experts = global_server_args_dict[
+    "num_fused_shared_experts"
+]
self.config = config
self.layer_id = layer_id
@@ -244,9 +246,9 @@ class DeepseekV2MoE(nn.Module):
self.experts = get_moe_impl_class()(
num_experts=config.n_routed_experts
-+ self.n_share_experts_fusion
++ self.num_fused_shared_experts
+ global_server_args_dict["ep_num_redundant_experts"],
-top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
+top_k=config.num_experts_per_tok + min(self.num_fused_shared_experts, 1),
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
layer_id=self.layer_id,
@@ -265,7 +267,7 @@ class DeepseekV2MoE(nn.Module):
),
)
-if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
+if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
# disable tp for shared experts when enable deepep moe
self.shared_experts = DeepseekV2MLP(
@@ -418,7 +420,7 @@ class DeepseekV2MoE(nn.Module):
return final_hidden_states
def _forward_shared_experts(self, hidden_states):
-if self.n_share_experts_fusion == 0:
+if self.num_fused_shared_experts == 0:
return self.shared_experts(hidden_states)
else:
return None
@@ -434,7 +436,7 @@ class DeepseekV2MoE(nn.Module):
def op_shared_experts(self, state):
hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
-if (self.n_share_experts_fusion == 0) and is_non_idle_and_non_empty(
+if (self.num_fused_shared_experts == 0) and is_non_idle_and_non_empty(
state.forward_batch.forward_mode, hidden_states_mlp_input
):
state.shared_output = self.shared_experts(hidden_states_mlp_input)
@@ -1648,7 +1650,7 @@ class DeepseekV2ForCausalLM(nn.Module):
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
-self.determine_n_share_experts_fusion()
+self.determine_num_fused_shared_experts()
self.model = DeepseekV2Model(
config, quant_config, prefix=add_prefix("model", prefix)
)
@@ -1674,28 +1676,30 @@ class DeepseekV2ForCausalLM(nn.Module):
def routed_experts_weights_of_layer(self):
return self._routed_experts_weights_of_layer.value
-def determine_n_share_experts_fusion(
+def determine_num_fused_shared_experts(
self, architecture: str = "DeepseekV3ForCausalLM"
):
-self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
-if self.n_share_experts_fusion > 0:
+self.num_fused_shared_experts = global_server_args_dict[
+    "num_fused_shared_experts"
+]
+if self.num_fused_shared_experts > 0:
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
if (
not _is_cuda
or self.config.architectures[0] != architecture
or self.config.n_routed_experts != 256
):
-self.n_share_experts_fusion = 0
-global_server_args_dict["n_share_experts_fusion"] = 0
+self.num_fused_shared_experts = 0
+global_server_args_dict["num_fused_shared_experts"] = 0
log_info_on_rank0(
logger,
"Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
)
else:
assert (
-self.n_share_experts_fusion == self.tp_size
+self.num_fused_shared_experts == self.tp_size
), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
-elif self.n_share_experts_fusion == 0:
+elif self.num_fused_shared_experts == 0:
if (
_is_cuda
and torch.cuda.get_device_capability("cuda") >= (9, 0)
@@ -1703,8 +1707,8 @@ class DeepseekV2ForCausalLM(nn.Module):
and self.config.n_routed_experts == 256
and (not global_server_args_dict["enable_deepep_moe"])
):
-self.n_share_experts_fusion = self.tp_size
-global_server_args_dict["n_share_experts_fusion"] = self.tp_size
+self.num_fused_shared_experts = self.tp_size
+global_server_args_dict["num_fused_shared_experts"] = self.tp_size
log_info_on_rank0(
logger,
"Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
@@ -1905,7 +1909,7 @@ class DeepseekV2ForCausalLM(nn.Module):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
-if self.n_share_experts_fusion > 0:
+if self.num_fused_shared_experts > 0:
weights_list = list(weights)
weights_dict = dict(weights_list)
if self.quant_config is not None:
@@ -1966,14 +1970,14 @@ class DeepseekV2ForCausalLM(nn.Module):
for moe_layer in tqdm(
moe_layers,
-desc=f"Cloning {self.n_share_experts_fusion} "
+desc=f"Cloning {self.num_fused_shared_experts} "
"replicas of the shared expert into MoE",
):
for suffix in suffix_list:
shared_expert_weight_name = (
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
)
-for num_repeat in range(self.n_share_experts_fusion):
+for num_repeat in range(self.num_fused_shared_experts):
weights_list.append(
(
f"model.layers.{moe_layer}."
@@ -1992,7 +1996,7 @@ class DeepseekV2ForCausalLM(nn.Module):
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
-num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
+num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
)
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None

View File

@@ -206,7 +206,7 @@ class ServerArgs:
flashinfer_mla_disable_ragged: bool = False
warmups: Optional[str] = None
moe_dense_tp_size: Optional[int] = None
-n_share_experts_fusion: int = 0
+num_fused_shared_experts: int = 0
disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False
mm_attention_backend: Optional[str] = None
@@ -1373,11 +1373,11 @@ class ServerArgs:
)
parser.add_argument(
-"--n-share-experts-fusion",
+"--num-fused-shared-experts",
type=int,
default=0,
help="The number of shared_experts need to be replicated to fuse with normal experts in deepseek v3/r1, "
-"set it to tp_size can get best optimized performance. Note that for architectures with SM==90, we have enabled the shared experts fusion optimization by default for DeepSeek V3/R1, with n_share_experts_fusion automatically set to the TP size.",
+"set it to tp_size can get best optimized performance. Note that for architectures with SM==90, we have enabled the shared experts fusion optimization by default for DeepSeek V3/R1, with num_fused_shared_experts automatically set to the TP size.",
)
parser.add_argument(
"--disable-chunked-prefix-cache",