[main][bugfix] Fix bugs and refactor cached mask generation logic (#2442)

### What this PR does / why we need it?
This PR fix bugs and refactor cached mask generation logic. Now just
pre-construct and use the cached mask on cpu instead of device on npu.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: v0.10.1.1
- vLLM main:
9b5f64238f

Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
rjg-lyh
2025-08-27 12:07:29 +08:00
committed by GitHub
parent 6881c19458
commit 2bfbf9b9b3
4 changed files with 97 additions and 136 deletions

View File

@@ -28,23 +28,32 @@ class TestAttentionMaskBuilder(TestBase):
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
torch.float16)
self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(1024, 1024))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(float("-inf"), dtype=torch.float16))
# generate attention_mask_builder with int8
attention_mask_builder = AttentionMaskBuilder(max_seq_len=512,
dtype=torch.int8)
self.assertEqual(attention_mask_builder._seq_len_cached, 512)
# generate attention_mask_builder with bfloat16
attention_mask_builder = AttentionMaskBuilder(max_seq_len=2048,
dtype=torch.bfloat16)
self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
torch.int8)
self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
torch.bfloat16)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(512, 512))
(2048, 2048))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(1, dtype=torch.int8))
torch.tensor(1, dtype=torch.bfloat16))
def test_get_mask_scale_factor(self):
# supported data types
self.assertEqual(
AttentionMaskBuilder.get_mask_scale_factor(torch.float16), 1)
self.assertEqual(
AttentionMaskBuilder.get_mask_scale_factor(torch.bfloat16), -10000)
# mask_scale_factor now only supports data types: torch.float16 and torch.bfloat16
# Otherwise raise ValueError
with self.assertRaises(ValueError):
AttentionMaskBuilder.get_mask_scale_factor(torch.int8)
def test_get_attn_mask(self):
# if the len is less than max_seq_len, the attn_mask_cache will not be updated
@@ -77,80 +86,48 @@ class TestAttentionMaskBuilder(TestBase):
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
dtype=torch.float16)
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=[512],
query_lens=[512],
position=torch.tensor([0]),
seq_lens=torch.tensor([10, 20, 100]),
position=torch.tensor([7, 8, 9, 18, 19, 99]),
dtype=torch.float16,
device=torch.device("cpu"),
)
self.assertEqual(attn_mask.shape, (1, 512))
self.assertEqual(attn_mask.shape, (6, 100))
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=[2048],
query_lens=[1024],
position=torch.tensor([0]),
seq_lens=torch.tensor([10, 3000, 2000]),
position=torch.tensor([7, 8, 9, 2999, 1999]),
dtype=torch.float16,
device=torch.device("cpu"),
)
self.assertEqual(attn_mask.shape, (1024, 2048))
self.assertEqual(attn_mask.shape, (5, 3000))
self.assertEqual(attention_mask_builder._seq_len_cached, 3000)
# splitfuse_attn_mask now only supports data types: torch.float16 and torch.bfloat16
# otherwise raise ValueError
with self.assertRaises(ValueError):
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=torch.tensor([10, 20, 100]),
position=torch.tensor([7, 8, 9, 18, 19, 99]),
dtype=torch.int8,
device=torch.device("cpu"),
)
def test_mask_value_cleanliness(self):
attention_mask_builder = AttentionMaskBuilder(max_seq_len=6,
dtype=torch.bfloat16)
self.assertEqual(attention_mask_builder.attn_mask_cache[-2][-1],
torch.tensor(1, dtype=torch.bfloat16))
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
dtype=torch.int8)
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=[512],
query_lens=[512],
position=torch.tensor([0]),
dtype=torch.int8,
seq_lens=torch.tensor([6]),
position=torch.tensor([3, 4, 5]),
dtype=torch.bfloat16,
device=torch.device("cpu"),
)
self.assertEqual(attn_mask.shape, (1, 512))
def test_use_multiple_masks(self):
max_seq_lens = [128, 512, 1024]
dtypes = [torch.float16, torch.bfloat16, torch.int8]
for max_seq_len, dtype in zip(max_seq_lens, dtypes):
with self.subTest(max_seq_len=max_seq_len, dtype=dtype):
self._test_use_multiple_masks(max_seq_len, dtype)
def _test_use_multiple_masks(self, max_seq_len, dtype):
expected_mask_value = torch.finfo(
torch.float32).min if dtype == torch.float16 else 1
if dtype == torch.float16:
expected_splitfuse_mask_value = expected_mask_value
elif dtype == torch.bfloat16:
expected_splitfuse_mask_value = -10000
else:
assert dtype == torch.int8, "Unsupported dtype for attention mask"
expected_splitfuse_mask_value = -16
attention_mask_builder = AttentionMaskBuilder(max_seq_len=max_seq_len,
dtype=dtype)
splitfuse_attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=[max_seq_len],
query_lens=[max_seq_len],
position=torch.tensor([0]),
dtype=dtype,
device=torch.device("cpu"),
)
self.assertEqual(splitfuse_attn_mask.shape, (1, max_seq_len))
self.assertEqual(
splitfuse_attn_mask[0][-1],
torch.tensor(expected_splitfuse_mask_value, dtype=dtype))
self.assertEqual(attention_mask_builder._seq_len_cached, max_seq_len)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(max_seq_len, max_seq_len))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(expected_mask_value, dtype=dtype))
attn_mask = attention_mask_builder.get_attn_mask(
max_seq_len=max_seq_len, dtype=dtype, device=torch.device("cpu"))
self.assertEqual(attn_mask.shape, (max_seq_len, max_seq_len))
self.assertEqual(attn_mask[0][-1],
torch.tensor(expected_mask_value, dtype=dtype))
self.assertEqual(attention_mask_builder._seq_len_cached, max_seq_len)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(max_seq_len, max_seq_len))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(expected_mask_value, dtype=dtype))
attn_mask[-2][-1],
torch.tensor(-10000, dtype=torch.bfloat16,
device=attn_mask.device))
self.assertEqual(attention_mask_builder.attn_mask_cache[-2][-1],
torch.tensor(1, dtype=torch.bfloat16))

