From 324f819b929ae2c30fe629e27df0c37c2cc607b7 Mon Sep 17 00:00:00 2001 From: ApsarasX Date: Fri, 9 May 2025 15:09:37 +0800 Subject: [PATCH] [Perf] Optimize fused_experts quantization code to save npu memory (#784) ### What this PR does / why we need it? In the w8a8 quantization code of `fused_experts`, the output of almost every operator is assigned a new variable name. If we want to save NPU memory, we manually `del` these variables to end their lifecycle, which fills the code with `del` statements and looks inelegant. Therefore, I plan to name the output of most operators as `hidden_states`, thereby ending the lifecycle of the previous `hidden_states`. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Signed-off-by: ApsarasX --- vllm_ascend/models/deepseek_v2.py | 7 +- vllm_ascend/quantization/w8a8_dynamic.py | 84 ++++++++++++++---------- 2 files changed, 50 insertions(+), 41 deletions(-) diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index ef47519..c46a3c2 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -222,9 +222,6 @@ class CustomDeepseekV2MoE(nn.Module): num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - if self.n_shared_experts is not None: - shared_output = self.shared_experts(hidden_states) - if (self.tp_size > 1 and self.enable_mc2 and not is_prefill): chunks = torch.chunk(hidden_states, get_tp_group().world_size, @@ -248,8 +245,8 @@ class CustomDeepseekV2MoE(nn.Module): else: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) - - if shared_output is not None: + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) final_hidden_states = final_hidden_states + shared_output return final_hidden_states.view(num_tokens, hidden_dim) diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 
4fbfadc..97bddba 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -16,7 +16,7 @@ # import os -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional import torch import torch_npu @@ -25,7 +25,7 @@ from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.ops.fused_moe import select_experts -def apply_mlp(x: torch.Tensor, +def apply_mlp(hidden_states_wrapper: List[torch.Tensor], w1: torch.Tensor, w1_scale: torch.Tensor, w2: torch.Tensor, @@ -37,7 +37,7 @@ def apply_mlp(x: torch.Tensor, apply MLP: gate_up_proj -> swiglu -> down_proj Args: - x: input hidden states with shape (num_tokens, hidden_size). + hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size). w1: expert weights1 with shape (num_experts, hidden_size, intermediate_size * 2) w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2) @@ -56,23 +56,27 @@ def apply_mlp(x: torch.Tensor, hidden_states: output hidden states after MLP. 
""" + assert len(hidden_states_wrapper) == 1 + hidden_states = hidden_states_wrapper.pop() if dynamic_scale is None: - h, pertoken_scale = torch_npu.npu_dynamic_quant(x) + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( + hidden_states) else: - h = x pertoken_scale = dynamic_scale # gmm1: gate_up_proj - gate_up_out = torch_npu.npu_grouped_matmul(x=[h], - weight=[w1], - split_item=3, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=torch.int32)[0] + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=3, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=torch.int32)[0] - swiglu_out, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( - x=gate_up_out, + # act_fn: swiglu + hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( + x=hidden_states, weight_scale=w1_scale, activation_scale=pertoken_scale, bias=None, @@ -83,17 +87,18 @@ def apply_mlp(x: torch.Tensor, quant_mode=1, ) - # down_proj - down_out = torch_npu.npu_grouped_matmul(x=[swiglu_out], - weight=[w2], - scale=[w2_scale], - per_token_scale=[swiglu_out_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=w2_scale.dtype)[0] - return down_out + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + scale=[w2_scale], + per_token_scale=[swiglu_out_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=w2_scale.dtype)[0] + return hidden_states def fused_experts_with_mc2( @@ -152,7 +157,11 @@ def fused_experts_with_mc2( if quant_mode == 0: dynamic_scale = None - down_out_list = apply_mlp(expand_x, + # place hidden_states in a list to transfer its ownership into the `apply_mlp` function + hidden_states_wrapper = [expand_x] + del expand_x + + down_out_list = apply_mlp(hidden_states_wrapper, w1, 
w1_scale, w2, @@ -246,7 +255,7 @@ def fused_experts(hidden_states: torch.Tensor, token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) expert_tokens = token_counts[:num_experts] # Rearrange hidden_states - sorted_hidden_states = hidden_states[sorted_token_indices] + hidden_states = hidden_states[sorted_token_indices] group_list_type = 1 else: row_idx_len = num_tokens * top_k @@ -255,19 +264,22 @@ def fused_experts(hidden_states: torch.Tensor, dtype=torch.int32, device=topk_weights.device).view( top_k, -1).permute(1, 0).contiguous() - sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( hidden_states, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens) - del hidden_states expert_tokens = torch_npu.npu_moe_compute_expert_tokens( expanded_expert_idx, num_experts) expert_tokens = expert_tokens.to(torch.int64) group_list_type = 0 - down_out_list = apply_mlp(sorted_hidden_states, + # place hidden_states in a list to transfer its ownership into the `apply_mlp` function + hidden_states_wrapper = [hidden_states] + del hidden_states + + hidden_states = apply_mlp(hidden_states_wrapper, w1, w1_scale, w2, @@ -276,23 +288,23 @@ def fused_experts(hidden_states: torch.Tensor, group_list_type=group_list_type) if expert_map is not None: - down_out_list.mul_(sorted_weights.unsqueeze(1)) + hidden_states.mul_(sorted_weights.unsqueeze(1)) final_hidden_states = torch.zeros(*original_shape, - device=hidden_states.device, + device=device, dtype=dtype) num_valid_tokens = mask.sum() valid_token_mask = torch.arange( 0, sorted_token_indices.shape[0], device=device).unsqueeze(1) < num_valid_tokens - down_out_list = down_out_list.masked_fill_(~valid_token_mask, + hidden_states = hidden_states.masked_fill_(~valid_token_mask, 0).to(dtype) - final_hidden_states.index_add_(0, sorted_token_indices, down_out_list) + final_hidden_states.index_add_(0, 
sorted_token_indices, hidden_states) else: # TODO: Reorder device memory 2 times here, replace the current # implementation here when suitable operators become available. final_hidden_states = torch_npu.npu_moe_finalize_routing( - down_out_list, + hidden_states, skip1=None, skip2=None, bias=None, @@ -300,7 +312,7 @@ def fused_experts(hidden_states: torch.Tensor, expanded_src_to_dst_row=expanded_row_idx, export_for_source_row=topk_ids, ) - del down_out_list + if len(original_shape) == 3: final_hidden_states = final_hidden_states.view(original_shape) return final_hidden_states