【0.11.0-dev】Optimization of Kimi-K2 on CANN 8.3 (#4555)

### What this PR does / why we need it?
In CANN 8.3, the `npu_moe_gating_top_k` operator supports 384 experts, so Kimi-K2 can use this operator for better performance.
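
For context, a minimal sketch of the call this change enables for Kimi-K2-sized routing. The shapes and group settings below (384 routed experts, top-8, a single expert group) are illustrative assumptions, not values taken from this diff:

```python
import torch
import torch_npu  # requires an Ascend device with CANN >= 8.3

num_tokens, num_experts, top_k = 4, 384, 8  # 384 routed experts, as in Kimi-K2
router_logits = torch.randn(num_tokens, num_experts).npu()
bias = torch.zeros(num_experts).npu()  # stand-in for e_score_correction_bias

# Same flag pattern as the torchair MoE methods in this diff; before CANN 8.3
# the operator only covered the 256-expert (DeepSeek-V3/R1) configuration.
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
    router_logits,
    k=top_k,
    bias=bias,
    k_group=1,            # assumed: Kimi-K2 routes with a single group
    group_count=1,
    group_select_mode=1,  # 1: rank groups by the sum of their top-2 experts
    renorm=0,
    norm_type=1,          # 1: sigmoid scoring
    routed_scaling_factor=1,
    eps=float(1e-20))
print(topk_weights.shape, topk_ids.shape)  # (4, 8) each
```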
---------

Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Levi committed via GitHub on 2025-12-09 08:49:15 +08:00
parent 0d094531b4
commit 9862a23985
4 changed files with 54 additions and 96 deletions

Changed file 1 of 4 (unit tests for the torchair fused MoE method):

```diff
@@ -383,7 +383,7 @@ class TestTorchairAscendUnquantizedFusedMoEMethod:
         else:
             assert result.shape == x.shape
 
-    @pytest.mark.parametrize("others_param", [16, 1, 4])
+    @pytest.mark.parametrize("others_param", [16, 4])
     def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                                    mock_moe_env, others_param):
         """
@@ -396,9 +396,18 @@ class TestTorchairAscendUnquantizedFusedMoEMethod:
         is_prefill = False
         forward_context = MagicMock(
             fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True))
-        with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \
-                patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3):
+        if ep_size == 1:
+            top_k_return = (torch.randn(16, 2), torch.randint(0, 16,
+                                                              (16, 2)), None)
+            expert_map = torch.tensor(
+                [0, 1, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1])
+        else:
+            top_k_return = (torch.randn(8, 2), torch.randint(0, 8,
+                                                             (8, 2)), None)
             expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
+        with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \
+                patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
+                patch('torch_npu.npu_moe_gating_top_k', return_value=top_k_return):
             moe_method.ep_size = ep_size
             x = torch.randn(8, 2, 2)
             if ep_size == 1:
```
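
The new `patch('torch_npu.npu_moe_gating_top_k', return_value=top_k_return)` is what keeps this test runnable without a CANN 8.3 kernel. A standalone sketch of that stubbing pattern, assuming only `torch` and the standard library (no NPU, no `torch_npu` install):

```python
import sys
from unittest.mock import MagicMock, patch

import torch

# Canned operator output, shaped like (topk_weights, topk_ids, extra)
# for 8 tokens with top-2 routing over 8 experts.
top_k_return = (torch.randn(8, 2), torch.randint(0, 8, (8, 2)), None)

# torch_npu may be absent on CI machines, so stub the whole module.
fake_torch_npu = MagicMock()
fake_torch_npu.npu_moe_gating_top_k.return_value = top_k_return

with patch.dict(sys.modules, {"torch_npu": fake_torch_npu}):
    import torch_npu  # resolves to the stub above
    weights, ids, _ = torch_npu.npu_moe_gating_top_k(torch.randn(8, 8), k=2)
    assert weights.shape == (8, 2) and ids.shape == (8, 2)
```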

Changed file 2 of 4 (TorchairAscendUnquantizedFusedMoEMethod):

```diff
@@ -857,38 +857,20 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         shared_experts: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
-        global_redundant_expert_num = get_ascend_config(
-        ).init_redundancy_expert
-        is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if is_deepseek_v3_r1:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # topk currently is 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix 8
-                group_select_mode=
-                1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                # out_flag=False, # todo new api; should the third output be output
-                # y2_flag=False, # old api; should the third output be output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = torchair_select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
+        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+            router_logits,
+            k=top_k,  # topk currently is 8
+            bias=e_score_correction_bias,
+            k_group=topk_group,  # fix: 4
+            group_count=num_expert_group,  # fix 8
+            group_select_mode=
+            1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+            renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+            norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+            # out_flag=False, # todo new api; should the third output be output
+            # y2_flag=False, # old api; should the third output be output
+            routed_scaling_factor=1,
+            eps=float(1e-20))
         topk_weights = topk_weights.to(x.dtype)
         # this is a naive implementation for experts load balance so as
```
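
The inline flag comments compress a lot of semantics. Below is a rough CPU reference for what this flag combination (norm_type=1 sigmoid scoring, group_select_mode=1 top-2-sum group ranking, renorm=0) computes, reconstructed from those comments and DeepSeek-V3-style routing; it is an illustrative sketch, not the operator's exact contract:

```python
import torch


def grouped_gating_top_k_ref(router_logits, k, bias, k_group, group_count,
                             routed_scaling_factor=1.0, eps=1e-20):
    """Rough CPU emulation of npu_moe_gating_top_k with the flags above."""
    tokens, experts = router_logits.shape
    scores = torch.sigmoid(router_logits)   # norm_type=1: sigmoid scoring
    biased = scores + bias                  # e_score_correction_bias
    groups = biased.view(tokens, group_count, experts // group_count)
    # group_select_mode=1: score each group by the sum of its top-2 experts
    group_scores = groups.topk(2, dim=-1).values.sum(dim=-1)
    top_groups = group_scores.topk(k_group, dim=-1).indices
    keep = torch.zeros_like(group_scores, dtype=torch.bool)
    keep.scatter_(1, top_groups, True)
    keep = keep.unsqueeze(-1).expand_as(groups).reshape(tokens, experts)
    # top-k over experts in the selected groups only
    topk_ids = biased.masked_fill(~keep, float("-inf")).topk(k, dim=-1).indices
    # weights come from the unbiased scores, renormalized over the top-k
    weights = scores.gather(1, topk_ids)
    weights = weights / (weights.sum(dim=-1, keepdim=True) + eps)
    return weights * routed_scaling_factor, topk_ids


w, ids = grouped_gating_top_k_ref(torch.randn(4, 256), k=8,
                                  bias=torch.zeros(256), k_group=4,
                                  group_count=8)
print(w.shape, ids.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```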

Changed file 3 of 4 (TorchairAscendW4A8DynamicFusedMoEMethod):

```diff
@@ -27,7 +27,6 @@ from vllm.forward_context import get_forward_context
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
 from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
     torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
 from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
@@ -322,34 +321,20 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
         assert router_logits.shape[
             1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
-        if global_num_experts == 256:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # topk currently is 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix 8
-                group_select_mode=
-                1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                # out_flag=False, # todo new api; should the third output be output
-                # y2_flag=False, # old api; should the third output be output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = torchair_select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
+        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+            router_logits,
+            k=top_k,  # topk currently is 8
+            bias=e_score_correction_bias,
+            k_group=topk_group,  # fix: 4
+            group_count=num_expert_group,  # fix 8
+            group_select_mode=
+            1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+            renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+            norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+            # out_flag=False, # todo new api; should the third output be output
+            # y2_flag=False, # old api; should the third output be output
+            routed_scaling_factor=1,
+            eps=float(1e-20))
         fused_moe_state = get_forward_context().fused_moe_state
         shared_gate_up, shared_dequant_scale = None, None
```
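
Note that the quantized paths drop the `global_num_experts == 256` guard entirely rather than widening it. If a deployment still had to run on an older CANN where the operator only covers 256 experts, a guard along these lines could be kept; this is a hypothetical sketch (the set of supported counts and the fallback are assumptions, not code from this PR):

```python
import torch

# Assumption: CANN 8.3 extends npu_moe_gating_top_k from 256 to 384 experts.
SUPPORTED_EXPERT_COUNTS = {256, 384}


def select_topk(router_logits: torch.Tensor, top_k: int,
                e_score_correction_bias: torch.Tensor,
                topk_group: int, num_expert_group: int):
    """Prefer the fused NPU kernel; otherwise use a plain torch fallback."""
    if router_logits.shape[-1] in SUPPORTED_EXPERT_COUNTS:
        import torch_npu  # only available on Ascend devices
        weights, ids, _ = torch_npu.npu_moe_gating_top_k(
            router_logits, k=top_k, bias=e_score_correction_bias,
            k_group=topk_group, group_count=num_expert_group,
            group_select_mode=1, renorm=0, norm_type=1,
            routed_scaling_factor=1, eps=float(1e-20))
        return weights, ids
    # Generic fallback: ungrouped sigmoid top-k, renormalized.
    scores = torch.sigmoid(router_logits)
    _, ids = (scores + e_score_correction_bias).topk(top_k, dim=-1)
    weights = scores.gather(-1, ids)
    return weights / weights.sum(dim=-1, keepdim=True), ids
```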

Changed file 4 of 4 (TorchairAscendW8A8DynamicFusedMoEMethod):

```diff
@@ -25,7 +25,6 @@ from vllm.forward_context import get_forward_context
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
 from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
                                         super_kernel)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
@@ -938,8 +937,6 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
         assert router_logits.shape[
             1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
-        is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
         fused_moe_state = get_forward_context().fused_moe_state
         if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
             fused_moe_state = FusedMoEState.All2All
@@ -948,35 +945,20 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
             with super_kernel(prefix,
                               "stream-fusion=1",
                               enabled=running_in_super_kernel):
-                # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-                if is_deepseek_v3_r1:
-                    topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                        router_logits,
-                        k=top_k,  # topk currently is 8
-                        bias=e_score_correction_bias,
-                        k_group=topk_group,  # fix: 4
-                        group_count=num_expert_group,  # fix 8
-                        group_select_mode=
-                        1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-                        renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                        norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                        # out_flag=False, # todo new api; should the third output be output
-                        # y2_flag=False, # old api; should the third output be output
-                        routed_scaling_factor=1,
-                        eps=float(1e-20))
-                else:
-                    topk_weights, topk_ids = torchair_select_experts(
-                        hidden_states=x,
-                        router_logits=router_logits,
-                        top_k=top_k,
-                        use_grouped_topk=use_grouped_topk,
-                        renormalize=renormalize,
-                        topk_group=topk_group,
-                        num_expert_group=num_expert_group,
-                        custom_routing_function=custom_routing_function,
-                        scoring_func=scoring_func,
-                        e_score_correction_bias=e_score_correction_bias,
-                    )
+                topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                    router_logits,
+                    k=top_k,  # topk currently is 8
+                    bias=e_score_correction_bias,
+                    k_group=topk_group,  # fix: 4
+                    group_count=num_expert_group,  # fix 8
+                    group_select_mode=
+                    1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+                    renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+                    norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+                    # out_flag=False, # todo new api; should the third output be output
+                    # y2_flag=False, # old api; should the third output be output
+                    routed_scaling_factor=1,
+                    eps=float(1e-20))
             if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
                 with npu_stream_switch("moe_secondary", 0):
```