[Perf]Enable npu_moe_gating_top_k_softmax on quantized scenarios (#2633)
### What this PR does / why we need it?
This PR enables `npu_moe_gating_top_k_softmax` when running quantized
MoE (such as W8A8). This op in fact makes no distinction between
quantized and non-quantized scenarios. Introducing this op reduces 3~4ms
for TPOT.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main:
ce30dca5c4
Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
@@ -716,8 +716,19 @@ class TestSelectExperts(TestBase):
|
||||
self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
|
||||
self.router_logits = torch.randn(self.num_tokens, self.num_experts)
|
||||
|
||||
def test_softmax_scoring(self):
|
||||
@patch('torch_npu.npu_moe_gating_top_k_softmax')
|
||||
def test_softmax_scoring(self, mock_topk):
|
||||
"""Test softmax scoring function"""
|
||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||
torch.zeros(self.num_tokens,
|
||||
self.top_k,
|
||||
dtype=torch.long),
|
||||
torch.arange(0,
|
||||
self.num_tokens * self.top_k,
|
||||
dtype=torch.int32).view(
|
||||
self.top_k,
|
||||
-1).permute(1,
|
||||
0).contiguous())
|
||||
|
||||
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
@@ -816,13 +827,19 @@ class TestSelectExperts(TestBase):
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.dtype, torch.int32)
|
||||
|
||||
@patch('torch.topk')
|
||||
@patch('torch_npu.npu_moe_gating_top_k_softmax')
|
||||
def test_renormalize(self, mock_topk):
|
||||
"""Test weight renormalization"""
|
||||
"""Test renormalization"""
|
||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||
torch.zeros(self.num_tokens,
|
||||
self.top_k,
|
||||
dtype=torch.long))
|
||||
dtype=torch.long),
|
||||
torch.arange(0,
|
||||
self.num_tokens * self.top_k,
|
||||
dtype=torch.int32).view(
|
||||
self.top_k,
|
||||
-1).permute(1,
|
||||
0).contiguous())
|
||||
|
||||
weights, ids, _ = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
@@ -836,13 +853,19 @@ class TestSelectExperts(TestBase):
|
||||
sums = weights.sum(dim=-1)
|
||||
self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
|
||||
|
||||
@patch('torch.topk')
|
||||
@patch('torch_npu.npu_moe_gating_top_k_softmax')
|
||||
def test_output_dtypes(self, mock_topk):
|
||||
"""Test output dtypes"""
|
||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||
torch.zeros(self.num_tokens,
|
||||
self.top_k,
|
||||
dtype=torch.long))
|
||||
dtype=torch.long),
|
||||
torch.arange(0,
|
||||
self.num_tokens * self.top_k,
|
||||
dtype=torch.int32).view(
|
||||
self.top_k,
|
||||
-1).permute(1,
|
||||
0).contiguous())
|
||||
|
||||
weights, ids, _ = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
|
||||
@@ -173,8 +173,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts,
|
||||
is_unquantized=True)
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
# this is a naive implementation for experts load balance so as
|
||||
|
||||
@@ -43,7 +43,6 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
indices_type: Optional[torch.dtype] = None,
|
||||
is_unquantized: bool = False,
|
||||
global_num_experts: int = -1):
|
||||
"""
|
||||
Fused experts with select experts.
|
||||
@@ -60,7 +59,6 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
scoring_func: Scoring function to use.
|
||||
e_score_correction_bias: Correction bias to apply to expert scores.
|
||||
indices_type: dtype of indices
|
||||
is_unquantized: Whether the data are unquantized.
|
||||
global_num_experts: Global number of experts.
|
||||
|
||||
Returns:
|
||||
@@ -80,8 +78,7 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
global_num_experts=global_num_experts,
|
||||
is_unquantized=is_unquantized)
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
if topk_weights is None:
|
||||
topk_weights, topk_ids = _native_select_experts(
|
||||
@@ -183,8 +180,7 @@ def _select_experts_with_fusion_ops(
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor=1.0,
|
||||
global_num_experts: int = -1,
|
||||
is_unquantized: bool = False):
|
||||
global_num_experts: int = -1):
|
||||
|
||||
topk_weights, topk_ids, row_idx = None, None, None
|
||||
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
|
||||
@@ -205,7 +201,7 @@ def _select_experts_with_fusion_ops(
|
||||
routed_scaling_factor=1,
|
||||
eps=float(1e-20))
|
||||
row_idx = return_row_idx(hidden_states, top_k)
|
||||
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" and is_unquantized:
|
||||
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
|
||||
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
|
||||
x=router_logits, finished=None, k=top_k)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
|
||||
Reference in New Issue
Block a user