[Refactor] Replace the implementations of o_proj, q_b_proj, and kv_b_proj with custom_op for sharded CP (#5698)

### What this PR does / why we need it?
Based on the Sharded-CP feature
PR:https://github.com/vllm-project/vllm-ascend/pull/4702;
RFC:https://github.com/vllm-project/vllm/issues/30055

This PR officially integrates Deepseek V3.2's DSA-CP support on the
basis of https://github.com/vllm-project/vllm-ascend/pull/4702,
improving inference efficiency and scalability under mixed
prefill-decode workloads. The main improvements include:
- Replace the implementations of o_proj, q_b_proj, and kv_b_proj with
custom_op for TP=1.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Signed-off-by: Kurumi5210 <jaychou1620@gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
This commit is contained in:
zzhxxx
2026-01-09 15:58:40 +08:00
committed by GitHub
parent e11ff8e535
commit 64d29875f9
4 changed files with 110 additions and 68 deletions

View File

@@ -10,8 +10,7 @@ from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.linear import (ReplicatedLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.attention.backends.utils import AttentionCGSupport
@@ -34,7 +33,7 @@ from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
enable_sp, maybe_trans_nz, replace_layer)
enable_dsa_cp, maybe_trans_nz)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
if TYPE_CHECKING:
@@ -149,8 +148,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
got {self.decode_threshold}"
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.enable_sfa_cp = enable_sp() and \
hasattr(self.model_config.hf_text_config, "index_topk")
self.enable_sfa_cp = enable_dsa_cp()
assert not (
self.enable_sfa_cp
@@ -368,13 +366,11 @@ class AscendSFAImpl(MLAAttentionImpl):
assert self.indexer is not None, "Indexer is required for DSA."
self.enable_sfa_cp = enable_sp()
self.enable_sfa_cp = enable_dsa_cp()
self.local_num_heads = self.num_heads
self.vllm_config = get_current_vllm_config()
if self.enable_sfa_cp:
self.local_num_heads = self.num_heads * self.tp_size
self._replace_linear_class_for_sfa_cp()
self.layer_sharding_kwargs = []
for layer_name in (get_ascend_config().layer_sharding or []):
if layer_name in kwargs:
@@ -925,42 +921,3 @@ class AscendSFAImpl(MLAAttentionImpl):
sparse_count=2048,
sparse_mode=3)
return topk_indices
def _replace_linear_class_for_sfa_cp(self):
vllm_config = get_current_vllm_config()
# Dispose tensor from the original q_proj
dispose_layer(self.q_proj)
# Construct the new q_proj using ReplicatedLinear
new_q_proj = ReplicatedLinear(self.q_lora_rank,
self.local_num_heads * self.qk_head_dim,
bias=False,
quant_config=vllm_config.quant_config,
prefix=self.q_proj.prefix)
# Replace the q_proj with the new one
replace_layer(self.q_proj, new_q_proj)
# Dispose tensor from the original kv_b_proj
dispose_layer(self.kv_b_proj)
# Construct the new kv_b_proj using ReplicatedLinear
new_kv_b_proj = ReplicatedLinear(
self.kv_lora_rank,
self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=vllm_config.quant_config,
prefix=self.kv_b_proj.prefix)
# Replace the kv_b_proj with the new one
replace_layer(self.kv_b_proj, new_kv_b_proj)
# Dispose tensor from the original o_proj
dispose_layer(self.o_proj)
# Construct the new o_proj using ReplicatedLinear
config = vllm_config.model_config.hf_text_config
new_o_proj = ReplicatedLinear(config.num_attention_heads *
config.v_head_dim,
config.hidden_size,
bias=False,
quant_config=vllm_config.quant_config,
prefix=self.o_proj.prefix)
# Replace the o_proj with the new one
replace_layer(self.o_proj, new_o_proj)