[perf] Fix MLAPO weight disposal for KV-consumer MLA in PD-mix deploy... (#5192)
### What this PR does / why we need it?
- Problem: In MLA+MLAPO, KV-consumer deployments keep
fused_qkv_a_proj/q_proj weights and quant params even though MLAPO uses
the prepacked buffers, increasing memory footprint on decode nodes.
- Fix: Conditionally drop those tensors only when
`kv_transfer_config.is_kv_consumer` to reclaim memory (consistent with
the SFA behavior #4774 ).
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: Chen Chen <0109chenchen@gmail.com>
This commit is contained in:
@@ -912,6 +912,21 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
||||
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
||||
|
||||
# On KV consumers (decode-only) MLAPO uses the transformed weights built above;
|
||||
# the original fused_qkv_a_proj/q_proj weights and quant params are no longer
|
||||
# referenced, so drop them to save memory.
|
||||
ascend_config = get_ascend_config()
|
||||
if self.vllm_config.kv_transfer_config is not None and \
|
||||
self.vllm_config.kv_transfer_config.is_kv_consumer and \
|
||||
ascend_config.recompute_scheduler_enable:
|
||||
self.fused_qkv_a_proj.weight = None
|
||||
self.fused_qkv_a_proj.deq_scale = None
|
||||
self.fused_qkv_a_proj.quant_bias = None
|
||||
self.q_proj.weight = None
|
||||
self.q_proj.deq_scale = None
|
||||
self.q_proj.quant_bias = None
|
||||
torch.npu.empty_cache()
|
||||
|
||||
def get_context_seq_len_npu(self, index: int,
|
||||
attn_metadata: AscendMLAMetadata):
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
|
||||
@@ -371,7 +371,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
if self.enable_sfa_cp:
|
||||
self.local_num_heads = self.num_heads * self.tp_size
|
||||
|
||||
#TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
|
||||
# TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
|
||||
self._replace_linear_class_for_sfa_cp()
|
||||
from vllm_ascend.distributed.parallel_state import \
|
||||
get_shared_weight_group
|
||||
@@ -537,7 +537,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
cache_mode=cache_mode,
|
||||
is_output_kv=True,
|
||||
)
|
||||
#TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
|
||||
# TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
|
||||
k_pe = get_tp_group().all_gather(k_pe, 0)
|
||||
k_nope = get_tp_group().all_gather(k_nope, 0)
|
||||
|
||||
@@ -659,8 +659,13 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
||||
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
||||
|
||||
# On KV consumers (decode-only) MLAPO uses the transformed weights built above;
|
||||
# the original fused_qkv_a_proj/q_proj weights and quant params are no longer
|
||||
# referenced, so drop them to save memory.
|
||||
ascend_config = get_ascend_config()
|
||||
if self.vllm_config.kv_transfer_config is not None and \
|
||||
self.vllm_config.kv_transfer_config.is_kv_consumer:
|
||||
self.vllm_config.kv_transfer_config.is_kv_consumer and \
|
||||
ascend_config.recompute_scheduler_enable:
|
||||
self.fused_qkv_a_proj.weight = None
|
||||
self.fused_qkv_a_proj.deq_scale = None
|
||||
self.fused_qkv_a_proj.quant_bias = None
|
||||
|
||||
Reference in New Issue
Block a user