Support DeepSeek quantization & mixed parallelism with graph mode (#585)
### What this PR does / why we need it? 1. Support DeepSeek with W8A8 quantization; 2. Support DeepSeek with mixed parallelism (multi-DP, EP+TP); 3. Support DeepSeek with graph mode. --------- Signed-off-by: wen-jie666 <wenjie39@huawei.com> Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com> Signed-off-by: libaokui <libaokui@huawei.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: wen-jie666 <wenjie39@huawei.com>
This commit is contained in:
@@ -330,17 +330,16 @@ def native_grouped_topk(
|
||||
|
||||
|
||||
def select_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: Optional[bool] = True
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Select top-k experts based on router logits.
|
||||
@@ -364,7 +363,6 @@ def select_experts(
|
||||
Raises:
|
||||
ValueError: If an unsupported scoring function is provided.
|
||||
"""
|
||||
|
||||
if custom_routing_function is not None:
|
||||
raise NotImplementedError(
|
||||
"Custom routing function is not supported now")
|
||||
@@ -466,21 +464,36 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
is_prefill=False,
|
||||
**kwargs,
|
||||
):
|
||||
# NOTE: prefill is always set to False here; this should be fixed
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
is_prefill=is_prefill)
|
||||
# NOTE: npu_moe_gating_top_k currently supports only the `group_count=256` pattern
|
||||
if global_num_experts == 256:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k, # topk当前写8
|
||||
bias=e_score_correction_bias,
|
||||
k_group=topk_group, # fix: 4
|
||||
group_count=num_expert_group, # fix 8
|
||||
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
|
||||
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
||||
# out_flag=False, # todo new api; 第三个输出是否输出
|
||||
# y2_flag=False, # old api; 第三个输出是否输出
|
||||
routed_scaling_factor=1,
|
||||
eps=float(1e-20))
|
||||
else:
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
if os.environ.get("VLLM_ENABLE_MC2") == "1" and not is_prefill:
|
||||
if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill:
|
||||
return fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@@ -611,10 +624,11 @@ class AscendFusedMoE(FusedMoE):
|
||||
real_top_k = self.top_k
|
||||
|
||||
if self.dp_size > 1:
|
||||
if int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore
|
||||
if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore
|
||||
) == 1 and not is_prefill:
|
||||
...
|
||||
elif int(os.environ.get("USING_LCCL_COM")) == 1: # type: ignore
|
||||
elif int(os.environ.get("USING_LCCL_COM",
|
||||
'0')) == 1: # type: ignore
|
||||
hidden_states = get_dp_group().all_gather(
|
||||
hidden_states, 0, False)
|
||||
router_logits = get_dp_group().all_gather(
|
||||
@@ -631,7 +645,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
top_k=real_top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.num_experts,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
@@ -641,7 +655,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
is_prefill=is_prefill)
|
||||
|
||||
if self.dp_size > 1:
|
||||
if int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore
|
||||
if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore
|
||||
) == 1 and not is_prefill:
|
||||
...
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user