[Refactor] Replace the implementations of o_proj, q_b_proj, and kv_b_proj with custom_op for sharded CP (#5698)

### What this PR does / why we need it?
Based on the Sharded-CP feature:
- PR: https://github.com/vllm-project/vllm-ascend/pull/4702
- RFC: https://github.com/vllm-project/vllm/issues/30055

This PR officially integrates DeepSeek V3.2's DSA-CP support on top of
https://github.com/vllm-project/vllm-ascend/pull/4702, improving inference
efficiency and scalability under mixed prefill-decode workloads. The main
improvement:
- Replace the implementations of o_proj, q_b_proj, and kv_b_proj with
custom_op, so that these projections run as if TP=1 when sharded CP is
enabled (a simplified sketch of the resulting dispatch is shown below).
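
For context, here is a minimal, self-contained sketch of the prefix-based op selection this refactor relies on. `ShardedCPLinearOpStub`, `pick_custom_op`, and the `enable_dsa_cp()` stub are illustrative stand-ins only; the actual classes added by this PR are `ShardedCPRowParallelOp` / `ShardedCPColumnParallelOp` (see the ops diff below).

```python
from types import SimpleNamespace


def enable_dsa_cp() -> bool:
    # Stand-in for vllm_ascend.utils.enable_dsa_cp(); pretend DSA-CP is enabled.
    return True


class ShardedCPLinearOpStub:
    """Illustrative stand-in for the ShardedCP*ParallelOp wrappers."""

    def __init__(self, layer):
        self.layer = layer

    @property
    def comm_group(self):
        # Fake single-rank comm group so the usual TP shard/reduce path is
        # bypassed and the projection behaves as if TP=1.
        return SimpleNamespace(world_size=1, rank_in_group=0, device_group=None)


def pick_custom_op(prefix: str, layer):
    # Prefix-based dispatch, mirroring _get_row_parallel_op /
    # _get_column_parallel_op in the diff below.
    if enable_dsa_cp() and "o_proj" in prefix:
        return ShardedCPLinearOpStub(layer)  # row-parallel projection
    if enable_dsa_cp() and ("q_b_proj" in prefix or "kv_b_proj" in prefix):
        return ShardedCPLinearOpStub(layer)  # column-parallel projection
    return None  # fall back to the default tensor-parallel op


if __name__ == "__main__":
    op = pick_custom_op("model.layers.0.self_attn.o_proj", layer=object())
    print(type(op).__name__, op.comm_group.world_size)
```

The real ops additionally set `reduce_results = False` (see `ShardedCPRowParallelOp.update_attrs` in the diff) so no tensor-parallel all-reduce is issued for these layers.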

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Signed-off-by: Kurumi5210 <jaychou1620@gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
zzhxxx
2026-01-09 15:58:40 +08:00
committed by GitHub
parent e11ff8e535
commit 64d29875f9
4 changed files with 110 additions and 68 deletions

View File

@@ -100,8 +100,22 @@ class TestAscendSFAMetadataBuilder(TestBase):
         assert builder.device == device
         assert builder.vllm_config == vllm_config

+    @patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
     @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
-    def test_ascend_sfa_metadata_builder_build(self, mock_get_cos_and_sin_mla):
+    @patch("vllm_ascend.attention.sfa_v1.enable_dsa_cp")
+    def test_ascend_sfa_metadata_builder_build(
+        self,
+        mock_enable_dsa_cp,
+        mock_get_cos_and_sin_mla,
+        mock_get_current_vllm_config,
+    ):
+        mock_enable_dsa_cp.return_value = False
+        cfg = MagicMock()
+        cfg.model_config = MagicMock()
+        cfg.model_config.hf_text_config = MagicMock()
+        mock_get_current_vllm_config.return_value = cfg
         kv_cache_spec = MagicMock()
         layer_names = ["layer1", "layer2"]
         vllm_config = MagicMock()
@@ -144,9 +158,16 @@ class TestAscendSFAMetadataBuilder(TestBase):
         assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
         assert metadata.slot_mapping.shape == (100, 4, 1024)

+    @patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
     @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
     def test_ascend_sfa_metadata_builder_build_for_graph_capture(
-            self, mock_get_cos_and_sin_mla):
+            self, mock_get_cos_and_sin_mla, mock_get_current_vllm_config):
+        cfg = MagicMock()
+        cfg.model_config = MagicMock()
+        cfg.model_config.hf_text_config = MagicMock()
+        mock_get_current_vllm_config.return_value = cfg
         kv_cache_spec = MagicMock()
         layer_names = ["layer1", "layer2"]
         vllm_config = MagicMock()

View File

@@ -10,8 +10,7 @@ from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger
-from vllm.model_executor.layers.linear import (ReplicatedLinear,
-                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.triton_utils import HAS_TRITON
 from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
 from vllm.v1.attention.backends.utils import AttentionCGSupport
@@ -34,7 +33,7 @@ from vllm_ascend.ops.triton.rope import rope_forward_triton
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
-                               enable_sp, maybe_trans_nz, replace_layer)
+                               enable_dsa_cp, maybe_trans_nz)
 from vllm_ascend.worker.npu_input_batch import NPUInputBatch

 if TYPE_CHECKING:
@@ -149,8 +148,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
             got {self.decode_threshold}"

         self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
-        self.enable_sfa_cp = enable_sp() and \
-            hasattr(self.model_config.hf_text_config, "index_topk")
+        self.enable_sfa_cp = enable_dsa_cp()

         assert not (
             self.enable_sfa_cp
@@ -368,13 +366,11 @@ class AscendSFAImpl(MLAAttentionImpl):
         assert self.indexer is not None, "Indexer is required for DSA."

-        self.enable_sfa_cp = enable_sp()
+        self.enable_sfa_cp = enable_dsa_cp()
         self.local_num_heads = self.num_heads
         self.vllm_config = get_current_vllm_config()
         if self.enable_sfa_cp:
             self.local_num_heads = self.num_heads * self.tp_size
-            self._replace_linear_class_for_sfa_cp()

         self.layer_sharding_kwargs = []
         for layer_name in (get_ascend_config().layer_sharding or []):
             if layer_name in kwargs:
@@ -925,42 +921,3 @@ class AscendSFAImpl(MLAAttentionImpl):
             sparse_count=2048,
             sparse_mode=3)
         return topk_indices
-
-    def _replace_linear_class_for_sfa_cp(self):
-        vllm_config = get_current_vllm_config()
-
-        # Dispose tensor from the original q_proj
-        dispose_layer(self.q_proj)
-        # Construct the new q_proj using ReplicatedLinear
-        new_q_proj = ReplicatedLinear(self.q_lora_rank,
-                                      self.local_num_heads * self.qk_head_dim,
-                                      bias=False,
-                                      quant_config=vllm_config.quant_config,
-                                      prefix=self.q_proj.prefix)
-        # Replace the q_proj with the new one
-        replace_layer(self.q_proj, new_q_proj)
-
-        # Dispose tensor from the original kv_b_proj
-        dispose_layer(self.kv_b_proj)
-        # Construct the new kv_b_proj using ReplicatedLinear
-        new_kv_b_proj = ReplicatedLinear(
-            self.kv_lora_rank,
-            self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False,
-            quant_config=vllm_config.quant_config,
-            prefix=self.kv_b_proj.prefix)
-        # Replace the kv_b_proj with the new one
-        replace_layer(self.kv_b_proj, new_kv_b_proj)
-
-        # Dispose tensor from the original o_proj
-        dispose_layer(self.o_proj)
-        # Construct the new o_proj using ReplicatedLinear
-        config = vllm_config.model_config.hf_text_config
-        new_o_proj = ReplicatedLinear(config.num_attention_heads *
-                                      config.v_head_dim,
-                                      config.hidden_size,
-                                      bias=False,
-                                      quant_config=vllm_config.quant_config,
-                                      prefix=self.o_proj.prefix)
-        # Replace the o_proj with the new one
-        replace_layer(self.o_proj, new_o_proj)

View File

@@ -38,6 +38,7 @@ Row parallel op follows a similar approach - inherit from RowColumnParallelOp an
 import re
 from functools import lru_cache
+from types import SimpleNamespace
 from typing import Optional, Union

 import torch
@@ -59,7 +60,7 @@ from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
                                                     get_flashcomm2_otp_group,
                                                     get_mlp_tp_group,
                                                     get_otp_group)
-from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
+from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable,
                                get_flashcomm2_reorgnized_batch_ids,
                                matmul_allreduce_enable, mlp_tp_enable,
                                oproj_tp_enable, shared_expert_dp_enabled)
@@ -609,9 +610,60 @@ class SequenceRowParallelOp(CustomRowParallelOp):
         self.unique_prefix = self.layer.unique_prefix


+class ShardedCPRowParallelOp(CustomRowParallelOp):
+
+    @property
+    def comm_group(self):
+        # fake comm group to bypass tp logic
+        return SimpleNamespace(world_size=1,
+                               rank_in_group=0,
+                               device_group=None)
+
+    def apply_impl(
+        self,
+        input_,
+    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        assert self.quant_method is not None
+        output = self.quant_method.apply(self.layer, input_, bias_)
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+    def update_attrs(self):
+        super().update_attrs()
+        self.layer.reduce_results = False
+
+
+class ShardedCPColumnParallelOp(CustomColumnParallelOp):
+
+    @property
+    def comm_group(self):
+        # fake comm group to bypass tp logic
+        return SimpleNamespace(world_size=1,
+                               rank_in_group=0,
+                               device_group=None)
+
+    def apply_impl(
+        self,
+        input_,
+    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
+        bias = self.bias if not self.skip_bias_add else None
+        assert self.quant_method is not None
+        output = self.quant_method.apply(self.layer, input_, bias)
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
 def _get_column_parallel_op(
-        prefix, layer
-) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
+        prefix, layer
+) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
+                    ShardedCPColumnParallelOp]]:
+    if enable_dsa_cp() and ("q_b_proj" in prefix or "kv_b_proj" in prefix):
+        return ShardedCPColumnParallelOp(layer)
     if "gate_up_proj" in prefix and mlp_tp_enable(
     ) and not is_moe_layer(prefix):
         return MLPColumnParallelOp(layer)
@@ -636,7 +688,9 @@ def _get_row_parallel_op(
         prefix, layer
 ) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                     Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
-                    SequenceRowParallelOp]]:
+                    SequenceRowParallelOp, ShardedCPRowParallelOp]]:
+    if enable_dsa_cp() and "o_proj" in prefix:
+        return ShardedCPRowParallelOp(layer)
     if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
         return MLPRowParallelOp(layer)
     if "o_proj" in prefix and oproj_tp_enable():
@@ -670,7 +724,8 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
                               MLPRowParallelOp, OProjRowParallelOp,
                               Flashcomm2OProjRowParallelOp,
                               MatmulAllreduceRowParallelOp,
-                              SequenceRowParallelOp]] = None
+                              SequenceRowParallelOp, ShardedCPRowParallelOp,
+                              ShardedCPColumnParallelOp]] = None

     if direct == "row":
         custom_op = _get_row_parallel_op(prefix, layer)

View File

@@ -1119,11 +1119,6 @@ def dispose_layer(layer: Any):
         dispose_tensor(attr_value)


-def replace_layer(original_layer: Any, new_layer: Any):
-    original_layer.__class__ = new_layer.__class__
-    original_layer.__dict__ = new_layer.__dict__
-
-
 def check_kv_extra_config(vllm_config):

     def _check(name: str, config: dict):
@@ -1166,17 +1161,31 @@ def singleton(cls):
     return get_instance


 @lru_cache(maxsize=1)
 def get_current_model_config():
     from vllm.config import get_current_vllm_config
     vllm_config = get_current_vllm_config()
     return vllm_config.model_config


-#TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32. and subsequent updates will introduce new interfaces. --zzhx1
 @lru_cache(maxsize=1)
 def enable_dsa_cp() -> bool:
     from vllm.config import get_current_vllm_config
     vllm_config = get_current_vllm_config()
-    is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
-    return is_ds_v32 and enable_sp()
+    if vllm_config is None:
+        return False
+    model_config = getattr(vllm_config, "model_config", None)
+    if model_config is None:
+        return False
+    hf_text_config = getattr(model_config, "hf_text_config", None)
+    if hf_text_config is None:
+        return False
+    return hasattr(hf_text_config, "index_topk")
+
+
+@lru_cache(maxsize=1)
+def enable_dsa_cp_with_layer_shard() -> bool:
+    if not enable_dsa_cp():
+        return False
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer
+    return is_prefill_instance