[Scheduler] Add AscendScheduler. (#543)

This PR adds AscendScheduler to vllm v1 engine.
This scheduler currently supports the v0-style prefill-first scheduling
strategy.
In the future, more scheduling methods will be supported by this scheduler.

---------

Signed-off-by: hw_whx <wanghexiang7@huawei.com>
Co-authored-by: hw_whx <wanghexiang7@huawei.com>
This commit is contained in:
whx
2025-04-17 19:31:50 +08:00
committed by GitHub
parent 697908f5cd
commit 20dff4deff
9 changed files with 967 additions and 72 deletions

View File

@@ -25,7 +25,7 @@ import numpy as np
import numpy.typing as npt
import torch
import torch.nn as nn
from vllm.attention import AttentionType
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
@@ -37,7 +37,8 @@ from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import DeviceMemoryProfiler, LayerBlockType, cdiv
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
@@ -45,15 +46,14 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.platform import NPUPlatform
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
NPU_PAGED_ATTENTION_MASK_VALUE = -10000
class NPUModelRunner:
@@ -74,6 +74,32 @@ class NPUModelRunner:
self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
vllm_config.parallel_config, LayerBlockType.attention)
self.hidden_size = self.model_config.get_hidden_size()
self.dtype = self.model_config.dtype
cache_config = vllm_config.cache_config
if cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
self.head_size = self.model_config.get_head_size()
self.attn_backend = get_attn_backend(
self.head_size,
self.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
)
if self.attn_backend is None:
error_msg = (
f"Error with get_att_backend: {self.head_size=}, "
f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{self.model_config.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 NPUModelRunner.")
# Multi-modal data support
self.input_registry = INPUT_REGISTRY
@@ -135,7 +161,7 @@ class NPUModelRunner:
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.model_config.dtype,
dtype=self.dtype,
device=self.device)
# OPTIMIZATION: Cache the tensors rather than creating them every step.
@@ -183,13 +209,8 @@ class NPUModelRunner:
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
self.attn_mask_len = min(self.model_config.max_model_len,
int(mask_len))
self.attn_mask_npu = torch.full(
(self.attn_mask_len, self.attn_mask_len),
NPU_PAGED_ATTENTION_MASK_VALUE,
device=self.device,
dtype=self.vllm_config.model_config.dtype)
self.attn_mask_npu.masked_fill_(
self.attn_mask_npu.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
self.attn_mask_len, self.dtype)
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
@@ -346,35 +367,20 @@ class NPUModelRunner:
def get_model(self) -> nn.Module:
    """Return the model instance held by this runner (set during loading)."""
    return self.model
def _make_attention_mask(self, seq_lens, query_lens,
                         position) -> torch.Tensor:
    """Build the additive attention mask for the scheduled batch.

    Args:
        seq_lens: per-request total sequence lengths (context + new tokens).
        query_lens: per-request number of newly scheduled query tokens.
        position: device tensor of token positions, used to index rows of
            the precomputed mask on the fast path.

    Returns:
        A (total_query_len, max_seq_len) mask whose disallowed (future)
        positions hold NPU_PAGED_ATTENTION_MASK_VALUE and allowed positions
        hold 0, placed on ``self.device``.
    """
    max_seq_len = max(seq_lens, default=0)
    # Fast path: all sequences fit in the precomputed NPU mask, so just
    # select the rows for the current positions and trim the columns.
    if max_seq_len <= self.attn_mask_len:
        return torch.index_select(self.attn_mask_npu,
                                  dim=0,
                                  index=position)[:, :max_seq_len]
    # Slow path: build the mask on CPU row-block by row-block, then copy it
    # to the device.
    total_q_len = sum(query_lens)
    attn_mask = torch.zeros((total_q_len, max_seq_len),
                            dtype=self.vllm_config.model_config.dtype,
                            device="cpu")
    current_row = 0
    for i in range(len(query_lens)):
        seq_len = seq_lens[i]
        q_len = query_lens[i]
        context_len = seq_len - q_len
        assert context_len >= 0
        # Mask everything from the context boundary onward, then re-open the
        # causal lower triangle of the query-vs-new-token region.
        attn_mask[current_row:current_row + q_len,
                  context_len:] = NPU_PAGED_ATTENTION_MASK_VALUE
        right_tensor = attn_mask[current_row:current_row + q_len,
                                 context_len:seq_len]
        # BUGFIX: the original called the nonexistent Tensor.mask_fill_,
        # which raises AttributeError; the correct in-place PyTorch API is
        # masked_fill_.
        right_tensor.masked_fill_(
            right_tensor.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
        current_row += q_len
    return attn_mask.to(self.device, non_blocking=True)
def _make_attention_mask(self, seq_lens, query_lens, position,
                         attn_state) -> torch.Tensor:
    """Select the attention mask appropriate for the batch's attention state.

    Chunked-prefill batches get a splitfuse mask, prefill-only batches get a
    mask sized to the longest sequence, and any other state (decode-only)
    needs no mask at all.
    """
    if attn_state == AscendAttentionState.ChunkedPrefill:
        return self.attn_mask_builder.get_splitfuse_attn_mask(
            seq_lens, query_lens, position, self.dtype, self.device)
    if attn_state == AscendAttentionState.PrefillOnly:
        longest_seq = max(seq_lens, default=0)
        return self.attn_mask_builder.get_attn_mask(longest_seq, self.dtype,
                                                    self.device)
    # Decode-only (or any other state): no mask is required.
    return None
def _process_reqs(
self,
@@ -408,6 +414,9 @@ class NPUModelRunner:
cu_num_tokens = np.cumsum(num_scheduled_tokens)
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
num_scheduled_tokens)
sample_indices = cu_num_tokens - 1
sample_indices = torch.from_numpy(sample_indices).to(self.device,
non_blocking=True)
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
positions_np = self.positions_np[:total_num_scheduled_tokens]
@@ -437,9 +446,18 @@ class NPUModelRunner:
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
self.device, non_blocking=True)
attn_state = AscendAttentionState.ChunkedPrefill
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
attn_state = AscendAttentionState.PrefillOnly
elif np.all(num_scheduled_tokens == 1):
attn_state = AscendAttentionState.DecodeOnly
else:
attn_state = AscendAttentionState.ChunkedPrefill
attn_mask = self._make_attention_mask(seq_lens=seq_lens,
query_lens=num_scheduled_tokens,
position=positions)
position=positions,
attn_state=attn_state)
attn_metadata = AscendMetadata(
seq_lens=query_lens,
@@ -448,6 +466,7 @@ class NPUModelRunner:
block_tables=(
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
attn_mask=attn_mask,
attn_state=attn_state,
)
# Prepare input_ids
@@ -472,7 +491,7 @@ class NPUModelRunner:
inputs_embeds=None,
)
return hidden_states[cu_num_tokens - 1]
return hidden_states[sample_indices]
@torch.inference_mode()
def execute_model(
@@ -636,7 +655,7 @@ class NPUModelRunner:
self.intermediate_tensors = (
self.model.make_empty_intermediate_tensors(
batch_size=self.max_num_tokens,
dtype=self.model_config.dtype,
dtype=self.dtype,
device=self.device))
intermediate_tensors = IntermediateTensors({
k: v[:self.max_num_tokens]
@@ -708,6 +727,7 @@ class NPUModelRunner:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
import torch_npu
kv_caches: Dict[str, torch.Tensor] = {}
for kv_cache_group in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group.kv_cache_spec
@@ -724,13 +744,14 @@ class NPUModelRunner:
# the min of all `num_blocks`. Verify it here.
assert num_blocks >= kv_cache_config.num_blocks
if isinstance(kv_cache_spec, FullAttentionSpec):
kv_cache_shape = AscendAttentionBackend.get_kv_cache_shape(
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
torch_npu.npu_format_cast(kv_caches[layer_name], 2)
else:
# TODO: add new branches when introducing more types of
# KV cache specs.