[Bugfix] Fix memory-leak caused by dist._functional_collectives.reduce_scatter_tensor (#1380)

### What this PR does / why we need it?
In some cases, `dist._functional_collectives.reduce_scatter_tensor` can
cause its input tensor not to be released immediately after the current
layer ends. Instead, it will only be released when the GPU memory usage
of the current process reaches a certain threshold (approximately every
15 layers each time).

**Before Fix**

<img width="1441" alt="截屏2025-06-24 01 26 13"
src="https://github.com/user-attachments/assets/72d5dbb3-c8c8-4778-bf64-8db7bab8aff0"
/>

**After Fix**
<img width="1475" alt="截屏2025-06-24 01 23 43"
src="https://github.com/user-attachments/assets/6c69cfcd-a469-4ee5-b8c6-210aeb3a5bdf"
/>

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?


- vLLM version: v0.9.1
- vLLM main:
9ff2af6d2b

---------

Signed-off-by: ApsarasX <apsarax@outlook.com>
This commit is contained in:
ApsarasX
2025-07-10 10:57:24 +08:00
committed by GitHub
parent b1c66b211f
commit 89c1a0f006
2 changed files with 29 additions and 5 deletions

View File

@@ -39,6 +39,8 @@ from vllm.model_executor.layers.quantization.base_config import \
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.communication_op import \
data_parallel_reduce_scatter
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
@@ -1342,11 +1344,8 @@ class AscendFusedMoE(FusedMoE):
final_hidden_states = final_hidden_states[start:end, :]
dispose_tensor(e_hidden_states)
elif fused_moe_state == FusedMoEState.AllGather:
final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
e_hidden_states,
"sum",
scatter_dim=0,
group=get_dp_group().device_group)
final_hidden_states = data_parallel_reduce_scatter(
e_hidden_states, dim=0)
final_hidden_states = final_hidden_states[:num_tokens]
dispose_tensor(e_hidden_states)
else: