[Scheduler] Add AscendScheduler. (#543)
This PR adds AscendScheduler to vllm v1 engine. This scheduler currently supports v0-style prefill-first scheduling strategy. In the future more schedule methods will be supported by this scheduler. --------- Signed-off-by: hw_whx <wanghexiang7@huawei.com> Co-authored-by: hw_whx <wanghexiang7@huawei.com>
This commit is contained in:
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
|
||||
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
||||
|
||||
|
||||
def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
|
||||
def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
|
||||
# Construct lower triangle matrix.
|
||||
mask_flag = torch.tril(
|
||||
torch.ones((max_seq_len, max_seq_len),
|
||||
@@ -52,10 +52,11 @@ def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
|
||||
mask_flag = ~mask_flag
|
||||
# Currently for fp16 dtype, the mask value should be set to -inf.
|
||||
# TODO: Eliminate this part in the future.
|
||||
if dtype == torch.float16:
|
||||
mask_value = torch.finfo(torch.float32).min
|
||||
else:
|
||||
mask_value = 1
|
||||
if mask_value is None:
|
||||
if dtype == torch.float16:
|
||||
mask_value = torch.finfo(torch.float32).min
|
||||
else:
|
||||
mask_value = 1
|
||||
attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
|
||||
mask_flag, mask_value).to(dtype)
|
||||
return attn_mask
|
||||
@@ -66,12 +67,14 @@ class AttentionMaskBuilder:
|
||||
def __init__(self, attn_mask: torch.Tensor):
|
||||
self._seq_len_cached = attn_mask.shape[0]
|
||||
self.attn_mask_cache = attn_mask
|
||||
self.splitfuse_mask_value = -10000
|
||||
|
||||
@classmethod
|
||||
def initialize_from_len(cls,
|
||||
max_seq_len: int,
|
||||
dtype: torch.dtype = torch.float16):
|
||||
return cls(generate_attn_mask(max_seq_len, dtype))
|
||||
dtype: torch.dtype = torch.float16,
|
||||
mask_value: Optional[int] = None):
|
||||
return cls(generate_attn_mask(max_seq_len, dtype, mask_value))
|
||||
|
||||
def update_attn_cache(self, seqlen: int, dtype: torch.dtype,
|
||||
device: torch.device):
|
||||
@@ -97,6 +100,49 @@ class AttentionMaskBuilder:
|
||||
return (self.attn_mask_cache.index_select(
|
||||
0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous())
|
||||
|
||||
def get_splitfuse_attn_mask(
|
||||
self,
|
||||
seq_lens,
|
||||
query_lens,
|
||||
position,
|
||||
dtype,
|
||||
device,
|
||||
) -> torch.Tensor:
|
||||
max_seq_len = max(seq_lens, default=0)
|
||||
if max_seq_len <= self._seq_len_cached:
|
||||
self.update_attn_cache(max_seq_len, dtype, device)
|
||||
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
|
||||
# is not the same. Fix this in the future when kernel is ready.
|
||||
if self.attn_mask_cache[0][1] > 0:
|
||||
attn_mask = self.get_attn_mask( # type: ignore
|
||||
max_seq_len, dtype, device)
|
||||
attn_mask *= -10000
|
||||
else:
|
||||
attn_mask = self.attn_mask_cache
|
||||
return torch.index_select(attn_mask, dim=0,
|
||||
index=position)[:, :max_seq_len]
|
||||
total_q_len = sum(query_lens)
|
||||
attn_mask = torch.zeros((total_q_len, max_seq_len),
|
||||
dtype=dtype,
|
||||
device="cpu")
|
||||
|
||||
current_row = 0
|
||||
for i in range(len(query_lens)):
|
||||
seq_len = seq_lens[i]
|
||||
q_len = query_lens[i]
|
||||
context_len = seq_len - q_len
|
||||
|
||||
assert context_len >= 0
|
||||
attn_mask[current_row:current_row + q_len,
|
||||
context_len:] = self.splitfuse_mask_value
|
||||
right_tensor = attn_mask[current_row:current_row + q_len,
|
||||
context_len:seq_len]
|
||||
right_tensor.mask_fill_(
|
||||
right_tensor.tril() == self.splitfuse_mask_value, 0)
|
||||
current_row += q_len
|
||||
|
||||
return attn_mask.to(device, non_blocking=True)
|
||||
|
||||
|
||||
class AscendAttentionBackend(AttentionBackend):
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
@@ -50,7 +51,7 @@ class AscendAttentionBackend(AttentionBackend):
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
@@ -83,6 +84,12 @@ class AscendAttentionBackend(AttentionBackend):
|
||||
value_caches[dst_indices] = value_caches[src_indices]
|
||||
|
||||
|
||||
class AscendAttentionState(Enum):
|
||||
PrefillOnly = 0
|
||||
DecodeOnly = 1
|
||||
ChunkedPrefill = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMetadata:
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
@@ -104,6 +111,8 @@ class AscendMetadata:
|
||||
# FlashAttention has better performance than PageAtttention,
|
||||
# but it does not support decode requests.
|
||||
is_only_prefill: bool = False
|
||||
# Current state of this attention run.
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
|
||||
@@ -139,7 +148,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.seq_len_cpu_tensor = None
|
||||
self.key_cache = None
|
||||
self.value_cache = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -190,30 +200,52 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
# TODO: Remove this contiguous in the future.
|
||||
value = value.contiguous()
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
if self.key_cache is None:
|
||||
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
||||
slots = attn_metadata.slot_mapping
|
||||
torch_npu._npu_reshape_and_cache(key=key,
|
||||
value=value,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
slot_indices=slots)
|
||||
|
||||
if hasattr(layer, 'quant_method'):
|
||||
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
|
||||
pass
|
||||
# V0-Style scheduler situation.
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
mask = attn_metadata.attn_mask
|
||||
torch_npu._npu_flash_attention(query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seq_len=attn_metadata.seq_lens,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
block_tables = attn_metadata.block_tables
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=block_tables,
|
||||
context_lens=attn_metadata.context_lens,
|
||||
out=output)
|
||||
# Normal V1 situation.
|
||||
else:
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
num_blocks, block_size, _ = key_cache.shape
|
||||
key_cache = key_cache.view(num_blocks, block_size,
|
||||
self.num_kv_heads, self.head_size)
|
||||
value_cache = value_cache.view(num_blocks, block_size,
|
||||
self.num_kv_heads,
|
||||
self.head_size)
|
||||
slots = attn_metadata.slot_mapping
|
||||
torch_npu._npu_reshape_and_cache(key=key,
|
||||
value=value,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
slot_indices=slots)
|
||||
|
||||
# use paged attention
|
||||
torch_npu._npu_paged_attention_splitfuse(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
mask=attn_metadata.attn_mask,
|
||||
block_table=attn_metadata.block_tables,
|
||||
seq_len=attn_metadata.seq_lens,
|
||||
|
||||
Reference in New Issue
Block a user