[main][bugfix] Fix bugs and refactor cached mask generation logic (#2442)
### What this PR does / why we need it?
This PR fixes bugs and refactors the cached mask generation logic. The cached
mask is now pre-constructed and used on the CPU instead of on the NPU device.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with newly added and existing tests.
- vLLM version: v0.10.1.1
- vLLM main:
9b5f64238f
Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -20,7 +20,6 @@
|
||||
import copy
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
@@ -233,8 +232,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
|
||||
vllm_config, device)
|
||||
self.attn_mask_builder = AttentionMaskBuilder(
|
||||
min(self.model_config.max_model_len,
|
||||
int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype)
|
||||
self.model_config.max_model_len, self.dtype)
|
||||
|
||||
# Set up speculative decoding.
|
||||
self.use_aux_hidden_state_outputs = False
|
||||
@@ -820,12 +818,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
return tuple(tasks)
|
||||
|
||||
def _make_attention_mask(self, seq_lens, query_lens, position,
|
||||
def _make_attention_mask(self, seq_lens, position,
|
||||
attn_state) -> torch.Tensor:
|
||||
# Chunk Prefill situation.
|
||||
if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
|
||||
return self.attn_mask_builder.get_splitfuse_attn_mask(
|
||||
seq_lens, query_lens, position, self.dtype, self.device)
|
||||
seq_lens, position, self.dtype, self.device)
|
||||
# Prefill without cache situation.
|
||||
elif attn_state == AscendAttentionState.PrefillNoCache:
|
||||
max_seq_len = max(seq_lens, default=0)
|
||||
@@ -1126,16 +1124,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
|
||||
non_blocking=True)
|
||||
|
||||
self.positions[total_num_scheduled_tokens:num_input_tokens].zero_()
|
||||
self.positions[:total_num_scheduled_tokens].copy_(
|
||||
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
||||
|
||||
self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
|
||||
self.positions[:num_input_tokens].copy_(
|
||||
self.positions_cpu[:num_input_tokens], non_blocking=True)
|
||||
positions_cpu = self.positions_cpu[:num_input_tokens]
|
||||
positions = self.positions[:num_input_tokens]
|
||||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||
|
||||
self.seq_lens_np[:num_reqs] = (
|
||||
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||||
num_scheduled_tokens)
|
||||
seq_lens = self.seq_lens_cpu[:num_reqs]
|
||||
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
|
||||
|
||||
block_table_indices = (req_indices * self.max_num_blocks_per_req +
|
||||
positions_np // self.block_size)
|
||||
@@ -1150,11 +1149,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
|
||||
num_valid_tokens)
|
||||
|
||||
self.attn_mask = self._make_attention_mask(
|
||||
seq_lens=seq_lens,
|
||||
query_lens=num_scheduled_tokens,
|
||||
position=self.positions[:num_input_tokens],
|
||||
attn_state=attn_state)
|
||||
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
|
||||
position=positions_cpu,
|
||||
attn_state=attn_state)
|
||||
self.attn_state = attn_state # type: ignore
|
||||
|
||||
self.query_start_loc_np[0] = 0
|
||||
|
||||
Reference in New Issue
Block a user