[Refactor] Replace the implementations of o_proj, q_b_proj, and kv_b_proj with custom_op for sharded CP (#5698)
### What this PR does / why we need it?
Based on the Sharded-CP feature
PR:https://github.com/vllm-project/vllm-ascend/pull/4702;
RFC:https://github.com/vllm-project/vllm/issues/30055
This PR officially integrates Deepseek V3.2's DSA-CP support on the
basis of https://github.com/vllm-project/vllm-ascend/pull/4702,
improving inference efficiency and scalability under mixed
prefill-decode workloads. The main improvements include:
- Replace the implementations of o_proj, q_b_proj, and kv_b_proj with
custom_op for TP=1.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Signed-off-by: Kurumi5210 <jaychou1620@gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
This commit is contained in:
@@ -100,8 +100,22 @@ class TestAscendSFAMetadataBuilder(TestBase):
|
|||||||
assert builder.device == device
|
assert builder.device == device
|
||||||
assert builder.vllm_config == vllm_config
|
assert builder.vllm_config == vllm_config
|
||||||
|
|
||||||
|
@patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
|
||||||
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
|
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
|
||||||
def test_ascend_sfa_metadata_builder_build(self, mock_get_cos_and_sin_mla):
|
@patch("vllm_ascend.attention.sfa_v1.enable_dsa_cp")
|
||||||
|
def test_ascend_sfa_metadata_builder_build(
|
||||||
|
self,
|
||||||
|
mock_enable_dsa_cp,
|
||||||
|
mock_get_cos_and_sin_mla,
|
||||||
|
mock_get_current_vllm_config,
|
||||||
|
):
|
||||||
|
mock_enable_dsa_cp.return_value = False
|
||||||
|
|
||||||
|
cfg = MagicMock()
|
||||||
|
cfg.model_config = MagicMock()
|
||||||
|
cfg.model_config.hf_text_config = MagicMock()
|
||||||
|
|
||||||
|
mock_get_current_vllm_config.return_value = cfg
|
||||||
kv_cache_spec = MagicMock()
|
kv_cache_spec = MagicMock()
|
||||||
layer_names = ["layer1", "layer2"]
|
layer_names = ["layer1", "layer2"]
|
||||||
vllm_config = MagicMock()
|
vllm_config = MagicMock()
|
||||||
@@ -144,9 +158,16 @@ class TestAscendSFAMetadataBuilder(TestBase):
|
|||||||
assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
|
assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
|
||||||
assert metadata.slot_mapping.shape == (100, 4, 1024)
|
assert metadata.slot_mapping.shape == (100, 4, 1024)
|
||||||
|
|
||||||
|
@patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
|
||||||
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
|
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
|
||||||
def test_ascend_sfa_metadata_builder_build_for_graph_capture(
|
def test_ascend_sfa_metadata_builder_build_for_graph_capture(
|
||||||
self, mock_get_cos_and_sin_mla):
|
self, mock_get_cos_and_sin_mla, mock_get_current_vllm_config):
|
||||||
|
cfg = MagicMock()
|
||||||
|
cfg.model_config = MagicMock()
|
||||||
|
cfg.model_config.hf_text_config = MagicMock()
|
||||||
|
|
||||||
|
mock_get_current_vllm_config.return_value = cfg
|
||||||
|
|
||||||
kv_cache_spec = MagicMock()
|
kv_cache_spec = MagicMock()
|
||||||
layer_names = ["layer1", "layer2"]
|
layer_names = ["layer1", "layer2"]
|
||||||
vllm_config = MagicMock()
|
vllm_config = MagicMock()
|
||||||
|
|||||||
@@ -10,8 +10,7 @@ from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
|
|||||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.model_executor.layers.linear import (ReplicatedLinear,
|
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||||
UnquantizedLinearMethod)
|
|
||||||
from vllm.triton_utils import HAS_TRITON
|
from vllm.triton_utils import HAS_TRITON
|
||||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
|
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
|
||||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||||
@@ -34,7 +33,7 @@ from vllm_ascend.ops.triton.rope import rope_forward_triton
|
|||||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
|
||||||
enable_sp, maybe_trans_nz, replace_layer)
|
enable_dsa_cp, maybe_trans_nz)
|
||||||
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -149,8 +148,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
|||||||
got {self.decode_threshold}"
|
got {self.decode_threshold}"
|
||||||
|
|
||||||
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
||||||
self.enable_sfa_cp = enable_sp() and \
|
self.enable_sfa_cp = enable_dsa_cp()
|
||||||
hasattr(self.model_config.hf_text_config, "index_topk")
|
|
||||||
|
|
||||||
assert not (
|
assert not (
|
||||||
self.enable_sfa_cp
|
self.enable_sfa_cp
|
||||||
@@ -368,13 +366,11 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
|
|
||||||
assert self.indexer is not None, "Indexer is required for DSA."
|
assert self.indexer is not None, "Indexer is required for DSA."
|
||||||
|
|
||||||
self.enable_sfa_cp = enable_sp()
|
self.enable_sfa_cp = enable_dsa_cp()
|
||||||
self.local_num_heads = self.num_heads
|
self.local_num_heads = self.num_heads
|
||||||
self.vllm_config = get_current_vllm_config()
|
self.vllm_config = get_current_vllm_config()
|
||||||
if self.enable_sfa_cp:
|
if self.enable_sfa_cp:
|
||||||
self.local_num_heads = self.num_heads * self.tp_size
|
self.local_num_heads = self.num_heads * self.tp_size
|
||||||
|
|
||||||
self._replace_linear_class_for_sfa_cp()
|
|
||||||
self.layer_sharding_kwargs = []
|
self.layer_sharding_kwargs = []
|
||||||
for layer_name in (get_ascend_config().layer_sharding or []):
|
for layer_name in (get_ascend_config().layer_sharding or []):
|
||||||
if layer_name in kwargs:
|
if layer_name in kwargs:
|
||||||
@@ -925,42 +921,3 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
sparse_count=2048,
|
sparse_count=2048,
|
||||||
sparse_mode=3)
|
sparse_mode=3)
|
||||||
return topk_indices
|
return topk_indices
|
||||||
|
|
||||||
def _replace_linear_class_for_sfa_cp(self):
|
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
|
||||||
# Dispose tensor from the original q_proj
|
|
||||||
dispose_layer(self.q_proj)
|
|
||||||
# Construct the new q_proj using ReplicatedLinear
|
|
||||||
new_q_proj = ReplicatedLinear(self.q_lora_rank,
|
|
||||||
self.local_num_heads * self.qk_head_dim,
|
|
||||||
bias=False,
|
|
||||||
quant_config=vllm_config.quant_config,
|
|
||||||
prefix=self.q_proj.prefix)
|
|
||||||
# Replace the q_proj with the new one
|
|
||||||
replace_layer(self.q_proj, new_q_proj)
|
|
||||||
|
|
||||||
# Dispose tensor from the original kv_b_proj
|
|
||||||
dispose_layer(self.kv_b_proj)
|
|
||||||
# Construct the new kv_b_proj using ReplicatedLinear
|
|
||||||
new_kv_b_proj = ReplicatedLinear(
|
|
||||||
self.kv_lora_rank,
|
|
||||||
self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
|
||||||
bias=False,
|
|
||||||
quant_config=vllm_config.quant_config,
|
|
||||||
prefix=self.kv_b_proj.prefix)
|
|
||||||
# Replace the kv_b_proj with the new one
|
|
||||||
replace_layer(self.kv_b_proj, new_kv_b_proj)
|
|
||||||
|
|
||||||
# Dispose tensor from the original o_proj
|
|
||||||
dispose_layer(self.o_proj)
|
|
||||||
# Construct the new o_proj using ReplicatedLinear
|
|
||||||
config = vllm_config.model_config.hf_text_config
|
|
||||||
new_o_proj = ReplicatedLinear(config.num_attention_heads *
|
|
||||||
config.v_head_dim,
|
|
||||||
config.hidden_size,
|
|
||||||
bias=False,
|
|
||||||
quant_config=vllm_config.quant_config,
|
|
||||||
prefix=self.o_proj.prefix)
|
|
||||||
# Replace the o_proj with the new one
|
|
||||||
replace_layer(self.o_proj, new_o_proj)
|
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ Row parallel op follows a similar approach - inherit from RowColumnParallelOp an
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
from types import SimpleNamespace
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -59,7 +60,7 @@ from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
|
|||||||
get_flashcomm2_otp_group,
|
get_flashcomm2_otp_group,
|
||||||
get_mlp_tp_group,
|
get_mlp_tp_group,
|
||||||
get_otp_group)
|
get_otp_group)
|
||||||
from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
|
from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable,
|
||||||
get_flashcomm2_reorgnized_batch_ids,
|
get_flashcomm2_reorgnized_batch_ids,
|
||||||
matmul_allreduce_enable, mlp_tp_enable,
|
matmul_allreduce_enable, mlp_tp_enable,
|
||||||
oproj_tp_enable, shared_expert_dp_enabled)
|
oproj_tp_enable, shared_expert_dp_enabled)
|
||||||
@@ -609,9 +610,60 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
|||||||
self.unique_prefix = self.layer.unique_prefix
|
self.unique_prefix = self.layer.unique_prefix
|
||||||
|
|
||||||
|
|
||||||
|
class ShardedCPRowParallelOp(CustomRowParallelOp):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def comm_group(self):
|
||||||
|
# fake comm group to bypass tp logic
|
||||||
|
return SimpleNamespace(world_size=1,
|
||||||
|
rank_in_group=0,
|
||||||
|
device_group=None)
|
||||||
|
|
||||||
|
def apply_impl(
|
||||||
|
self,
|
||||||
|
input_,
|
||||||
|
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||||
|
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||||
|
assert self.quant_method is not None
|
||||||
|
output = self.quant_method.apply(self.layer, input_, bias_)
|
||||||
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
|
if not self.return_bias:
|
||||||
|
return output
|
||||||
|
return output, output_bias
|
||||||
|
|
||||||
|
def update_attrs(self):
|
||||||
|
super().update_attrs()
|
||||||
|
self.layer.reduce_results = False
|
||||||
|
|
||||||
|
|
||||||
|
class ShardedCPColumnParallelOp(CustomColumnParallelOp):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def comm_group(self):
|
||||||
|
# fake comm group to bypass tp logic
|
||||||
|
return SimpleNamespace(world_size=1,
|
||||||
|
rank_in_group=0,
|
||||||
|
device_group=None)
|
||||||
|
|
||||||
|
def apply_impl(
|
||||||
|
self,
|
||||||
|
input_,
|
||||||
|
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||||
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
|
assert self.quant_method is not None
|
||||||
|
output = self.quant_method.apply(self.layer, input_, bias)
|
||||||
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
|
if not self.return_bias:
|
||||||
|
return output
|
||||||
|
return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
def _get_column_parallel_op(
|
def _get_column_parallel_op(
|
||||||
prefix, layer
|
prefix, layer
|
||||||
) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
|
) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
|
||||||
|
ShardedCPColumnParallelOp]]:
|
||||||
|
if enable_dsa_cp() and ("q_b_proj" in prefix or "kv_b_proj" in prefix):
|
||||||
|
return ShardedCPColumnParallelOp(layer)
|
||||||
if "gate_up_proj" in prefix and mlp_tp_enable(
|
if "gate_up_proj" in prefix and mlp_tp_enable(
|
||||||
) and not is_moe_layer(prefix):
|
) and not is_moe_layer(prefix):
|
||||||
return MLPColumnParallelOp(layer)
|
return MLPColumnParallelOp(layer)
|
||||||
@@ -636,7 +688,9 @@ def _get_row_parallel_op(
|
|||||||
prefix, layer
|
prefix, layer
|
||||||
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
|
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
|
||||||
Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
|
Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
|
||||||
SequenceRowParallelOp]]:
|
SequenceRowParallelOp, ShardedCPRowParallelOp]]:
|
||||||
|
if enable_dsa_cp() and "o_proj" in prefix:
|
||||||
|
return ShardedCPRowParallelOp(layer)
|
||||||
if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
|
if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
|
||||||
return MLPRowParallelOp(layer)
|
return MLPRowParallelOp(layer)
|
||||||
if "o_proj" in prefix and oproj_tp_enable():
|
if "o_proj" in prefix and oproj_tp_enable():
|
||||||
@@ -670,7 +724,8 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
|
|||||||
MLPRowParallelOp, OProjRowParallelOp,
|
MLPRowParallelOp, OProjRowParallelOp,
|
||||||
Flashcomm2OProjRowParallelOp,
|
Flashcomm2OProjRowParallelOp,
|
||||||
MatmulAllreduceRowParallelOp,
|
MatmulAllreduceRowParallelOp,
|
||||||
SequenceRowParallelOp]] = None
|
SequenceRowParallelOp, ShardedCPRowParallelOp,
|
||||||
|
ShardedCPColumnParallelOp]] = None
|
||||||
if direct == "row":
|
if direct == "row":
|
||||||
custom_op = _get_row_parallel_op(prefix, layer)
|
custom_op = _get_row_parallel_op(prefix, layer)
|
||||||
|
|
||||||
|
|||||||
@@ -1119,11 +1119,6 @@ def dispose_layer(layer: Any):
|
|||||||
dispose_tensor(attr_value)
|
dispose_tensor(attr_value)
|
||||||
|
|
||||||
|
|
||||||
def replace_layer(original_layer: Any, new_layer: Any):
|
|
||||||
original_layer.__class__ = new_layer.__class__
|
|
||||||
original_layer.__dict__ = new_layer.__dict__
|
|
||||||
|
|
||||||
|
|
||||||
def check_kv_extra_config(vllm_config):
|
def check_kv_extra_config(vllm_config):
|
||||||
|
|
||||||
def _check(name: str, config: dict):
|
def _check(name: str, config: dict):
|
||||||
@@ -1166,17 +1161,31 @@ def singleton(cls):
|
|||||||
return get_instance
|
return get_instance
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
|
||||||
def get_current_model_config():
|
|
||||||
from vllm.config import get_current_vllm_config
|
|
||||||
vllm_config = get_current_vllm_config()
|
|
||||||
return vllm_config.model_config
|
|
||||||
|
|
||||||
|
|
||||||
#TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32. and subsequent updates will introduce new interfaces. --zzhx1
|
#TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32. and subsequent updates will introduce new interfaces. --zzhx1
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
def enable_dsa_cp() -> bool:
|
def enable_dsa_cp() -> bool:
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
|
if vllm_config is None:
|
||||||
return is_ds_v32 and enable_sp()
|
return False
|
||||||
|
|
||||||
|
model_config = getattr(vllm_config, "model_config", None)
|
||||||
|
if model_config is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
hf_text_config = getattr(model_config, "hf_text_config", None)
|
||||||
|
if hf_text_config is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return hasattr(hf_text_config, "index_topk")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def enable_dsa_cp_with_layer_shard() -> bool:
|
||||||
|
if not enable_dsa_cp():
|
||||||
|
return False
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
|
vllm_config = get_current_vllm_config()
|
||||||
|
is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer
|
||||||
|
return is_prefill_instance
|
||||||
|
|||||||
Reference in New Issue
Block a user