[perf] Fix MLAPO weight disposal for KV-consumer MLA in PD-mix deploy... (#5192)

### What this PR does / why we need it?

- Problem: In MLA+MLAPO, KV-consumer (decode-only) deployments keep the
  original fused_qkv_a_proj/q_proj weights and quant params even though
  MLAPO only reads the prepacked buffers, which inflates the memory
  footprint on decode nodes.
- Fix: Drop those tensors when `kv_transfer_config.is_kv_consumer` (and the
  recompute scheduler is enabled), reclaiming the memory and matching the
  existing SFA behavior (#4774). A minimal sketch of the check is shown
  below.
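
An illustrative sketch of the disposal pattern (the `fused_qkv_a_proj`/`q_proj` attribute names and config fields mirror the diff below; the standalone helper and its signature are hypothetical, not part of the patch, and running it requires `torch_npu` on an Ascend device):

```python
import torch
import torch_npu  # noqa: F401  # registers the torch.npu namespace on Ascend


def drop_prepacked_source_weights(layer, vllm_config, ascend_config) -> None:
    """Release fused_qkv_a_proj/q_proj tensors on KV-consumer (decode-only) ranks.

    Once MLAPO has built its prepacked buffers, the original weights and
    quantization params are never read again on KV consumers, so dropping
    the references lets the allocator reclaim the memory.
    """
    kv_cfg = vllm_config.kv_transfer_config
    if kv_cfg is None or not kv_cfg.is_kv_consumer:
        return  # prefill / mixed nodes still need the original weights
    if not ascend_config.recompute_scheduler_enable:
        return  # the patch only drops weights under the recompute scheduler
    for proj in (layer.fused_qkv_a_proj, layer.q_proj):
        proj.weight = None
        proj.deq_scale = None
        proj.quant_bias = None
    # Hand the freed blocks back to the NPU caching allocator.
    torch.npu.empty_cache()
```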

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: Chen Chen <0109chenchen@gmail.com>
Author: Chen Chen
Date: 2026-01-05 21:29:45 +08:00
Committed by: GitHub
Parent: b10ef9b9f3
Commit: a2daacbd71
2 changed files with 23 additions and 3 deletions


@@ -912,6 +912,21 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
         self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
+
+        # On KV consumers (decode-only) MLAPO uses the transformed weights built above;
+        # the original fused_qkv_a_proj/q_proj weights and quant params are no longer
+        # referenced, so drop them to save memory.
+        ascend_config = get_ascend_config()
+        if self.vllm_config.kv_transfer_config is not None and \
+                self.vllm_config.kv_transfer_config.is_kv_consumer and \
+                ascend_config.recompute_scheduler_enable:
+            self.fused_qkv_a_proj.weight = None
+            self.fused_qkv_a_proj.deq_scale = None
+            self.fused_qkv_a_proj.quant_bias = None
+            self.q_proj.weight = None
+            self.q_proj.deq_scale = None
+            self.q_proj.quant_bias = None
+            torch.npu.empty_cache()
 
     def get_context_seq_len_npu(self, index: int,
                                 attn_metadata: AscendMLAMetadata):
         prefill_metadata = attn_metadata.prefill


@@ -371,7 +371,7 @@ class AscendSFAImpl(MLAAttentionImpl):
         if self.enable_sfa_cp:
             self.local_num_heads = self.num_heads * self.tp_size
-            #TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
+            # TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
             self._replace_linear_class_for_sfa_cp()
             from vllm_ascend.distributed.parallel_state import \
                 get_shared_weight_group
@@ -537,7 +537,7 @@ class AscendSFAImpl(MLAAttentionImpl):
             cache_mode=cache_mode,
             is_output_kv=True,
         )
-        #TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
+        # TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
         k_pe = get_tp_group().all_gather(k_pe, 0)
         k_nope = get_tp_group().all_gather(k_nope, 0)
@@ -659,8 +659,13 @@ class AscendSFAImpl(MLAAttentionImpl):
         self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
         self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
+        # On KV consumers (decode-only) MLAPO uses the transformed weights built above;
+        # the original fused_qkv_a_proj/q_proj weights and quant params are no longer
+        # referenced, so drop them to save memory.
+        ascend_config = get_ascend_config()
         if self.vllm_config.kv_transfer_config is not None and \
-                self.vllm_config.kv_transfer_config.is_kv_consumer:
+                self.vllm_config.kv_transfer_config.is_kv_consumer and \
+                ascend_config.recompute_scheduler_enable:
             self.fused_qkv_a_proj.weight = None
             self.fused_qkv_a_proj.deq_scale = None
             self.fused_qkv_a_proj.quant_bias = None