Support DeepSeek quantization & mixed parallelism with graph mode (#585)

### What this PR does / why we need it?
1. Support DeepSeek with W8A8 quantization;
2. Support DeepSeek with mixed parallelism (multi-DP, EP+TP);
3. Support DeepSeek with graph mode.
---------

Signed-off-by: wen-jie666 <wenjie39@huawei.com>
Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com>
Signed-off-by: libaokui <libaokui@huawei.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: wen-jie666 <wenjie39@huawei.com>
This commit is contained in:
zzzzwwjj
2025-04-23 16:23:25 +08:00
committed by GitHub
parent e74331a1ed
commit 5c6d05a59e
13 changed files with 520 additions and 221 deletions

View File

@@ -330,17 +330,16 @@ def native_grouped_topk(
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: Optional[bool] = True
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Select top-k experts based on router logits.
@@ -364,7 +363,6 @@ def select_experts(
Raises:
ValueError: If an unsupported scoring function is provided.
"""
if custom_routing_function is not None:
raise NotImplementedError(
"Custom routing function is not supported now")
@@ -466,21 +464,36 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
is_prefill=False,
**kwargs,
):
# set prefill as false always, should fix this
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
is_prefill=is_prefill)
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk当前写8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; 第三个输出是否输出
# y2_flag=False, # old api; 第三个输出是否输出
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
if os.environ.get("VLLM_ENABLE_MC2") == "1" and not is_prefill:
if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill:
return fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
@@ -611,10 +624,11 @@ class AscendFusedMoE(FusedMoE):
real_top_k = self.top_k
if self.dp_size > 1:
if int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore
if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore
) == 1 and not is_prefill:
...
elif int(os.environ.get("USING_LCCL_COM")) == 1: # type: ignore
elif int(os.environ.get("USING_LCCL_COM",
'0')) == 1: # type: ignore
hidden_states = get_dp_group().all_gather(
hidden_states, 0, False)
router_logits = get_dp_group().all_gather(
@@ -631,7 +645,7 @@ class AscendFusedMoE(FusedMoE):
top_k=real_top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.num_experts,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
@@ -641,7 +655,7 @@ class AscendFusedMoE(FusedMoE):
is_prefill=is_prefill)
if self.dp_size > 1:
if int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore
if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore
) == 1 and not is_prefill:
...
else: