[Qwen-moe] Remove the minor operation arange (#2373)
### What this PR does / why we need it?
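Build the `row_idx` tensor (the `torch.arange` result) once in `select_experts`,
or reuse the one returned by `npu_moe_gating_top_k_softmax`, and pass it to the
fused-experts functions instead of recomputing it inside each of them. This removes
the redundant arange operations, reducing the time spent and improving performance.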
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main:
56dcf4e7e9
---------
Signed-off-by: s30076806 <songjiayang2@h-partners.com>
This commit is contained in:
@@ -92,8 +92,15 @@ def test_fused_experts(
|
||||
score = torch.softmax(score, dim=-1, dtype=dtype)
|
||||
topk_weights, topk_ids = torch.topk(score, topk)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
row_idx = (torch.arange(
|
||||
0,
|
||||
m * topk,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
).view(topk, -1).permute(1, 0).contiguous())
|
||||
|
||||
output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
|
||||
output = fused_experts(a, w1, w2, topk_weights, topk_ids, row_idx, topk,
|
||||
e_map)
|
||||
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map)
|
||||
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
|
||||
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
|
||||
@@ -148,7 +155,7 @@ def test_select_experts(
|
||||
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
|
||||
x)
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
topk_weights, topk_ids, row_idx = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=topk,
|
||||
@@ -169,6 +176,7 @@ def test_select_experts(
|
||||
assert topk_weights.shape == (m, topk)
|
||||
assert topk_ids.shape == (m, topk)
|
||||
assert topk_ids.dtype == torch.int32
|
||||
assert row_idx.shape == (m, topk)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICE)
|
||||
|
||||
@@ -405,7 +405,7 @@ class TestExpertsSelector:
|
||||
|
||||
x = torch.randn(8, 2)
|
||||
router_logits = torch.randn(8, 2)
|
||||
topk_weights, topk_ids = select_experts(
|
||||
topk_weights, topk_ids, _ = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=2,
|
||||
|
||||
@@ -719,12 +719,12 @@ class TestSelectExperts(TestBase):
|
||||
def test_softmax_scoring(self):
|
||||
"""Test softmax scoring function"""
|
||||
|
||||
weights, ids = select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="softmax")
|
||||
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="softmax")
|
||||
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
@@ -732,12 +732,12 @@ class TestSelectExperts(TestBase):
|
||||
def test_sigmoid_scoring(self):
|
||||
"""Test sigmoid scoring function"""
|
||||
|
||||
weights, ids = select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="sigmoid")
|
||||
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="sigmoid")
|
||||
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
@@ -760,13 +760,13 @@ class TestSelectExperts(TestBase):
|
||||
self.top_k,
|
||||
dtype=torch.long))
|
||||
|
||||
weights, ids = select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=False,
|
||||
topk_group=4,
|
||||
num_expert_group=2)
|
||||
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=False,
|
||||
topk_group=4,
|
||||
num_expert_group=2)
|
||||
|
||||
mock_topk.assert_called()
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
@@ -780,7 +780,7 @@ class TestSelectExperts(TestBase):
|
||||
self.num_experts)
|
||||
|
||||
e_score_correction_bias = torch.randn(self.num_experts)
|
||||
weights, ids = select_experts(
|
||||
weights, ids, _ = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
@@ -803,7 +803,7 @@ class TestSelectExperts(TestBase):
|
||||
self.top_k,
|
||||
dtype=torch.int32))
|
||||
|
||||
weights, ids = select_experts(
|
||||
weights, ids, _ = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
@@ -824,7 +824,7 @@ class TestSelectExperts(TestBase):
|
||||
self.top_k,
|
||||
dtype=torch.long))
|
||||
|
||||
weights, _ = select_experts(
|
||||
weights, ids, _ = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
@@ -844,7 +844,7 @@ class TestSelectExperts(TestBase):
|
||||
self.top_k,
|
||||
dtype=torch.long))
|
||||
|
||||
weights, ids = select_experts(
|
||||
weights, ids, _ = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
|
||||
@@ -55,6 +55,12 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
|
||||
torch.randn(self.num_tokens),
|
||||
)
|
||||
mock_moe_finalize_routing.return_value = self.placeholder
|
||||
row_idx_len = self.num_tokens * 8
|
||||
row_idx = (torch.arange(
|
||||
0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
).view(8, -1).permute(1, 0).contiguous())
|
||||
|
||||
result = fused_experts_with_all2all(
|
||||
hidden_states=self.placeholder,
|
||||
@@ -64,6 +70,7 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
|
||||
w2_scale=self.placeholder,
|
||||
topk_weights=self.placeholder,
|
||||
topk_ids=self.placeholder,
|
||||
row_idx=row_idx,
|
||||
top_k=8,
|
||||
expert_map=expert_map,
|
||||
ep_group=ep_group,
|
||||
|
||||
@@ -130,7 +130,7 @@ def forward_oot(
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
topk_weights, topk_ids, _ = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
|
||||
@@ -326,6 +326,7 @@ def fused_experts_with_all2all(
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
ep_group: GroupCoordinator = None,
|
||||
@@ -336,17 +337,10 @@ def fused_experts_with_all2all(
|
||||
|
||||
num_tokens, _ = hidden_states.shape
|
||||
num_experts = w1.shape[0]
|
||||
device = hidden_states.device
|
||||
|
||||
if expert_map is not None:
|
||||
global_num_experts = len(expert_map)
|
||||
local_num_experts = global_num_experts // ep_group.world_size
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = (torch.arange(0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
device=device).view(top_k, -1).permute(
|
||||
1, 0).contiguous())
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
@@ -380,12 +374,6 @@ def fused_experts_with_all2all(
|
||||
|
||||
hidden_states = hidden_states[sorted_idx]
|
||||
else:
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = torch.arange(0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
device=topk_weights.device).view(
|
||||
top_k, -1).permute(1, 0).contiguous()
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
@@ -459,6 +447,7 @@ def fused_experts_with_all2all_buffer(
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
top_k: int,
|
||||
max_model_len: int,
|
||||
global_batch_size: int,
|
||||
@@ -470,14 +459,10 @@ def fused_experts_with_all2all_buffer(
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
|
||||
num_tokens, _ = hidden_states.shape
|
||||
device = hidden_states.device
|
||||
|
||||
global_num_experts = len(expert_map)
|
||||
local_num_experts = global_num_experts // ep_group.world_size
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
|
||||
device=device).view(top_k,
|
||||
-1).permute(1, 0).contiguous())
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
@@ -690,6 +675,7 @@ def fused_experts(
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
@@ -781,12 +767,6 @@ def fused_experts(
|
||||
# Rearrange hidden_states
|
||||
sorted_hidden_states = hidden_states[sorted_token_indices]
|
||||
else:
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = (torch.arange(0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
device=device).view(top_k, -1).permute(
|
||||
1, 0).contiguous())
|
||||
active_num = max_num_tokens if max_num_tokens is not None else num_tokens
|
||||
sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
@@ -908,7 +888,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
topk_weights, topk_ids, row_idx = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
@@ -952,6 +932,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map)
|
||||
elif MOE_ALL2ALL_BUFFER:
|
||||
@@ -961,6 +942,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
top_k=top_k,
|
||||
max_model_len=self.max_model_len,
|
||||
global_batch_size=self.global_batch_size,
|
||||
@@ -982,6 +964,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
ep_group=get_ep_group())
|
||||
|
||||
@@ -20,6 +20,17 @@ import torch
|
||||
import torch_npu
|
||||
|
||||
|
||||
def return_row_idx(hidden_states, top_k):
    """Build the ``row_idx`` tensor for ``top_k`` expert slots per token.

    Args:
        hidden_states: Tensor whose first dimension is the number of tokens;
            only ``shape[0]`` and ``device`` are read.
        top_k: Number of expert slots per token.

    Returns:
        A contiguous ``torch.int32`` tensor of shape ``(num_tokens, top_k)``
        on ``hidden_states.device`` where ``row_idx[i, j] == j * num_tokens + i``.
    """
    num_tokens = hidden_states.shape[0]
    flat_idx = torch.arange(0,
                            num_tokens * top_k,
                            dtype=torch.int32,
                            device=hidden_states.device)
    # Give both sizes explicitly rather than using ``-1``: PyTorch cannot
    # infer a ``-1`` dimension on an empty tensor, so the inferred form
    # raised RuntimeError whenever num_tokens (or top_k) was 0.
    return flat_idx.view(top_k, num_tokens).permute(1, 0).contiguous()
|
||||
|
||||
|
||||
def select_experts(hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
@@ -56,7 +67,8 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
topk_ids: selected expert IDs of shape (num_tokens, top_k).
|
||||
"""
|
||||
|
||||
topk_weights, topk_ids = _select_experts_with_fusion_ops(
|
||||
topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
@@ -83,7 +95,9 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts,
|
||||
)
|
||||
return topk_weights, topk_ids
|
||||
if row_idx is None:
|
||||
row_idx = return_row_idx(hidden_states, top_k)
|
||||
return topk_weights, topk_ids, row_idx
|
||||
|
||||
|
||||
def _native_grouped_topk(
|
||||
@@ -156,6 +170,7 @@ def _select_expert_use_group_topk(
|
||||
|
||||
|
||||
def _select_experts_with_fusion_ops(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
@@ -168,7 +183,7 @@ def _select_experts_with_fusion_ops(
|
||||
global_num_experts: int = -1,
|
||||
is_unquantized: bool = False):
|
||||
|
||||
topk_weights, topk_ids = None, None
|
||||
topk_weights, topk_ids, row_idx = None, None, None
|
||||
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
if is_deepseek_v3_r1:
|
||||
@@ -186,14 +201,14 @@ def _select_experts_with_fusion_ops(
|
||||
# y2_flag=False, # old api; should the third output be output
|
||||
routed_scaling_factor=1,
|
||||
eps=float(1e-20))
|
||||
|
||||
row_idx = return_row_idx(hidden_states, top_k)
|
||||
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" and is_unquantized:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
|
||||
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
|
||||
x=router_logits, finished=None, k=top_k)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
return topk_weights, topk_ids, row_idx
|
||||
|
||||
|
||||
def _native_select_experts(
|
||||
|
||||
@@ -268,7 +268,7 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
topk_weights, topk_ids = select_experts(
|
||||
topk_weights, topk_ids, row_idx = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
@@ -334,6 +334,7 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
ep_group=self.ep_group,
|
||||
|
||||
@@ -365,14 +365,9 @@ def fused_experts_with_mc2(
|
||||
return hidden_states, shared_output
|
||||
|
||||
|
||||
def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts):
|
||||
def init_routing_quant(hidden_states, top_k, topk_ids, row_idx,
|
||||
global_num_experts):
|
||||
num_tokens, _ = hidden_states.shape
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = (torch.arange(0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device).view(
|
||||
top_k, -1).permute(1, 0).contiguous())
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
@@ -398,6 +393,7 @@ def fused_experts_with_all2all(
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
ep_group: GroupCoordinator = None,
|
||||
@@ -431,7 +427,7 @@ def fused_experts_with_all2all(
|
||||
)
|
||||
else:
|
||||
quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
|
||||
hidden_states, top_k, topk_ids, global_num_experts)
|
||||
hidden_states, top_k, topk_ids, row_idx, global_num_experts)
|
||||
|
||||
gather_sizes = global_expert_tokens.new_empty(
|
||||
global_expert_tokens.shape[0])
|
||||
@@ -463,12 +459,6 @@ def fused_experts_with_all2all(
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 1
|
||||
else:
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = torch.arange(0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
device=topk_weights.device).view(
|
||||
top_k, -1).permute(1, 0).contiguous()
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
@@ -627,6 +617,7 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None):
|
||||
original_shape = hidden_states.shape
|
||||
@@ -677,12 +668,6 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
hidden_states = hidden_states[sorted_token_indices]
|
||||
group_list_type = 1
|
||||
else:
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = torch.arange(0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
device=topk_weights.device).view(
|
||||
top_k, -1).permute(1, 0).contiguous()
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
@@ -903,7 +888,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
topk_weights, topk_ids, row_idx = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
@@ -973,6 +958,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map)
|
||||
else:
|
||||
@@ -988,6 +974,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
ep_group=self.ep_group,
|
||||
|
||||
Reference in New Issue
Block a user