[Refactor] Replace the implementations of o_proj, q_b_proj, and kv_b_proj with custom_op for sharded CP (#5698)
### What this PR does / why we need it?
Based on the Sharded-CP feature
PR:https://github.com/vllm-project/vllm-ascend/pull/4702;
RFC:https://github.com/vllm-project/vllm/issues/30055
This PR officially integrates Deepseek V3.2's DSA-CP support on the
basis of https://github.com/vllm-project/vllm-ascend/pull/4702,
improving inference efficiency and scalability under mixed
prefill-decode workloads. The main improvements include:
- Replace the implementations of o_proj, q_b_proj, and kv_b_proj with
custom_op for TP=1.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Signed-off-by: Kurumi5210 <jaychou1620@gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
This commit is contained in:
@@ -100,8 +100,22 @@ class TestAscendSFAMetadataBuilder(TestBase):
|
||||
assert builder.device == device
|
||||
assert builder.vllm_config == vllm_config
|
||||
|
||||
@patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
|
||||
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
|
||||
def test_ascend_sfa_metadata_builder_build(self, mock_get_cos_and_sin_mla):
|
||||
@patch("vllm_ascend.attention.sfa_v1.enable_dsa_cp")
|
||||
def test_ascend_sfa_metadata_builder_build(
|
||||
self,
|
||||
mock_enable_dsa_cp,
|
||||
mock_get_cos_and_sin_mla,
|
||||
mock_get_current_vllm_config,
|
||||
):
|
||||
mock_enable_dsa_cp.return_value = False
|
||||
|
||||
cfg = MagicMock()
|
||||
cfg.model_config = MagicMock()
|
||||
cfg.model_config.hf_text_config = MagicMock()
|
||||
|
||||
mock_get_current_vllm_config.return_value = cfg
|
||||
kv_cache_spec = MagicMock()
|
||||
layer_names = ["layer1", "layer2"]
|
||||
vllm_config = MagicMock()
|
||||
@@ -144,9 +158,16 @@ class TestAscendSFAMetadataBuilder(TestBase):
|
||||
assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
|
||||
assert metadata.slot_mapping.shape == (100, 4, 1024)
|
||||
|
||||
@patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
|
||||
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
|
||||
def test_ascend_sfa_metadata_builder_build_for_graph_capture(
|
||||
self, mock_get_cos_and_sin_mla):
|
||||
self, mock_get_cos_and_sin_mla, mock_get_current_vllm_config):
|
||||
cfg = MagicMock()
|
||||
cfg.model_config = MagicMock()
|
||||
cfg.model_config.hf_text_config = MagicMock()
|
||||
|
||||
mock_get_current_vllm_config.return_value = cfg
|
||||
|
||||
kv_cache_spec = MagicMock()
|
||||
layer_names = ["layer1", "layer2"]
|
||||
vllm_config = MagicMock()
|
||||
|
||||
Reference in New Issue
Block a user