From bc0fd7ca7217498d5faa91504b0e8c3f822a5cc6 Mon Sep 17 00:00:00 2001 From: xiaocongtou6 <105542647+xiaocongtou6@users.noreply.github.com> Date: Fri, 6 Mar 2026 16:10:24 +0800 Subject: [PATCH] [Feat]Adapt the graph mode (piecewise and full_decode_only) of PCP and DCP for DeepSeek v3.2. (#6940) ### What this PR does / why we need it? Adapt the graph mode (piecewise and full_decode_only) of PCP and DCP for DeepSeek v3.2. ### How was this patch tested? Test output: {"object":"text_completion","model":"deepeek_v3","choices":[{"index":0,"text":" the head of state and head of government of the United States, indirectly elected to a four-year term by the American people through the Electoral College. The officeholder leads the executive branch of the federal government and is the commander-in-chief of the United States","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null},{"index":1,"text":" Paris. This is the largest city in France and its main political, cultural and commercial center. The modern location of the city is the north of the central part of the country, on the banks of the Seine River Seine River Seine in 3\n\n","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null},{"index":2,"text":" now\n\n# AI future is now\n\nThe world is changing at a rapid pace, and artificial intelligence (AI) is at the forefront of this transformation. From self-driving cars to virtual assistants, AI is already making a significant impact on our daily lives","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null},{"index":3,"text":" a 3rd year student at the University of Lincoln studying Media Production. This blog is about my work throughout my final year on the course.\n\n## Tuesday 3 May 2016\n### Final Major Project - Evaluation\n\nFor my final project I","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":27,"total_tokens":227,"completion_tokens":200,"prompt_tokens_details":null},"kv_transfer_params":null} - vLLM version: v0.16.0 - vLLM main: https://github.com/vllm-project/vllm/commit/15d76f74e2fdb12a95ea00f0ca283acf6219a2b7 --------- Signed-off-by: xiaocongtou6 <2066962956@qq.com> Signed-off-by: xiaocongtou6 <105542647+xiaocongtou6@users.noreply.github.com> --- .../4-cards/long_sequence/test_basic.py | 10 ++- .../attention/context_parallel/common_cp.py | 1 + .../attention/context_parallel/sfa_cp.py | 82 +++++++++++-------- 3 files changed, 60 insertions(+), 33 deletions(-) diff --git a/tests/e2e/multicard/4-cards/long_sequence/test_basic.py b/tests/e2e/multicard/4-cards/long_sequence/test_basic.py index fa6e2633..66dbbc0b 100644 --- a/tests/e2e/multicard/4-cards/long_sequence/test_basic.py +++ b/tests/e2e/multicard/4-cards/long_sequence/test_basic.py @@ -25,6 +25,7 @@ from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free os.environ["HCCL_BUFFSIZE"] = "768" +@wait_until_npu_memory_free() def test_models_pcp_dcp_basic(): prompts = [ "The capital of France is", "Hello, my name is Tom, I am", @@ -59,12 +60,12 @@ def test_models_pcp_dcp_basic(): model = "vllm-ascend/DeepSeek-V3.2-W8A8-Pruning" with VllmRunner( model, - enforce_eager=True, max_model_len=1024, tensor_parallel_size=2, prefill_context_parallel_size=2, decode_context_parallel_size=2, enable_expert_parallel=True, + gpu_memory_utilization=0.2, block_size=128, quantization="ascend", ) as runner: @@ -84,6 +85,7 @@ def test_models_pcp_dcp_basic(): runner.model.generate(prompts, sampling_params) +@wait_until_npu_memory_free() def test_models_pcp_dcp_full_graph(): prompts = [ "The capital of France is", "Hello, my name is Tom, I am", @@ -121,6 +123,7 @@ def test_models_pcp_dcp_full_graph(): runner.model.generate(prompts, sampling_params) +@wait_until_npu_memory_free() def test_models_pcp_dcp_piece_wise(): prompts = [ "The capital of France is", "Hello, my name is Tom, I am", @@ -172,6 +175,7 @@ def test_pcp_basic(): runner.model.generate(prompts, sampling_params) +@wait_until_npu_memory_free() def test_pcp_full_graph(): prompts = [ "The capital of France is", "Hello, my name is Tom, I am", @@ -195,6 +199,7 @@ def test_pcp_full_graph(): runner.model.generate(prompts, sampling_params) +@wait_until_npu_memory_free() def test_pcp_piece_wise(): prompts = [ "The capital of France is", "Hello, my name is Tom, I am", @@ -214,6 +219,7 @@ def test_pcp_piece_wise(): runner.model.generate(prompts, sampling_params) +@wait_until_npu_memory_free() def test_dcp_basic(): prompts = [ "The capital of France is", "Hello, my name is Tom, I am", @@ -233,6 +239,7 @@ def test_dcp_basic(): runner.model.generate(prompts, sampling_params) +@wait_until_npu_memory_free() def test_dcp_full_graph(): prompts = [ "The capital of France is", "Hello, my name is Tom, I am", @@ -256,6 +263,7 @@ def test_dcp_full_graph(): runner.model.generate(prompts, sampling_params) +@wait_until_npu_memory_free() def test_dcp_piece_wise(): prompts = [ "The capital of France is", "Hello, my name is Tom, I am", diff --git a/vllm_ascend/attention/context_parallel/common_cp.py b/vllm_ascend/attention/context_parallel/common_cp.py index 6eb42040..e06b0e11 100644 --- a/vllm_ascend/attention/context_parallel/common_cp.py +++ b/vllm_ascend/attention/context_parallel/common_cp.py @@ -34,6 +34,7 @@ class AscendPCPMetadata: block_table_cp: torch.Tensor = None valid_block_ids: torch.Tensor = None prefill_q_cum_seqlens: torch.Tensor = None + block_arange: torch.Tensor = None @dataclass diff --git a/vllm_ascend/attention/context_parallel/sfa_cp.py b/vllm_ascend/attention/context_parallel/sfa_cp.py index 2b6f361d..dbf6d163 100644 --- a/vllm_ascend/attention/context_parallel/sfa_cp.py +++ b/vllm_ascend/attention/context_parallel/sfa_cp.py @@ -5,6 +5,7 @@ import torch import torch_npu from vllm.config import VllmConfig from vllm.distributed import get_dcp_group, get_pcp_group +from vllm.forward_context import get_forward_context from vllm.triton_utils import HAS_TRITON from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata @@ -55,6 +56,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder): dtype=torch.int32, device=device, ) + self.block_arange_buffer = torch.arange(self.pcp_size * self.dcp_size, dtype=torch.int32, device=device) def build( self, @@ -70,20 +72,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder): assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == common_attn_metadata.num_actual_tokens - block_table = metadata_cls.block_table - valid_block_ids, new_block_table = block_table.flatten().unique(return_inverse=True) - num_blocks = valid_block_ids.shape[0] - # Note(qcs): `block_table_cp` will have dirty values in the part beyond kv_lens. - # We assume that we can always get the correct kv_lens or kv index, - # so we omit the dirty value processing here. - block_table_cp = ( - new_block_table.unsqueeze(-1).to(block_table) - + (torch.arange(self.pcp_size * self.dcp_size) * num_blocks).view(1, 1, -1).to(block_table) - ).reshape(block_table.shape[0], -1) - - sfa_cp_metadata = self.build_cp_metadata( - block_table_cp, valid_block_ids, metadata_cls.seq_lens, common_attn_metadata - ) + sfa_cp_metadata = self.build_cp_metadata(self.block_arange_buffer, metadata_cls.seq_lens, common_attn_metadata) metadata_cls.num_decode_tokens = num_decode_tokens metadata_cls.num_decodes = num_decodes metadata_cls.num_prefills = num_prefills @@ -127,8 +116,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder): def build_cp_metadata( self, - block_table_cp: torch.Tensor, - valid_block_ids: torch.Tensor, + block_arange: torch.Tensor, seq_lens: torch.Tensor, common_attn_metadata: AscendCommonAttentionMetadata, ) -> AscendPCPMetadata | None: @@ -144,8 +132,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder): head_attn_nomask_seqlens=q_head_kv_lens, tail_attn_nomask_seqlens=q_tail_kv_lens, pcp_allgather_restore_idx=common_long_seq_metadata.pcp_allgather_restore_idx, - block_table_cp=block_table_cp, - valid_block_ids=valid_block_ids, + block_arange=block_arange, ) @@ -198,16 +185,17 @@ class AscendSFACPImpl(AscendSFAImpl): kv = kv_cache[0] key_rope = kv_cache[1] + block_table = attn_metadata.block_table assert attn_metadata.sfa_cp_metadata is not None - valid_block_ids = attn_metadata.sfa_cp_metadata.valid_block_ids - kv = self.gather_kv_cross_cp(kv, valid_block_ids) - key_rope = self.gather_kv_cross_cp(key_rope, valid_block_ids) - block_table = attn_metadata.sfa_cp_metadata.block_table_cp - + block_arange = attn_metadata.sfa_cp_metadata.block_arange + kv, block_table = self.gather_kv_cross_cp(kv, block_table, block_arange) + key_rope, _ = self.gather_kv_cross_cp(key_rope) + assert block_table is not None if self.pcp_size == 1: - return self._execute_sparse_flash_attention( + attn_output = self._execute_sparse_flash_attention( ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key ) + return self._align_to_graph_bucket_tokens(attn_output, attn_metadata) num_decodes = attn_metadata.num_decodes num_decode_tokens = attn_metadata.num_decode_tokens num_prefills = attn_metadata.num_prefills @@ -225,7 +213,7 @@ class AscendSFACPImpl(AscendSFAImpl): ) if num_prefills < 1: - return decode_attn_out + return self._align_to_graph_bucket_tokens(decode_attn_out, attn_metadata) # q split for head and tail q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx @@ -266,7 +254,30 @@ class AscendSFACPImpl(AscendSFAImpl): if decode_attn_out is not None: attn_output = torch.cat([decode_attn_out, attn_output], dim=0) - return attn_output + return self._align_to_graph_bucket_tokens(attn_output, attn_metadata) + + def _align_to_graph_bucket_tokens(self, attn_output: torch.Tensor | None, attn_metadata: M) -> torch.Tensor | None: + if attn_output is None: + return None + # In graph/piecewise mode, output buffer uses graph bucket token size + # (forward_context.num_tokens), while PCP path may compute only valid + # tokens. Align to the larger one to avoid later write-back mismatch. + forward_context = get_forward_context() + target_tokens = max( + attn_metadata.num_input_tokens, + forward_context.num_tokens if forward_context is not None else 0, + ) + + if attn_output.shape[0] == target_tokens: + return attn_output + aligned = torch.zeros( + (target_tokens, *attn_output.shape[1:]), + dtype=attn_output.dtype, + device=attn_output.device, + ) + valid_tokens = min(attn_output.shape[0], target_tokens) + aligned[:valid_tokens] = attn_output[:valid_tokens] + return aligned def _execute_sparse_flash_attention( self, ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key @@ -289,14 +300,20 @@ class AscendSFACPImpl(AscendSFAImpl): ) return attn_output - def gather_kv_cross_cp(self, kv_cache: torch.Tensor, valid_block_ids: torch.Tensor) -> torch.Tensor: + def gather_kv_cross_cp( + self, kv_cache: torch.Tensor, block_tables: torch.Tensor | None = None, block_arange: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor | None]: # Note(qcs): we need set kv_cache_interleave_size = block_size for sfa!!! - kv_cache = torch.index_select(kv_cache, 0, valid_block_ids) + block_num = kv_cache.shape[0] if self.dcp_size > 1: kv_cache = get_dcp_group().all_gather(kv_cache, 0) if self.pcp_size > 1: kv_cache = get_pcp_group().all_gather(kv_cache, 0) - return kv_cache + if block_tables is not None and block_arange is not None: + block_tables = ( + block_tables.unsqueeze(-1) + (block_arange * block_num).view(1, 1, -1).to(block_tables) + ).reshape(block_tables.shape[0], -1) + return kv_cache, block_tables def indexer_select_post_process( self, @@ -330,9 +347,11 @@ class AscendSFACPImpl(AscendSFAImpl): q = q_li key = kv_cache[2] + block_table = attn_metadata.block_table assert attn_metadata.sfa_cp_metadata is not None - key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids) - block_table = attn_metadata.sfa_cp_metadata.block_table_cp + block_arange = attn_metadata.sfa_cp_metadata.block_arange + + key, block_table = self.gather_kv_cross_cp(key, block_table, block_arange) if self.pcp_size == 1: return self._execute_indexer_select( @@ -353,7 +372,6 @@ class AscendSFACPImpl(AscendSFAImpl): actual_seq_lengths_key[:num_decodes], block_table[:num_decodes], ) - # prefill compute if num_prefills == 0: return decode_topk_indices