From e3c7f71462f1c252a60ce6476336fb24311a593c Mon Sep 17 00:00:00 2001 From: ApsarasX Date: Thu, 29 May 2025 11:48:26 +0800 Subject: [PATCH] [Perf] Refactor tensor disposal logic to reduce memory usage (#966) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? 1. In previous PRs https://github.com/vllm-project/vllm-ascend/pull/580 and https://github.com/vllm-project/vllm-ascend/pull/784, I saved GPU memory by promptly deleting unnecessary tensors. For tensors passed from upper-layer functions, I used a list container to transfer the parameter and then popped the tensor from the list within the inner function to achieve deletion. Recently, I discovered a better implementation in sglang—the `dispose_tensor` function—and I recommend adopting this approach. 2. Dispose `hidden_states` and `residual` from the previous layer once they're no longer used. 3. Avoid generating `self.inputs_embeds` in `ModelRunnerV1` in non-multimodal scenarios. With the aforementioned optimizations, using the DeepSeek-R1-W8A8 model under the conditions of `TP=16` and `max-model-len=32768`, we can save 1.3 GB of NPU memory. **Reference**: https://github.com/sgl-project/sglang/pull/6147 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? 
--------- Signed-off-by: ApsarasX --- vllm_ascend/models/deepseek_v2.py | 7 +++++ vllm_ascend/quantization/w8a8_dynamic.py | 33 ++++++++++-------------- vllm_ascend/utils.py | 4 +++ vllm_ascend/worker/model_runner_v1.py | 9 ++++--- 4 files changed, 30 insertions(+), 23 deletions(-) diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 5e97444..264a798 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -68,6 +68,7 @@ from vllm.sequence import IntermediateTensors import vllm_ascend.envs as envs_ascend from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod +from vllm_ascend.utils import dispose_tensor VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 @@ -518,8 +519,14 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: + previous_hidden_states, previous_residual = hidden_states, residual hidden_states, residual = self.input_layernorm( hidden_states, residual) + # Dispose hidden_states and residual from the previous layer + # to save npu memory because they're no longer used. + dispose_tensor(previous_hidden_states) + dispose_tensor(previous_residual) + hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 0f54b01..a847364 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -15,7 +15,7 @@ # limitations under the License. 
# -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, Optional import torch import torch.distributed as dist @@ -25,11 +25,12 @@ from vllm.distributed import GroupCoordinator import vllm_ascend.envs as envs_ascend from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.ops.fused_moe import select_experts +from vllm_ascend.utils import dispose_tensor VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 -def apply_mlp(hidden_states_wrapper: List[torch.Tensor], +def apply_mlp(hidden_states: torch.Tensor, w1: torch.Tensor, w1_scale: torch.Tensor, w2: torch.Tensor, @@ -41,7 +42,7 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor], apply MLP: gate_up_proj -> swiglu -> down_proj Args: - hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size). + hidden_states: input hidden states with shape (num_tokens, hidden_size). w1: expert weights1 with shape (num_experts, hidden_size, intermediate_size * 2) w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2) @@ -60,11 +61,13 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor], hidden_states: output hidden states after MLP. """ - assert len(hidden_states_wrapper) == 1 - hidden_states = hidden_states_wrapper.pop() if dynamic_scale is None: + unquantized_hidden_states = hidden_states hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( hidden_states) + # Dispose the original unquantized hidden states + # to save npu memory because they're no longer used. 
+ dispose_tensor(unquantized_hidden_states) else: pertoken_scale = dynamic_scale @@ -155,11 +158,8 @@ def fused_experts_with_mc2( if quant_mode == 0: dynamic_scale = None - # place hidden_states in a list to transfer its ownership into the `apply_mlp` function - hidden_states_wrapper = [expand_x] - del expand_x - - down_out_list = apply_mlp(hidden_states_wrapper, + # `expand_x` will be disposed in the `apply_mlp` function + down_out_list = apply_mlp(expand_x, w1, w1_scale, w2, @@ -281,10 +281,8 @@ def fused_experts_with_all2all( expert_tokens = expert_tokens.to(torch.int64) group_list_type = 0 - hidden_states_wrapper = [hidden_states] - del hidden_states - - hidden_states = apply_mlp(hidden_states_wrapper, + # `hidden_states` will be disposed in the `apply_mlp` function + hidden_states = apply_mlp(hidden_states, w1, w1_scale, w2, @@ -399,11 +397,8 @@ def fused_experts(hidden_states: torch.Tensor, expert_tokens = expert_tokens.to(torch.int64) group_list_type = 0 - # place hidden_states in a list to transfer its ownership into the `apply_mlp` function - hidden_states_wrapper = [hidden_states] - del hidden_states - - hidden_states = apply_mlp(hidden_states_wrapper, + # `hidden_states` will be disposed in the `apply_mlp` function + hidden_states = apply_mlp(hidden_states, w1, w1_scale, w2, diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index cd83fae..67cc0b8 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -169,3 +169,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: "No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes", vllm_config.model_config.architectures[0], num_hidden_layers, len(original_sizes)) + + +def dispose_tensor(x: torch.Tensor): + x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype)) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 24bd2b4..184f352 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ 
b/vllm_ascend/worker/model_runner_v1.py @@ -240,10 +240,11 @@ class NPUModelRunner(LoRAModelRunnerMixin): device="cpu", pin_memory=True) - self.inputs_embeds = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=self.device) + if self.is_multimodal_model: + self.inputs_embeds = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=self.device) # OPTIMIZATION: Cache the tensors rather than creating them every step. self.arange_np: npt.NDArray[np.int32] = np.arange(max(