[Qwen-moe] Remove the minor operation arange (#2373)
### What this PR does / why we need it?
Integrate the `arange` operator to reduce time spent in the routing path and improve
performance.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main commit:
56dcf4e7e9
---------
Signed-off-by: s30076806 <songjiayang2@h-partners.com>
This commit is contained in:
@@ -365,14 +365,9 @@ def fused_experts_with_mc2(
|
||||
return hidden_states, shared_output
|
||||
|
||||
|
||||
def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts):
|
||||
def init_routing_quant(hidden_states, top_k, topk_ids, row_idx,
|
||||
global_num_experts):
|
||||
num_tokens, _ = hidden_states.shape
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = (torch.arange(0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device).view(
|
||||
top_k, -1).permute(1, 0).contiguous())
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
@@ -398,6 +393,7 @@ def fused_experts_with_all2all(
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
ep_group: GroupCoordinator = None,
|
||||
@@ -431,7 +427,7 @@ def fused_experts_with_all2all(
|
||||
)
|
||||
else:
|
||||
quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
|
||||
hidden_states, top_k, topk_ids, global_num_experts)
|
||||
hidden_states, top_k, topk_ids, row_idx, global_num_experts)
|
||||
|
||||
gather_sizes = global_expert_tokens.new_empty(
|
||||
global_expert_tokens.shape[0])
|
||||
@@ -463,12 +459,6 @@ def fused_experts_with_all2all(
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 1
|
||||
else:
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = torch.arange(0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
device=topk_weights.device).view(
|
||||
top_k, -1).permute(1, 0).contiguous()
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
@@ -627,6 +617,7 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None):
|
||||
original_shape = hidden_states.shape
|
||||
@@ -677,12 +668,6 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
hidden_states = hidden_states[sorted_token_indices]
|
||||
group_list_type = 1
|
||||
else:
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = torch.arange(0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
device=topk_weights.device).view(
|
||||
top_k, -1).permute(1, 0).contiguous()
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
@@ -903,7 +888,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
topk_weights, topk_ids, row_idx = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
@@ -973,6 +958,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map)
|
||||
else:
|
||||
@@ -988,6 +974,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
ep_group=self.ep_group,
|
||||
|
||||
Reference in New Issue
Block a user