[main] adapt usage of npu_moe_gating_top_k_softmax and remove envs.SELECT_GATING_TOPK_SOTFMAX_EXPERTS (#2112)
backport of v0.9.1-dev:
https://github.com/vllm-project/vllm-ascend/pull/1902
origin main npu_moe_gating_top_k_softmax:
https://github.com/vllm-project/vllm-ascend/pull/1355
- vLLM version: v0.10.0
- vLLM main:
055bd3978e
Signed-off-by: huangxialu <huangxialu1@huawei.com>
This commit is contained in:
@@ -23,11 +23,13 @@ Run `pytest tests/ops/test_fused_moe.py`.
|
|||||||
# here to make the test pass.
|
# here to make the test pass.
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa
|
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
|
||||||
from vllm_ascend.ops.fused_moe import fused_experts
|
from vllm_ascend.ops.fused_moe import fused_experts, select_experts
|
||||||
|
|
||||||
NUM_EXPERTS = [8, 64]
|
NUM_EXPERTS = [8, 64]
|
||||||
EP_SIZE = [1, 4]
|
EP_SIZE = [1, 4]
|
||||||
@@ -98,3 +100,97 @@ def test_fused_experts(
|
|||||||
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
|
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
|
||||||
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
|
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
|
||||||
torch.npu.empty_cache()
|
torch.npu.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("m", [1, 33, 64])
|
||||||
|
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||||
|
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||||
|
@pytest.mark.parametrize("topk", TOP_KS)
|
||||||
|
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
|
||||||
|
@pytest.mark.parametrize("use_grouped_topk", [True, False])
|
||||||
|
@pytest.mark.parametrize("renormalize", [True, False])
|
||||||
|
@pytest.mark.parametrize("with_e_correction", [True, False])
|
||||||
|
@pytest.mark.parametrize("custom_routing", [True, False])
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
|
@pytest.mark.parametrize("device", DEVICE)
|
||||||
|
def test_select_experts(
|
||||||
|
m: int,
|
||||||
|
n: int,
|
||||||
|
e: int,
|
||||||
|
topk: int,
|
||||||
|
scoring_func: str,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
renormalize: bool,
|
||||||
|
with_e_correction: bool,
|
||||||
|
custom_routing: bool,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: str,
|
||||||
|
):
|
||||||
|
topk_group = 4 if use_grouped_topk else None
|
||||||
|
num_expert_group = e // 4 if use_grouped_topk else None
|
||||||
|
|
||||||
|
hidden_states = torch.randn(m, n, device=device, dtype=dtype)
|
||||||
|
router_logits = torch.randn(m, e, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
e_score_correction_bias = (torch.randn(e, device=device, dtype=dtype)
|
||||||
|
if with_e_correction else None)
|
||||||
|
|
||||||
|
custom_routing_function = None
|
||||||
|
if custom_routing:
|
||||||
|
custom_routing_function = MagicMock()
|
||||||
|
mock_weights = torch.randn(m, topk, device=device, dtype=dtype)
|
||||||
|
mock_ids = torch.randint(0,
|
||||||
|
e, (m, topk),
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32)
|
||||||
|
custom_routing_function.return_value = (mock_weights, mock_ids)
|
||||||
|
|
||||||
|
with patch("vllm_ascend.ops.fused_moe.native_grouped_topk"
|
||||||
|
) as mock_native_grouped_topk:
|
||||||
|
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
|
||||||
|
x)
|
||||||
|
|
||||||
|
topk_weights, topk_ids = select_experts(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
router_logits=router_logits,
|
||||||
|
top_k=topk,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
renormalize=renormalize,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
scoring_func=scoring_func,
|
||||||
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_grouped_topk:
|
||||||
|
mock_native_grouped_topk.assert_called_once()
|
||||||
|
else:
|
||||||
|
mock_native_grouped_topk.assert_not_called()
|
||||||
|
|
||||||
|
assert topk_weights.shape == (m, topk)
|
||||||
|
assert topk_ids.shape == (m, topk)
|
||||||
|
assert topk_ids.dtype == torch.int32
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device", DEVICE)
|
||||||
|
def test_select_experts_invalid_scoring_func(device: str):
|
||||||
|
with pytest.raises(ValueError,
|
||||||
|
match="Unsupported scoring function: invalid"):
|
||||||
|
select_experts(hidden_states=torch.randn(1, 128, device=device),
|
||||||
|
router_logits=torch.randn(1, 8, device=device),
|
||||||
|
top_k=2,
|
||||||
|
use_grouped_topk=False,
|
||||||
|
renormalize=False,
|
||||||
|
scoring_func="invalid")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device", DEVICE)
|
||||||
|
def test_select_experts_missing_group_params(device: str):
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
select_experts(hidden_states=torch.randn(1, 128, device=device),
|
||||||
|
router_logits=torch.randn(1, 64, device=device),
|
||||||
|
top_k=2,
|
||||||
|
use_grouped_topk=True,
|
||||||
|
renormalize=False,
|
||||||
|
scoring_func="softmax")
|
||||||
|
|||||||
@@ -297,9 +297,8 @@ class TestAscendUnquantizedFusedMoEMethod:
|
|||||||
assert not layer.w13_weight.requires_grad
|
assert not layer.w13_weight.requires_grad
|
||||||
assert not layer.w2_weight.requires_grad
|
assert not layer.w2_weight.requires_grad
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("others_param",
|
||||||
"others_param",
|
[[256, 4], [128, 1], [128, 1], [128, 4]])
|
||||||
[[256, 4, False], [128, 1, False], [128, 1, True], [128, 4, False]])
|
|
||||||
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
|
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
|
||||||
mock_moe_env, others_param):
|
mock_moe_env, others_param):
|
||||||
"""
|
"""
|
||||||
@@ -308,15 +307,13 @@ class TestAscendUnquantizedFusedMoEMethod:
|
|||||||
3 test use select_gating_topk_softmax_experts and fused_experts
|
3 test use select_gating_topk_softmax_experts and fused_experts
|
||||||
4 test use select_experts and fused_experts_with_all2all_buffer
|
4 test use select_experts and fused_experts_with_all2all_buffer
|
||||||
"""
|
"""
|
||||||
global_num_experts, ep_size, select_softmax = others_param
|
global_num_experts, ep_size = others_param
|
||||||
is_prefill = False
|
is_prefill = False
|
||||||
is_deepseek_v3_r1 = global_num_experts == 256
|
is_deepseek_v3_r1 = global_num_experts == 256
|
||||||
forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
|
forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
|
||||||
ep_size, is_prefill, is_deepseek_v3_r1))
|
ep_size, is_prefill, is_deepseek_v3_r1))
|
||||||
with patch(
|
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
|
||||||
"vllm_ascend.ops.fused_moe.SELECT_GATING_TOPK_SOTFMAX_EXPERTS",
|
return_value=forward_context):
|
||||||
select_softmax), \
|
|
||||||
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context):
|
|
||||||
moe_method.ep_size = ep_size
|
moe_method.ep_size = ep_size
|
||||||
x = torch.randn(8, 2, 2)
|
x = torch.randn(8, 2, 2)
|
||||||
router_logits = torch.randn(8, 8)
|
router_logits = torch.randn(8, 8)
|
||||||
|
|||||||
@@ -117,11 +117,6 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
# value to False to disable the optimized model.
|
# value to False to disable the optimized model.
|
||||||
"USE_OPTIMIZED_MODEL":
|
"USE_OPTIMIZED_MODEL":
|
||||||
lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
|
lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
|
||||||
# SELECT_GATING_TOPK_SOTFMAX_EXPERTS is the equivalent of select_experts in non-quantized scenarios.
|
|
||||||
# In theory, it should have better performance than select_experts.
|
|
||||||
# Subsequent versions will remove the SELECT_GATING_TOPK_SOTFMAX_EXPERTS tag and use it as the default mode.
|
|
||||||
"SELECT_GATING_TOPK_SOTFMAX_EXPERTS":
|
|
||||||
lambda: bool(int(os.getenv("SELECT_GATING_TOPK_SOTFMAX_EXPERTS", '0'))),
|
|
||||||
# The tolerance of the kv cache size, if the difference between the
|
# The tolerance of the kv cache size, if the difference between the
|
||||||
# actual kv cache size and the cached kv cache size is less than this value,
|
# actual kv cache size and the cached kv cache size is less than this value,
|
||||||
# then the cached kv cache size will be used.
|
# then the cached kv cache size will be used.
|
||||||
|
|||||||
@@ -22,13 +22,10 @@ from vllm.config import CompilationLevel, get_current_vllm_config
|
|||||||
from vllm.model_executor.layers.fused_moe.layer import \
|
from vllm.model_executor.layers.fused_moe.layer import \
|
||||||
UnquantizedFusedMoEMethod
|
UnquantizedFusedMoEMethod
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
|
||||||
from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
|
from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
|
||||||
select_experts,
|
select_experts)
|
||||||
select_gating_top_k_softmax_experts)
|
|
||||||
from vllm_ascend.utils import is_310p
|
from vllm_ascend.utils import is_310p
|
||||||
|
|
||||||
SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
|
|
||||||
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
||||||
|
|
||||||
|
|
||||||
@@ -61,26 +58,19 @@ def forward_oot(
|
|||||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||||
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
|
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
|
||||||
if SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
|
topk_weights, topk_ids = select_experts(
|
||||||
topk_weights, topk_ids = select_gating_top_k_softmax_experts(
|
global_num_experts=global_num_experts,
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
renormalize=renormalize)
|
use_grouped_topk=use_grouped_topk,
|
||||||
else:
|
renormalize=renormalize,
|
||||||
topk_weights, topk_ids = select_experts(
|
topk_group=topk_group,
|
||||||
global_num_experts=global_num_experts,
|
num_expert_group=num_expert_group,
|
||||||
hidden_states=x,
|
custom_routing_function=custom_routing_function,
|
||||||
router_logits=router_logits,
|
scoring_func=scoring_func,
|
||||||
top_k=top_k,
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
use_grouped_topk=use_grouped_topk,
|
)
|
||||||
renormalize=renormalize,
|
|
||||||
topk_group=topk_group,
|
|
||||||
num_expert_group=num_expert_group,
|
|
||||||
custom_routing_function=custom_routing_function,
|
|
||||||
scoring_func=scoring_func,
|
|
||||||
e_score_correction_bias=e_score_correction_bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
if topk_ids.shape[1] < top_k or is_310p():
|
if topk_ids.shape[1] < top_k or is_310p():
|
||||||
assert global_num_experts is not None
|
assert global_num_experts is not None
|
||||||
|
|||||||
@@ -52,7 +52,6 @@ from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
|||||||
get_rm_router_logits_state, is_310p)
|
get_rm_router_logits_state, is_310p)
|
||||||
|
|
||||||
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
||||||
SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
|
|
||||||
|
|
||||||
|
|
||||||
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
|
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
|
||||||
@@ -859,39 +858,6 @@ def fused_experts(
|
|||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
|
|
||||||
def select_gating_top_k_softmax_experts(
|
|
||||||
hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int,
|
|
||||||
renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Select top-k experts based on router logits.
|
|
||||||
only supports float16、bfloat16、float32
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
|
||||||
router_logits: Router logits of shape (num_tokens, num_experts).
|
|
||||||
top_k: Number of experts to select.
|
|
||||||
renormalize: Whether to renormalize the routing weights.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
topk_weights: Routing weights of shape (num_tokens, top_k).
|
|
||||||
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If an unsupported scoring function is provided.
|
|
||||||
"""
|
|
||||||
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
|
|
||||||
router_logits, None, k=top_k)
|
|
||||||
|
|
||||||
# # Required by npu_moe_init_routing
|
|
||||||
# topk_weights = topk_weights.to(hidden_states.dtype)
|
|
||||||
# topk_ids = topk_ids.to(torch.int32)
|
|
||||||
|
|
||||||
if renormalize:
|
|
||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
return topk_weights, topk_ids
|
|
||||||
|
|
||||||
|
|
||||||
def native_grouped_topk(
|
def native_grouped_topk(
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
num_expert_group: Optional[int],
|
num_expert_group: Optional[int],
|
||||||
@@ -953,8 +919,24 @@ def select_experts(
|
|||||||
ValueError: If an unsupported scoring function is provided.
|
ValueError: If an unsupported scoring function is provided.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def _renormalize_topk_weights(
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
renormalize: bool,
|
||||||
|
):
|
||||||
|
if renormalize:
|
||||||
|
topk_weights = topk_weights / topk_weights.sum(dim=-1,
|
||||||
|
keepdim=True)
|
||||||
|
return topk_weights
|
||||||
|
|
||||||
if scoring_func == "softmax":
|
if scoring_func == "softmax":
|
||||||
# NOTE: vLLM use dtype=torch.float here
|
# NOTE: vLLM use dtype=torch.float here
|
||||||
|
if not use_grouped_topk and custom_routing_function is None:
|
||||||
|
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
|
||||||
|
x=router_logits, finished=None, k=top_k)
|
||||||
|
topk_ids = topk_ids.to(torch.int32)
|
||||||
|
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||||
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
topk_weights = router_logits.softmax(dim=-1)
|
topk_weights = router_logits.softmax(dim=-1)
|
||||||
elif scoring_func == "sigmoid":
|
elif scoring_func == "sigmoid":
|
||||||
topk_weights = router_logits.sigmoid()
|
topk_weights = router_logits.sigmoid()
|
||||||
@@ -988,10 +970,11 @@ def select_experts(
|
|||||||
k=top_k,
|
k=top_k,
|
||||||
dim=-1,
|
dim=-1,
|
||||||
sorted=False)
|
sorted=False)
|
||||||
elif custom_routing_function is None:
|
topk_ids = topk_ids.to(torch.int32)
|
||||||
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
|
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||||
topk_weights = topk_weights.to(hidden_states.dtype)
|
return topk_weights, topk_ids
|
||||||
else:
|
|
||||||
|
if custom_routing_function is not None:
|
||||||
topk_weights, topk_ids = custom_routing_function(
|
topk_weights, topk_ids = custom_routing_function(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
gating_output=router_logits,
|
gating_output=router_logits,
|
||||||
@@ -1002,11 +985,12 @@ def select_experts(
|
|||||||
topk_ids = topk_ids.to(torch.int32)
|
topk_ids = topk_ids.to(torch.int32)
|
||||||
return topk_weights, topk_ids
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
|
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
|
||||||
|
topk_weights = topk_weights.to(hidden_states.dtype)
|
||||||
|
|
||||||
# Required by npu_moe_init_routing
|
# Required by npu_moe_init_routing
|
||||||
topk_ids = topk_ids.to(torch.int32)
|
topk_ids = topk_ids.to(torch.int32)
|
||||||
|
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||||
if renormalize:
|
|
||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
return topk_weights, topk_ids
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
@@ -1070,23 +1054,18 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
if is_deepseek_v3_r1:
|
if is_deepseek_v3_r1:
|
||||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||||
router_logits,
|
router_logits,
|
||||||
k=top_k, # topk当前写8
|
k=top_k, # topk currently is 8
|
||||||
bias=e_score_correction_bias,
|
bias=e_score_correction_bias,
|
||||||
k_group=topk_group, # fix: 4
|
k_group=topk_group, # fix: 4
|
||||||
group_count=num_expert_group, # fix 8
|
group_count=num_expert_group, # fix 8
|
||||||
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
|
group_select_mode=
|
||||||
|
1, # 0: the maximum in the group; 1: topk2.sum(fix)
|
||||||
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||||
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
||||||
# out_flag=False, # todo new api; 第三个输出是否输出
|
# out_flag=False, # todo new api; should the third output be output
|
||||||
# y2_flag=False, # old api; 第三个输出是否输出
|
# y2_flag=False, # old api; should the third output be output
|
||||||
routed_scaling_factor=1,
|
routed_scaling_factor=1,
|
||||||
eps=float(1e-20))
|
eps=float(1e-20))
|
||||||
elif SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
|
|
||||||
topk_weights, topk_ids = select_gating_top_k_softmax_experts(
|
|
||||||
hidden_states=x,
|
|
||||||
router_logits=router_logits,
|
|
||||||
top_k=top_k,
|
|
||||||
renormalize=renormalize)
|
|
||||||
else:
|
else:
|
||||||
topk_weights, topk_ids = select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
|
|||||||
Reference in New Issue
Block a user