From 2f965d833930455ca099693f4a82ab1fb60dc960 Mon Sep 17 00:00:00 2001
From: zzhxxx <2783294813@qq.com>
Date: Thu, 11 Dec 2025 16:44:14 +0800
Subject: [PATCH] [Bugfix] Fix the bug in sfa-cp under multi-DP scenarios.
 (#4850)

### What this PR does / why we need it?
This PR fixes a bug in sfa-cp under multi-DP scenarios: the context-parallel padding is now derived from `num_input_tokens` and from each tensor's own length (`cos.shape[0]`, `slot_mapping.shape[0]`) rather than from `num_actual_tokens`, and the now-unused `pad_size`/`local_pad_size` fields are dropped from `SfaCpContext`.

### Does this PR introduce _any_ user-facing change?
None

### How was this patch tested?
None

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

Signed-off-by: zzhxx <2783294813@qq.com>
Co-authored-by: clrs97 <524936896@qq.com>
---
 vllm_ascend/attention/sfa_v1.py | 32 +++++++++++++++++++++++---------
 1 file changed, 23 insertions(+), 9 deletions(-)

diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index 3a962b87..03eed5a7 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -66,8 +66,6 @@ class SfaCpContext:
     local_start: int
     local_end: int
     local_end_with_pad: int
-    pad_size: int
-    local_pad_size: int
     slot_mapping_cp: torch.Tensor
     actual_seq_lengths_query: torch.Tensor
     actual_seq_lengths_key: torch.Tensor
@@ -206,23 +204,41 @@ class AscendSFAMetadataBuilder:
         sfa_cp_context = None
         if self.enable_sfa_cp:
             global_tp_size = get_tp_group().world_size
-            num_tokens = num_actual_tokens
-            num_tokens_pad = _round_up(num_actual_tokens, global_tp_size)
+            num_tokens = num_input_tokens
+            num_tokens_pad = _round_up(num_tokens, global_tp_size)
             num_tokens_per_device = num_tokens_pad // global_tp_size
-            pad_size = num_tokens_pad - num_tokens
             local_start = get_tp_group().rank_in_group * num_tokens_per_device
             local_end_with_pad = local_start + num_tokens_per_device
             local_end = min(local_end_with_pad, num_actual_tokens)
-            local_pad_size = local_end_with_pad - local_end
+
+            pad_size = num_tokens_pad - cos.shape[0]
+            assert cos.shape == sin.shape, \
+                f"cos.shape must be equal to sin.shape, got {cos.shape} and {sin.shape}"

             if pad_size > 0:
                 cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
                 sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size))
-                slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size),
+
+            pad_size_slot = num_tokens_pad - slot_mapping.shape[0]
+            if pad_size_slot > 0:
+                slot_mapping = nn.functional.pad(slot_mapping,
+                                                 (0, pad_size_slot),
                                                  value=-1)
+            else:
+                slot_mapping = slot_mapping[:num_tokens_pad]
+
             cos = cos[local_start:local_end_with_pad]
             sin = sin[local_start:local_end_with_pad]
             slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
+            assert cos.shape[0] == num_tokens_per_device, \
+                f"cos.shape[0] must be equal to num_tokens_per_device, \
+                got {cos.shape[0]} and {num_tokens_per_device}"
+            assert slot_mapping_cp.shape[0] == num_tokens_per_device, \
+                f"slot_mapping_cp.shape[0] must be equal to num_tokens_per_device, \
+                got {slot_mapping_cp.shape[0]} and {num_tokens_per_device}"
+            assert slot_mapping.shape[0] == num_tokens_pad, \
+                f"slot_mapping.shape[0] must be equal to num_tokens_pad, \
+                got {slot_mapping.shape[0]} and {num_tokens_pad}"

             actual_seq_lengths_query = torch.empty_like(cum_query_lens)
             actual_seq_lengths_key = torch.empty_like(seq_lens)
@@ -254,8 +270,6 @@ class AscendSFAMetadataBuilder:
                 local_start=local_start,
                 local_end=local_end,
                 local_end_with_pad=local_end_with_pad,
-                pad_size=pad_size,
-                local_pad_size=local_pad_size,
                 slot_mapping_cp=slot_mapping_cp,
                 actual_seq_lengths_query=actual_seq_lengths_query,
                 actual_seq_lengths_key=actual_seq_lengths_key,
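
For readers unfamiliar with the pad-then-slice pattern the patch lands on, here is a minimal standalone sketch of the core idea. It is an illustration, not the vllm-ascend code: `shard_for_rank` is a hypothetical helper, and the real builder additionally handles `cos`/`sin` and the sequence-length metadata.

```python
import torch
import torch.nn.functional as F


def _round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple (same helper name as the patch)."""
    return (x + multiple - 1) // multiple * multiple


def shard_for_rank(slot_mapping: torch.Tensor, num_input_tokens: int,
                   world_size: int, rank: int) -> torch.Tensor:
    """Hypothetical helper: pad the global slot mapping up to a multiple of
    world_size, then take this rank's contiguous slice."""
    num_tokens_pad = _round_up(num_input_tokens, world_size)
    num_tokens_per_device = num_tokens_pad // world_size

    # Key point of the fix: the pad size is computed from the tensor's *own*
    # length, because under multi-DP slot_mapping may already carry DP padding
    # beyond num_actual_tokens, so a single pad size derived from
    # num_actual_tokens would be wrong.
    pad = num_tokens_pad - slot_mapping.shape[0]
    if pad > 0:
        slot_mapping = F.pad(slot_mapping, (0, pad), value=-1)  # -1 = no slot
    else:
        slot_mapping = slot_mapping[:num_tokens_pad]

    local_start = rank * num_tokens_per_device
    return slot_mapping[local_start:local_start + num_tokens_per_device]


# 10 tokens on 4 ranks -> padded to 12, 3 per rank; rank 3 gets [9, -1, -1].
print(shard_for_rank(torch.arange(10), num_input_tokens=10,
                     world_size=4, rank=3))
```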