This reverts commit c4a11a745a.
ops npu_gating_top_k caused Qwen3-30B precision problem, so revert it.
Signed-off-by: 1092626063 <1092626063@qq.com>
This commit is contained in:
@@ -28,8 +28,7 @@ import torch
|
|||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
|
||||||
from vllm_ascend.ops.moe.experts_selector import (check_npu_moe_gating_top_k,
|
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||||
select_experts)
|
|
||||||
from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
|
from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
|
||||||
from vllm_ascend.ops.moe.token_dispatcher import TokenDispatcherWithAllGather
|
from vllm_ascend.ops.moe.token_dispatcher import TokenDispatcherWithAllGather
|
||||||
|
|
||||||
@@ -297,10 +296,7 @@ def test_select_experts(
|
|||||||
e_score_correction_bias=e_score_correction_bias,
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
call_moe_gatingtopk = check_npu_moe_gating_top_k(
|
if use_grouped_topk:
|
||||||
hidden_states, topk, topk_group, num_expert_group, scoring_func,
|
|
||||||
custom_routing_function)
|
|
||||||
if not call_moe_gatingtopk and use_grouped_topk:
|
|
||||||
mock_native_grouped_topk.assert_called_once()
|
mock_native_grouped_topk.assert_called_once()
|
||||||
else:
|
else:
|
||||||
mock_native_grouped_topk.assert_not_called()
|
mock_native_grouped_topk.assert_not_called()
|
||||||
|
|||||||
@@ -753,14 +753,6 @@ class TestSelectExperts(TestBase):
|
|||||||
|
|
||||||
self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
|
self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
|
||||||
self.router_logits = torch.randn(self.num_tokens, self.num_experts)
|
self.router_logits = torch.randn(self.num_tokens, self.num_experts)
|
||||||
"""Mock custom routing"""
|
|
||||||
self.mock_custom_routing = MagicMock()
|
|
||||||
self.mock_custom_routing.return_value = (torch.ones(
|
|
||||||
self.num_tokens, self.top_k),
|
|
||||||
torch.zeros(
|
|
||||||
self.num_tokens,
|
|
||||||
self.top_k,
|
|
||||||
dtype=torch.int32))
|
|
||||||
|
|
||||||
self.mock_ctx = MagicMock()
|
self.mock_ctx = MagicMock()
|
||||||
self.mock_ctx.weight_prefetch_method = MagicMock()
|
self.mock_ctx.weight_prefetch_method = MagicMock()
|
||||||
@@ -770,7 +762,7 @@ class TestSelectExperts(TestBase):
|
|||||||
self.addCleanup(patcher.stop)
|
self.addCleanup(patcher.stop)
|
||||||
patcher.start()
|
patcher.start()
|
||||||
|
|
||||||
@patch('torch_npu.npu_moe_gating_top_k')
|
@patch('torch_npu.npu_moe_gating_top_k_softmax')
|
||||||
def test_softmax_scoring(self, mock_topk):
|
def test_softmax_scoring(self, mock_topk):
|
||||||
"""Test softmax scoring function"""
|
"""Test softmax scoring function"""
|
||||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||||
@@ -797,14 +789,12 @@ class TestSelectExperts(TestBase):
|
|||||||
def test_sigmoid_scoring(self):
|
def test_sigmoid_scoring(self):
|
||||||
"""Test sigmoid scoring function"""
|
"""Test sigmoid scoring function"""
|
||||||
|
|
||||||
weights, ids = select_experts(
|
weights, ids = select_experts(hidden_states=self.hidden_states,
|
||||||
hidden_states=self.hidden_states,
|
router_logits=self.router_logits,
|
||||||
router_logits=self.router_logits,
|
top_k=self.top_k,
|
||||||
top_k=self.top_k,
|
use_grouped_topk=False,
|
||||||
use_grouped_topk=False,
|
renormalize=False,
|
||||||
renormalize=False,
|
scoring_func="sigmoid")
|
||||||
scoring_func="sigmoid",
|
|
||||||
custom_routing_function=self.mock_custom_routing)
|
|
||||||
|
|
||||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||||
@@ -863,20 +853,27 @@ class TestSelectExperts(TestBase):
|
|||||||
|
|
||||||
def test_custom_routing_function(self):
|
def test_custom_routing_function(self):
|
||||||
"""Test custom routing function"""
|
"""Test custom routing function"""
|
||||||
|
mock_custom_routing = MagicMock()
|
||||||
|
mock_custom_routing.return_value = (torch.ones(self.num_tokens,
|
||||||
|
self.top_k),
|
||||||
|
torch.zeros(self.num_tokens,
|
||||||
|
self.top_k,
|
||||||
|
dtype=torch.int32))
|
||||||
|
|
||||||
weights, ids = select_experts(
|
weights, ids = select_experts(
|
||||||
hidden_states=self.hidden_states,
|
hidden_states=self.hidden_states,
|
||||||
router_logits=self.router_logits,
|
router_logits=self.router_logits,
|
||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
use_grouped_topk=False,
|
use_grouped_topk=False,
|
||||||
renormalize=False,
|
renormalize=False,
|
||||||
custom_routing_function=self.mock_custom_routing)
|
custom_routing_function=mock_custom_routing)
|
||||||
|
|
||||||
self.mock_custom_routing.assert_called_once()
|
mock_custom_routing.assert_called_once()
|
||||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||||
self.assertEqual(ids.dtype, torch.int32)
|
self.assertEqual(ids.dtype, torch.int32)
|
||||||
|
|
||||||
@patch('torch_npu.npu_moe_gating_top_k')
|
@patch('torch_npu.npu_moe_gating_top_k_softmax')
|
||||||
def test_renormalize(self, mock_topk):
|
def test_renormalize(self, mock_topk):
|
||||||
"""Test renormalization"""
|
"""Test renormalization"""
|
||||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||||
@@ -902,13 +899,13 @@ class TestSelectExperts(TestBase):
|
|||||||
sums = weights.sum(dim=-1)
|
sums = weights.sum(dim=-1)
|
||||||
self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
|
self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
|
||||||
|
|
||||||
@patch('torch_npu.npu_moe_gating_top_k')
|
@patch('torch_npu.npu_moe_gating_top_k_softmax')
|
||||||
def test_output_dtypes(self, mock_topk):
|
def test_output_dtypes(self, mock_topk):
|
||||||
"""Test output dtypes"""
|
"""Test output dtypes"""
|
||||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||||
torch.zeros(self.num_tokens,
|
torch.zeros(self.num_tokens,
|
||||||
self.top_k,
|
self.top_k,
|
||||||
dtype=torch.int32),
|
dtype=torch.long),
|
||||||
torch.arange(0,
|
torch.arange(0,
|
||||||
self.num_tokens * self.top_k,
|
self.num_tokens * self.top_k,
|
||||||
dtype=torch.int32).view(
|
dtype=torch.int32).view(
|
||||||
|
|||||||
@@ -96,7 +96,6 @@ def set_ascend_forward_context(
|
|||||||
ep_size = (get_ep_group().world_size if
|
ep_size = (get_ep_group().world_size if
|
||||||
vllm_config.parallel_config.enable_expert_parallel else 1)
|
vllm_config.parallel_config.enable_expert_parallel else 1)
|
||||||
|
|
||||||
# fused_moe_state is used in torchair, it will be deleted along with torchair
|
|
||||||
is_deepseek_v3_r1 = hasattr(
|
is_deepseek_v3_r1 = hasattr(
|
||||||
vllm_config.model_config.hf_config, 'n_routed_experts'
|
vllm_config.model_config.hf_config, 'n_routed_experts'
|
||||||
) and vllm_config.model_config.hf_config.n_routed_experts == 256
|
) and vllm_config.model_config.hf_config.n_routed_experts == 256
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ import torch
|
|||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
|
|
||||||
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
|
||||||
|
|
||||||
def select_experts(hidden_states: torch.Tensor,
|
def select_experts(hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
@@ -60,28 +62,21 @@ def select_experts(hidden_states: torch.Tensor,
|
|||||||
if weight_prefetch_method:
|
if weight_prefetch_method:
|
||||||
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
|
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
|
||||||
hidden_states, "gate_up")
|
hidden_states, "gate_up")
|
||||||
is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
|
topk_weights, topk_ids = _select_experts_with_fusion_ops(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
|
router_logits=router_logits,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
|
renormalize=renormalize,
|
||||||
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
scoring_func=scoring_func,
|
scoring_func=scoring_func,
|
||||||
custom_routing_function=custom_routing_function)
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
global_num_experts=global_num_experts)
|
||||||
|
|
||||||
if is_support_npu_moe_gating_top_k:
|
if topk_weights is None:
|
||||||
topk_weights, topk_ids = _select_experts_with_fusion_ops(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
router_logits=router_logits,
|
|
||||||
top_k=top_k,
|
|
||||||
use_grouped_topk=use_grouped_topk,
|
|
||||||
topk_group=topk_group,
|
|
||||||
renormalize=renormalize,
|
|
||||||
e_score_correction_bias=e_score_correction_bias,
|
|
||||||
num_expert_group=num_expert_group,
|
|
||||||
scoring_func=scoring_func,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
global_num_experts=global_num_experts)
|
|
||||||
else:
|
|
||||||
topk_weights, topk_ids = _native_select_experts(
|
topk_weights, topk_ids = _native_select_experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
@@ -98,32 +93,6 @@ def select_experts(hidden_states: torch.Tensor,
|
|||||||
return topk_weights, topk_ids
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
|
|
||||||
def check_npu_moe_gating_top_k(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
top_k: int,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
scoring_func: str = "softmax",
|
|
||||||
custom_routing_function: Optional[Callable] = None):
|
|
||||||
if custom_routing_function is not None:
|
|
||||||
return False
|
|
||||||
if scoring_func != "softmax" and scoring_func != "sigmoid":
|
|
||||||
return False
|
|
||||||
topk_group = topk_group if topk_group is not None else 1
|
|
||||||
num_expert_group = num_expert_group if num_expert_group is not None else 1
|
|
||||||
if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group
|
|
||||||
== 0 and hidden_states.shape[-1] // num_expert_group > 2):
|
|
||||||
return False
|
|
||||||
if topk_group < 1 or topk_group > num_expert_group:
|
|
||||||
return False
|
|
||||||
if top_k < 1 or \
|
|
||||||
top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
|
|
||||||
return False
|
|
||||||
if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def _native_grouped_topk(
|
def _native_grouped_topk(
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
num_expert_group: Optional[int],
|
num_expert_group: Optional[int],
|
||||||
@@ -202,31 +171,34 @@ def _select_experts_with_fusion_ops(
|
|||||||
e_score_correction_bias: Optional[torch.Tensor],
|
e_score_correction_bias: Optional[torch.Tensor],
|
||||||
topk_group: Optional[int],
|
topk_group: Optional[int],
|
||||||
num_expert_group: Optional[int],
|
num_expert_group: Optional[int],
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
scoring_func: str = "softmax",
|
scoring_func: str = "softmax",
|
||||||
routed_scaling_factor=1.0,
|
routed_scaling_factor=1.0,
|
||||||
global_num_experts: int = -1):
|
global_num_experts: int = -1):
|
||||||
|
|
||||||
topk_group = topk_group if topk_group is not None else 1
|
topk_weights, topk_ids = None, None
|
||||||
num_expert_group = num_expert_group if num_expert_group is not None else 1
|
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
|
||||||
norm_type = 0 if scoring_func == "softmax" else 1
|
global_redundant_expert_num = get_ascend_config().init_redundancy_expert
|
||||||
if e_score_correction_bias is not None and \
|
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
|
||||||
e_score_correction_bias.dtype != router_logits.dtype:
|
if is_deepseek_v3_r1:
|
||||||
e_score_correction_bias = e_score_correction_bias.to(
|
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||||
router_logits.dtype)
|
router_logits,
|
||||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
k=top_k, # topk currently 8
|
||||||
router_logits,
|
bias=e_score_correction_bias,
|
||||||
k=top_k,
|
k_group=topk_group, # fix: 4
|
||||||
bias=e_score_correction_bias,
|
group_count=num_expert_group, # fix 8
|
||||||
k_group=topk_group,
|
group_select_mode=
|
||||||
group_count=num_expert_group,
|
1, # 0: the maximum in the group; 1: topk2.sum(fix)
|
||||||
group_select_mode=1, # 0: the maximum in the group; 1: topk2.sum(fix)
|
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||||
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
||||||
norm_type=norm_type, # 0: softmax; 1: sigmoid
|
# out_flag=False, # todo new api; should the third output be output
|
||||||
# out_flag=False, # todo new api; should the third output be output
|
# y2_flag=False, # old api; should the third output be output
|
||||||
# y2_flag=False, # old api; should the third output be output
|
routed_scaling_factor=1,
|
||||||
routed_scaling_factor=1,
|
eps=float(1e-20))
|
||||||
eps=float(1e-20))
|
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
|
||||||
if scoring_func == "softmax":
|
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
|
||||||
|
x=router_logits, finished=None, k=top_k)
|
||||||
|
topk_ids = topk_ids.to(torch.int32)
|
||||||
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||||
|
|
||||||
return topk_weights, topk_ids
|
return topk_weights, topk_ids
|
||||||
|
|||||||
Reference in New Issue
Block a user