[ModelRunner][V1] Optimize V1 attention mask (#442)
### What this PR does / why we need it?

Pre-construct a mask matrix to improve the efficiency of attention mask construction during inference.

Note that the size of the matrix needs to be carefully balanced: a matrix that is too large will consume excessive VRAM, while a matrix that is too small will require the mask to be built dynamically during inference, leading to performance degradation. Therefore, an environment variable (`PAGED_ATTENTION_MASK_LEN`, default 10000, capped at the model's maximum length) is added to set the size of the pre-constructed mask matrix based on requirements.

---------

Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: didongli182 <didongli@huawei.com>
This commit is contained in:
@@ -18,6 +18,7 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
|
import os
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -53,6 +54,8 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|
||||||
|
NPU_PAGED_ATTENTION_MASK_VALUE = -10000
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -210,6 +213,24 @@ class NPUModelRunner:
|
|||||||
self.max_num_tokens,
|
self.max_num_tokens,
|
||||||
device="cpu")
|
device="cpu")
|
||||||
|
|
||||||
|
# NOTE: Pre-construct a mask matrix to improve the efficiency of
|
||||||
|
# attention mask construction during inference.
|
||||||
|
# Note that the length of the matrix needs to be carefully balanced: a
|
||||||
|
# matrix that is too large will consume excessive VRAM, while a matrix
|
||||||
|
# that is too small will require dynamic concatenation during inference,
|
||||||
|
# leading to performance degradation.
|
||||||
|
# Therefore, an environment variable is added here to dynamically set
|
||||||
|
# the size of the pre-constructed mask matrix based on requirements.
|
||||||
|
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
|
||||||
|
self.attn_mask_len = min(self.max_model_len, int(mask_len))
|
||||||
|
self.attn_mask_npu = torch.full(
|
||||||
|
(self.attn_mask_len, self.attn_mask_len),
|
||||||
|
NPU_PAGED_ATTENTION_MASK_VALUE,
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.vllm_config.model_config.dtype)
|
||||||
|
self.attn_mask_npu.masked_fill_(
|
||||||
|
self.attn_mask_npu.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
|
||||||
|
|
||||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||||
"""Update the cached states and the persistent batch with the scheduler
|
"""Update the cached states and the persistent batch with the scheduler
|
||||||
output.
|
output.
|
||||||
@@ -365,40 +386,52 @@ class NPUModelRunner:
|
|||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
@staticmethod
|
def make_attention_mask(self, seq_lens, query_lens,
|
||||||
def make_attention_mask(kv_dtype, kv_device, max_seq_len, seq_lens,
|
position) -> torch.Tensor:
|
||||||
query_lens):
|
max_seq_len = max(seq_lens, default=0)
|
||||||
# for paged attention
|
if max_seq_len <= self.attn_mask_len:
|
||||||
atten_mask = np.zeros([0, max_seq_len])
|
return torch.index_select(self.attn_mask_npu,
|
||||||
for i, context_length in enumerate(seq_lens):
|
dim=0,
|
||||||
q_len = query_lens[i]
|
index=position)[:, :max_seq_len]
|
||||||
ones_len = context_length - q_len
|
|
||||||
ones = np.ones((q_len, ones_len), dtype=np.float16)
|
|
||||||
bias_cache = np.tril(
|
|
||||||
np.ones((q_len, max_seq_len - ones_len), dtype=np.float16))
|
|
||||||
bias_cache = np.concatenate((ones, bias_cache), axis=1)
|
|
||||||
mask_value = -10000
|
|
||||||
bias_cache[bias_cache == 0] = mask_value
|
|
||||||
bias_cache[bias_cache == 1] = 0
|
|
||||||
|
|
||||||
atten_mask = np.concatenate([atten_mask, bias_cache], axis=0)
|
total_q_len = sum(query_lens)
|
||||||
atten_mask = torch.from_numpy(atten_mask).to(kv_dtype).to(kv_device)
|
attn_mask = torch.zeros((total_q_len, max_seq_len),
|
||||||
return atten_mask
|
dtype=self.vllm_config.model_config.dtype,
|
||||||
|
device="cpu")
|
||||||
|
|
||||||
|
current_row = 0
|
||||||
|
for i in range(len(query_lens)):
|
||||||
|
seq_len = seq_lens[i]
|
||||||
|
q_len = query_lens[i]
|
||||||
|
context_len = seq_len - q_len
|
||||||
|
|
||||||
|
assert context_len >= 0
|
||||||
|
attn_mask[current_row:current_row + q_len,
|
||||||
|
context_len:] = NPU_PAGED_ATTENTION_MASK_VALUE
|
||||||
|
right_tensor = attn_mask[current_row:current_row + q_len,
|
||||||
|
context_len:seq_len]
|
||||||
|
right_tensor.mask_fill_(
|
||||||
|
right_tensor.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
|
||||||
|
current_row += q_len
|
||||||
|
|
||||||
|
return attn_mask.to(self.device, non_blocking=True)
|
||||||
|
|
||||||
def _process_reqs(
|
def _process_reqs(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# check input valid
|
# Check input valid
|
||||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
assert total_num_scheduled_tokens > 0
|
assert total_num_scheduled_tokens > 0
|
||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
assert num_reqs > 0
|
assert num_reqs > 0
|
||||||
|
|
||||||
|
# Copy the blocks from CPU to NPU.
|
||||||
# OPTIMIZATION: Start copying the block table first.
|
# OPTIMIZATION: Start copying the block table first.
|
||||||
# This way, we can overlap the copy with the following CPU operations.
|
# This way, we can overlap the copy with the following CPU operations.
|
||||||
self.input_batch.block_table.commit(num_reqs)
|
self.input_batch.block_table.commit(num_reqs)
|
||||||
|
|
||||||
# Get the number of scheduled tokens for each request.
|
# Get the number of scheduled tokens for each request.
|
||||||
# TODO: The Python loop can be slow. Optimize.
|
# TODO: The Python loop can be slow. Optimize.
|
||||||
num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
|
num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
|
||||||
@@ -409,7 +442,7 @@ class NPUModelRunner:
|
|||||||
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
|
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
|
||||||
num_tokens)
|
num_tokens)
|
||||||
|
|
||||||
# prepare positions
|
# Prepare positions
|
||||||
req_indices = np.repeat(self.arange_np[:num_reqs],
|
req_indices = np.repeat(self.arange_np[:num_reqs],
|
||||||
num_scheduled_tokens)
|
num_scheduled_tokens)
|
||||||
cu_num_tokens = np.cumsum(num_scheduled_tokens)
|
cu_num_tokens = np.cumsum(num_scheduled_tokens)
|
||||||
@@ -444,9 +477,9 @@ class NPUModelRunner:
|
|||||||
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
|
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
|
||||||
self.device, non_blocking=True)
|
self.device, non_blocking=True)
|
||||||
|
|
||||||
attn_mask = self.make_attention_mask(
|
attn_mask = self.make_attention_mask(seq_lens=seq_lens,
|
||||||
self.vllm_config.model_config.dtype, self.device,
|
query_lens=num_scheduled_tokens,
|
||||||
max(seq_lens, default=0), seq_lens, num_scheduled_tokens)
|
position=positions)
|
||||||
|
|
||||||
attn_metadata = AscendMetadata(
|
attn_metadata = AscendMetadata(
|
||||||
seq_lens=query_lens,
|
seq_lens=query_lens,
|
||||||
@@ -457,7 +490,7 @@ class NPUModelRunner:
|
|||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
# prepare input_ids
|
# Prepare input_ids
|
||||||
token_indices = (positions_np +
|
token_indices = (positions_np +
|
||||||
req_indices * self.input_batch.token_ids_cpu.shape[1])
|
req_indices * self.input_batch.token_ids_cpu.shape[1])
|
||||||
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
|
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
|
||||||
@@ -468,6 +501,7 @@ class NPUModelRunner:
|
|||||||
self.input_ids[:total_num_scheduled_tokens].copy_(
|
self.input_ids[:total_num_scheduled_tokens].copy_(
|
||||||
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
||||||
input_ids = self.input_ids[:total_num_scheduled_tokens]
|
input_ids = self.input_ids[:total_num_scheduled_tokens]
|
||||||
|
|
||||||
# Run forward pass
|
# Run forward pass
|
||||||
with set_forward_context(attn_metadata, self.vllm_config):
|
with set_forward_context(attn_metadata, self.vllm_config):
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user