[Qwen-moe] Remove the minor operation arange (#2373)

### What this PR does / why we need it?
Integrate the `arange` operator (precompute the row-index tensor once in `select_experts` instead of recomputing it in each fused-experts path) to reduce time spent and improve performance.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

- vLLM version: v0.10.1.1
- vLLM main:
56dcf4e7e9

---------

Signed-off-by: s30076806 <songjiayang2@h-partners.com>
This commit is contained in:
s30076806
2025-08-27 09:13:31 +08:00
committed by GitHub
parent 358ba68994
commit 6a4ec186e7
9 changed files with 80 additions and 79 deletions

View File

@@ -130,7 +130,7 @@ def forward_oot(
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
topk_weights, topk_ids, _ = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,

View File

@@ -326,6 +326,7 @@ def fused_experts_with_all2all(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
ep_group: GroupCoordinator = None,
@@ -336,17 +337,10 @@ def fused_experts_with_all2all(
num_tokens, _ = hidden_states.shape
num_experts = w1.shape[0]
device = hidden_states.device
if expert_map is not None:
global_num_experts = len(expert_map)
local_num_experts = global_num_experts // ep_group.world_size
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=device).view(top_k, -1).permute(
1, 0).contiguous())
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
@@ -380,12 +374,6 @@ def fused_experts_with_all2all(
hidden_states = hidden_states[sorted_idx]
else:
row_idx_len = num_tokens * top_k
row_idx = torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=topk_weights.device).view(
top_k, -1).permute(1, 0).contiguous()
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
@@ -459,6 +447,7 @@ def fused_experts_with_all2all_buffer(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
top_k: int,
max_model_len: int,
global_batch_size: int,
@@ -470,14 +459,10 @@ def fused_experts_with_all2all_buffer(
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
num_tokens, _ = hidden_states.shape
device = hidden_states.device
global_num_experts = len(expert_map)
local_num_experts = global_num_experts // ep_group.world_size
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
device=device).view(top_k,
-1).permute(1, 0).contiguous())
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
@@ -690,6 +675,7 @@ def fused_experts(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
apply_router_weight_on_input: bool = False,
@@ -781,12 +767,6 @@ def fused_experts(
# Rearrange hidden_states
sorted_hidden_states = hidden_states[sorted_token_indices]
else:
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=device).view(top_k, -1).permute(
1, 0).contiguous())
active_num = max_num_tokens if max_num_tokens is not None else num_tokens
sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
@@ -908,7 +888,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
**kwargs,
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
@@ -952,6 +932,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
top_k=top_k,
expert_map=expert_map)
elif MOE_ALL2ALL_BUFFER:
@@ -961,6 +942,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
top_k=top_k,
max_model_len=self.max_model_len,
global_batch_size=self.global_batch_size,
@@ -982,6 +964,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
top_k=top_k,
expert_map=expert_map,
ep_group=get_ep_group())

View File

@@ -20,6 +20,17 @@ import torch
import torch_npu
def return_row_idx(hidden_states, top_k):
    """Build the row-index tensor consumed by ``npu_moe_init_routing``.

    Produces an int32 tensor of shape ``(num_tokens, top_k)`` on the same
    device as ``hidden_states``, where entry ``[t, k] == k * num_tokens + t``
    (i.e. a column-major numbering of the ``num_tokens * top_k`` routed rows).

    Args:
        hidden_states: tensor whose leading dimension is the token count.
        top_k: number of experts selected per token.

    Returns:
        A contiguous ``torch.int32`` tensor of shape ``(num_tokens, top_k)``.
    """
    num_tokens = hidden_states.shape[0]
    flat_idx = torch.arange(num_tokens * top_k,
                            dtype=torch.int32,
                            device=hidden_states.device)
    # reshape + transpose is equivalent to view(top_k, -1).permute(1, 0)
    return flat_idx.reshape(top_k, num_tokens).t().contiguous()
def select_experts(hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -56,7 +67,8 @@ def select_experts(hidden_states: torch.Tensor,
topk_ids: selected expert IDs of shape (num_tokens, top_k).
"""
topk_weights, topk_ids = _select_experts_with_fusion_ops(
topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
@@ -83,7 +95,9 @@ def select_experts(hidden_states: torch.Tensor,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
)
return topk_weights, topk_ids
if row_idx is None:
row_idx = return_row_idx(hidden_states, top_k)
return topk_weights, topk_ids, row_idx
def _native_grouped_topk(
@@ -156,6 +170,7 @@ def _select_expert_use_group_topk(
def _select_experts_with_fusion_ops(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
@@ -168,7 +183,7 @@ def _select_experts_with_fusion_ops(
global_num_experts: int = -1,
is_unquantized: bool = False):
topk_weights, topk_ids = None, None
topk_weights, topk_ids, row_idx = None, None, None
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
is_deepseek_v3_r1 = global_num_experts == 256
if is_deepseek_v3_r1:
@@ -186,14 +201,14 @@ def _select_experts_with_fusion_ops(
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
row_idx = return_row_idx(hidden_states, top_k)
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" and is_unquantized:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
x=router_logits, finished=None, k=top_k)
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids
return topk_weights, topk_ids, row_idx
def _native_select_experts(