From ffdd1a36e20180e881036a6d502dac1810ba085e Mon Sep 17 00:00:00 2001 From: linfeng-yuan <1102311262@qq.com> Date: Mon, 22 Sep 2025 14:06:43 +0800 Subject: [PATCH] [bugfix][torchair] fix wasted NPU memory buffer allocation for quantized deepseek with unquantized MTP layer (#3068) ### What this PR does / why we need it? While running quantized deepseek models with unquantized MTP layer, free NPU memory abnormally decreases by `2*HCCL_BUFFSIZE` bytes. This results from the wasted VRAM buffer allocation caused by calling `dist.all_to_all_single` without the correct device process group argument. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? We run vllm online serving with quantized deepseek-r1 and unquantized MTP layer, and observed that free_memory increased without redundant VRAM buffer for HCCL communication op (all_to_all_single). - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/6d8246aaffff3ebec84767e373212a7b8da328e2 Signed-off-by: linfeng-yuan <1102311262@qq.com> --- .../quantization/torchair_w8a8_dynamic.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py index 832cbc5..be212e2 100644 --- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py @@ -416,6 +416,7 @@ def torchair_fused_experts_with_all2all( num_experts = w1.shape[0] if expert_map is not None: + assert ep_group is not None, "ep_group must be provided when expert_map is given" global_num_experts = len(expert_map) + global_redundant_expert_num if hasattr(torch_npu, "npu_moe_init_routing_quant"): quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant( @@ -435,8 +436,9 @@ def torchair_fused_experts_with_all2all( gather_sizes = 
global_expert_tokens.new_empty( global_expert_tokens.shape[0]) - dist.all_to_all_single(gather_sizes, global_expert_tokens) - + dist.all_to_all_single(gather_sizes, + global_expert_tokens, + group=ep_group.device_group) token_counts_combined = torch.stack( [gather_sizes, global_expert_tokens], dim=0) token_counts_combined = token_counts_combined.view( @@ -451,10 +453,16 @@ def torchair_fused_experts_with_all2all( gather_size_list = token_counts_combined_cpu[1] scatter_size_list = token_counts_combined_cpu[0] - dist.all_to_all_single(gathered_tokens, quantized_tokens, - scatter_size_list, gather_size_list) - dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list, - gather_size_list) + dist.all_to_all_single(gathered_tokens, + quantized_tokens, + scatter_size_list, + gather_size_list, + group=ep_group.device_group) + dist.all_to_all_single(dynamic_scale, + token_scales, + scatter_size_list, + gather_size_list, + group=ep_group.device_group) hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing( gathered_tokens, @@ -502,9 +510,11 @@ def torchair_fused_experts_with_all2all( index=inverse_indices.to(torch.float32).argsort().to(torch.int32)) hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape) - dist.all_to_all_single(hidden_states, reordered_outputs, - gather_size_list, scatter_size_list) - + dist.all_to_all_single(hidden_states, + reordered_outputs, + gather_size_list, + scatter_size_list, + group=ep_group.device_group) final_hidden_states = torch_npu.npu_moe_finalize_routing( hidden_states, skip1=None,