diff --git a/.github/workflows/schedule_nightly_test_a3.yaml b/.github/workflows/schedule_nightly_test_a3.yaml
index 5f7dce93..9965fb00 100644
--- a/.github/workflows/schedule_nightly_test_a3.yaml
+++ b/.github/workflows/schedule_nightly_test_a3.yaml
@@ -77,6 +77,9 @@ jobs:
           - name: multi-node-qwenw8a8-2node-longseq
             config_file_path: Qwen3-235B-W8A8-longseq.yaml
             size: 2
+          - name: multi-node-deepseek-V3_2-W8A8-cp
+            config_file_path: DeepSeek-V3_2-W8A8-cp.yaml
+            size: 2
           - name: multi-node-qwen-disagg-pd
             config_file_path: Qwen3-235B-disagg-pd.yaml
             size: 2
diff --git a/tests/e2e/nightly/multi_node/config/DeepSeek-V3_2-W8A8-cp.yaml b/tests/e2e/nightly/multi_node/config/DeepSeek-V3_2-W8A8-cp.yaml
new file mode 100644
index 00000000..ad5f0476
--- /dev/null
+++ b/tests/e2e/nightly/multi_node/config/DeepSeek-V3_2-W8A8-cp.yaml
@@ -0,0 +1,91 @@
+test_name: "test DeepSeek-V3.2-W8A8 for PCP&DCP"
+model: "vllm-ascend/DeepSeek-V3.2-W8A8"
+num_nodes: 2
+npu_per_node: 16
+env_common:
+  HCCL_OP_EXPANSION_MODE: "AIV"
+
+  VLLM_USE_MODELSCOPE: true
+  HCCL_BUFFSIZE: 1024
+  SERVER_PORT: 8080
+  OMP_PROC_BIND: false
+  OMP_NUM_THREADS: 1
+  PYTORCH_NPU_ALLOC_CONF: "expandable_segments:True"
+  VLLM_ASCEND_ENABLE_FLASHCOMM1: 1
+  ASCEND_A3_EBA_ENABLE: 1
+
+
+deployment:
+  -
+    server_cmd: >
+      vllm serve vllm-ascend/DeepSeek-V3.2-W8A8
+      --host 0.0.0.0
+      --port $SERVER_PORT
+      --data-parallel-size 4
+      --data-parallel-size-local 2
+      --data-parallel-address $LOCAL_IP
+      --data-parallel-rpc-port 13399
+      --tensor-parallel-size 8
+      --decode-context-parallel-size 8
+      --quantization ascend
+      --seed 1024
+      --enable-expert-parallel
+      --max-num-seqs 16
+      --max-model-len 8192
+      --max-num-batched-tokens 4096
+      --no-enable-prefix-caching
+      --gpu-memory-utilization 0.85
+      --trust-remote-code
+      --speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}'
+      --compilation-config '{"cudagraph_capture_sizes": [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48], "cudagraph_mode": "FULL_DECODE_ONLY"}'
+      --additional-config '{"layer_sharding": ["q_b_proj", "o_proj"]}'
+      --tokenizer-mode deepseek_v32
+      --reasoning-parser deepseek_v3
+
+  -
+    server_cmd: >
+      vllm serve vllm-ascend/DeepSeek-V3.2-W8A8
+      --headless
+      --data-parallel-size 4
+      --data-parallel-rpc-port 13399
+      --data-parallel-size-local 2
+      --data-parallel-start-rank 2
+      --data-parallel-address $MASTER_IP
+      --tensor-parallel-size 8
+      --decode-context-parallel-size 8
+      --quantization ascend
+      --seed 1024
+      --enable-expert-parallel
+      --max-num-seqs 16
+      --max-model-len 8192
+      --max-num-batched-tokens 4096
+      --no-enable-prefix-caching
+      --gpu-memory-utilization 0.85
+      --trust-remote-code
+      --speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}'
+      --compilation-config '{"cudagraph_capture_sizes": [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48], "cudagraph_mode": "FULL_DECODE_ONLY"}'
+      --additional-config '{"layer_sharding": ["q_b_proj", "o_proj"]}'
+      --tokenizer-mode deepseek_v32
+      --reasoning-parser deepseek_v3
+benchmarks:
+  perf:
+    case_type: performance
+    dataset_path: vllm-ascend/GSM8K-in3500-bs2800
+    request_conf: vllm_api_stream_chat
+    dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_str_perf
+    num_prompts: 512
+    max_out_len: 3000
+    batch_size: 512
+    request_rate: 11.2
+    baseline: 1253.8466
+    threshold: 0.97
+
+  acc:
+    case_type: accuracy
+    dataset_path: vllm-ascend/gsm8k-lite
+    request_conf: vllm_api_general_chat
+    dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_chat_prompt
+    max_out_len: 4096
+    batch_size: 64
+    baseline: 95
+    threshold: 5
diff --git a/tests/ut/attention/test_sfa_v1.py b/tests/ut/attention/test_sfa_v1.py
index fd456b46..8d3cde39 100644
--- a/tests/ut/attention/test_sfa_v1.py
+++ b/tests/ut/attention/test_sfa_v1.py
@@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
 
 import torch
 
