[Refactor] Replace the implementations of o_proj, q_b_proj, and kv_b_proj with custom_op for sharded CP (#5698)
### What this PR does / why we need it?
Based on the Sharded-CP feature:
- PR: https://github.com/vllm-project/vllm-ascend/pull/4702
- RFC: https://github.com/vllm-project/vllm/issues/30055

This PR officially integrates DeepSeek V3.2's DSA-CP support on top of
https://github.com/vllm-project/vllm-ascend/pull/4702, improving inference
efficiency and scalability under mixed prefill-decode workloads. The main
improvements include:
- Replace the implementations of o_proj, q_b_proj, and kv_b_proj with
custom_op for TP=1 (see the shape sketch below).
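
For context, the shape bookkeeping involved is visible in the removed `_replace_linear_class_for_sfa_cp` helper further down the diff: under sharded CP every rank keeps the full set of attention heads, so these projections run unsharded (ReplicatedLinear-style) rather than TP-sharded. Below is a minimal, self-contained sketch of those shapes with illustrative DeepSeek-style dimensions; it is not the custom_op implementation this PR actually introduces.

```python
# Hedged sketch: mirrors the shapes the removed helper rebuilt with
# ReplicatedLinear. Dimensions are illustrative DeepSeek-V3-style values,
# not taken from this PR's custom_op code.
from vllm.model_executor.layers.linear import ReplicatedLinear

tp_size = 8                              # hypothetical tensor-parallel size
num_heads = 16                           # heads per rank under plain TP
local_num_heads = num_heads * tp_size    # under DSA-CP each rank keeps all heads

q_lora_rank, kv_lora_rank = 1536, 512
qk_head_dim, qk_nope_head_dim, v_head_dim = 192, 128, 128
hidden_size = 7168

# q_b_proj: latent q -> all heads * qk_head_dim, kept unsharded
q_b_proj = ReplicatedLinear(q_lora_rank,
                            local_num_heads * qk_head_dim,
                            bias=False)

# kv_b_proj: latent kv -> all heads * (nope + value dims), kept unsharded
kv_b_proj = ReplicatedLinear(kv_lora_rank,
                             local_num_heads * (qk_nope_head_dim + v_head_dim),
                             bias=False)

# o_proj: all heads * v_head_dim -> hidden_size, kept unsharded
o_proj = ReplicatedLinear(local_num_heads * v_head_dim,
                          hidden_size,
                          bias=False)
```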
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Signed-off-by: Kurumi5210 <jaychou1620@gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
@@ -10,8 +10,7 @@ from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger
-from vllm.model_executor.layers.linear import (ReplicatedLinear,
-                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.triton_utils import HAS_TRITON
 from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
 from vllm.v1.attention.backends.utils import AttentionCGSupport
@@ -34,7 +33,7 @@ from vllm_ascend.ops.triton.rope import rope_forward_triton
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
-                               enable_sp, maybe_trans_nz, replace_layer)
+                               enable_dsa_cp, maybe_trans_nz)
 from vllm_ascend.worker.npu_input_batch import NPUInputBatch

 if TYPE_CHECKING:
@@ -149,8 +148,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
             got {self.decode_threshold}"

         self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
-        self.enable_sfa_cp = enable_sp() and \
-            hasattr(self.model_config.hf_text_config, "index_topk")
+        self.enable_sfa_cp = enable_dsa_cp()

         assert not (
             self.enable_sfa_cp
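
As a reading aid, the expression removed above suggests what `enable_dsa_cp()` encapsulates. Here is a hypothetical restatement of that superseded condition; the real helper in `vllm_ascend.utils` may be implemented differently.

```python
# Hypothetical restatement of the condition replaced above; the actual
# enable_dsa_cp() in vllm_ascend.utils may differ.
from vllm_ascend.utils import enable_sp  # pre-PR helper used by the old check


def dsa_cp_condition(model_config) -> bool:
    # Sharded/sequence parallelism must be enabled, and the model must expose
    # DeepSeek sparse attention, signalled by the `index_topk` attribute.
    return enable_sp() and hasattr(model_config.hf_text_config, "index_topk")
```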
@@ -368,13 +366,11 @@ class AscendSFAImpl(MLAAttentionImpl):

         assert self.indexer is not None, "Indexer is required for DSA."

-        self.enable_sfa_cp = enable_sp()
+        self.enable_sfa_cp = enable_dsa_cp()
         self.local_num_heads = self.num_heads
         self.vllm_config = get_current_vllm_config()
         if self.enable_sfa_cp:
             self.local_num_heads = self.num_heads * self.tp_size
-
-            self._replace_linear_class_for_sfa_cp()
         self.layer_sharding_kwargs = []
         for layer_name in (get_ascend_config().layer_sharding or []):
             if layer_name in kwargs:
@@ -925,42 +921,3 @@ class AscendSFAImpl(MLAAttentionImpl):
             sparse_count=2048,
             sparse_mode=3)
         return topk_indices
-
-    def _replace_linear_class_for_sfa_cp(self):
-
-        vllm_config = get_current_vllm_config()
-        # Dispose tensor from the original q_proj
-        dispose_layer(self.q_proj)
-        # Construct the new q_proj using ReplicatedLinear
-        new_q_proj = ReplicatedLinear(self.q_lora_rank,
-                                      self.local_num_heads * self.qk_head_dim,
-                                      bias=False,
-                                      quant_config=vllm_config.quant_config,
-                                      prefix=self.q_proj.prefix)
-        # Replace the q_proj with the new one
-        replace_layer(self.q_proj, new_q_proj)
-
-        # Dispose tensor from the original kv_b_proj
-        dispose_layer(self.kv_b_proj)
-        # Construct the new kv_b_proj using ReplicatedLinear
-        new_kv_b_proj = ReplicatedLinear(
-            self.kv_lora_rank,
-            self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False,
-            quant_config=vllm_config.quant_config,
-            prefix=self.kv_b_proj.prefix)
-        # Replace the kv_b_proj with the new one
-        replace_layer(self.kv_b_proj, new_kv_b_proj)
-
-        # Dispose tensor from the original o_proj
-        dispose_layer(self.o_proj)
-        # Construct the new o_proj using ReplicatedLinear
-        config = vllm_config.model_config.hf_text_config
-        new_o_proj = ReplicatedLinear(config.num_attention_heads *
-                                      config.v_head_dim,
-                                      config.hidden_size,
-                                      bias=False,
-                                      quant_config=vllm_config.quant_config,
-                                      prefix=self.o_proj.prefix)
-        # Replace the o_proj with the new one
-        replace_layer(self.o_proj, new_o_proj)