Fix incorrect MLAPO weight release in PD mixed scenarios. (#4774)

### What this PR does / why we need it?
Fix incorrect MLAPO weight release in PD mixed scenarios.
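With this change, the original `fused_qkv_a_proj` / `q_proj` tensors are only dropped on KV-consumer (decode-only) instances instead of whenever a `kv_transfer_config` is present. A minimal sketch of the guarded release, mirroring the condition added in the diff below (the helper name `maybe_release_mlapo_source_weights` and the `attn_impl` parameter are illustrative and not part of this patch):

```python
# Sketch only: mirrors the guard introduced by this patch.
# `attn_impl` stands in for an AscendSFAImpl-like object;
# torch.npu.empty_cache() requires the torch_npu plugin.
import torch


def maybe_release_mlapo_source_weights(attn_impl) -> None:
    kv_cfg = attn_impl.vllm_config.kv_transfer_config
    # Release only on pure KV-consumer (decode-only) instances; in a PD mixed
    # deployment the original projection weights are still needed, so keep them.
    if kv_cfg is not None and kv_cfg.is_kv_consumer:
        for proj in (attn_impl.fused_qkv_a_proj, attn_impl.q_proj):
            proj.weight = None
            proj.deq_scale = None
            proj.quant_bias = None
        torch.npu.empty_cache()
```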

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: ZYang6263 <zy626375@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Authored by ZYang6263 on 2025-12-08 23:17:45 +08:00, committed by GitHub.
Parent: b230e7e987
Commit: 432b861cae

@@ -470,7 +470,7 @@ class AscendSFAImpl(MLAAttentionImpl):
         if self.fused_qkv_a_proj is None or not isinstance(
                 quant_method, AscendW8A8LinearMethod):
             reasons.append(
-                "Currently mlapo only supports W8A8 quantization in MLA scenario."
+                "Currently mlapo only supports W8A8 quantization in SFA scenario."
                 "Some layers in your model are not quantized with W8A8,"
                 "thus mlapo is disabled for these layers.")
         if self.enable_sfa_cp:
@@ -597,8 +597,6 @@ class AscendSFAImpl(MLAAttentionImpl):
         q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
             ..., :self.q_lora_rank].contiguous()
-        self.fused_qkv_a_proj.weight = None
         kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
         kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
         kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
@@ -673,9 +671,12 @@ class AscendSFAImpl(MLAAttentionImpl):
         self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
         self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
-        if self.vllm_config.kv_transfer_config is not None:
+        if self.vllm_config.kv_transfer_config is not None and \
+                self.vllm_config.kv_transfer_config.is_kv_consumer:
+            self.fused_qkv_a_proj.weight = None
             self.fused_qkv_a_proj.deq_scale = None
             self.fused_qkv_a_proj.quant_bias = None
             self.q_proj.weight = None
             self.q_proj.deq_scale = None
             self.q_proj.quant_bias = None
             torch.npu.empty_cache()