[Perf] Refactor tensor disposal logic to reduce memory usage (#966)
### What this PR does / why we need it? 1. In previous PRs https://github.com/vllm-project/vllm-ascend/pull/580 https://github.com/vllm-project/vllm-ascend/pull/784, I saved GPU memory by promptly deleting unnecessary tensors. For tensors passed from upper-layer functions, I used a list container to transfer the parameter and then popped the tensor from the list within the inner function to achieve deletion. Recently, I discovered a better implementation in sglang—the `dispose_tensor` function and I recommend adopting this approach. 2. Dispose `hidden_states` and `residual` from the previous layer once they're no longer used. 3. Avoid to generate `self.inputs_embeds` in `ModelRunnerV1` in non-multimodal scenarios. With the aforementioned optimizations, using the DeepSeek-R1-W8A8 model under the conditions of `TP=16` and `max-model-len=32768`, we can save 1.3GB of npu memory. **Reference**: https://github.com/sgl-project/sglang/pull/6147 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? --------- Signed-off-by: ApsarasX <apsarax@outlook.com>
This commit is contained in:
@@ -68,6 +68,7 @@ from vllm.sequence import IntermediateTensors
|
|||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||||
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
||||||
|
from vllm_ascend.utils import dispose_tensor
|
||||||
|
|
||||||
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||||
|
|
||||||
@@ -518,8 +519,14 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
else:
|
else:
|
||||||
|
previous_hidden_states, previous_residual = hidden_states, residual
|
||||||
hidden_states, residual = self.input_layernorm(
|
hidden_states, residual = self.input_layernorm(
|
||||||
hidden_states, residual)
|
hidden_states, residual)
|
||||||
|
# Dispose hidden_states and residual from the previous layer
|
||||||
|
# to save npu memory because they're no longer used.
|
||||||
|
dispose_tensor(previous_hidden_states)
|
||||||
|
dispose_tensor(previous_residual)
|
||||||
|
|
||||||
hidden_states = self.self_attn(
|
hidden_states = self.self_attn(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -25,11 +25,12 @@ from vllm.distributed import GroupCoordinator
|
|||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.distributed.parallel_state import get_ep_group
|
from vllm_ascend.distributed.parallel_state import get_ep_group
|
||||||
from vllm_ascend.ops.fused_moe import select_experts
|
from vllm_ascend.ops.fused_moe import select_experts
|
||||||
|
from vllm_ascend.utils import dispose_tensor
|
||||||
|
|
||||||
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||||
|
|
||||||
|
|
||||||
def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
def apply_mlp(hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w1_scale: torch.Tensor,
|
w1_scale: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
@@ -41,7 +42,7 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
|||||||
apply MLP: gate_up_proj -> swiglu -> down_proj
|
apply MLP: gate_up_proj -> swiglu -> down_proj
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
|
hidden_states: input hidden states with shape (num_tokens, hidden_size).
|
||||||
w1: expert weights1 with shape
|
w1: expert weights1 with shape
|
||||||
(num_experts, hidden_size, intermediate_size * 2)
|
(num_experts, hidden_size, intermediate_size * 2)
|
||||||
w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
|
w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
|
||||||
@@ -60,11 +61,13 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
|||||||
hidden_states: output hidden states after MLP.
|
hidden_states: output hidden states after MLP.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert len(hidden_states_wrapper) == 1
|
|
||||||
hidden_states = hidden_states_wrapper.pop()
|
|
||||||
if dynamic_scale is None:
|
if dynamic_scale is None:
|
||||||
|
unquantized_hidden_states = hidden_states
|
||||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||||
hidden_states)
|
hidden_states)
|
||||||
|
# Dispose the original unquantized hidden states
|
||||||
|
# to save npu memory because they're no longer used.
|
||||||
|
dispose_tensor(unquantized_hidden_states)
|
||||||
else:
|
else:
|
||||||
pertoken_scale = dynamic_scale
|
pertoken_scale = dynamic_scale
|
||||||
|
|
||||||
@@ -155,11 +158,8 @@ def fused_experts_with_mc2(
|
|||||||
if quant_mode == 0:
|
if quant_mode == 0:
|
||||||
dynamic_scale = None
|
dynamic_scale = None
|
||||||
|
|
||||||
# place hidden_states in a list to transfer its ownership into the `apply_mlp` function
|
# `expand_x` will be disposed in the `apply_mlp` function
|
||||||
hidden_states_wrapper = [expand_x]
|
down_out_list = apply_mlp(expand_x,
|
||||||
del expand_x
|
|
||||||
|
|
||||||
down_out_list = apply_mlp(hidden_states_wrapper,
|
|
||||||
w1,
|
w1,
|
||||||
w1_scale,
|
w1_scale,
|
||||||
w2,
|
w2,
|
||||||
@@ -281,10 +281,8 @@ def fused_experts_with_all2all(
|
|||||||
expert_tokens = expert_tokens.to(torch.int64)
|
expert_tokens = expert_tokens.to(torch.int64)
|
||||||
group_list_type = 0
|
group_list_type = 0
|
||||||
|
|
||||||
hidden_states_wrapper = [hidden_states]
|
# `hidden_states` will be disposed in the `apply_mlp` function
|
||||||
del hidden_states
|
hidden_states = apply_mlp(hidden_states,
|
||||||
|
|
||||||
hidden_states = apply_mlp(hidden_states_wrapper,
|
|
||||||
w1,
|
w1,
|
||||||
w1_scale,
|
w1_scale,
|
||||||
w2,
|
w2,
|
||||||
@@ -399,11 +397,8 @@ def fused_experts(hidden_states: torch.Tensor,
|
|||||||
expert_tokens = expert_tokens.to(torch.int64)
|
expert_tokens = expert_tokens.to(torch.int64)
|
||||||
group_list_type = 0
|
group_list_type = 0
|
||||||
|
|
||||||
# place hidden_states in a list to transfer its ownership into the `apply_mlp` function
|
# `hidden_states` will be disposed in the `apply_mlp` function
|
||||||
hidden_states_wrapper = [hidden_states]
|
hidden_states = apply_mlp(hidden_states,
|
||||||
del hidden_states
|
|
||||||
|
|
||||||
hidden_states = apply_mlp(hidden_states_wrapper,
|
|
||||||
w1,
|
w1,
|
||||||
w1_scale,
|
w1_scale,
|
||||||
w2,
|
w2,
|
||||||
|
|||||||
@@ -169,3 +169,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
|||||||
"No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
|
"No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
|
||||||
vllm_config.model_config.architectures[0], num_hidden_layers,
|
vllm_config.model_config.architectures[0], num_hidden_layers,
|
||||||
len(original_sizes))
|
len(original_sizes))
|
||||||
|
|
||||||
|
|
||||||
|
def dispose_tensor(x: torch.Tensor):
|
||||||
|
x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))
|
||||||
|
|||||||
@@ -240,10 +240,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=True)
|
pin_memory=True)
|
||||||
|
|
||||||
self.inputs_embeds = torch.zeros(
|
if self.is_multimodal_model:
|
||||||
(self.max_num_tokens, self.hidden_size),
|
self.inputs_embeds = torch.zeros(
|
||||||
dtype=self.dtype,
|
(self.max_num_tokens, self.hidden_size),
|
||||||
device=self.device)
|
dtype=self.dtype,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
# OPTIMIZATION: Cache the tensors rather than creating them every step.
|
# OPTIMIZATION: Cache the tensors rather than creating them every step.
|
||||||
self.arange_np: npt.NDArray[np.int32] = np.arange(max(
|
self.arange_np: npt.NDArray[np.int32] = np.arange(max(
|
||||||
|
|||||||
Reference in New Issue
Block a user