[Refactor] Replace the implementations of o_proj, q_b_proj, and kv_b_proj with custom_op for sharded CP (#5698)
### What this PR does / why we need it?
Based on the Sharded-CP feature:
- PR: https://github.com/vllm-project/vllm-ascend/pull/4702
- RFC: https://github.com/vllm-project/vllm/issues/30055

This PR officially integrates DeepSeek V3.2's DSA-CP support on top of https://github.com/vllm-project/vllm-ascend/pull/4702, improving inference efficiency and scalability under mixed prefill-decode workloads. The main improvements include:
- Replace the implementations of o_proj, q_b_proj, and kv_b_proj with custom_op for TP=1 (a rough illustrative sketch of the pattern follows).
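
As a rough, hypothetical sketch of this pattern (not the code added by this PR): the projection is built behind a small factory that swaps in a custom implementation when the DSA-CP path is active, with a plain `torch.nn.Linear` as the fallback. `CustomQBProj` and `build_q_b_proj` are made-up names, and the stubbed `enable_dsa_cp` stands in for the helper introduced in the diff below.

```python
import torch


def enable_dsa_cp() -> bool:
    # Stub standing in for the feature flag added in this PR.
    return True


class CustomQBProj(torch.nn.Module):
    # Hypothetical custom_op-backed q_b_proj: the reference path here is a
    # plain matmul; a real Ascend implementation would dispatch to a fused
    # NPU kernel instead.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(x, self.weight)


def build_q_b_proj(in_features: int, out_features: int) -> torch.nn.Module:
    # Only swap in the custom implementation when DSA-CP (TP=1) is enabled.
    if enable_dsa_cp():
        return CustomQBProj(in_features, out_features)
    return torch.nn.Linear(in_features, out_features, bias=False)


if __name__ == "__main__":
    proj = build_q_b_proj(128, 256)
    print(proj(torch.randn(4, 128)).shape)  # torch.Size([4, 256])
```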
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Signed-off-by: Kurumi5210 <jaychou1620@gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
```diff
@@ -1119,11 +1119,6 @@ def dispose_layer(layer: Any):
             dispose_tensor(attr_value)
 
 
-def replace_layer(original_layer: Any, new_layer: Any):
-    original_layer.__class__ = new_layer.__class__
-    original_layer.__dict__ = new_layer.__dict__
-
-
 def check_kv_extra_config(vllm_config):
 
     def _check(name: str, config: dict):
```
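
For reference, the `replace_layer` helper deleted in this hunk rebinds a live object's `__class__` and `__dict__` so existing references pick up the new behavior in place. A self-contained toy sketch of that mechanism (with made-up `OldLayer`/`NewLayer` classes) is:

```python
from typing import Any


def replace_layer(original_layer: Any, new_layer: Any) -> None:
    # In-place swap: the existing object now behaves like new_layer, so
    # references held elsewhere keep pointing at the updated layer.
    original_layer.__class__ = new_layer.__class__
    original_layer.__dict__ = new_layer.__dict__


class OldLayer:
    def __init__(self) -> None:
        self.scale = 1.0

    def forward(self, x: float) -> float:
        return x * self.scale


class NewLayer:
    def __init__(self) -> None:
        self.scale = 2.0

    def forward(self, x: float) -> float:
        return x * self.scale


layer = OldLayer()
replace_layer(layer, NewLayer())
print(type(layer).__name__, layer.forward(3.0))  # NewLayer 6.0
```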
```diff
@@ -1166,17 +1161,31 @@ def singleton(cls):
     return get_instance
 
 
 @lru_cache(maxsize=1)
 def get_current_model_config():
     from vllm.config import get_current_vllm_config
     vllm_config = get_current_vllm_config()
     return vllm_config.model_config
 
 
 #TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32. and subsequent updates will introduce new interfaces. --zzhx1
 @lru_cache(maxsize=1)
 def enable_dsa_cp() -> bool:
     from vllm.config import get_current_vllm_config
 
     vllm_config = get_current_vllm_config()
-    is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
-    return is_ds_v32 and enable_sp()
+    if vllm_config is None:
+        return False
+
+    model_config = getattr(vllm_config, "model_config", None)
+    if model_config is None:
+        return False
+
+    hf_text_config = getattr(model_config, "hf_text_config", None)
+    if hf_text_config is None:
+        return False
+
+    return hasattr(hf_text_config, "index_topk")
 
 
 @lru_cache(maxsize=1)
 def enable_dsa_cp_with_layer_shard() -> bool:
     if not enable_dsa_cp():
         return False
     from vllm.config import get_current_vllm_config
     vllm_config = get_current_vllm_config()
     is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer
     return is_prefill_instance
```
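
A toy illustration of how these `lru_cache`-based flags behave (using a module-level dict in place of vLLM's global config, purely as an assumption for the example): each result is computed once per process and then cached, and layer sharding is additionally gated on the instance being a KV producer (the prefill side).

```python
from functools import lru_cache

# Toy config standing in for get_current_vllm_config(); the real helpers
# read vLLM's global config instead of this module-level dict.
_CONFIG = {"index_topk": True, "is_kv_producer": True}


@lru_cache(maxsize=1)
def enable_dsa_cp() -> bool:
    # Evaluated once per process and cached, mirroring the helper in the diff.
    return bool(_CONFIG.get("index_topk"))


@lru_cache(maxsize=1)
def enable_dsa_cp_with_layer_shard() -> bool:
    # Layer sharding only applies on prefill (KV-producer) instances.
    return enable_dsa_cp() and bool(_CONFIG.get("is_kv_producer"))


if __name__ == "__main__":
    print(enable_dsa_cp(), enable_dsa_cp_with_layer_shard())  # True True
    _CONFIG["index_topk"] = False
    # Still True: lru_cache(maxsize=1) pins the first result for the process.
    print(enable_dsa_cp())
```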