diff --git a/tests/e2e/multicard/4-cards/long_sequence/test_basic.py b/tests/e2e/multicard/4-cards/long_sequence/test_basic.py index fa6e2633..66dbbc0b 100644 --- a/tests/e2e/multicard/4-cards/long_sequence/test_basic.py +++ b/tests/e2e/multicard/4-cards/long_sequence/test_basic.py @@ -25,6 +25,7 @@ from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free os.environ["HCCL_BUFFSIZE"] = "768" +@wait_until_npu_memory_free() def test_models_pcp_dcp_basic(): prompts = [ "The capital of France is", "Hello, my name is Tom, I am", @@ -59,12 +60,12 @@ def test_models_pcp_dcp_basic(): model = "vllm-ascend/DeepSeek-V3.2-W8A8-Pruning" with VllmRunner( model, - enforce_eager=True, max_model_len=1024, tensor_parallel_size=2, prefill_context_parallel_size=2, decode_context_parallel_size=2, enable_expert_parallel=True, + gpu_memory_utilization=0.2, block_size=128, quantization="ascend", ) as runner: @@ -84,6 +85,7 @@ def test_models_pcp_dcp_basic(): runner.model.generate(prompts, sampling_params) +@wait_until_npu_memory_free() def test_models_pcp_dcp_full_graph(): prompts = [ "The capital of France is", "Hello, my name is Tom, I am", @@ -121,6 +123,7 @@ def test_models_pcp_dcp_full_graph(): runner.model.generate(prompts, sampling_params) +@wait_until_npu_memory_free() def test_models_pcp_dcp_piece_wise(): prompts = [ "The capital of France is", "Hello, my name is Tom, I am", @@ -172,6 +175,7 @@ def test_pcp_basic(): runner.model.generate(prompts, sampling_params) +@wait_until_npu_memory_free() def test_pcp_full_graph(): prompts = [ "The capital of France is", "Hello, my name is Tom, I am", @@ -195,6 +199,7 @@ def test_pcp_full_graph(): runner.model.generate(prompts, sampling_params) +@wait_until_npu_memory_free() def test_pcp_piece_wise(): prompts = [ "The capital of France is", "Hello, my name is Tom, I am", @@ -214,6 +219,7 @@ def test_pcp_piece_wise(): runner.model.generate(prompts, sampling_params) +@wait_until_npu_memory_free() def test_dcp_basic(): prompts = [ "The capital of France is", "Hello, my name is Tom, I am", @@ -233,6 +239,7 @@ def test_dcp_basic(): runner.model.generate(prompts, sampling_params) +@wait_until_npu_memory_free() def test_dcp_full_graph(): prompts = [ "The capital of France is", "Hello, my name is Tom, I am", @@ -256,6 +263,7 @@ def test_dcp_full_graph(): runner.model.generate(prompts, sampling_params) +@wait_until_npu_memory_free() def test_dcp_piece_wise(): prompts = [ "The capital of France is", "Hello, my name is Tom, I am", diff --git a/vllm_ascend/attention/context_parallel/common_cp.py b/vllm_ascend/attention/context_parallel/common_cp.py index 6eb42040..e06b0e11 100644 --- a/vllm_ascend/attention/context_parallel/common_cp.py +++ b/vllm_ascend/attention/context_parallel/common_cp.py @@ -34,6 +34,7 @@ class AscendPCPMetadata: block_table_cp: torch.Tensor = None valid_block_ids: torch.Tensor = None prefill_q_cum_seqlens: torch.Tensor = None + block_arange: torch.Tensor = None @dataclass diff --git a/vllm_ascend/attention/context_parallel/sfa_cp.py b/vllm_ascend/attention/context_parallel/sfa_cp.py index 2b6f361d..dbf6d163 100644 --- a/vllm_ascend/attention/context_parallel/sfa_cp.py +++ b/vllm_ascend/attention/context_parallel/sfa_cp.py @@ -5,6 +5,7 @@ import torch import torch_npu from vllm.config import VllmConfig from vllm.distributed import get_dcp_group, get_pcp_group +from vllm.forward_context import get_forward_context from vllm.triton_utils import HAS_TRITON from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata @@ -55,6 +56,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder): dtype=torch.int32, device=device, ) + self.block_arange_buffer = torch.arange(self.pcp_size * self.dcp_size, dtype=torch.int32, device=device) def build( self, @@ -70,20 +72,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder): assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == common_attn_metadata.num_actual_tokens - block_table = metadata_cls.block_table - valid_block_ids, new_block_table = block_table.flatten().unique(return_inverse=True) - num_blocks = valid_block_ids.shape[0] - # Note(qcs): `block_table_cp` will have dirty values in the part beyond kv_lens. - # We assume that we can always get the correct kv_lens or kv index, - # so we omit the dirty value processing here. - block_table_cp = ( - new_block_table.unsqueeze(-1).to(block_table) - + (torch.arange(self.pcp_size * self.dcp_size) * num_blocks).view(1, 1, -1).to(block_table) - ).reshape(block_table.shape[0], -1) - - sfa_cp_metadata = self.build_cp_metadata( - block_table_cp, valid_block_ids, metadata_cls.seq_lens, common_attn_metadata - ) + sfa_cp_metadata = self.build_cp_metadata(self.block_arange_buffer, metadata_cls.seq_lens, common_attn_metadata) metadata_cls.num_decode_tokens = num_decode_tokens metadata_cls.num_decodes = num_decodes metadata_cls.num_prefills = num_prefills @@ -127,8 +116,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder): def build_cp_metadata( self, - block_table_cp: torch.Tensor, - valid_block_ids: torch.Tensor, + block_arange: torch.Tensor, seq_lens: torch.Tensor, common_attn_metadata: AscendCommonAttentionMetadata, ) -> AscendPCPMetadata | None: @@ -144,8 +132,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder): head_attn_nomask_seqlens=q_head_kv_lens, tail_attn_nomask_seqlens=q_tail_kv_lens, pcp_allgather_restore_idx=common_long_seq_metadata.pcp_allgather_restore_idx, - block_table_cp=block_table_cp, - valid_block_ids=valid_block_ids, + block_arange=block_arange, ) @@ -198,16 +185,17 @@ class AscendSFACPImpl(AscendSFAImpl): kv = kv_cache[0] key_rope = kv_cache[1] + block_table = attn_metadata.block_table assert attn_metadata.sfa_cp_metadata is not None - valid_block_ids = attn_metadata.sfa_cp_metadata.valid_block_ids - kv = self.gather_kv_cross_cp(kv, valid_block_ids) - key_rope = self.gather_kv_cross_cp(key_rope, valid_block_ids) - block_table = attn_metadata.sfa_cp_metadata.block_table_cp - + block_arange = attn_metadata.sfa_cp_metadata.block_arange + kv, block_table = self.gather_kv_cross_cp(kv, block_table, block_arange) + key_rope, _ = self.gather_kv_cross_cp(key_rope) + assert block_table is not None if self.pcp_size == 1: - return self._execute_sparse_flash_attention( + attn_output = self._execute_sparse_flash_attention( ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key ) + return self._align_to_graph_bucket_tokens(attn_output, attn_metadata) num_decodes = attn_metadata.num_decodes num_decode_tokens = attn_metadata.num_decode_tokens num_prefills = attn_metadata.num_prefills @@ -225,7 +213,7 @@ class AscendSFACPImpl(AscendSFAImpl): ) if num_prefills < 1: - return decode_attn_out + return self._align_to_graph_bucket_tokens(decode_attn_out, attn_metadata) # q split for head and tail q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx @@ -266,7 +254,30 @@ class AscendSFACPImpl(AscendSFAImpl): if decode_attn_out is not None: attn_output = torch.cat([decode_attn_out, attn_output], dim=0) - return attn_output + return self._align_to_graph_bucket_tokens(attn_output, attn_metadata) + + def _align_to_graph_bucket_tokens(self, attn_output: torch.Tensor | None, attn_metadata: M) -> torch.Tensor | None: + if attn_output is None: + return None + # In graph/piecewise mode, output buffer uses graph bucket token size + # (forward_context.num_tokens), while PCP path may compute only valid + # tokens. Align to the larger one to avoid later write-back mismatch. + forward_context = get_forward_context() + target_tokens = max( + attn_metadata.num_input_tokens, + forward_context.num_tokens if forward_context is not None else 0, + ) + + if attn_output.shape[0] == target_tokens: + return attn_output + aligned = torch.zeros( + (target_tokens, *attn_output.shape[1:]), + dtype=attn_output.dtype, + device=attn_output.device, + ) + valid_tokens = min(attn_output.shape[0], target_tokens) + aligned[:valid_tokens] = attn_output[:valid_tokens] + return aligned def _execute_sparse_flash_attention( self, ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key @@ -289,14 +300,20 @@ class AscendSFACPImpl(AscendSFAImpl): ) return attn_output - def gather_kv_cross_cp(self, kv_cache: torch.Tensor, valid_block_ids: torch.Tensor) -> torch.Tensor: + def gather_kv_cross_cp( + self, kv_cache: torch.Tensor, block_tables: torch.Tensor | None = None, block_arange: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor | None]: # Note(qcs): we need set kv_cache_interleave_size = block_size for sfa!!! - kv_cache = torch.index_select(kv_cache, 0, valid_block_ids) + block_num = kv_cache.shape[0] if self.dcp_size > 1: kv_cache = get_dcp_group().all_gather(kv_cache, 0) if self.pcp_size > 1: kv_cache = get_pcp_group().all_gather(kv_cache, 0) - return kv_cache + if block_tables is not None and block_arange is not None: + block_tables = ( + block_tables.unsqueeze(-1) + (block_arange * block_num).view(1, 1, -1).to(block_tables) + ).reshape(block_tables.shape[0], -1) + return kv_cache, block_tables def indexer_select_post_process( self, @@ -330,9 +347,11 @@ class AscendSFACPImpl(AscendSFAImpl): q = q_li key = kv_cache[2] + block_table = attn_metadata.block_table assert attn_metadata.sfa_cp_metadata is not None - key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids) - block_table = attn_metadata.sfa_cp_metadata.block_table_cp + block_arange = attn_metadata.sfa_cp_metadata.block_arange + + key, block_table = self.gather_kv_cross_cp(key, block_table, block_arange) if self.pcp_size == 1: return self._execute_indexer_select( @@ -353,7 +372,6 @@ class AscendSFACPImpl(AscendSFAImpl): actual_seq_lengths_key[:num_decodes], block_table[:num_decodes], ) - # prefill compute if num_prefills == 0: return decode_topk_indices