[main][bugfix] Fix bugs and refactor cached mask generation logic (#2442)

### What this PR does / why we need it?
This PR fixes bugs in the cached mask generation logic and refactors it. The
cached mask is now pre-constructed and used on the CPU instead of on the NPU device.
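
For illustration, a minimal sketch of the approach (not the exact vllm-ascend code; the `generate_attn_mask` helper and the `"npu"` device string below are assumptions): the triangular mask is built once on the CPU, and each step only selects the rows for the current token positions, scales them for bfloat16, and copies that small slice to the device.

```python
import torch


def generate_attn_mask(max_seq_len: int, dtype: torch.dtype) -> torch.Tensor:
    # Positions after the current token are masked. float16 stores -inf
    # directly (scale factor 1); bfloat16 stores 1 and is scaled to -10000
    # only when the mask is actually used, so the cache stays clean.
    mask_value = float("-inf") if dtype == torch.float16 else 1
    mask = torch.full((max_seq_len, max_seq_len), mask_value, dtype=dtype)
    return torch.triu(mask, diagonal=1)


# Pre-construct the cache once, on the CPU.
attn_mask_cache = generate_attn_mask(max_seq_len=2048, dtype=torch.bfloat16)

# Per scheduler step: pick one row per query token, scale, and move only the
# selected slice to the device (left commented out: needs Ascend hardware).
position = torch.tensor([7, 8, 9, 18, 19, 99])
max_seq_len = 100
attn_mask = attn_mask_cache.index_select(0, position)[:, :max_seq_len] * -10000
# attn_mask = attn_mask.to("npu", non_blocking=True)
print(attn_mask.shape)  # torch.Size([6, 100])
```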

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.10.1.1
- vLLM main: 9b5f64238f

Signed-off-by: rjg-lyh <1318825571@qq.com>
Authored by rjg-lyh on 2025-08-27 12:07:29 +08:00, committed by GitHub.
Commit 2bfbf9b9b3 (parent 6881c19458).
4 changed files with 97 additions and 136 deletions.


@@ -28,23 +28,32 @@ class TestAttentionMaskBuilder(TestBase):
         self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
         self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
                          torch.float16)
-        self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
         self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
                          (1024, 1024))
         self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
                          torch.tensor(float("-inf"), dtype=torch.float16))

-        # generate attention_mask_builder with int8
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=512,
-                                                      dtype=torch.int8)
-        self.assertEqual(attention_mask_builder._seq_len_cached, 512)
+        # generate attention_mask_builder with bfloat16
+        attention_mask_builder = AttentionMaskBuilder(max_seq_len=2048,
+                                                      dtype=torch.bfloat16)
+        self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
         self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
-                         torch.int8)
-        self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
+                         torch.bfloat16)
         self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (512, 512))
+                         (2048, 2048))
         self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
-                         torch.tensor(1, dtype=torch.int8))
+                         torch.tensor(1, dtype=torch.bfloat16))
+
+    def test_get_mask_scale_factor(self):
+        # supported data types
+        self.assertEqual(
+            AttentionMaskBuilder.get_mask_scale_factor(torch.float16), 1)
+        self.assertEqual(
+            AttentionMaskBuilder.get_mask_scale_factor(torch.bfloat16), -10000)
+        # mask_scale_factor now only supports data types: torch.float16 and torch.bfloat16
+        # Otherwise raise ValueError
+        with self.assertRaises(ValueError):
+            AttentionMaskBuilder.get_mask_scale_factor(torch.int8)

     def test_get_attn_mask(self):
         # if the len is less than max_seq_len, the attn_mask_cache will not be updated
@@ -77,80 +86,48 @@ class TestAttentionMaskBuilder(TestBase):
         attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
                                                       dtype=torch.float16)
         attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[512],
-            query_lens=[512],
-            position=torch.tensor([0]),
+            seq_lens=torch.tensor([10, 20, 100]),
+            position=torch.tensor([7, 8, 9, 18, 19, 99]),
             dtype=torch.float16,
             device=torch.device("cpu"),
         )
-        self.assertEqual(attn_mask.shape, (1, 512))
+        self.assertEqual(attn_mask.shape, (6, 100))
         self.assertEqual(attention_mask_builder._seq_len_cached, 1024)

         attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[2048],
-            query_lens=[1024],
-            position=torch.tensor([0]),
+            seq_lens=torch.tensor([10, 3000, 2000]),
+            position=torch.tensor([7, 8, 9, 2999, 1999]),
             dtype=torch.float16,
             device=torch.device("cpu"),
         )
-        self.assertEqual(attn_mask.shape, (1024, 2048))
-
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
-                                                      dtype=torch.int8)
-        attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[512],
-            query_lens=[512],
-            position=torch.tensor([0]),
-            dtype=torch.int8,
-            device=torch.device("cpu"),
-        )
-        self.assertEqual(attn_mask.shape, (1, 512))
-
-    def test_use_multiple_masks(self):
-        max_seq_lens = [128, 512, 1024]
-        dtypes = [torch.float16, torch.bfloat16, torch.int8]
-        for max_seq_len, dtype in zip(max_seq_lens, dtypes):
-            with self.subTest(max_seq_len=max_seq_len, dtype=dtype):
-                self._test_use_multiple_masks(max_seq_len, dtype)
-
-    def _test_use_multiple_masks(self, max_seq_len, dtype):
-        expected_mask_value = torch.finfo(
-            torch.float32).min if dtype == torch.float16 else 1
-        if dtype == torch.float16:
-            expected_splitfuse_mask_value = expected_mask_value
-        elif dtype == torch.bfloat16:
-            expected_splitfuse_mask_value = -10000
-        else:
-            assert dtype == torch.int8, "Unsupported dtype for attention mask"
-            expected_splitfuse_mask_value = -16
-
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=max_seq_len,
-                                                      dtype=dtype)
-        splitfuse_attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[max_seq_len],
-            query_lens=[max_seq_len],
-            position=torch.tensor([0]),
-            dtype=dtype,
-            device=torch.device("cpu"),
-        )
-        self.assertEqual(splitfuse_attn_mask.shape, (1, max_seq_len))
-        self.assertEqual(
-            splitfuse_attn_mask[0][-1],
-            torch.tensor(expected_splitfuse_mask_value, dtype=dtype))
-        self.assertEqual(attention_mask_builder._seq_len_cached, max_seq_len)
-        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (max_seq_len, max_seq_len))
-        self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
-                         torch.tensor(expected_mask_value, dtype=dtype))
-
-        attn_mask = attention_mask_builder.get_attn_mask(
-            max_seq_len=max_seq_len, dtype=dtype, device=torch.device("cpu"))
-        self.assertEqual(attn_mask.shape, (max_seq_len, max_seq_len))
-        self.assertEqual(attn_mask[0][-1],
-                         torch.tensor(expected_mask_value, dtype=dtype))
-        self.assertEqual(attention_mask_builder._seq_len_cached, max_seq_len)
-        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (max_seq_len, max_seq_len))
-        self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
-                         torch.tensor(expected_mask_value, dtype=dtype))
+        self.assertEqual(attn_mask.shape, (5, 3000))
+        self.assertEqual(attention_mask_builder._seq_len_cached, 3000)
+
+        # splitfuse_attn_mask now only supports data types: torch.float16 and torch.bfloat16
+        # otherwise raise ValueError
+        with self.assertRaises(ValueError):
+            attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
+                seq_lens=torch.tensor([10, 20, 100]),
+                position=torch.tensor([7, 8, 9, 18, 19, 99]),
+                dtype=torch.int8,
+                device=torch.device("cpu"),
+            )
+
+    def test_mask_value_cleanliness(self):
+        attention_mask_builder = AttentionMaskBuilder(max_seq_len=6,
+                                                      dtype=torch.bfloat16)
+        self.assertEqual(attention_mask_builder.attn_mask_cache[-2][-1],
+                         torch.tensor(1, dtype=torch.bfloat16))
+
+        attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
+            seq_lens=torch.tensor([6]),
+            position=torch.tensor([3, 4, 5]),
+            dtype=torch.bfloat16,
+            device=torch.device("cpu"),
+        )
+        self.assertEqual(
+            attn_mask[-2][-1],
+            torch.tensor(-10000, dtype=torch.bfloat16,
+                         device=attn_mask.device))
+        self.assertEqual(attention_mask_builder.attn_mask_cache[-2][-1],
+                         torch.tensor(1, dtype=torch.bfloat16))
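
To make the new shape assertions easier to follow, here is a hedged walk-through of the first `get_splitfuse_attn_mask` call in the updated test: the three requests with `seq_lens = [10, 20, 100]` contribute 3, 2, and 1 query tokens (the entries of `position`), so the returned mask has one row per query token and `max(seq_lens)` columns.

```python
import torch

# Values taken from the updated test above; the 3/2/1 split per request is
# inferred from the position values, not stated explicitly by the test.
seq_lens = torch.tensor([10, 20, 100])
position = torch.tensor([7, 8, 9, 18, 19, 99])

expected_shape = (position.numel(), int(seq_lens.max()))
assert expected_shape == (6, 100)
```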


@@ -44,61 +44,50 @@ class AttentionMaskBuilder:
         self._seq_len_cached = attn_mask.shape[0]
         self.attn_mask_cache = attn_mask
-        self.splitfuse_mask_value = -10000
+
+    @staticmethod
+    def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
+        if dtype == torch.float16:
+            mask_scale_factor = 1
+        elif dtype == torch.bfloat16:
+            mask_scale_factor = -10000
+        else:
+            raise ValueError(
+                "The current operation now only supports data types: torch.float16 and "
+                "torch.bfloat16. Please ensure the input is of one of these types."
+            )
+        return mask_scale_factor

     def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                       device: torch.device):
-        self._update_attn_cache(max_seq_len, dtype, device)
-        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
+        self._update_attn_cache(max_seq_len, dtype)
+        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
+        ).to(device)

     def get_splitfuse_attn_mask(
         self,
-        seq_lens,
-        query_lens,
-        position,
-        dtype,
-        device,
+        seq_lens: torch.Tensor,
+        position: torch.Tensor,
+        dtype: torch.dtype,
+        device: torch.device,
     ) -> torch.Tensor:
+        if dtype not in [torch.float16, torch.bfloat16]:
+            raise ValueError(
+                "splitfuse_attn_mask now only supports bf16 and fp16")
         max_seq_len = max(seq_lens, default=0)
-        if max_seq_len <= self._seq_len_cached:
-            self._update_attn_cache(max_seq_len, dtype, device)
-            # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
-            # is not the same. Fix this in the future when kernel is ready.
-            if self.attn_mask_cache.numel(
-            ) > 1 and self.attn_mask_cache[0][1] > 0:
-                attn_mask = self.get_attn_mask(  # type: ignore
-                    max_seq_len, dtype, device)
-                # Do not use in-place multiplication to avoid modifying `self.attn_mask_cache`!
-                attn_mask = attn_mask * -10000
-            else:
-                attn_mask = self.attn_mask_cache
-            return torch.index_select(attn_mask, dim=0,
-                                      index=position)[:, :max_seq_len]
-        total_q_len = sum(query_lens)
-        attn_mask = torch.zeros((total_q_len, max_seq_len),
-                                dtype=dtype,
-                                device="cpu")
-        current_row = 0
-        for i in range(len(query_lens)):
-            seq_len = seq_lens[i]
-            q_len = query_lens[i]
-            context_len = seq_len - q_len
-            assert context_len >= 0
-            attn_mask[current_row:current_row + q_len,
-                      context_len:] = self.splitfuse_mask_value
-            right_tensor = attn_mask[current_row:current_row + q_len,
-                                     context_len:seq_len]
-            right_tensor.masked_fill_(
-                right_tensor.tril() == self.splitfuse_mask_value, 0)
-            current_row += q_len
-        return attn_mask.to(device, non_blocking=True)
+        self._update_attn_cache(max_seq_len, dtype)
+        # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
+        # is not the same. Fix this in the future when kernel is ready.
+        mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
+        attn_mask = torch.index_select(self.attn_mask_cache,
+                                       dim=0,
+                                       index=position)[:, :max_seq_len]
+        attn_mask *= mask_scale_factor
+        return attn_mask.contiguous().to(device, non_blocking=True)

-    def _update_attn_cache(self, seqlen: int, dtype: torch.dtype,
-                           device: torch.device):
+    def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
         if seqlen > self._seq_len_cached:
             self._seq_len_cached = seqlen
             self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
-        if self.attn_mask_cache.device != device:
-            self.attn_mask_cache = self.attn_mask_cache.to(device)
+        if self.attn_mask_cache.dtype != dtype:
+            self.attn_mask_cache = self.attn_mask_cache.to(dtype)
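
As a usage sketch of the refactored builder (the import path below is an assumption; on real Ascend hardware the target device would be `torch.device("npu")` instead of CPU), the cache now lives on the CPU and regrows lazily when a longer sequence arrives, mirroring the updated unit test:

```python
import torch

# Assumed import path; adjust to wherever AttentionMaskBuilder lives in your tree.
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder

builder = AttentionMaskBuilder(max_seq_len=1024, dtype=torch.float16)

attn_mask = builder.get_splitfuse_attn_mask(
    seq_lens=torch.tensor([10, 3000, 2000]),
    position=torch.tensor([7, 8, 9, 2999, 1999]),
    dtype=torch.float16,
    device=torch.device("cpu"),  # torch.device("npu") on Ascend hardware
)
assert attn_mask.shape == (5, 3000)
# Requesting a sequence longer than the cached size regrows the CPU-side cache.
assert builder._seq_len_cached == 3000
```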


@@ -79,11 +79,10 @@ class EagleProposer:
     def _make_attention_mask(
         self,
         seq_lens,
-        query_lens,
         position,
     ) -> torch.Tensor:
         return self.attn_mask_builder.get_splitfuse_attn_mask(
-            seq_lens, query_lens, position, self.dtype, self.device)
+            seq_lens, position, self.dtype, self.device)

     def propose(
         self,
@@ -247,7 +246,6 @@ class EagleProposer:
         positions = positions_cpu.to(device)
         attn_mask = self._make_attention_mask(
             seq_lens=attn_metadata.seq_lens,
-            query_lens=attn_metadata.max_query_len,
             position=positions,
         )
         attn_metadata.attn_mask = attn_mask


@@ -20,7 +20,6 @@
 import copy
 import gc
 import math
-import os
 import time
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
@@ -233,8 +232,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             vllm_config, device)
         self.attn_mask_builder = AttentionMaskBuilder(
-            min(self.model_config.max_model_len,
-                int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype)
+            self.model_config.max_model_len, self.dtype)

         # Set up speculative decoding.
         self.use_aux_hidden_state_outputs = False
@@ -820,12 +818,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         return tuple(tasks)

-    def _make_attention_mask(self, seq_lens, query_lens, position,
+    def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
         # Chunk Prefill situation.
         if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
             return self.attn_mask_builder.get_splitfuse_attn_mask(
-                seq_lens, query_lens, position, self.dtype, self.device)
+                seq_lens, position, self.dtype, self.device)
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
             max_seq_len = max(seq_lens, default=0)
@@ -1126,16 +1124,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                 non_blocking=True)

-        self.positions[total_num_scheduled_tokens:num_input_tokens].zero_()
-        self.positions[:total_num_scheduled_tokens].copy_(
-            self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
+        self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
+        self.positions[:num_input_tokens].copy_(
+            self.positions_cpu[:num_input_tokens], non_blocking=True)
+        positions_cpu = self.positions_cpu[:num_input_tokens]
+        positions = self.positions[:num_input_tokens]

         self.query_lens = torch.from_numpy(num_scheduled_tokens)

         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)
-        seq_lens = self.seq_lens_cpu[:num_reqs]
+        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]

         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions_np // self.block_size)
@@ -1150,11 +1149,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
                                             num_valid_tokens)
-        self.attn_mask = self._make_attention_mask(
-            seq_lens=seq_lens,
-            query_lens=num_scheduled_tokens,
-            position=self.positions[:num_input_tokens],
-            attn_state=attn_state)
+        self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
+                                                   position=positions_cpu,
+                                                   attn_state=attn_state)
         self.attn_state = attn_state  # type: ignore

         self.query_start_loc_np[0] = 0