Fix incorrect MLAPO weight release in PD mixex scenarios. (#4774)
### What this PR does / why we need it?
Fix incorrect MLAPO weight release in PD mixex scenarios.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: ZYang6263 <zy626375@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -470,7 +470,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
if self.fused_qkv_a_proj is None or not isinstance(
|
||||
quant_method, AscendW8A8LinearMethod):
|
||||
reasons.append(
|
||||
"Currently mlapo only supports W8A8 quantization in MLA scenario."
|
||||
"Currently mlapo only supports W8A8 quantization in SFA scenario."
|
||||
"Some layers in your model are not quantized with W8A8,"
|
||||
"thus mlapo is disabled for these layers.")
|
||||
if self.enable_sfa_cp:
|
||||
@@ -597,8 +597,6 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
|
||||
..., :self.q_lora_rank].contiguous()
|
||||
|
||||
self.fused_qkv_a_proj.weight = None
|
||||
|
||||
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
|
||||
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
|
||||
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
|
||||
@@ -673,9 +671,12 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
||||
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
||||
|
||||
if self.vllm_config.kv_transfer_config is not None:
|
||||
if self.vllm_config.kv_transfer_config is not None and \
|
||||
self.vllm_config.kv_transfer_config.is_kv_consumer:
|
||||
self.fused_qkv_a_proj.weight = None
|
||||
self.fused_qkv_a_proj.deq_scale = None
|
||||
self.fused_qkv_a_proj.quant_bias = None
|
||||
self.q_proj.weight = None
|
||||
self.q_proj.deq_scale = None
|
||||
self.q_proj.quant_bias = None
|
||||
torch.npu.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user