View File

@@ -44,61 +44,50 @@ class AttentionMaskBuilder:
self._seq_len_cached = attn_mask.shape[0]
self.attn_mask_cache = attn_mask
self.splitfuse_mask_value = -10000
@staticmethod
def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
if dtype == torch.float16:
mask_scale_factor = 1
elif dtype == torch.bfloat16:
mask_scale_factor = -10000
else:
raise ValueError(
"The current operation now only supports data types: torch.float16 and "
"torch.bfloat16. Please ensure the input is of one of these types."
)
return mask_scale_factor
def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
device: torch.device):
self._update_attn_cache(max_seq_len, dtype, device)
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
self._update_attn_cache(max_seq_len, dtype)
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
).to(device)
def get_splitfuse_attn_mask(
self,
seq_lens,
query_lens,
position,
dtype,
device,
seq_lens: torch.Tensor,
position: torch.Tensor,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor:
if dtype not in [torch.float16, torch.bfloat16]:
raise ValueError(
"splitfuse_attn_mask now only supports bf16 and fp16")
max_seq_len = max(seq_lens, default=0)
if max_seq_len <= self._seq_len_cached:
self._update_attn_cache(max_seq_len, dtype, device)
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
# is not the same. Fix this in the future when kernel is ready.
if self.attn_mask_cache.numel(
) > 1 and self.attn_mask_cache[0][1] > 0:
attn_mask = self.get_attn_mask( # type: ignore
max_seq_len, dtype, device)
# Do not use in-place multiplication to avoid modifying `self.attn_mask_cache`!
attn_mask = attn_mask * -10000
else:
attn_mask = self.attn_mask_cache
return torch.index_select(attn_mask, dim=0,
index=position)[:, :max_seq_len]
total_q_len = sum(query_lens)
attn_mask = torch.zeros((total_q_len, max_seq_len),
dtype=dtype,
device="cpu")
current_row = 0
for i in range(len(query_lens)):
seq_len = seq_lens[i]
q_len = query_lens[i]
context_len = seq_len - q_len
self._update_attn_cache(max_seq_len, dtype)
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
# is not the same. Fix this in the future when kernel is ready.
mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
attn_mask = torch.index_select(self.attn_mask_cache,
dim=0,
index=position)[:, :max_seq_len]
attn_mask *= mask_scale_factor
return attn_mask.contiguous().to(device, non_blocking=True)
assert context_len >= 0
attn_mask[current_row:current_row + q_len,
context_len:] = self.splitfuse_mask_value
right_tensor = attn_mask[current_row:current_row + q_len,
context_len:seq_len]
right_tensor.masked_fill_(
right_tensor.tril() == self.splitfuse_mask_value, 0)
current_row += q_len
return attn_mask.to(device, non_blocking=True)
def _update_attn_cache(self, seqlen: int, dtype: torch.dtype,
device: torch.device):
def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
if seqlen > self._seq_len_cached:
self._seq_len_cached = seqlen
self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
if self.attn_mask_cache.device != device:
self.attn_mask_cache = self.attn_mask_cache.to(device)
if self.attn_mask_cache.dtype != dtype:
self.attn_mask_cache = self.attn_mask_cache.to(dtype)

View File

@@ -79,11 +79,10 @@ class EagleProposer:
def _make_attention_mask(
self,
seq_lens,
query_lens,
position,
) -> torch.Tensor:
return self.attn_mask_builder.get_splitfuse_attn_mask(
seq_lens, query_lens, position, self.dtype, self.device)
seq_lens, position, self.dtype, self.device)
def propose(
self,
@@ -247,7 +246,6 @@ class EagleProposer:
positions = positions_cpu.to(device)
attn_mask = self._make_attention_mask(
seq_lens=attn_metadata.seq_lens,
query_lens=attn_metadata.max_query_len,
position=positions,
)
attn_metadata.attn_mask = attn_mask

View File

@@ -20,7 +20,6 @@
import copy
import gc
import math
import os
import time
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
@@ -233,8 +232,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
vllm_config, device)
self.attn_mask_builder = AttentionMaskBuilder(
min(self.model_config.max_model_len,
int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype)
self.model_config.max_model_len, self.dtype)
# Set up speculative decoding.
self.use_aux_hidden_state_outputs = False
@@ -820,12 +818,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return tuple(tasks)
def _make_attention_mask(self, seq_lens, query_lens, position,
def _make_attention_mask(self, seq_lens, position,
attn_state) -> torch.Tensor:
# Chunk Prefill situation.
if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
return self.attn_mask_builder.get_splitfuse_attn_mask(
seq_lens, query_lens, position, self.dtype, self.device)
seq_lens, position, self.dtype, self.device)
# Prefill without cache situation.
elif attn_state == AscendAttentionState.PrefillNoCache:
max_seq_len = max(seq_lens, default=0)
@@ -1126,16 +1124,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
non_blocking=True)
self.positions[total_num_scheduled_tokens:num_input_tokens].zero_()
self.positions[:total_num_scheduled_tokens].copy_(
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
self.positions[:num_input_tokens].copy_(
self.positions_cpu[:num_input_tokens], non_blocking=True)
positions_cpu = self.positions_cpu[:num_input_tokens]
positions = self.positions[:num_input_tokens]
self.query_lens = torch.from_numpy(num_scheduled_tokens)
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
seq_lens = self.seq_lens_cpu[:num_reqs]
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions_np // self.block_size)
@@ -1150,11 +1149,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
num_valid_tokens)
self.attn_mask = self._make_attention_mask(
seq_lens=seq_lens,
query_lens=num_scheduled_tokens,
position=self.positions[:num_input_tokens],
attn_state=attn_state)
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
position=positions_cpu,
attn_state=attn_state)
self.attn_state = attn_state # type: ignore
self.query_start_loc_np[0] = 0