[Bugfix] Fix the bug in sfa-cp under multi-DP scenarios. (#4850)
### What this PR does / why we need it?
This PR fixes a bug in sfa-cp under multi-DP scenarios.
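For context, a rough sketch of the kind of mismatch this change appears to guard against: under multi-DP, the padded batch size (`num_input_tokens`) can exceed the number of scheduled tokens (`num_actual_tokens`), so a pad length derived from the latter can disagree with the actual lengths of `cos`/`sin`/`slot_mapping`. The scenario and numbers below are illustrative assumptions, not taken from the PR:

```python
# Illustrative only: assumes num_input_tokens can exceed num_actual_tokens
# under multi-DP padding, which is the scenario this PR appears to address.
def _round_up(x: int, align: int) -> int:
    return (x + align - 1) // align * align

num_actual_tokens = 10   # tokens really scheduled on this DP rank (assumed)
num_input_tokens = 16    # padded batch size under multi-DP (assumed)
global_tp_size = 4

# Old basis: pad relative to the actual token count.
old_pad = _round_up(num_actual_tokens, global_tp_size)   # 12

# New basis: pad relative to the (already padded) input token count,
# which matches the length of the cos/sin/slot_mapping buffers.
new_pad = _round_up(num_input_tokens, global_tp_size)    # 16

print(old_pad, new_pad)  # 12 16 -> per-rank slice sizes differ: 3 vs 4 tokens
```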
### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?
None
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: zzhxx <2783294813@qq.com>
Co-authored-by: clrs97 <524936896@qq.com>
@@ -66,8 +66,6 @@ class SfaCpContext:
    local_start: int
    local_end: int
    local_end_with_pad: int
    pad_size: int
    local_pad_size: int
    slot_mapping_cp: torch.Tensor
    actual_seq_lengths_query: torch.Tensor
    actual_seq_lengths_key: torch.Tensor
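For orientation, here is a hedged, standalone sketch of how fields like these can be filled in. It mirrors the per-rank slicing arithmetic visible in the builder hunk below, but the dataclass and helper names are placeholders, not the module's actual API:

```python
import torch
from dataclasses import dataclass


@dataclass
class _CpSliceSketch:
    # Hypothetical stand-in for SfaCpContext, for illustration only.
    local_start: int
    local_end: int
    local_end_with_pad: int
    pad_size: int
    local_pad_size: int
    slot_mapping_cp: torch.Tensor


def make_slice(rank: int, world_size: int, num_tokens: int,
               num_actual_tokens: int,
               slot_mapping: torch.Tensor) -> _CpSliceSketch:
    # Round the token count up so it splits evenly across the ranks;
    # assumes slot_mapping has already been padded to that length.
    num_tokens_pad = (num_tokens + world_size - 1) // world_size * world_size
    per_device = num_tokens_pad // world_size
    local_start = rank * per_device
    local_end_with_pad = local_start + per_device
    local_end = min(local_end_with_pad, num_actual_tokens)
    return _CpSliceSketch(
        local_start=local_start,
        local_end=local_end,
        local_end_with_pad=local_end_with_pad,
        pad_size=num_tokens_pad - num_tokens,
        local_pad_size=local_end_with_pad - local_end,
        slot_mapping_cp=slot_mapping[local_start:local_end_with_pad],
    )


# Example: 10 tokens padded to 12 across 4 ranks -> 3 tokens per rank.
ctx = make_slice(rank=1, world_size=4, num_tokens=10, num_actual_tokens=10,
                 slot_mapping=torch.full((12,), -1, dtype=torch.long))
```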
@@ -206,23 +204,41 @@ class AscendSFAMetadataBuilder:
         sfa_cp_context = None
         if self.enable_sfa_cp:
             global_tp_size = get_tp_group().world_size
-            num_tokens = num_actual_tokens
-            num_tokens_pad = _round_up(num_actual_tokens, global_tp_size)
+            num_tokens = num_input_tokens
+            num_tokens_pad = _round_up(num_tokens, global_tp_size)
             num_tokens_per_device = num_tokens_pad // global_tp_size
             pad_size = num_tokens_pad - num_tokens
             local_start = get_tp_group().rank_in_group * num_tokens_per_device
             local_end_with_pad = local_start + num_tokens_per_device
             local_end = min(local_end_with_pad, num_actual_tokens)
             local_pad_size = local_end_with_pad - local_end

+            pad_size = num_tokens_pad - cos.shape[0]
+            assert cos.shape == sin.shape, \
+                f"cos.shape must be equal to sin.shape, got {cos.shape} and {sin.shape}"
+
             if pad_size > 0:
                 cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
                 sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size))
-                slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size),
-                                                 value=-1)
+
+            pad_size_slot = num_tokens_pad - slot_mapping.shape[0]
+            if pad_size_slot > 0:
+                slot_mapping = nn.functional.pad(slot_mapping,
+                                                 (0, pad_size_slot),
+                                                 value=-1)
+            else:
+                slot_mapping = slot_mapping[:num_tokens_pad]

             cos = cos[local_start:local_end_with_pad]
             sin = sin[local_start:local_end_with_pad]
             slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
+            assert cos.shape[0] == num_tokens_per_device, \
+                f"cos.shape[0] must be equal to num_tokens_per_device, \
+                got {cos.shape[0]} and {num_tokens_per_device}"
+            assert slot_mapping_cp.shape[0] == num_tokens_per_device, \
+                f"slot_mapping_cp.shape[0] must be equal to num_tokens_per_device, \
+                got {slot_mapping_cp.shape[0]} and {num_tokens_per_device}"
+            assert slot_mapping.shape[0] == num_tokens_pad, \
+                f"slot_mapping.shape[0] must be equal to num_tokens_pad, \
+                got {slot_mapping.shape[0]} and {num_tokens_pad}"

             actual_seq_lengths_query = torch.empty_like(cum_query_lens)
             actual_seq_lengths_key = torch.empty_like(seq_lens)
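To make the slot-mapping handling above concrete, here is a small, self-contained sketch of the pad-or-truncate step: the pad is computed against the tensor's real length so the result always has exactly `num_tokens_pad` entries. The function name is a placeholder, not the module's actual code:

```python
import torch
from torch import nn


def pad_or_truncate_slots(slot_mapping: torch.Tensor,
                          num_tokens_pad: int) -> torch.Tensor:
    # Pad with -1 (an invalid slot) when the mapping is shorter than the
    # padded length, otherwise cut it down; either way the result has
    # exactly num_tokens_pad entries, so per-rank slices stay in bounds.
    pad_size_slot = num_tokens_pad - slot_mapping.shape[0]
    if pad_size_slot > 0:
        return nn.functional.pad(slot_mapping, (0, pad_size_slot), value=-1)
    return slot_mapping[:num_tokens_pad]


# Example: 10 real slots, padded length 12 -> two trailing -1 entries.
slots = torch.arange(10)
print(pad_or_truncate_slots(slots, 12))
```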
@@ -254,8 +270,6 @@ class AscendSFAMetadataBuilder:
                local_start=local_start,
                local_end=local_end,
                local_end_with_pad=local_end_with_pad,
                pad_size=pad_size,
                local_pad_size=local_pad_size,
                slot_mapping_cp=slot_mapping_cp,
                actual_seq_lengths_query=actual_seq_lengths_query,
                actual_seq_lengths_key=actual_seq_lengths_key,