[perf] Fix MLAPO weight disposal for KV-consumer MLA in PD-mix deploy... (#5192)
### What this PR does / why we need it?
- Problem: In MLA+MLAPO, KV-consumer deployments keep
fused_qkv_a_proj/q_proj weights and quant params even though MLAPO uses
the prepacked buffers, increasing memory footprint on decode nodes.
- Fix: Conditionally drop those tensors only when
`kv_transfer_config.is_kv_consumer` to reclaim memory (consistent with
the SFA behavior #4774 ).
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: Chen Chen <0109chenchen@gmail.com>
This commit is contained in:
@@ -912,6 +912,21 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
||||||
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
||||||
|
|
||||||
|
# On KV consumers (decode-only) MLAPO uses the transformed weights built above;
|
||||||
|
# the original fused_qkv_a_proj/q_proj weights and quant params are no longer
|
||||||
|
# referenced, so drop them to save memory.
|
||||||
|
ascend_config = get_ascend_config()
|
||||||
|
if self.vllm_config.kv_transfer_config is not None and \
|
||||||
|
self.vllm_config.kv_transfer_config.is_kv_consumer and \
|
||||||
|
ascend_config.recompute_scheduler_enable:
|
||||||
|
self.fused_qkv_a_proj.weight = None
|
||||||
|
self.fused_qkv_a_proj.deq_scale = None
|
||||||
|
self.fused_qkv_a_proj.quant_bias = None
|
||||||
|
self.q_proj.weight = None
|
||||||
|
self.q_proj.deq_scale = None
|
||||||
|
self.q_proj.quant_bias = None
|
||||||
|
torch.npu.empty_cache()
|
||||||
|
|
||||||
def get_context_seq_len_npu(self, index: int,
|
def get_context_seq_len_npu(self, index: int,
|
||||||
attn_metadata: AscendMLAMetadata):
|
attn_metadata: AscendMLAMetadata):
|
||||||
prefill_metadata = attn_metadata.prefill
|
prefill_metadata = attn_metadata.prefill
|
||||||
|
|||||||
@@ -371,7 +371,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
if self.enable_sfa_cp:
|
if self.enable_sfa_cp:
|
||||||
self.local_num_heads = self.num_heads * self.tp_size
|
self.local_num_heads = self.num_heads * self.tp_size
|
||||||
|
|
||||||
#TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
|
# TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
|
||||||
self._replace_linear_class_for_sfa_cp()
|
self._replace_linear_class_for_sfa_cp()
|
||||||
from vllm_ascend.distributed.parallel_state import \
|
from vllm_ascend.distributed.parallel_state import \
|
||||||
get_shared_weight_group
|
get_shared_weight_group
|
||||||
@@ -537,7 +537,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
cache_mode=cache_mode,
|
cache_mode=cache_mode,
|
||||||
is_output_kv=True,
|
is_output_kv=True,
|
||||||
)
|
)
|
||||||
#TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
|
# TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
|
||||||
k_pe = get_tp_group().all_gather(k_pe, 0)
|
k_pe = get_tp_group().all_gather(k_pe, 0)
|
||||||
k_nope = get_tp_group().all_gather(k_nope, 0)
|
k_nope = get_tp_group().all_gather(k_nope, 0)
|
||||||
|
|
||||||
@@ -659,8 +659,13 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
||||||
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
||||||
|
|
||||||
|
# On KV consumers (decode-only) MLAPO uses the transformed weights built above;
|
||||||
|
# the original fused_qkv_a_proj/q_proj weights and quant params are no longer
|
||||||
|
# referenced, so drop them to save memory.
|
||||||
|
ascend_config = get_ascend_config()
|
||||||
if self.vllm_config.kv_transfer_config is not None and \
|
if self.vllm_config.kv_transfer_config is not None and \
|
||||||
self.vllm_config.kv_transfer_config.is_kv_consumer:
|
self.vllm_config.kv_transfer_config.is_kv_consumer and \
|
||||||
|
ascend_config.recompute_scheduler_enable:
|
||||||
self.fused_qkv_a_proj.weight = None
|
self.fused_qkv_a_proj.weight = None
|
||||||
self.fused_qkv_a_proj.deq_scale = None
|
self.fused_qkv_a_proj.deq_scale = None
|
||||||
self.fused_qkv_a_proj.quant_bias = None
|
self.fused_qkv_a_proj.quant_bias = None
|
||||||
|
|||||||
Reference in New Issue
Block a user