[Scheduler] Add AscendScheduler. (#543)

This PR adds AscendScheduler to vllm v1 engine.
This scheduler currently supports v0-style prefill-first scheduling
strategy.
In the future more schedule methods will be supported by this scheduler.

---------

Signed-off-by: hw_whx <wanghexiang7@huawei.com>
Co-authored-by: hw_whx <wanghexiang7@huawei.com>
This commit is contained in:
whx
2025-04-17 19:31:50 +08:00
committed by GitHub
parent 697908f5cd
commit 20dff4deff
9 changed files with 967 additions and 72 deletions

View File

@@ -43,7 +43,7 @@ if TYPE_CHECKING:
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
# Construct lower triangle matrix.
mask_flag = torch.tril(
torch.ones((max_seq_len, max_seq_len),
@@ -52,10 +52,11 @@ def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
mask_flag = ~mask_flag
# Currently for fp16 dtype, the mask value should be set to -inf.
# TODO: Eliminate this part in the future.
if dtype == torch.float16:
mask_value = torch.finfo(torch.float32).min
else:
mask_value = 1
if mask_value is None:
if dtype == torch.float16:
mask_value = torch.finfo(torch.float32).min
else:
mask_value = 1
attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
mask_flag, mask_value).to(dtype)
return attn_mask
@@ -66,12 +67,14 @@ class AttentionMaskBuilder:
def __init__(self, attn_mask: torch.Tensor):
self._seq_len_cached = attn_mask.shape[0]
self.attn_mask_cache = attn_mask
self.splitfuse_mask_value = -10000
@classmethod
def initialize_from_len(cls,
max_seq_len: int,
dtype: torch.dtype = torch.float16):
return cls(generate_attn_mask(max_seq_len, dtype))
dtype: torch.dtype = torch.float16,
mask_value: Optional[int] = None):
return cls(generate_attn_mask(max_seq_len, dtype, mask_value))
def update_attn_cache(self, seqlen: int, dtype: torch.dtype,
device: torch.device):
@@ -97,6 +100,49 @@ class AttentionMaskBuilder:
return (self.attn_mask_cache.index_select(
0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous())
def get_splitfuse_attn_mask(
self,
seq_lens,
query_lens,
position,
dtype,
device,
) -> torch.Tensor:
max_seq_len = max(seq_lens, default=0)
if max_seq_len <= self._seq_len_cached:
self.update_attn_cache(max_seq_len, dtype, device)
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
# is not the same. Fix this in the future when kernel is ready.
if self.attn_mask_cache[0][1] > 0:
attn_mask = self.get_attn_mask( # type: ignore
max_seq_len, dtype, device)
attn_mask *= -10000
else:
attn_mask = self.attn_mask_cache
return torch.index_select(attn_mask, dim=0,
index=position)[:, :max_seq_len]
total_q_len = sum(query_lens)
attn_mask = torch.zeros((total_q_len, max_seq_len),
dtype=dtype,
device="cpu")
current_row = 0
for i in range(len(query_lens)):
seq_len = seq_lens[i]
q_len = query_lens[i]
context_len = seq_len - q_len
assert context_len >= 0
attn_mask[current_row:current_row + q_len,
context_len:] = self.splitfuse_mask_value
right_tensor = attn_mask[current_row:current_row + q_len,
context_len:seq_len]
right_tensor.mask_fill_(
right_tensor.tril() == self.splitfuse_mask_value, 0)
current_row += q_len
return attn_mask.to(device, non_blocking=True)
class AscendAttentionBackend(AttentionBackend):