【0.11.0-dev】Optimization of kimi-k2 in CANN 8.3 (#4555)
### What this PR does / why we need it?
In CANN 8.3, the npu_moe_gating_top_k operator supports models with 384 experts, so kimi-k2 can use this operator for better performance.

---------

Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
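For reference, a minimal sketch of the gating call this PR enables for kimi-k2. The keyword arguments mirror the call sites in the diff below; the tensor shapes and group sizes are illustrative only, and running it requires torch_npu with CANN 8.3 on an Ascend NPU:

```python
import torch
import torch_npu  # CANN 8.3: npu_moe_gating_top_k now supports 384 experts

num_tokens, num_experts = 4, 384  # illustrative sizes
router_logits = torch.randn(num_tokens, num_experts).npu()
e_score_correction_bias = torch.zeros(num_experts).npu()

# Arguments mirror this PR's call sites; group sizes are illustrative.
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
    router_logits,
    k=8,                      # experts selected per token
    bias=e_score_correction_bias,
    k_group=4,                # groups kept per token
    group_count=8,            # number of expert groups
    group_select_mode=1,      # rank groups by the sum of their top-2 scores
    renorm=0,                 # softmax->topk order
    norm_type=1,              # sigmoid scoring
    routed_scaling_factor=1,
    eps=float(1e-20))
```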
@@ -383,7 +383,7 @@ class TestTorchairAscendUnquantizedFusedMoEMethod:
         else:
             assert result.shape == x.shape
 
-    @pytest.mark.parametrize("others_param", [16, 1, 4])
+    @pytest.mark.parametrize("others_param", [16, 4])
     def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                                    mock_moe_env, others_param):
         """
@@ -396,9 +396,18 @@ class TestTorchairAscendUnquantizedFusedMoEMethod:
         is_prefill = False
         forward_context = MagicMock(
             fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True))
-        with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \
-                patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3):
+        if ep_size == 1:
+            top_k_return = (torch.randn(16, 2), torch.randint(0, 16,
+                                                              (16, 2)), None)
+            expert_map = torch.tensor(
+                [0, 1, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1])
+        else:
+            top_k_return = (torch.randn(8, 2), torch.randint(0, 8,
+                                                             (8, 2)), None)
             expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
+        with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \
+                patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
+                patch('torch_npu.npu_moe_gating_top_k', return_value=top_k_return):
             moe_method.ep_size = ep_size
             x = torch.randn(8, 2, 2)
             if ep_size == 1:
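The new branch in the test above stubs the fused gating operator with a canned return tuple so both expert-map sizes can be exercised without an NPU. A self-contained sketch of that mocking pattern (the fake torch_npu module here is only so the snippet runs anywhere; the real test patches the installed torch_npu):

```python
import sys
import types
from unittest.mock import patch

import torch

# Fake torch_npu so this sketch runs without an NPU (assumption: the real
# test suite patches the installed torch_npu package instead).
sys.modules.setdefault("torch_npu", types.ModuleType("torch_npu"))

top_k_return = (torch.randn(16, 2),             # topk_weights
                torch.randint(0, 16, (16, 2)),  # topk_ids
                None)                           # third output, unused here

with patch("torch_npu.npu_moe_gating_top_k",
           return_value=top_k_return,
           create=True):
    import torch_npu
    w, ids, _ = torch_npu.npu_moe_gating_top_k(torch.randn(16, 384))
    assert w.shape == (16, 2) and ids.shape == (16, 2)
```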
@@ -857,38 +857,20 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         shared_experts: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
-        global_redundant_expert_num = get_ascend_config(
-        ).init_redundancy_expert
-        is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if is_deepseek_v3_r1:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # topk currently is 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix 8
-                group_select_mode=
-                1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                # out_flag=False, # todo new api; should the third output be output
-                # y2_flag=False, # old api; should the third output be output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = torchair_select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
+        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+            router_logits,
+            k=top_k,  # topk currently is 8
+            bias=e_score_correction_bias,
+            k_group=topk_group,  # fix: 4
+            group_count=num_expert_group,  # fix 8
+            group_select_mode=
+            1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+            renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+            norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+            # out_flag=False, # todo new api; should the third output be output
+            # y2_flag=False, # old api; should the third output be output
+            routed_scaling_factor=1,
+            eps=float(1e-20))
 
         topk_weights = topk_weights.to(x.dtype)
         # this is a naive implementation for experts load balance so as
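For intuition, a rough CPU approximation of what the fused operator computes, pieced together from the inline comments above (norm_type=1: sigmoid scoring; group_select_mode=1: rank groups by the sum of their top-2 scores; renorm=0). This is not the operator's exact semantics, only a sketch:

```python
import torch

def grouped_topk_reference(router_logits, bias, k=8, k_group=4, group_count=8,
                           routed_scaling_factor=1.0, eps=1e-20):
    """Rough approximation; num_experts must divide evenly by group_count."""
    num_tokens, num_experts = router_logits.shape
    scores = torch.sigmoid(router_logits) + bias               # norm_type=1
    grouped = scores.view(num_tokens, group_count, -1)
    # group_select_mode=1: score each group by the sum of its top-2 experts
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)
    top_groups = group_scores.topk(k_group, dim=-1).indices
    mask = torch.zeros_like(group_scores).scatter_(-1, top_groups, 1.0)
    masked = (grouped * mask.unsqueeze(-1)).reshape(num_tokens, num_experts)
    topk_weights, topk_ids = masked.topk(k, dim=-1)
    topk_weights = topk_weights / (topk_weights.sum(-1, keepdim=True) + eps)
    return topk_weights * routed_scaling_factor, topk_ids

# e.g. 16 tokens routed over 384 experts in 8 groups of 48
w, ids = grouped_topk_reference(torch.randn(16, 384), torch.zeros(384))
```

With CANN 8.3 the fused operator handles this 384-expert shape directly, which is what lets the code above drop the torchair_select_experts fallback and the is_deepseek_v3_r1 gate.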
@@ -27,7 +27,6 @@ from vllm.forward_context import get_forward_context
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
 from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
     torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
 from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
@@ -322,34 +321,20 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
         assert router_logits.shape[
             1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
 
-        if global_num_experts == 256:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # topk currently is 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix 8
-                group_select_mode=
-                1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                # out_flag=False, # todo new api; should the third output be output
-                # y2_flag=False, # old api; should the third output be output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = torchair_select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
+        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+            router_logits,
+            k=top_k,  # topk currently is 8
+            bias=e_score_correction_bias,
+            k_group=topk_group,  # fix: 4
+            group_count=num_expert_group,  # fix 8
+            group_select_mode=
+            1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+            renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+            norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+            # out_flag=False, # todo new api; should the third output be output
+            # y2_flag=False, # old api; should the third output be output
+            routed_scaling_factor=1,
+            eps=float(1e-20))
 
         fused_moe_state = get_forward_context().fused_moe_state
         shared_gate_up, shared_dequant_scale = None, None
@@ -25,7 +25,6 @@ from vllm.forward_context import get_forward_context
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
 from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
                                         super_kernel)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
@@ -938,8 +937,6 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
         assert router_logits.shape[
             1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
 
-        is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
-
         fused_moe_state = get_forward_context().fused_moe_state
         if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
             fused_moe_state = FusedMoEState.All2All
@@ -948,35 +945,20 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
             with super_kernel(prefix,
                               "stream-fusion=1",
                               enabled=running_in_super_kernel):
-                # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-                if is_deepseek_v3_r1:
-                    topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                        router_logits,
-                        k=top_k,  # topk currently is 8
-                        bias=e_score_correction_bias,
-                        k_group=topk_group,  # fix: 4
-                        group_count=num_expert_group,  # fix 8
-                        group_select_mode=
-                        1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-                        renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                        norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                        # out_flag=False, # todo new api; should the third output be output
-                        # y2_flag=False, # old api; should the third output be output
-                        routed_scaling_factor=1,
-                        eps=float(1e-20))
-                else:
-                    topk_weights, topk_ids = torchair_select_experts(
-                        hidden_states=x,
-                        router_logits=router_logits,
-                        top_k=top_k,
-                        use_grouped_topk=use_grouped_topk,
-                        renormalize=renormalize,
-                        topk_group=topk_group,
-                        num_expert_group=num_expert_group,
-                        custom_routing_function=custom_routing_function,
-                        scoring_func=scoring_func,
-                        e_score_correction_bias=e_score_correction_bias,
-                    )
+                topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                    router_logits,
+                    k=top_k,  # topk currently is 8
+                    bias=e_score_correction_bias,
+                    k_group=topk_group,  # fix: 4
+                    group_count=num_expert_group,  # fix 8
+                    group_select_mode=
+                    1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+                    renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+                    norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+                    # out_flag=False, # todo new api; should the third output be output
+                    # y2_flag=False, # old api; should the third output be output
+                    routed_scaling_factor=1,
+                    eps=float(1e-20))
 
         if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
             with npu_stream_switch("moe_secondary", 0):
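Since the unquantized, W4A8, and W8A8 paths now issue an identical gating call, a natural follow-up would be to hoist it into a single helper. A hypothetical sketch (not part of this PR; the helper name is invented, and the arguments mirror the three call sites above):

```python
def _npu_gating_top_k(router_logits, top_k, topk_group, num_expert_group,
                      e_score_correction_bias):
    # Hypothetical shared helper; requires torch_npu with CANN >= 8.3.
    import torch_npu
    topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
        router_logits,
        k=top_k,
        bias=e_score_correction_bias,
        k_group=topk_group,
        group_count=num_expert_group,
        group_select_mode=1,  # rank groups by the sum of their top-2 scores
        renorm=0,             # softmax->topk order
        norm_type=1,          # sigmoid scoring
        routed_scaling_factor=1,
        eps=float(1e-20))
    return topk_weights, topk_ids
```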