[Perf] Optimize fused_experts quantization code to save npu memory (#784)

### What this PR does / why we need it?
In the w8a8 quantization code of `fused_experts`, the output of almost
every operator is assigned a new variable name. If we want to save NPU
memory, we must manually `del` these variables to end their lifecycles,
which fills the code with `del` statements and looks inelegant.
Therefore, I plan to name the output of most operators
`hidden_states`, thereby ending the lifecycle of the previous
`hidden_states` tensor.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

Signed-off-by: ApsarasX <apsarax@outlook.com>
This commit is contained in:
ApsarasX
2025-05-09 15:09:37 +08:00
committed by GitHub
parent 2c685e3b61
commit 324f819b92
2 changed files with 50 additions and 41 deletions

View File

@@ -222,9 +222,6 @@ class CustomDeepseekV2MoE(nn.Module):
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
if (self.tp_size > 1 and self.enable_mc2 and not is_prefill):
chunks = torch.chunk(hidden_states,
get_tp_group().world_size,
@@ -248,8 +245,8 @@ class CustomDeepseekV2MoE(nn.Module):
else:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
if shared_output is not None:
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
final_hidden_states = final_hidden_states + shared_output
return final_hidden_states.view(num_tokens, hidden_dim)

View File

@@ -16,7 +16,7 @@
#
import os
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional
import torch
import torch_npu
@@ -25,7 +25,7 @@ from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import select_experts
def apply_mlp(x: torch.Tensor,
def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
@@ -37,7 +37,7 @@ def apply_mlp(x: torch.Tensor,
apply MLP: gate_up_proj -> swiglu -> down_proj
Args:
x: input hidden states with shape (num_tokens, hidden_size).
hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
w1: expert weights1 with shape
(num_experts, hidden_size, intermediate_size * 2)
w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
@@ -56,23 +56,27 @@ def apply_mlp(x: torch.Tensor,
hidden_states: output hidden states after MLP.
"""
assert len(hidden_states_wrapper) == 1
hidden_states = hidden_states_wrapper.pop()
if dynamic_scale is None:
h, pertoken_scale = torch_npu.npu_dynamic_quant(x)
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
hidden_states)
else:
h = x
pertoken_scale = dynamic_scale
# gmm1: gate_up_proj
gate_up_out = torch_npu.npu_grouped_matmul(x=[h],
weight=[w1],
split_item=3,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32)[0]
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=3,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32)[0]
swiglu_out, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=gate_up_out,
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale,
activation_scale=pertoken_scale,
bias=None,
@@ -83,17 +87,18 @@ def apply_mlp(x: torch.Tensor,
quant_mode=1,
)
# down_proj
down_out = torch_npu.npu_grouped_matmul(x=[swiglu_out],
weight=[w2],
scale=[w2_scale],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=w2_scale.dtype)[0]
return down_out
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
scale=[w2_scale],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=w2_scale.dtype)[0]
return hidden_states
def fused_experts_with_mc2(
@@ -152,7 +157,11 @@ def fused_experts_with_mc2(
if quant_mode == 0:
dynamic_scale = None
down_out_list = apply_mlp(expand_x,
# place hidden_states in a list to transfer its ownership into the `apply_mlp` function
hidden_states_wrapper = [expand_x]
del expand_x
down_out_list = apply_mlp(hidden_states_wrapper,
w1,
w1_scale,
w2,
@@ -246,7 +255,7 @@ def fused_experts(hidden_states: torch.Tensor,
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
expert_tokens = token_counts[:num_experts]
# Rearrange hidden_states
sorted_hidden_states = hidden_states[sorted_token_indices]
hidden_states = hidden_states[sorted_token_indices]
group_list_type = 1
else:
row_idx_len = num_tokens * top_k
@@ -255,19 +264,22 @@ def fused_experts(hidden_states: torch.Tensor,
dtype=torch.int32,
device=topk_weights.device).view(
top_k, -1).permute(1, 0).contiguous()
sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)
del hidden_states
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
expanded_expert_idx, num_experts)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 0
down_out_list = apply_mlp(sorted_hidden_states,
# place hidden_states in a list to transfer its ownership into the `apply_mlp` function
hidden_states_wrapper = [hidden_states]
del hidden_states
hidden_states = apply_mlp(hidden_states_wrapper,
w1,
w1_scale,
w2,
@@ -276,23 +288,23 @@ def fused_experts(hidden_states: torch.Tensor,
group_list_type=group_list_type)
if expert_map is not None:
down_out_list.mul_(sorted_weights.unsqueeze(1))
hidden_states.mul_(sorted_weights.unsqueeze(1))
final_hidden_states = torch.zeros(*original_shape,
device=hidden_states.device,
device=device,
dtype=dtype)
num_valid_tokens = mask.sum()
valid_token_mask = torch.arange(
0, sorted_token_indices.shape[0],
device=device).unsqueeze(1) < num_valid_tokens
down_out_list = down_out_list.masked_fill_(~valid_token_mask,
hidden_states = hidden_states.masked_fill_(~valid_token_mask,
0).to(dtype)
final_hidden_states.index_add_(0, sorted_token_indices, down_out_list)
final_hidden_states.index_add_(0, sorted_token_indices, hidden_states)
else:
# TODO: Reorder device memory 2 times here, replace the current
# implementation here when suitable operators become available.
final_hidden_states = torch_npu.npu_moe_finalize_routing(
down_out_list,
hidden_states,
skip1=None,
skip2=None,
bias=None,
@@ -300,7 +312,7 @@ def fused_experts(hidden_states: torch.Tensor,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
del down_out_list
if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)
return final_hidden_states