[Bugfix] Fix the bug in sfa-cp under multi-DP scenarios. (#4850)
### What this PR does / why we need it?
This PR fixes the bug in sfa-cp under multi-DP scenarios.
### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?
None
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: zzhxx <2783294813@qq.com>
Co-authored-by: clrs97 <524936896@qq.com>
This commit is contained in:
@@ -66,8 +66,6 @@ class SfaCpContext:
|
|||||||
local_start: int
|
local_start: int
|
||||||
local_end: int
|
local_end: int
|
||||||
local_end_with_pad: int
|
local_end_with_pad: int
|
||||||
pad_size: int
|
|
||||||
local_pad_size: int
|
|
||||||
slot_mapping_cp: torch.Tensor
|
slot_mapping_cp: torch.Tensor
|
||||||
actual_seq_lengths_query: torch.Tensor
|
actual_seq_lengths_query: torch.Tensor
|
||||||
actual_seq_lengths_key: torch.Tensor
|
actual_seq_lengths_key: torch.Tensor
|
||||||
@@ -206,23 +204,41 @@ class AscendSFAMetadataBuilder:
|
|||||||
sfa_cp_context = None
|
sfa_cp_context = None
|
||||||
if self.enable_sfa_cp:
|
if self.enable_sfa_cp:
|
||||||
global_tp_size = get_tp_group().world_size
|
global_tp_size = get_tp_group().world_size
|
||||||
num_tokens = num_actual_tokens
|
num_tokens = num_input_tokens
|
||||||
num_tokens_pad = _round_up(num_actual_tokens, global_tp_size)
|
num_tokens_pad = _round_up(num_tokens, global_tp_size)
|
||||||
num_tokens_per_device = num_tokens_pad // global_tp_size
|
num_tokens_per_device = num_tokens_pad // global_tp_size
|
||||||
pad_size = num_tokens_pad - num_tokens
|
|
||||||
local_start = get_tp_group().rank_in_group * num_tokens_per_device
|
local_start = get_tp_group().rank_in_group * num_tokens_per_device
|
||||||
local_end_with_pad = local_start + num_tokens_per_device
|
local_end_with_pad = local_start + num_tokens_per_device
|
||||||
local_end = min(local_end_with_pad, num_actual_tokens)
|
local_end = min(local_end_with_pad, num_actual_tokens)
|
||||||
local_pad_size = local_end_with_pad - local_end
|
|
||||||
|
pad_size = num_tokens_pad - cos.shape[0]
|
||||||
|
assert cos.shape == sin.shape, \
|
||||||
|
f"cos.shape must be equal to sin.shape, got {cos.shape} and {sin.shape}"
|
||||||
|
|
||||||
if pad_size > 0:
|
if pad_size > 0:
|
||||||
cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
|
cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
|
||||||
sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size))
|
sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size))
|
||||||
slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size),
|
|
||||||
|
pad_size_slot = num_tokens_pad - slot_mapping.shape[0]
|
||||||
|
if pad_size_slot > 0:
|
||||||
|
slot_mapping = nn.functional.pad(slot_mapping,
|
||||||
|
(0, pad_size_slot),
|
||||||
value=-1)
|
value=-1)
|
||||||
|
else:
|
||||||
|
slot_mapping = slot_mapping[:num_tokens_pad]
|
||||||
|
|
||||||
cos = cos[local_start:local_end_with_pad]
|
cos = cos[local_start:local_end_with_pad]
|
||||||
sin = sin[local_start:local_end_with_pad]
|
sin = sin[local_start:local_end_with_pad]
|
||||||
slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
|
slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
|
||||||
|
assert cos.shape[0] == num_tokens_per_device, \
|
||||||
|
f"cos.shape[0] must be equal to num_tokens_per_device, \
|
||||||
|
got {cos.shape[0]} and {num_tokens_per_device}"
|
||||||
|
assert slot_mapping_cp.shape[0] == num_tokens_per_device, \
|
||||||
|
f"slot_mapping_cp.shape[0] must be equal to num_tokens_per_device, \
|
||||||
|
got {slot_mapping_cp.shape[0]} and {num_tokens_per_device}"
|
||||||
|
assert slot_mapping.shape[0] == num_tokens_pad, \
|
||||||
|
f"slot_mapping.shape[0] must be equal to num_tokens_pad, \
|
||||||
|
got {slot_mapping.shape[0]} and {num_tokens_pad}"
|
||||||
|
|
||||||
actual_seq_lengths_query = torch.empty_like(cum_query_lens)
|
actual_seq_lengths_query = torch.empty_like(cum_query_lens)
|
||||||
actual_seq_lengths_key = torch.empty_like(seq_lens)
|
actual_seq_lengths_key = torch.empty_like(seq_lens)
|
||||||
@@ -254,8 +270,6 @@ class AscendSFAMetadataBuilder:
|
|||||||
local_start=local_start,
|
local_start=local_start,
|
||||||
local_end=local_end,
|
local_end=local_end,
|
||||||
local_end_with_pad=local_end_with_pad,
|
local_end_with_pad=local_end_with_pad,
|
||||||
pad_size=pad_size,
|
|
||||||
local_pad_size=local_pad_size,
|
|
||||||
slot_mapping_cp=slot_mapping_cp,
|
slot_mapping_cp=slot_mapping_cp,
|
||||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||||
actual_seq_lengths_key=actual_seq_lengths_key,
|
actual_seq_lengths_key=actual_seq_lengths_key,
|
||||||
|
|||||||
Reference in New Issue
Block a user