[Refactor] Replace the implementations of o_proj, q_b_proj, and kv_b_proj with custom_op for sharded CP (#5698)

### What this PR does / why we need it?
Based on the Sharded-CP feature:
- PR: https://github.com/vllm-project/vllm-ascend/pull/4702
- RFC: https://github.com/vllm-project/vllm/issues/30055

This PR integrates DeepSeek V3.2's DSA-CP support on top of
https://github.com/vllm-project/vllm-ascend/pull/4702, improving inference
efficiency and scalability under mixed prefill-decode workloads. The main
change:
- Replace the implementations of o_proj, q_b_proj, and kv_b_proj with
  custom_op for TP=1 (see the sketch after this list).
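
For intuition only, here is a minimal sketch of why a per-token projection such as o_proj can run directly on a context-parallel (CP) sequence shard, which is the property a CP-aware projection implementation relies on. This is plain PyTorch, not vllm-ascend code; the class name `ShardedCPLinear` and the `gather_output` knob are illustrative assumptions, not the PR's actual API.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.distributed as dist


class ShardedCPLinear(nn.Module):
    """Sketch of a CP-aware projection: each rank projects only its own
    slice of the token dimension; communication happens only if the caller
    explicitly asks for the gathered full sequence."""

    def __init__(self, in_features: int, out_features: int,
                 cp_group: Optional[dist.ProcessGroup] = None):
        super().__init__()
        self.proj = nn.Linear(in_features, out_features, bias=False)
        self.cp_group = cp_group

    def forward(self, x_shard: torch.Tensor,
                gather_output: bool = False) -> torch.Tensor:
        # x_shard: [local_num_tokens, in_features], this rank's CP slice.
        # A linear layer acts on each token independently, so no collective
        # is needed to compute the local projection.
        out = self.proj(x_shard)
        if gather_output and self.cp_group is not None:
            # Assumes equal shard sizes across ranks, for brevity.
            world = dist.get_world_size(self.cp_group)
            parts = [torch.empty_like(out) for _ in range(world)]
            dist.all_gather(parts, out, group=self.cp_group)
            out = torch.cat(parts, dim=0)
        return out
```

With TP=1, the projections listed above are token-local in exactly this sense, which is what makes it possible to swap them for CP-aware custom_op variants without adding collectives to the projection itself.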

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Signed-off-by: Kurumi5210 <jaychou1620@gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
Author: zzhxxx
Date: 2026-01-09 15:58:40 +08:00 (committed by GitHub)
Commit: 64d29875f9 (parent: e11ff8e535)
4 changed files with 110 additions and 68 deletions


@@ -1119,11 +1119,6 @@ def dispose_layer(layer: Any):
        dispose_tensor(attr_value)


def replace_layer(original_layer: Any, new_layer: Any):
    original_layer.__class__ = new_layer.__class__
    original_layer.__dict__ = new_layer.__dict__


def check_kv_extra_config(vllm_config):
    def _check(name: str, config: dict):
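
As an aside, the in-place `__class__`/`__dict__` swap used by `replace_layer` above keeps every existing reference to the layer valid, so a parent module does not need to be rewired. A tiny self-contained illustration of that pattern (the `FastLinear` name is made up for this example):

```python
import torch.nn as nn


class FastLinear(nn.Linear):
    """Hypothetical drop-in variant of nn.Linear, used only for this demo."""

    def forward(self, x):
        return super().forward(x)  # an optimized path would go here


parent = nn.Sequential(nn.Linear(8, 8))
new_layer = FastLinear(8, 8)

# Same mechanism as replace_layer(): mutate the existing object in place.
old_layer = parent[0]
old_layer.__class__ = new_layer.__class__
old_layer.__dict__ = new_layer.__dict__

# parent[0] is still the same Python object, but now behaves like FastLinear
# (and carries new_layer's weights, since __dict__ was replaced wholesale).
assert isinstance(parent[0], FastLinear)
```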
@@ -1166,17 +1161,31 @@ def singleton(cls):
    return get_instance


@lru_cache(maxsize=1)
def get_current_model_config():
    from vllm.config import get_current_vllm_config
    vllm_config = get_current_vllm_config()
    return vllm_config.model_config


# TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32;
# subsequent updates will introduce new interfaces. --zzhx1
@lru_cache(maxsize=1)
def enable_dsa_cp() -> bool:
    from vllm.config import get_current_vllm_config
    vllm_config = get_current_vllm_config()
-    is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
-    return is_ds_v32 and enable_sp()
+    if vllm_config is None:
+        return False
+    model_config = getattr(vllm_config, "model_config", None)
+    if model_config is None:
+        return False
+    hf_text_config = getattr(model_config, "hf_text_config", None)
+    if hf_text_config is None:
+        return False
+    return hasattr(hf_text_config, "index_topk")


@lru_cache(maxsize=1)
def enable_dsa_cp_with_layer_shard() -> bool:
    if not enable_dsa_cp():
        return False
    from vllm.config import get_current_vllm_config
    vllm_config = get_current_vllm_config()
    is_prefill_instance = (vllm_config.kv_transfer_config is not None
                           and vllm_config.kv_transfer_config.is_kv_producer)
    return is_prefill_instance
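
For orientation, here is a hedged sketch of how gates like `enable_dsa_cp()` and `enable_dsa_cp_with_layer_shard()` could be consumed when constructing a projection layer. The `DefaultOProj`/`ShardedCPOProj` classes and the `build_o_proj` helper are invented for this example and are not the PR's actual wiring; the sketch only assumes the two utilities defined above are importable.

```python
import torch.nn as nn

# enable_dsa_cp / enable_dsa_cp_with_layer_shard are assumed to come from
# the utils module patched above.


class DefaultOProj(nn.Linear):
    """Placeholder for the default dense o_proj implementation."""


class ShardedCPOProj(nn.Linear):
    """Placeholder for a custom_op-backed, CP-aware o_proj."""

    def shard_weights(self) -> None:
        # The real implementation would split weights across the CP group;
        # omitted in this sketch.
        pass


def build_o_proj(num_heads: int, head_dim: int,
                 hidden_size: int) -> nn.Linear:
    # Use the CP-aware implementation only when DSA-CP is enabled
    # (DeepSeek V3.2 detected via index_topk, per enable_dsa_cp()).
    if enable_dsa_cp():
        o_proj = ShardedCPOProj(num_heads * head_dim, hidden_size, bias=False)
        # Per-layer weight sharding only on prefill / KV-producer instances.
        if enable_dsa_cp_with_layer_shard():
            o_proj.shard_weights()
        return o_proj
    return DefaultOProj(num_heads * head_dim, hidden_size, bias=False)
```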