From cb7c419bc0365fc3ae586893354addc649289d27 Mon Sep 17 00:00:00 2001 From: Qiu Date: Mon, 9 Feb 2026 18:52:25 +0800 Subject: [PATCH] [Feat](sfa,dcp) support dcp for sfa (#6563) ### What this PR does / why we need it? This PR adds DCP support to the SFA backend. Please note that due to operator constraints, the current implementation has to all-gather the entire KV cache and modify the block table to satisfy the operator input requirements. This results in significantly increased communication overhead and peak memory usage. Therefore, this is only a temporary workaround and will be refactored once the operator provides proper support. Additionally, because of the above limitations, `cp_kv_cache_interleave_size` is currently required to be equal to `block_size`. This restriction will also be removed after the refactor. #### Test accuracy test using DeepSeek-V3.2-Exp-W8A8 with dp2tp8dcp8 | dataset | version | metric | mode | vllm-api-general-stream | |----- | ----- | ----- | ----- | -----| | gsm8kdataset | - | accuracy | gen | 96.35 | - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0 --------- Signed-off-by: QiuChunshuo --- .../workflows/schedule_nightly_test_a3.yaml | 3 + .../config/DeepSeek-V3_2-W8A8-cp.yaml | 91 +++++++++++++++++++ tests/ut/attention/test_sfa_v1.py | 10 +- vllm_ascend/attention/sfa_v1.py | 82 +++++++++++++++-- vllm_ascend/platform.py | 17 +++- 5 files changed, 190 insertions(+), 13 deletions(-) create mode 100644 tests/e2e/nightly/multi_node/config/DeepSeek-V3_2-W8A8-cp.yaml diff --git a/.github/workflows/schedule_nightly_test_a3.yaml b/.github/workflows/schedule_nightly_test_a3.yaml index 5f7dce93..9965fb00 100644 --- a/.github/workflows/schedule_nightly_test_a3.yaml +++ b/.github/workflows/schedule_nightly_test_a3.yaml @@ -77,6 +77,9 @@ jobs: - name: multi-node-qwenw8a8-2node-longseq config_file_path: Qwen3-235B-W8A8-longseq.yaml size: 2 + - name: multi-node-deepseek-V3_2-W8A8-cp + config_file_path: DeepSeek-V3_2-W8A8-cp.yaml + size: 2 - name: multi-node-qwen-disagg-pd config_file_path: Qwen3-235B-disagg-pd.yaml size: 2 diff --git a/tests/e2e/nightly/multi_node/config/DeepSeek-V3_2-W8A8-cp.yaml b/tests/e2e/nightly/multi_node/config/DeepSeek-V3_2-W8A8-cp.yaml new file mode 100644 index 00000000..ad5f0476 --- /dev/null +++ b/tests/e2e/nightly/multi_node/config/DeepSeek-V3_2-W8A8-cp.yaml @@ -0,0 +1,91 @@ +test_name: "test DeepSeek-V3.2-W8A8 for PCP&DCP" +model: "vllm-ascend/DeepSeek-V3.2-W8A8" +num_nodes: 2 +npu_per_node: 16 +env_common: + HCCL_OP_EXPANSION_MODE: "AIV" + + VLLM_USE_MODELSCOPE: true + HCCL_BUFFSIZE: 1024 + SERVER_PORT: 8080 + OMP_PROC_BIND: false + OMP_NUM_THREADS: 1 + PYTORCH_NPU_ALLOC_CONF: "expandable_segments:True" + VLLM_ASCEND_ENABLE_FLASHCOMM1: 1 + ASCEND_A3_EBA_ENABLE: 1 + + +deployment: + - + server_cmd: > + vllm serve vllm-ascend/DeepSeek-V3.2-W8A8 + --host 0.0.0.0 + --port $SERVER_PORT + --data-parallel-size 4 + --data-parallel-size-local 2 + --data-parallel-address $LOCAL_IP + --data-parallel-rpc-port 13399 + --tensor-parallel-size 8 + --decode-context-parallel-size 8 + --quantization ascend + --seed 1024 + --enable-expert-parallel + --max-num-seqs 16 + --max-model-len 8192 + --max-num-batched-tokens 4096 + --no-enable-prefix-caching + --gpu-memory-utilization 0.85 + --trust-remote-code + --speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}' + --compilation-config '{"cudagraph_capture_sizes": [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48], "cudagraph_mode": "FULL_DECODE_ONLY"}' + --additional-config '{"layer_sharding": ["q_b_proj", "o_proj"]}' + --tokenizer-mode deepseek_v32 + --reasoning-parser deepseek_v3 + + - + server_cmd: > + vllm serve vllm-ascend/DeepSeek-V3.2-W8A8 + --headless + --data-parallel-size 4 + --data-parallel-rpc-port 13399 + --data-parallel-size-local 2 + --data-parallel-start-rank 2 + --data-parallel-address $MASTER_IP + --tensor-parallel-size 8 + --decode-context-parallel-size 8 + --quantization ascend + --seed 1024 + --enable-expert-parallel + --max-num-seqs 16 + --max-model-len 8192 + --max-num-batched-tokens 4096 + --no-enable-prefix-caching + --gpu-memory-utilization 0.85 + --trust-remote-code + --speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}' + --compilation-config '{"cudagraph_capture_sizes": [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48], "cudagraph_mode": "FULL_DECODE_ONLY"}' + --additional-config '{"layer_sharding": ["q_b_proj", "o_proj"]}' + --tokenizer-mode deepseek_v32 + --reasoning-parser deepseek_v3 +benchmarks: + perf: + case_type: performance + dataset_path: vllm-ascend/GSM8K-in3500-bs2800 + request_conf: vllm_api_stream_chat + dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_str_perf + num_prompts: 512 + max_out_len: 3000 + batch_size: 512 + request_rate: 11.2 + baseline: 1253.8466 + threshold: 0.97 + + acc: + case_type: accuracy + dataset_path: vllm-ascend/gsm8k-lite + request_conf: vllm_api_general_chat + dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_chat_prompt + max_out_len: 4096 + batch_size: 64 + baseline: 95 + threshold: 5 diff --git a/tests/ut/attention/test_sfa_v1.py b/tests/ut/attention/test_sfa_v1.py index fd456b46..8d3cde39 100644 --- a/tests/ut/attention/test_sfa_v1.py +++ b/tests/ut/attention/test_sfa_v1.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import torch +from tests.ut.attention.utils import patch_distributed_groups from tests.ut.base import TestBase from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm.distributed.parallel_state import GroupCoordinator @@ -41,7 +42,7 @@ class TestAscendSFAMetadata(TestBase): slot_mapping = torch.randn(100, 4, 1024) seq_lens = torch.tensor([30, 50]) cum_query_lens = torch.tensor([0, 30, 80]) - block_tables = torch.randint(0, 100, (100, 4)) + block_table = torch.randint(0, 100, (100, 4)) rope_dim = 32 max_seq_len = int(seq_lens.max().item()) @@ -58,7 +59,7 @@ class TestAscendSFAMetadata(TestBase): slot_mapping=slot_mapping, seq_lens=seq_lens, cum_query_lens=cum_query_lens, - block_tables=block_tables, + block_table=block_table, sin=sin, cos=cos, num_input_tokens=num_input_tokens, @@ -71,7 +72,7 @@ class TestAscendSFAMetadata(TestBase): self.assertIs(metadata.slot_mapping, slot_mapping) self.assertTrue(torch.equal(metadata.seq_lens, seq_lens)) self.assertTrue(torch.equal(metadata.cum_query_lens, cum_query_lens)) - self.assertIs(metadata.block_tables, block_tables) + self.assertIs(metadata.block_table, block_table) self.assertIs(metadata.sin, sin) self.assertIs(metadata.cos, cos) self.assertEqual(metadata.num_input_tokens, num_input_tokens) @@ -135,6 +136,7 @@ class TestAscendSFAMetadataBuilder(TestBase): self.patcher.stop() self.parent_init_patcher.stop() + @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False) def test_ascend_sfa_metadata_builder_default(self): kv_cache_spec = MagicMock() layer_names = ["layer1", "layer2"] @@ -160,6 +162,7 @@ class TestAscendSFAMetadataBuilder(TestBase): @patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config") @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla") @patch("vllm_ascend.attention.sfa_v1.enable_dsa_cp") + @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False) def test_ascend_sfa_metadata_builder_build( self, mock_enable_dsa_cp, @@ -222,6 +225,7 @@ class TestAscendSFAMetadataBuilder(TestBase): @patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config") @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla") + @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False) def test_ascend_sfa_metadata_builder_build_for_graph_capture( self, mock_get_cos_and_sin_mla, mock_get_current_vllm_config): cfg = MagicMock() diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index e2271956..c926fdf5 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -6,7 +6,7 @@ import torch_npu import vllm.envs as envs_vllm from torch import nn from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group +from vllm.distributed import get_dcp_group, get_pcp_group, get_tensor_model_parallel_world_size, get_tp_group from vllm.forward_context import get_forward_context from vllm.logger import logger from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder @@ -95,6 +95,12 @@ class DSACPContext: actual_seq_lengths_key: torch.Tensor +@dataclass +class SFACPMetadata: + block_table_cp: torch.Tensor + valid_block_ids: torch.Tensor + + @dataclass class AscendSFAMetadata: """Metadata for MLACommon. @@ -114,7 +120,7 @@ class AscendSFAMetadata: slot_mapping: torch.Tensor seq_lens: torch.Tensor cum_query_lens: torch.Tensor - block_tables: torch.Tensor + block_table: torch.Tensor sin: torch.Tensor cos: torch.Tensor @@ -127,6 +133,7 @@ class AscendSFAMetadata: attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill dsa_cp_context: DSACPContext | None = None reshape_cache_event: torch.npu.Event = None + sfa_cp_metadata: SFACPMetadata | None = None M = TypeVar("M", bound=AscendSFAMetadata) @@ -178,6 +185,14 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device) self.actual_seq_lengths_key = torch.empty_like(self.actual_seq_lengths_query) + self.pcp_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0 + self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None + + self.dcp_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0 + self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None + @staticmethod def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int: return ascend_chunked_prefill_workspace_size(vllm_config) @@ -294,6 +309,22 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): actual_seq_lengths_key=actual_seq_lengths_key, ) + sfa_cp_metadata = None + if self.pcp_size * self.dcp_size > 1: + valid_block_ids, new_block_table = block_table.flatten().unique(return_inverse=True) + num_blocks = valid_block_ids.shape[0] + # Note(qcs): `block_table_cp` will have dirty values in the part beyond kv_lens. + # We assume that we can always get the correct kv_lens or kv index, + # so we omit the dirty value processing here. + block_table_cp = ( + new_block_table.unsqueeze(-1).to(block_table) + + (torch.arange(self.pcp_size * self.dcp_size) * num_blocks).view(1, 1, -1).to(block_table) + ).reshape(block_table.shape[0], -1) + sfa_cp_metadata = SFACPMetadata( + block_table_cp=block_table_cp, + valid_block_ids=valid_block_ids, + ) + return self.metadata_cls( # type: ignore num_input_tokens=common_attn_metadata.num_input_tokens, num_actual_tokens=num_actual_tokens, @@ -303,10 +334,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): head_dim=self.model_config.get_head_size(), attn_mask=self.attn_mask_builder.get_attention_mask(self.model_config), attn_state=common_attn_metadata.attn_state, - block_tables=block_table, + block_table=block_table, sin=sin[:num_input_tokens], cos=cos[:num_input_tokens], dsa_cp_context=dsa_cp_context, + sfa_cp_metadata=sfa_cp_metadata, ) def build_for_graph_capture( @@ -417,6 +449,14 @@ class AscendSFAImpl(MLAAttentionImpl): ) register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs) + self.pcp_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0 + self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None + + self.dcp_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0 + self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None + def process_weights_after_loading(self, act_dtype: torch.dtype): # NOTE: We currently do not support quant kv_b_proj. assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod) @@ -849,18 +889,28 @@ class AscendSFAImpl(MLAAttentionImpl): need_gather_q_kv=need_gather_q_kv, ) + block_table = attn_metadata.block_table + kv = kv_cache[0] + key_rope = kv_cache[1] + if self.pcp_size * self.dcp_size > 1: + assert attn_metadata.sfa_cp_metadata is not None + valid_block_ids = attn_metadata.sfa_cp_metadata.valid_block_ids + kv = self.gather_kv_cross_cp(kv, valid_block_ids) + key_rope = self.gather_kv_cross_cp(key_rope, valid_block_ids) + block_table = attn_metadata.sfa_cp_metadata.block_table_cp + attn_output = torch.ops._C_ascend.npu_sparse_flash_attention( query=ql_nope, - key=kv_cache[0], - value=kv_cache[0], + key=kv, + value=kv, sparse_indices=topk_indices, scale_value=self.scale, sparse_block_size=1, - block_table=attn_metadata.block_tables, + block_table=block_table, actual_seq_lengths_query=actual_seq_lengths_query, actual_seq_lengths_kv=actual_seq_lengths_key, query_rope=q_pe, - key_rope=kv_cache[1], + key_rope=key_rope, layout_query="TND", layout_kv="PA_BSND", sparse_mode=3, @@ -894,6 +944,15 @@ class AscendSFAImpl(MLAAttentionImpl): return output_padded + def gather_kv_cross_cp(self, kv_cache: torch.Tensor, valid_block_ids: torch.Tensor) -> torch.Tensor: + # Note(qcs): we need set kv_cache_interleave_size = block_size for sfa!!! + kv_cache = torch.index_select(kv_cache, 0, valid_block_ids) + if self.dcp_size > 1: + kv_cache = get_dcp_group().all_gather(kv_cache, 0) + if self.pcp_size > 1: + kv_cache = get_pcp_group().all_gather(kv_cache, 0) + return kv_cache + def indexer_select_pre_process( self, x: torch.Tensor, @@ -969,11 +1028,16 @@ class AscendSFAImpl(MLAAttentionImpl): weights, _ = self.weights_proj(x) weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv) - block_table = attn_metadata.block_tables + key = kv_cache[2] + block_table = attn_metadata.block_table + if self.pcp_size * self.dcp_size > 1: + assert attn_metadata.sfa_cp_metadata is not None + key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids) + block_table = attn_metadata.sfa_cp_metadata.block_table_cp topk_indices = torch.ops._C_ascend.npu_lightning_indexer( query=q, - key=kv_cache[2], + key=key, weights=weights, actual_seq_lengths_query=actual_seq_lengths_query, actual_seq_lengths_key=actual_seq_lengths_key, diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 940d02a5..00c4e6cf 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -378,10 +378,11 @@ class NPUPlatform(Platform): vllm_config.scheduler_config.enable_chunked_prefill = True vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch + cp_size = parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size if ( vllm_config.kv_transfer_config is not None and cache_config.block_size != parallel_config.cp_kv_cache_interleave_size - and parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1 + and cp_size > 1 ): raise AssertionError( f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size}) " @@ -389,6 +390,20 @@ class NPUPlatform(Platform): "needs to be equal if use pcp or dcp > 1 in P/D disaggregate and kv pool scenario." ) + use_sparse = ( + model_config is not None + and model_config.hf_text_config is not None + and hasattr(model_config.hf_text_config, "index_topk") + ) + if use_sparse and cp_size > 1 and parallel_config.cp_kv_cache_interleave_size != cache_config.block_size: + logger.warning_once( + "The current SFA's PCP&DCP implementation requires" + f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size})" + f" == block_size({cache_config.block_size}). " + f"Override cp_kv_cache_interleave_size to {cache_config.block_size}." + ) + vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size + if is_vl_model(vllm_config): if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) or bool( int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))