diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 76b3e31..a6736c5 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -18,6 +18,7 @@
 #
 
 import gc
+import os
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import numpy as np
@@ -53,6 +54,8 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
 
+NPU_PAGED_ATTENTION_MASK_VALUE = -10000
+
 logger = init_logger(__name__)
 
@@ -210,6 +213,24 @@ class NPUModelRunner:
                                          self.max_num_tokens,
                                          device="cpu")
 
+        # NOTE: Pre-construct a mask matrix to speed up attention mask
+        # construction during inference.
+        # The size of the matrix must be balanced carefully: a matrix that is
+        # too large consumes excessive NPU memory, while one that is too small
+        # forces dynamic concatenation during inference, which degrades
+        # performance.
+        # The PAGED_ATTENTION_MASK_LEN environment variable therefore lets the
+        # size of the pre-constructed matrix be tuned per deployment.
+        mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", "10000")
+        self.attn_mask_len = min(self.max_model_len, int(mask_len))
+        self.attn_mask_npu = torch.full(
+            (self.attn_mask_len, self.attn_mask_len),
+            NPU_PAGED_ATTENTION_MASK_VALUE,
+            device=self.device,
+            dtype=self.vllm_config.model_config.dtype)
+        self.attn_mask_npu.masked_fill_(
+            self.attn_mask_npu.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
+
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the
         scheduler output.
@@ -365,40 +386,52 @@
     def get_model(self) -> nn.Module:
         return self.model
 
-    @staticmethod
-    def make_attention_mask(kv_dtype, kv_device, max_seq_len, seq_lens,
-                            query_lens):
-        # for paged attention
-        atten_mask = np.zeros([0, max_seq_len])
-        for i, context_length in enumerate(seq_lens):
-            q_len = query_lens[i]
-            ones_len = context_length - q_len
-            ones = np.ones((q_len, ones_len), dtype=np.float16)
-            bias_cache = np.tril(
-                np.ones((q_len, max_seq_len - ones_len), dtype=np.float16))
-            bias_cache = np.concatenate((ones, bias_cache), axis=1)
-            mask_value = -10000
-            bias_cache[bias_cache == 0] = mask_value
-            bias_cache[bias_cache == 1] = 0
+    def make_attention_mask(self, seq_lens, query_lens,
+                            position) -> torch.Tensor:
+        max_seq_len = max(seq_lens, default=0)
+        if max_seq_len <= self.attn_mask_len:
+            return torch.index_select(self.attn_mask_npu,
+                                      dim=0,
+                                      index=position)[:, :max_seq_len]
 
-            atten_mask = np.concatenate([atten_mask, bias_cache], axis=0)
-        atten_mask = torch.from_numpy(atten_mask).to(kv_dtype).to(kv_device)
-        return atten_mask
+        total_q_len = sum(query_lens)
+        attn_mask = torch.zeros((total_q_len, max_seq_len),
+                                dtype=self.vllm_config.model_config.dtype,
+                                device="cpu")
+
+        current_row = 0
+        for i in range(len(query_lens)):
+            seq_len = seq_lens[i]
+            q_len = query_lens[i]
+            context_len = seq_len - q_len
+
+            assert context_len >= 0
+            attn_mask[current_row:current_row + q_len,
+                      context_len:] = NPU_PAGED_ATTENTION_MASK_VALUE
+            right_tensor = attn_mask[current_row:current_row + q_len,
+                                     context_len:seq_len]
+            right_tensor.masked_fill_(
+                right_tensor.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
+            current_row += q_len
+
+        return attn_mask.to(self.device, non_blocking=True)
 
     def _process_reqs(
         self,
        scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
-        # check input valid
+        # Check input valid
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
 
+        # Copy the blocks from CPU to NPU.
         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
         self.input_batch.block_table.commit(num_reqs)
 
+        # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
         num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
@@ -409,7 +442,7 @@
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)
 
-        # prepare positions
+        # Prepare positions
         req_indices = np.repeat(self.arange_np[:num_reqs],
                                 num_scheduled_tokens)
         cu_num_tokens = np.cumsum(num_scheduled_tokens)
@@ -444,9 +477,9 @@
         slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
             self.device, non_blocking=True)
 
-        attn_mask = self.make_attention_mask(
-            self.vllm_config.model_config.dtype, self.device,
-            max(seq_lens, default=0), seq_lens, num_scheduled_tokens)
+        attn_mask = self.make_attention_mask(seq_lens=seq_lens,
+                                             query_lens=num_scheduled_tokens,
+                                             position=positions)
 
         attn_metadata = AscendMetadata(
             seq_lens=query_lens,
@@ -457,7 +490,7 @@
             attn_mask=attn_mask,
         )
 
-        # prepare input_ids
+        # Prepare input_ids
         token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])
         torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
@@ -468,6 +501,7 @@
         self.input_ids[:total_num_scheduled_tokens].copy_(
             self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
         input_ids = self.input_ids[:total_num_scheduled_tokens]
 
+        # Run forward pass
         with set_forward_context(attn_metadata, self.vllm_config):
            assert self.model is not None
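
Reviewer note on the fast path in `make_attention_mask`: the precomputed matrix turns per-step mask construction into a single row gather. The sketch below is a minimal, self-contained illustration of that lookup, not code from this patch; it runs on CPU, and `MASK_VALUE`, the 6x6 size, and the example batch are made up for the demo.

```python
import torch

MASK_VALUE = -10000.0
mask_len = 6

# Build the additive causal mask once: after masked_fill_, row p holds
# 0 in columns 0..p (visible positions) and MASK_VALUE everywhere else.
mask = torch.full((mask_len, mask_len), MASK_VALUE)
mask.masked_fill_(mask.tril() == MASK_VALUE, 0)

# A toy batch: request A decodes 2 tokens at positions 1 and 2 on top of
# 1 cached token; request B prefills 3 tokens at positions 0, 1, 2.
positions = torch.tensor([1, 2, 0, 1, 2])
seq_lens = [3, 3]
max_seq_len = max(seq_lens)

# Each token's mask row is just row `position` of the precomputed matrix,
# so the whole batch mask is one gather plus a slice, with no Python loop.
attn_mask = torch.index_select(mask, dim=0, index=positions)[:, :max_seq_len]

assert attn_mask.shape == (5, 3)
assert (attn_mask[0] == torch.tensor([0., 0., MASK_VALUE])).all()
assert (attn_mask[2] == torch.tensor([0., MASK_VALUE, MASK_VALUE])).all()
```

The slow path in the patch only runs when `max_seq_len` exceeds the precomputed size, which is why `PAGED_ATTENTION_MASK_LEN` trades NPU memory against how often the per-request CPU loop is taken.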