[main][bugfix] Fix bugs and refactor cached mask generation logic (#2442)
### What this PR does / why we need it?
This PR fix bugs and refactor cached mask generation logic. Now just
pre-construct and use the cached mask on cpu instead of device on npu.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with new added/existing test.
- vLLM version: v0.10.1.1
- vLLM main:
9b5f64238f
Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -44,61 +44,50 @@ class AttentionMaskBuilder:
|
||||
|
||||
self._seq_len_cached = attn_mask.shape[0]
|
||||
self.attn_mask_cache = attn_mask
|
||||
self.splitfuse_mask_value = -10000
|
||||
|
||||
@staticmethod
|
||||
def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
|
||||
if dtype == torch.float16:
|
||||
mask_scale_factor = 1
|
||||
elif dtype == torch.bfloat16:
|
||||
mask_scale_factor = -10000
|
||||
else:
|
||||
raise ValueError(
|
||||
"The current operation now only supports data types: torch.float16 and "
|
||||
"torch.bfloat16. Please ensure the input is of one of these types."
|
||||
)
|
||||
return mask_scale_factor
|
||||
|
||||
def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
|
||||
device: torch.device):
|
||||
self._update_attn_cache(max_seq_len, dtype, device)
|
||||
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
|
||||
self._update_attn_cache(max_seq_len, dtype)
|
||||
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
|
||||
).to(device)
|
||||
|
||||
def get_splitfuse_attn_mask(
|
||||
self,
|
||||
seq_lens,
|
||||
query_lens,
|
||||
position,
|
||||
dtype,
|
||||
device,
|
||||
seq_lens: torch.Tensor,
|
||||
position: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
if dtype not in [torch.float16, torch.bfloat16]:
|
||||
raise ValueError(
|
||||
"splitfuse_attn_mask now only supports bf16 and fp16")
|
||||
max_seq_len = max(seq_lens, default=0)
|
||||
if max_seq_len <= self._seq_len_cached:
|
||||
self._update_attn_cache(max_seq_len, dtype, device)
|
||||
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
|
||||
# is not the same. Fix this in the future when kernel is ready.
|
||||
if self.attn_mask_cache.numel(
|
||||
) > 1 and self.attn_mask_cache[0][1] > 0:
|
||||
attn_mask = self.get_attn_mask( # type: ignore
|
||||
max_seq_len, dtype, device)
|
||||
# Do not use in-place multiplication to avoid modifying `self.attn_mask_cache`!
|
||||
attn_mask = attn_mask * -10000
|
||||
else:
|
||||
attn_mask = self.attn_mask_cache
|
||||
return torch.index_select(attn_mask, dim=0,
|
||||
index=position)[:, :max_seq_len]
|
||||
total_q_len = sum(query_lens)
|
||||
attn_mask = torch.zeros((total_q_len, max_seq_len),
|
||||
dtype=dtype,
|
||||
device="cpu")
|
||||
current_row = 0
|
||||
for i in range(len(query_lens)):
|
||||
seq_len = seq_lens[i]
|
||||
q_len = query_lens[i]
|
||||
context_len = seq_len - q_len
|
||||
self._update_attn_cache(max_seq_len, dtype)
|
||||
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
|
||||
# is not the same. Fix this in the future when kernel is ready.
|
||||
mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
|
||||
attn_mask = torch.index_select(self.attn_mask_cache,
|
||||
dim=0,
|
||||
index=position)[:, :max_seq_len]
|
||||
attn_mask *= mask_scale_factor
|
||||
return attn_mask.contiguous().to(device, non_blocking=True)
|
||||
|
||||
assert context_len >= 0
|
||||
attn_mask[current_row:current_row + q_len,
|
||||
context_len:] = self.splitfuse_mask_value
|
||||
right_tensor = attn_mask[current_row:current_row + q_len,
|
||||
context_len:seq_len]
|
||||
right_tensor.masked_fill_(
|
||||
right_tensor.tril() == self.splitfuse_mask_value, 0)
|
||||
current_row += q_len
|
||||
|
||||
return attn_mask.to(device, non_blocking=True)
|
||||
|
||||
def _update_attn_cache(self, seqlen: int, dtype: torch.dtype,
|
||||
device: torch.device):
|
||||
def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
|
||||
if seqlen > self._seq_len_cached:
|
||||
self._seq_len_cached = seqlen
|
||||
self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
|
||||
if self.attn_mask_cache.device != device:
|
||||
self.attn_mask_cache = self.attn_mask_cache.to(device)
|
||||
if self.attn_mask_cache.dtype != dtype:
|
||||
self.attn_mask_cache = self.attn_mask_cache.to(dtype)
|
||||
|
||||
Reference in New Issue
Block a user