use npu_moe_gating_top_k_softmax (#1355)
### What this PR does / why we need it?
The optimization solution for non-deepseek select_experts is to replace
gating_topk_softmax with softmax+topk+to, which is optimized from 37us
to 14us on bf16/fp16 of qwen3-235b
- vLLM version: v0.9.2
- vLLM main:
1a4f35e2ea
---------
Signed-off-by: ttanzhiqiang <389825161@qq.com>
This commit is contained in:
37
tests/e2e/singlecard/ops/test_gating_top_k_softmax.py
Normal file
37
tests/e2e/singlecard/ops/test_gating_top_k_softmax.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'B',
|
||||||
|
[1, 16, 64, 128, 32768],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'D',
|
||||||
|
[8, 16, 32, 64, 128],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'top_k',
|
||||||
|
[1, 2, 4, 8],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"dtype, atol, rtol",
|
||||||
|
[
|
||||||
|
(torch.float16, 1e-3, 1e-3),
|
||||||
|
(torch.bfloat16, 1e-3, 1e-3),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_quant_fpx_linear(B: int, D: int, top_k: int, dtype, atol, rtol):
|
||||||
|
x = torch.rand((B, D), dtype=dtype).to("npu")
|
||||||
|
# finished = torch.randint(1, size=(B,), dtype=torch.bool).to("npu")
|
||||||
|
finished = None
|
||||||
|
y, expert_idx, row_idx = torch_npu.npu_moe_gating_top_k_softmax(x,
|
||||||
|
finished,
|
||||||
|
k=top_k)
|
||||||
|
|
||||||
|
topk_weights = x.softmax(dim=-1)
|
||||||
|
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
|
||||||
|
topk_ids = topk_ids.to(torch.int32)
|
||||||
|
torch.allclose(y, topk_weights, atol=atol, rtol=rtol)
|
||||||
|
torch.allclose(expert_idx, topk_ids, atol=atol, rtol=rtol)
|
||||||
@@ -117,6 +117,11 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
# value to False to disable the optimized model.
|
# value to False to disable the optimized model.
|
||||||
"USE_OPTIMIZED_MODEL":
|
"USE_OPTIMIZED_MODEL":
|
||||||
lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
|
lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
|
||||||
|
# SELECT_GATING_TOPK_SOTFMAX_EXPERTS is the equivalent of select_experts in non-quantized scenarios.
|
||||||
|
# In theory, it should have better performance than select_experts.
|
||||||
|
# Subsequent versions will remove the SELECT_GATING_TOPK_SOTFMAX_EXPERTS tag and use it as the default mode.
|
||||||
|
"SELECT_GATING_TOPK_SOTFMAX_EXPERTS":
|
||||||
|
lambda: bool(int(os.getenv("SELECT_GATING_TOPK_SOTFMAX_EXPERTS", '0'))),
|
||||||
# The tolerance of the kv cache size, if the difference between the
|
# The tolerance of the kv cache size, if the difference between the
|
||||||
# actual kv cache size and the cached kv cache size is less than this value,
|
# actual kv cache size and the cached kv cache size is less than this value,
|
||||||
# then the cached kv cache size will be used.
|
# then the cached kv cache size will be used.
|
||||||
|
|||||||
@@ -22,10 +22,13 @@ from vllm.config import CompilationLevel, get_current_vllm_config
|
|||||||
from vllm.model_executor.layers.fused_moe.layer import \
|
from vllm.model_executor.layers.fused_moe.layer import \
|
||||||
UnquantizedFusedMoEMethod
|
UnquantizedFusedMoEMethod
|
||||||
|
|
||||||
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
|
from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
|
||||||
select_experts)
|
select_experts,
|
||||||
|
select_gating_top_k_softmax_experts)
|
||||||
from vllm_ascend.utils import is_310p
|
from vllm_ascend.utils import is_310p
|
||||||
|
|
||||||
|
SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
|
||||||
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
||||||
|
|
||||||
|
|
||||||
@@ -54,6 +57,14 @@ def forward_oot(
|
|||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
if SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
|
||||||
|
topk_weights, topk_ids = select_gating_top_k_softmax_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
top_k=top_k,
|
||||||
|
renormalize=renormalize)
|
||||||
|
else:
|
||||||
topk_weights, topk_ids = select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
|
|||||||
npu_stream_switch, npu_wait_tensor)
|
npu_stream_switch, npu_wait_tensor)
|
||||||
|
|
||||||
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
||||||
|
SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
|
||||||
|
|
||||||
|
|
||||||
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
|
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
|
||||||
@@ -821,6 +822,39 @@ def fused_experts(
|
|||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def select_gating_top_k_softmax_experts(
|
||||||
|
hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int,
|
||||||
|
renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Select top-k experts based on router logits.
|
||||||
|
only supports float16、bfloat16、float32
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||||
|
router_logits: Router logits of shape (num_tokens, num_experts).
|
||||||
|
top_k: Number of experts to select.
|
||||||
|
renormalize: Whether to renormalize the routing weights.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
topk_weights: Routing weights of shape (num_tokens, top_k).
|
||||||
|
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If an unsupported scoring function is provided.
|
||||||
|
"""
|
||||||
|
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
|
||||||
|
router_logits, None, k=top_k)
|
||||||
|
|
||||||
|
# # Required by npu_moe_init_routing
|
||||||
|
# topk_weights = topk_weights.to(hidden_states.dtype)
|
||||||
|
# topk_ids = topk_ids.to(torch.int32)
|
||||||
|
|
||||||
|
if renormalize:
|
||||||
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
|
|
||||||
def native_grouped_topk(
|
def native_grouped_topk(
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
num_expert_group: Optional[int],
|
num_expert_group: Optional[int],
|
||||||
@@ -1013,6 +1047,12 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
# y2_flag=False, # old api; 第三个输出是否输出
|
# y2_flag=False, # old api; 第三个输出是否输出
|
||||||
routed_scaling_factor=1,
|
routed_scaling_factor=1,
|
||||||
eps=float(1e-20))
|
eps=float(1e-20))
|
||||||
|
elif SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
|
||||||
|
topk_weights, topk_ids = select_gating_top_k_softmax_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
top_k=top_k,
|
||||||
|
renormalize=renormalize)
|
||||||
else:
|
else:
|
||||||
topk_weights, topk_ids = select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
|
|||||||
Reference in New Issue
Block a user