+from tests.ut.attention.utils import patch_distributed_groups
 from tests.ut.base import TestBase
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm.distributed.parallel_state import GroupCoordinator
@@ -41,7 +42,7 @@ class TestAscendSFAMetadata(TestBase):
         slot_mapping = torch.randn(100, 4, 1024)
         seq_lens = torch.tensor([30, 50])
         cum_query_lens = torch.tensor([0, 30, 80])
-        block_tables = torch.randint(0, 100, (100, 4))
+        block_table = torch.randint(0, 100, (100, 4))
         rope_dim = 32
         max_seq_len = int(seq_lens.max().item())
 
@@ -58,7 +59,7 @@ class TestAscendSFAMetadata(TestBase):
             slot_mapping=slot_mapping,
             seq_lens=seq_lens,
             cum_query_lens=cum_query_lens,
-            block_tables=block_tables,
+            block_table=block_table,
             sin=sin,
             cos=cos,
             num_input_tokens=num_input_tokens,
@@ -71,7 +72,7 @@ class TestAscendSFAMetadata(TestBase):
         self.assertIs(metadata.slot_mapping, slot_mapping)
         self.assertTrue(torch.equal(metadata.seq_lens, seq_lens))
         self.assertTrue(torch.equal(metadata.cum_query_lens, cum_query_lens))
-        self.assertIs(metadata.block_tables, block_tables)
+        self.assertIs(metadata.block_table, block_table)
         self.assertIs(metadata.sin, sin)
         self.assertIs(metadata.cos, cos)
         self.assertEqual(metadata.num_input_tokens, num_input_tokens)
@@ -135,6 +136,7 @@ class TestAscendSFAMetadataBuilder(TestBase):
         self.patcher.stop()
         self.parent_init_patcher.stop()
 
+    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
     def test_ascend_sfa_metadata_builder_default(self):
         kv_cache_spec = MagicMock()
         layer_names = ["layer1", "layer2"]
@@ -160,6 +162,7 @@ class TestAscendSFAMetadataBuilder(TestBase):
     @patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
     @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
     @patch("vllm_ascend.attention.sfa_v1.enable_dsa_cp")
