[Scheduler] Add AscendScheduler. (#543)
This PR adds AscendScheduler to the vLLM v1 engine. The scheduler currently supports the v0-style prefill-first scheduling strategy; more scheduling methods will be supported by this scheduler in the future. --------- Signed-off-by: hw_whx <wanghexiang7@huawei.com> Co-authored-by: hw_whx <wanghexiang7@huawei.com>
This commit is contained in:
@@ -25,7 +25,7 @@ import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.attention import AttentionType, get_attn_backend
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
@@ -37,7 +37,8 @@ from vllm.model_executor.model_loader import get_model
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import DeviceMemoryProfiler, LayerBlockType, cdiv
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
LayerBlockType, cdiv)
|
||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
@@ -45,15 +46,14 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
|
||||
from vllm_ascend.attention.attention import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
||||
AscendMetadata)
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
NPU_PAGED_ATTENTION_MASK_VALUE = -10000
|
||||
|
||||
|
||||
class NPUModelRunner:
|
||||
|
||||
@@ -74,6 +74,32 @@ class NPUModelRunner:
|
||||
self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
|
||||
vllm_config.parallel_config, LayerBlockType.attention)
|
||||
self.hidden_size = self.model_config.get_hidden_size()
|
||||
self.dtype = self.model_config.dtype
|
||||
cache_config = vllm_config.cache_config
|
||||
if cache_config.cache_dtype == "auto":
|
||||
self.kv_cache_dtype = self.dtype
|
||||
else:
|
||||
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
cache_config.cache_dtype]
|
||||
|
||||
self.head_size = self.model_config.get_head_size()
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.head_size,
|
||||
self.dtype,
|
||||
self.kv_cache_dtype,
|
||||
self.block_size,
|
||||
self.model_config.is_attention_free,
|
||||
use_mla=self.model_config.use_mla,
|
||||
)
|
||||
if self.attn_backend is None:
|
||||
error_msg = (
|
||||
f"Error with get_att_backend: {self.head_size=}, "
|
||||
f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
|
||||
f"{self.model_config.is_attention_free=}, "
|
||||
f"{self.model_config.use_mla=}")
|
||||
logger.error(error_msg)
|
||||
raise NotImplementedError(
|
||||
"Non-Attention backend is not supported by V1 NPUModelRunner.")
|
||||
|
||||
# Multi-modal data support
|
||||
self.input_registry = INPUT_REGISTRY
|
||||
@@ -135,7 +161,7 @@ class NPUModelRunner:
|
||||
|
||||
self.inputs_embeds = torch.zeros(
|
||||
(self.max_num_tokens, self.hidden_size),
|
||||
dtype=self.model_config.dtype,
|
||||
dtype=self.dtype,
|
||||
device=self.device)
|
||||
|
||||
# OPTIMIZATION: Cache the tensors rather than creating them every step.
|
||||
@@ -183,13 +209,8 @@ class NPUModelRunner:
|
||||
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
|
||||
self.attn_mask_len = min(self.model_config.max_model_len,
|
||||
int(mask_len))
|
||||
self.attn_mask_npu = torch.full(
|
||||
(self.attn_mask_len, self.attn_mask_len),
|
||||
NPU_PAGED_ATTENTION_MASK_VALUE,
|
||||
device=self.device,
|
||||
dtype=self.vllm_config.model_config.dtype)
|
||||
self.attn_mask_npu.masked_fill_(
|
||||
self.attn_mask_npu.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
|
||||
self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
|
||||
self.attn_mask_len, self.dtype)
|
||||
|
||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
"""Update the cached states and the persistent batch with the scheduler
|
||||
@@ -346,35 +367,20 @@ class NPUModelRunner:
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
def _make_attention_mask(self, seq_lens, query_lens,
                         position) -> torch.Tensor:
    """Build the attention mask for the currently scheduled requests.

    Fast path: when the longest sequence fits inside the cached
    ``self.attn_mask_npu`` table, the rows named by ``position`` are
    gathered from it and truncated to ``max_seq_len`` columns.

    Slow path: otherwise a ``(sum(query_lens), max(seq_lens))`` mask is
    assembled on the CPU and then copied to ``self.device``.

    Args:
        seq_lens: Per-request total sequence lengths.
        query_lens: Per-request number of query tokens in this step.
        position: Index tensor selecting rows of ``self.attn_mask_npu``.

    Returns:
        A mask holding 0 where attention is allowed and
        ``NPU_PAGED_ATTENTION_MASK_VALUE`` where it is blocked.
    """
    max_seq_len = max(seq_lens, default=0)
    if max_seq_len <= self.attn_mask_len:
        return torch.index_select(self.attn_mask_npu,
                                  dim=0,
                                  index=position)[:, :max_seq_len]

    total_q_len = sum(query_lens)
    attn_mask = torch.zeros((total_q_len, max_seq_len),
                            dtype=self.vllm_config.model_config.dtype,
                            device="cpu")

    current_row = 0
    for i, q_len in enumerate(query_lens):
        seq_len = seq_lens[i]
        context_len = seq_len - q_len

        assert context_len >= 0
        # Columns before ``context_len`` stay 0, so already-processed
        # context tokens remain visible; everything from there on is
        # blocked first.
        attn_mask[current_row:current_row + q_len,
                  context_len:] = NPU_PAGED_ATTENTION_MASK_VALUE
        # ``right_tensor`` is a view into ``attn_mask``: clearing its lower
        # triangle in place re-opens the causal part of the new tokens.
        right_tensor = attn_mask[current_row:current_row + q_len,
                                 context_len:seq_len]
        # Fix: the in-place tensor method is ``masked_fill_``; the previous
        # ``mask_fill_`` does not exist and raised AttributeError.
        right_tensor.masked_fill_(
            right_tensor.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
        current_row += q_len

    return attn_mask.to(self.device, non_blocking=True)
|
||||
def _make_attention_mask(self, seq_lens, query_lens, position,
                         attn_state) -> torch.Tensor:
    """Pick the attention mask that matches ``attn_state``.

    Args:
        seq_lens: Per-request sequence lengths.
        query_lens: Per-request query lengths for this step.
        position: Position tensor forwarded to the split-fuse mask builder.
        attn_state: An ``AscendAttentionState`` member describing the batch.

    Returns:
        The mask produced by ``self.attn_mask_builder`` for chunked-prefill
        and prefill-only batches. Any other state (decode-only) yields
        ``None`` despite the declared return annotation.
    """
    # Chunked prefill: the mask depends on per-request lengths and positions.
    if attn_state == AscendAttentionState.ChunkedPrefill:
        return self.attn_mask_builder.get_splitfuse_attn_mask(
            seq_lens, query_lens, position, self.dtype, self.device)

    # Prefill only: the mask is sized by the longest sequence in the batch.
    if attn_state == AscendAttentionState.PrefillOnly:
        longest_seq = max(seq_lens, default=0)
        return self.attn_mask_builder.get_attn_mask(longest_seq, self.dtype,
                                                    self.device)

    # Decode only: no mask is built.
    return None
|
||||
|
||||
def _process_reqs(
|
||||
self,
|
||||
@@ -408,6 +414,9 @@ class NPUModelRunner:
|
||||
cu_num_tokens = np.cumsum(num_scheduled_tokens)
|
||||
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
|
||||
num_scheduled_tokens)
|
||||
sample_indices = cu_num_tokens - 1
|
||||
sample_indices = torch.from_numpy(sample_indices).to(self.device,
|
||||
non_blocking=True)
|
||||
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
|
||||
|
||||
positions_np = self.positions_np[:total_num_scheduled_tokens]
|
||||
@@ -437,9 +446,18 @@ class NPUModelRunner:
|
||||
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
|
||||
self.device, non_blocking=True)
|
||||
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
||||
attn_state = AscendAttentionState.PrefillOnly
|
||||
elif np.all(num_scheduled_tokens == 1):
|
||||
attn_state = AscendAttentionState.DecodeOnly
|
||||
else:
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
|
||||
attn_mask = self._make_attention_mask(seq_lens=seq_lens,
|
||||
query_lens=num_scheduled_tokens,
|
||||
position=positions)
|
||||
position=positions,
|
||||
attn_state=attn_state)
|
||||
|
||||
attn_metadata = AscendMetadata(
|
||||
seq_lens=query_lens,
|
||||
@@ -448,6 +466,7 @@ class NPUModelRunner:
|
||||
block_tables=(
|
||||
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state,
|
||||
)
|
||||
|
||||
# Prepare input_ids
|
||||
@@ -472,7 +491,7 @@ class NPUModelRunner:
|
||||
inputs_embeds=None,
|
||||
)
|
||||
|
||||
return hidden_states[cu_num_tokens - 1]
|
||||
return hidden_states[sample_indices]
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
@@ -636,7 +655,7 @@ class NPUModelRunner:
|
||||
self.intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors(
|
||||
batch_size=self.max_num_tokens,
|
||||
dtype=self.model_config.dtype,
|
||||
dtype=self.dtype,
|
||||
device=self.device))
|
||||
intermediate_tensors = IntermediateTensors({
|
||||
k: v[:self.max_num_tokens]
|
||||
@@ -708,6 +727,7 @@ class NPUModelRunner:
|
||||
kv_cache_config: Configuration for the KV cache, including the KV
|
||||
cache size of each layer
|
||||
"""
|
||||
import torch_npu
|
||||
kv_caches: Dict[str, torch.Tensor] = {}
|
||||
for kv_cache_group in kv_cache_config.kv_cache_groups:
|
||||
kv_cache_spec = kv_cache_group.kv_cache_spec
|
||||
@@ -724,13 +744,14 @@ class NPUModelRunner:
|
||||
# the min of all `num_blocks`. Verify it here.
|
||||
assert num_blocks >= kv_cache_config.num_blocks
|
||||
if isinstance(kv_cache_spec, FullAttentionSpec):
|
||||
kv_cache_shape = AscendAttentionBackend.get_kv_cache_shape(
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
dtype = kv_cache_spec.dtype
|
||||
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
torch_npu.npu_format_cast(kv_caches[layer_name], 2)
|
||||
else:
|
||||
# TODO: add new branches when introducing more types of
|
||||
# KV cache specs.
|
||||
|
||||
Reference in New Issue
Block a user