[Bugfix] Fix the bug in sfa-cp under multi-DP scenarios. (#4850)

### What this PR does / why we need it?
This PR fix the bug in sfa-cp under multi-DP scenarios.

### Does this PR introduce _any_ user-facing change?
None

### How was this patch tested?
None

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: zzhxx <2783294813@qq.com>
Co-authored-by: clrs97 <524936896@qq.com>
This commit is contained in:
zzhxxx
2025-12-11 16:44:14 +08:00
committed by GitHub
parent 5ebb9bd8d2
commit 2f965d8339

View File

@@ -66,8 +66,6 @@ class SfaCpContext:
local_start: int
local_end: int
local_end_with_pad: int
pad_size: int
local_pad_size: int
slot_mapping_cp: torch.Tensor
actual_seq_lengths_query: torch.Tensor
actual_seq_lengths_key: torch.Tensor
@@ -206,23 +204,41 @@ class AscendSFAMetadataBuilder:
sfa_cp_context = None
if self.enable_sfa_cp:
global_tp_size = get_tp_group().world_size
num_tokens = num_actual_tokens
num_tokens_pad = _round_up(num_actual_tokens, global_tp_size)
num_tokens = num_input_tokens
num_tokens_pad = _round_up(num_tokens, global_tp_size)
num_tokens_per_device = num_tokens_pad // global_tp_size
pad_size = num_tokens_pad - num_tokens
local_start = get_tp_group().rank_in_group * num_tokens_per_device
local_end_with_pad = local_start + num_tokens_per_device
local_end = min(local_end_with_pad, num_actual_tokens)
local_pad_size = local_end_with_pad - local_end
pad_size = num_tokens_pad - cos.shape[0]
assert cos.shape == sin.shape, \
f"cos.shape must be equal to sin.shape, got {cos.shape} and {sin.shape}"
if pad_size > 0:
cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size))
slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size),
pad_size_slot = num_tokens_pad - slot_mapping.shape[0]
if pad_size_slot > 0:
slot_mapping = nn.functional.pad(slot_mapping,
(0, pad_size_slot),
value=-1)
else:
slot_mapping = slot_mapping[:num_tokens_pad]
cos = cos[local_start:local_end_with_pad]
sin = sin[local_start:local_end_with_pad]
slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
assert cos.shape[0] == num_tokens_per_device, \
f"cos.shape[0] must be equal to num_tokens_per_device, \
got {cos.shape[0]} and {num_tokens_per_device}"
assert slot_mapping_cp.shape[0] == num_tokens_per_device, \
f"slot_mapping_cp.shape[0] must be equal to num_tokens_per_device, \
got {slot_mapping_cp.shape[0]} and {num_tokens_per_device}"
assert slot_mapping.shape[0] == num_tokens_pad, \
f"slot_mapping.shape[0] must be equal to num_tokens_pad, \
got {slot_mapping.shape[0]} and {num_tokens_pad}"
actual_seq_lengths_query = torch.empty_like(cum_query_lens)
actual_seq_lengths_key = torch.empty_like(seq_lens)
@@ -254,8 +270,6 @@ class AscendSFAMetadataBuilder:
local_start=local_start,
local_end=local_end,
local_end_with_pad=local_end_with_pad,
pad_size=pad_size,
local_pad_size=local_pad_size,
slot_mapping_cp=slot_mapping_cp,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,