[Perf] Optimize fused_experts quantization code to save npu memory (#784)
### What this PR does / why we need it? In the w8a8 quantization code of `fused_experts`, the output of almost every operator is assigned a new variable name. If we want to save NPU memory, we manually `del` these variables to end their lifecycle, which fills the code with `del` statements and looks inelegant. Therefore, I plan to names the output of most operators as `hidden_states`, thereby ending the lifecycle of the previous `hidden_states`. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Signed-off-by: ApsarasX <apsarax@outlook.com>
This commit is contained in:
@@ -222,9 +222,6 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
|
||||
if (self.tp_size > 1 and self.enable_mc2 and not is_prefill):
|
||||
chunks = torch.chunk(hidden_states,
|
||||
get_tp_group().world_size,
|
||||
@@ -248,8 +245,8 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
else:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
if shared_output is not None:
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
#
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
@@ -25,7 +25,7 @@ from vllm_ascend.distributed.parallel_state import get_ep_group
|
||||
from vllm_ascend.ops.fused_moe import select_experts
|
||||
|
||||
|
||||
def apply_mlp(x: torch.Tensor,
|
||||
def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@@ -37,7 +37,7 @@ def apply_mlp(x: torch.Tensor,
|
||||
apply MLP: gate_up_proj -> swiglu -> down_proj
|
||||
|
||||
Args:
|
||||
x: input hidden states with shape (num_tokens, hidden_size).
|
||||
hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
|
||||
w1: expert weights1 with shape
|
||||
(num_experts, hidden_size, intermediate_size * 2)
|
||||
w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
|
||||
@@ -56,23 +56,27 @@ def apply_mlp(x: torch.Tensor,
|
||||
hidden_states: output hidden states after MLP.
|
||||
"""
|
||||
|
||||
assert len(hidden_states_wrapper) == 1
|
||||
hidden_states = hidden_states_wrapper.pop()
|
||||
if dynamic_scale is None:
|
||||
h, pertoken_scale = torch_npu.npu_dynamic_quant(x)
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
else:
|
||||
h = x
|
||||
pertoken_scale = dynamic_scale
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(x=[h],
|
||||
weight=[w1],
|
||||
split_item=3,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32)[0]
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=3,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32)[0]
|
||||
|
||||
swiglu_out, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=gate_up_out,
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=w1_scale,
|
||||
activation_scale=pertoken_scale,
|
||||
bias=None,
|
||||
@@ -83,17 +87,18 @@ def apply_mlp(x: torch.Tensor,
|
||||
quant_mode=1,
|
||||
)
|
||||
|
||||
# down_proj
|
||||
down_out = torch_npu.npu_grouped_matmul(x=[swiglu_out],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=w2_scale.dtype)[0]
|
||||
return down_out
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=w2_scale.dtype)[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def fused_experts_with_mc2(
|
||||
@@ -152,7 +157,11 @@ def fused_experts_with_mc2(
|
||||
if quant_mode == 0:
|
||||
dynamic_scale = None
|
||||
|
||||
down_out_list = apply_mlp(expand_x,
|
||||
# place hidden_states in a list to transfer its ownership into the `apply_mlp` function
|
||||
hidden_states_wrapper = [expand_x]
|
||||
del expand_x
|
||||
|
||||
down_out_list = apply_mlp(hidden_states_wrapper,
|
||||
w1,
|
||||
w1_scale,
|
||||
w2,
|
||||
@@ -246,7 +255,7 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
|
||||
expert_tokens = token_counts[:num_experts]
|
||||
# Rearrange hidden_states
|
||||
sorted_hidden_states = hidden_states[sorted_token_indices]
|
||||
hidden_states = hidden_states[sorted_token_indices]
|
||||
group_list_type = 1
|
||||
else:
|
||||
row_idx_len = num_tokens * top_k
|
||||
@@ -255,19 +264,22 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
dtype=torch.int32,
|
||||
device=topk_weights.device).view(
|
||||
top_k, -1).permute(1, 0).contiguous()
|
||||
sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
expert_idx=topk_ids,
|
||||
active_num=num_tokens)
|
||||
del hidden_states
|
||||
|
||||
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||
expanded_expert_idx, num_experts)
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 0
|
||||
|
||||
down_out_list = apply_mlp(sorted_hidden_states,
|
||||
# place hidden_states in a list to transfer its ownership into the `apply_mlp` function
|
||||
hidden_states_wrapper = [hidden_states]
|
||||
del hidden_states
|
||||
|
||||
hidden_states = apply_mlp(hidden_states_wrapper,
|
||||
w1,
|
||||
w1_scale,
|
||||
w2,
|
||||
@@ -276,23 +288,23 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
group_list_type=group_list_type)
|
||||
|
||||
if expert_map is not None:
|
||||
down_out_list.mul_(sorted_weights.unsqueeze(1))
|
||||
hidden_states.mul_(sorted_weights.unsqueeze(1))
|
||||
final_hidden_states = torch.zeros(*original_shape,
|
||||
device=hidden_states.device,
|
||||
device=device,
|
||||
dtype=dtype)
|
||||
|
||||
num_valid_tokens = mask.sum()
|
||||
valid_token_mask = torch.arange(
|
||||
0, sorted_token_indices.shape[0],
|
||||
device=device).unsqueeze(1) < num_valid_tokens
|
||||
down_out_list = down_out_list.masked_fill_(~valid_token_mask,
|
||||
hidden_states = hidden_states.masked_fill_(~valid_token_mask,
|
||||
0).to(dtype)
|
||||
final_hidden_states.index_add_(0, sorted_token_indices, down_out_list)
|
||||
final_hidden_states.index_add_(0, sorted_token_indices, hidden_states)
|
||||
else:
|
||||
# TODO: Reorder device memory 2 times here, replace the current
|
||||
# implementation here when suitable operators become available.
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
down_out_list,
|
||||
hidden_states,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
@@ -300,7 +312,7 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
expanded_src_to_dst_row=expanded_row_idx,
|
||||
export_for_source_row=topk_ids,
|
||||
)
|
||||
del down_out_list
|
||||
|
||||
if len(original_shape) == 3:
|
||||
final_hidden_states = final_hidden_states.view(original_shape)
|
||||
return final_hidden_states
|
||||
|
||||
Reference in New Issue
Block a user