+    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
     def test_ascend_sfa_metadata_builder_build(
         self,
         mock_enable_dsa_cp,
@@ -222,6 +225,7 @@ class TestAscendSFAMetadataBuilder(TestBase):
 
     @patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
     @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
+    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
     def test_ascend_sfa_metadata_builder_build_for_graph_capture(
             self, mock_get_cos_and_sin_mla, mock_get_current_vllm_config):
         cfg = MagicMock()
diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index e2271956..c926fdf5 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -6,7 +6,7 @@ import torch_npu
 import vllm.envs as envs_vllm
 from torch import nn
 from vllm.config import VllmConfig, get_current_vllm_config
-from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
+from vllm.distributed import get_dcp_group, get_pcp_group, get_tensor_model_parallel_world_size, get_tp_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder
@@ -95,6 +95,12 @@ class DSACPContext:
     actual_seq_lengths_key: torch.Tensor
 
 
+@dataclass
+class SFACPMetadata:
+    block_table_cp: torch.Tensor
+    valid_block_ids: torch.Tensor
+
+
 @dataclass
 class AscendSFAMetadata:
     """Metadata for MLACommon.
@@ -114,7 +120,7 @@ class AscendSFAMetadata:
     slot_mapping: torch.Tensor
     seq_lens: torch.Tensor
     cum_query_lens: torch.Tensor
-    block_tables: torch.Tensor
+    block_table: torch.Tensor
     sin: torch.Tensor
     cos: torch.Tensor
 
@@ -127,6 +133,7 @@ class AscendSFAMetadata:
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
     dsa_cp_context: DSACPContext | None = None
     reshape_cache_event: torch.npu.Event = None
+    sfa_cp_metadata: SFACPMetadata | None = None
 
 
 M = TypeVar("M", bound=AscendSFAMetadata)
@@ -178,6 +185,14 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
         self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device)
         self.actual_seq_lengths_key = torch.empty_like(self.actual_seq_lengths_query)
 
+        self.pcp_size = get_pcp_group().world_size
+        self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
+        self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
+
+        self.dcp_size = get_dcp_group().world_size
+        self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0
+        self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
+
     @staticmethod
     def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
         return ascend_chunked_prefill_workspace_size(vllm_config)
@@ -294,6 +309,22 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
             actual_seq_lengths_key=actual_seq_lengths_key,
         )
 
+        sfa_cp_metadata = None
+        if self.pcp_size * self.dcp_size > 1:
+            valid_block_ids, new_block_table = block_table.flatten().unique(return_inverse=True)
+            num_blocks = valid_block_ids.shape[0]
+            # Note(qcs): `block_table_cp` will have dirty values in the part beyond kv_lens.
+            # We assume that we can always get the correct kv_lens or kv index,
+            # so we omit the dirty value processing here.
+            block_table_cp = (
+                new_block_table.unsqueeze(-1).to(block_table)
+                + (torch.arange(self.pcp_size * self.dcp_size) * num_blocks).view(1, 1, -1).to(block_table)
+            ).reshape(block_table.shape[0], -1)
+            sfa_cp_metadata = SFACPMetadata(
+                block_table_cp=block_table_cp,
+                valid_block_ids=valid_block_ids,
+            )
+
         return self.metadata_cls(  # type: ignore
             num_input_tokens=common_attn_metadata.num_input_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -303,10 +334,11 @@
             head_dim=self.model_config.get_head_size(),
             attn_mask=self.attn_mask_builder.get_attention_mask(self.model_config),
             attn_state=common_attn_metadata.attn_state,
-            block_tables=block_table,
+            block_table=block_table,
             sin=sin[:num_input_tokens],
             cos=cos[:num_input_tokens],
             dsa_cp_context=dsa_cp_context,
+            sfa_cp_metadata=sfa_cp_metadata,
         )
 
     def build_for_graph_capture(
@@ -417,6 +449,14 @@ class AscendSFAImpl(MLAAttentionImpl):
             )
             register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)
 
+        self.pcp_size = get_pcp_group().world_size
+        self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
+        self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
+
+        self.dcp_size = get_dcp_group().world_size
+        self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0
+        self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
+
     def process_weights_after_loading(self, act_dtype: torch.dtype):
         # NOTE: We currently do not support quant kv_b_proj.
         assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
@@ -849,18 +889,28 @@ class AscendSFAImpl(MLAAttentionImpl):
             need_gather_q_kv=need_gather_q_kv,
         )
 
+        block_table = attn_metadata.block_table
+        kv = kv_cache[0]
+        key_rope = kv_cache[1]
+        if self.pcp_size * self.dcp_size > 1:
+            assert attn_metadata.sfa_cp_metadata is not None
+            valid_block_ids = attn_metadata.sfa_cp_metadata.valid_block_ids
+            kv = self.gather_kv_cross_cp(kv, valid_block_ids)
+            key_rope = self.gather_kv_cross_cp(key_rope, valid_block_ids)
+            block_table = attn_metadata.sfa_cp_metadata.block_table_cp
+
         attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
             query=ql_nope,
-            key=kv_cache[0],
-            value=kv_cache[0],
+            key=kv,
+            value=kv,
             sparse_indices=topk_indices,
             scale_value=self.scale,
             sparse_block_size=1,
-            block_table=attn_metadata.block_tables,
+            block_table=block_table,
             actual_seq_lengths_query=actual_seq_lengths_query,
             actual_seq_lengths_kv=actual_seq_lengths_key,
             query_rope=q_pe,
-            key_rope=kv_cache[1],
+            key_rope=key_rope,
             layout_query="TND",
             layout_kv="PA_BSND",
             sparse_mode=3,
@@ -894,6 +944,15 @@ class AscendSFAImpl(MLAAttentionImpl):
 
         return output_padded
 
+    def gather_kv_cross_cp(self, kv_cache: torch.Tensor, valid_block_ids: torch.Tensor) -> torch.Tensor:
+        # Note(qcs): kv_cache_interleave_size must be set equal to block_size for SFA.
+        kv_cache = torch.index_select(kv_cache, 0, valid_block_ids)
+        if self.dcp_size > 1:
+            kv_cache = get_dcp_group().all_gather(kv_cache, 0)
+        if self.pcp_size > 1:
+            kv_cache = get_pcp_group().all_gather(kv_cache, 0)
+        return kv_cache
+
     def indexer_select_pre_process(
         self,
         x: torch.Tensor,
@@ -969,11 +1028,16 @@ class AscendSFAImpl(MLAAttentionImpl):
         weights, _ = self.weights_proj(x)
         weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv)
 
-        block_table = attn_metadata.block_tables
+        key = kv_cache[2]
+        block_table = attn_metadata.block_table
+        if self.pcp_size * self.dcp_size > 1:
+            assert attn_metadata.sfa_cp_metadata is not None
+            key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids)
+            block_table = attn_metadata.sfa_cp_metadata.block_table_cp
 
         topk_indices = torch.ops._C_ascend.npu_lightning_indexer(
             query=q,
-            key=kv_cache[2],
+            key=key,
             weights=weights,
             actual_seq_lengths_query=actual_seq_lengths_query,
             actual_seq_lengths_key=actual_seq_lengths_key,
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 940d02a5..00c4e6cf 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -378,10 +378,11 @@ class NPUPlatform(Platform):
             vllm_config.scheduler_config.enable_chunked_prefill = True
         vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch
 
+        cp_size = parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size
         if (
             vllm_config.kv_transfer_config is not None
             and cache_config.block_size != parallel_config.cp_kv_cache_interleave_size
-            and parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1
+            and cp_size > 1
        ):
            raise AssertionError(
                f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size}) "
@@ -389,6 +390,20 @@
                "needs to be equal if use pcp or dcp > 1 in P/D disaggregate and kv pool scenario."
            )
 
+        use_sparse = (
+            model_config is not None
+            and model_config.hf_text_config is not None
+            and hasattr(model_config.hf_text_config, "index_topk")
+        )
+        if use_sparse and cp_size > 1 and parallel_config.cp_kv_cache_interleave_size != cache_config.block_size:
+            logger.warning_once(
+                "The current SFA's PCP&DCP implementation requires "
+                f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size})"
+                f" == block_size({cache_config.block_size}). "
+                f"Overriding cp_kv_cache_interleave_size to {cache_config.block_size}."
+            )
+            vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size
+
         if is_vl_model(vllm_config):
             if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) or bool(
                 int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))