[Qwen-moe] Remove the minor operation arange (#2373)

### What this PR does / why we need it?
Integrate the arange operator to reduce the time spent and improve
performance

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

- vLLM version: v0.10.1.1
- vLLM main:
56dcf4e7e9

---------

Signed-off-by: s30076806 <songjiayang2@h-partners.com>
This commit is contained in:
s30076806
2025-08-27 09:13:31 +08:00
committed by GitHub
parent 358ba68994
commit 6a4ec186e7
9 changed files with 80 additions and 79 deletions

View File

@@ -365,14 +365,9 @@ def fused_experts_with_mc2(
return hidden_states, shared_output
def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts):
def init_routing_quant(hidden_states, top_k, topk_ids, row_idx,
global_num_experts):
num_tokens, _ = hidden_states.shape
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=hidden_states.device).view(
top_k, -1).permute(1, 0).contiguous())
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
@@ -398,6 +393,7 @@ def fused_experts_with_all2all(
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
ep_group: GroupCoordinator = None,
@@ -431,7 +427,7 @@ def fused_experts_with_all2all(
)
else:
quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
hidden_states, top_k, topk_ids, global_num_experts)
hidden_states, top_k, topk_ids, row_idx, global_num_experts)
gather_sizes = global_expert_tokens.new_empty(
global_expert_tokens.shape[0])
@@ -463,12 +459,6 @@ def fused_experts_with_all2all(
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1
else:
row_idx_len = num_tokens * top_k
row_idx = torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=topk_weights.device).view(
top_k, -1).permute(1, 0).contiguous()
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
@@ -627,6 +617,7 @@ def fused_experts(hidden_states: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None):
original_shape = hidden_states.shape
@@ -677,12 +668,6 @@ def fused_experts(hidden_states: torch.Tensor,
hidden_states = hidden_states[sorted_token_indices]
group_list_type = 1
else:
row_idx_len = num_tokens * top_k
row_idx = torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=topk_weights.device).view(
top_k, -1).permute(1, 0).contiguous()
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
@@ -903,7 +888,7 @@ class AscendW8A8DynamicFusedMoEMethod:
assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch"
topk_weights, topk_ids = select_experts(
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
@@ -973,6 +958,7 @@ class AscendW8A8DynamicFusedMoEMethod:
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
top_k=top_k,
expert_map=expert_map)
else:
@@ -988,6 +974,7 @@ class AscendW8A8DynamicFusedMoEMethod:
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
top_k=top_k,
expert_map=expert_map,
ep_group=self.ep_group,