[main][bugfix] Fix bugs and refactor cached mask generation logic (#2442)

### What this PR does / why we need it?
This PR fixes bugs in the cached mask generation logic and refactors it. The cached attention mask is now pre-constructed and consumed on the CPU instead of on the NPU device.
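
The idea, roughly, is to allocate one causal mask on the host up front and hand out slices of it as needed. Below is a minimal sketch of that pattern, not the actual `AttentionMaskBuilder` code; the mask size, dtype, and the `mask_for` helper are illustrative assumptions.

```python
import torch

# Pre-construct the full causal mask once, on the CPU (size is illustrative).
max_model_len = 4096
cached_mask = torch.ones(max_model_len, max_model_len).triu_(diagonal=1).bool()

def mask_for(seq_len: int, device: torch.device) -> torch.Tensor:
    # Slice only what the current batch needs and copy that slice to the
    # device; the full mask itself never has to live in NPU memory.
    return cached_mask[:seq_len, :seq_len].to(device, non_blocking=True)
```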

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.10.1.1
- vLLM main: 9b5f64238f

Signed-off-by: rjg-lyh <1318825571@qq.com>
Authored by rjg-lyh on 2025-08-27 12:07:29 +08:00, committed by GitHub
commit 2bfbf9b9b3, parent 6881c19458
4 changed files with 97 additions and 136 deletions

@@ -20,7 +20,6 @@
 import copy
 import gc
 import math
-import os
 import time
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
@@ -233,8 +232,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             vllm_config, device)
         self.attn_mask_builder = AttentionMaskBuilder(
-            min(self.model_config.max_model_len,
-                int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype)
+            self.model_config.max_model_len, self.dtype)
         # Set up speculative decoding.
         self.use_aux_hidden_state_outputs = False
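
Taken in isolation, the sizing change above means the cached mask is no longer capped by the `PAGED_ATTENTION_MASK_LEN` environment variable and always covers the model's full context length. A trivial before/after comparison (values illustrative):

```python
import os

max_model_len = 32768  # illustrative long-context model

# Before: the cached mask was capped at 10000 rows unless the env var was raised.
old_mask_len = min(max_model_len,
                   int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)))

# After: the cached mask always spans the full max_model_len.
new_mask_len = max_model_len
```
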
@@ -820,12 +818,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         return tuple(tasks)
-    def _make_attention_mask(self, seq_lens, query_lens, position,
+    def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
         # Chunk Prefill situation.
         if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
             return self.attn_mask_builder.get_splitfuse_attn_mask(
-                seq_lens, query_lens, position, self.dtype, self.device)
+                seq_lens, position, self.dtype, self.device)
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
             max_seq_len = max(seq_lens, default=0)
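
For the chunked-prefill branch above, one plausible shape of a CPU-side lookup is sketched below. This is not the real `get_splitfuse_attn_mask`; gathering per-token rows from the cached mask by absolute position is an assumption made only to illustrate why CPU-resident `seq_lens` and `position` tensors are enough.

```python
import torch

def splitfuse_mask_sketch(cached_mask: torch.Tensor, seq_lens: torch.Tensor,
                          position: torch.Tensor, dtype: torch.dtype,
                          device: torch.device) -> torch.Tensor:
    # seq_lens and position arrive as CPU tensors after this PR, so the rows
    # can be gathered on the host and shipped to the NPU in a single copy.
    max_seq_len = int(seq_lens.max().item()) if seq_lens.numel() > 0 else 0
    rows = cached_mask.index_select(0, position.long())[:, :max_seq_len]
    return rows.to(device=device, dtype=dtype, non_blocking=True)
```
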
@@ -1126,16 +1124,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                 non_blocking=True)
-        self.positions[total_num_scheduled_tokens:num_input_tokens].zero_()
-        self.positions[:total_num_scheduled_tokens].copy_(
-            self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
+        self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
+        self.positions[:num_input_tokens].copy_(
+            self.positions_cpu[:num_input_tokens], non_blocking=True)
+        positions_cpu = self.positions_cpu[:num_input_tokens]
+        positions = self.positions[:num_input_tokens]
         self.query_lens = torch.from_numpy(num_scheduled_tokens)
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)
-        seq_lens = self.seq_lens_cpu[:num_reqs]
+        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions_np // self.block_size)
@@ -1150,11 +1149,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
                                             num_valid_tokens)
-        self.attn_mask = self._make_attention_mask(
-            seq_lens=seq_lens,
-            query_lens=num_scheduled_tokens,
-            position=self.positions[:num_input_tokens],
-            attn_state=attn_state)
+        self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
+                                                   position=positions_cpu,
+                                                   attn_state=attn_state)
         self.attn_state = attn_state  # type: ignore
         self.query_start_loc_np[0] = 0