diff --git a/tests/scheduler/test_scheduler.py b/tests/scheduler/test_scheduler.py new file mode 100644 index 0000000..6eddd4f --- /dev/null +++ b/tests/scheduler/test_scheduler.py @@ -0,0 +1,396 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/blob/main/tests/models/utils.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Optional + +import pytest +import torch +from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig +from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.sampling_params import SamplingParams +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request, RequestStatus +from vllm.v1.structured_output import StructuredOutputManager + +from vllm_ascend.core.scheduler import AscendScheduler + +EOS_TOKEN_ID = 50256 + + +def create_scheduler( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 8192, + enable_prefix_caching: Optional[bool] = None, + long_prefill_token_threshold: int = 0, + disable_chunked_mm_input: bool = False, +) -> AscendScheduler: + '''Create scheduler under test. 
+
+    Args:
+        model: model under test
+        max_num_seqs: max sequences to schedule
+        max_num_batched_tokens: max num tokens to batch
+        enable_prefix_caching: optionally force APC config
+                               (True/False) or use default
+                               (None)
+        long_prefill_token_threshold: optional cap on the number of
+                                      prefill tokens scheduled at once
+                                      (0 disables the cap)
+        disable_chunked_mm_input: whether to disable chunked multimodal
+                                  input
+
+    Returns:
+        :class:`AscendScheduler` instance
+    '''
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=max_num_seqs,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_model_len=max_num_batched_tokens,
+        long_prefill_token_threshold=long_prefill_token_threshold,
+        disable_chunked_mm_input=disable_chunked_mm_input,
+    )
+    model_config = ModelConfig(
+        model=model,
+        task="auto",
+        tokenizer=model,
+        tokenizer_mode="auto",
+        trust_remote_code=True,
+        dtype="float16",
+        seed=42,
+    )
+    # Cache config, optionally force APC
+    kwargs_cache = ({} if enable_prefix_caching is None else {
+        'enable_prefix_caching': enable_prefix_caching
+    })
+    cache_config = CacheConfig(
+        block_size=16,
+        gpu_memory_utilization=0.9,
+        swap_space=0,
+        cache_dtype="auto",
+        **kwargs_cache,
+    )
+    vllm_config = VllmConfig(
+        scheduler_config=scheduler_config,
+        model_config=model_config,
+        cache_config=cache_config,
+    )
+    kv_cache_config = KVCacheConfig(
+        num_blocks=10000,  # A large number of blocks to hold all requests
+        tensors={},
+        kv_cache_groups=[
+            KVCacheGroupSpec(['layer'],
+                             FullAttentionSpec(16, 1, 1, torch.float32, False))
+        ],
+    )
+    cache_config.num_gpu_blocks = 10000
+    return AscendScheduler(
+        scheduler_config,
+        model_config,
+        cache_config,
+        lora_config=None,
+        kv_cache_config=kv_cache_config,
+        log_stats=True,
+        structured_output_manager=StructuredOutputManager(vllm_config),
+    )
+
+
+def create_requests(num_requests: int,
+                    num_tokens: int = 10,
+                    mm_positions: Optional[list[PlaceholderRange]] = None,
+                    max_tokens: int = 16,
+                    stop_token_ids: Optional[list[int]] = None,
+                    prompt_logprobs: Optional[int] = None):
+    sampling_params = SamplingParams(ignore_eos=False,
+                                     max_tokens=max_tokens,
+                                     stop_token_ids=stop_token_ids,
+                                     prompt_logprobs=prompt_logprobs)
+    requests = []
+    for i in range(num_requests):
+        if mm_positions is not None:
+            mm_position = mm_positions[i]
+            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
+        else:
+            mm_position = None
+            mm_inputs = None
+        request = Request(
+            request_id=f"{i}",
+            prompt=None,
+            prompt_token_ids=[i] * num_tokens,
+            sampling_params=sampling_params,
+            multi_modal_inputs=mm_inputs,
+            multi_modal_placeholders=mm_position,
+            multi_modal_hashes=None,
+            eos_token_id=EOS_TOKEN_ID,
+            arrival_time=0,
+        )
+        requests.append(request)
+    return requests
+
+
+def test_add_requests():
+    scheduler = create_scheduler()
+    requests = create_requests(num_requests=10)
+
+    for i, request in enumerate(requests):
+        scheduler.add_request(request)
+        assert request.request_id in scheduler.requests
+        assert len(scheduler.waiting) == i + 1
+
+
+def test_finish_request():
+    scheduler = create_scheduler()
+    requests = create_requests(num_requests=10)
+    for request in requests:
+        scheduler.add_request(request)
+
+    for i, request in enumerate(requests):
+        scheduler.finish_requests(request.request_id,
+                                  RequestStatus.FINISHED_ABORTED)
+        assert request.request_id not in scheduler.requests
+        assert len(scheduler.waiting) == 9 - i
+
+
+def test_get_num_unfinished_requests():
+    scheduler = create_scheduler()
+    requests = create_requests(num_requests=10)
+    for request in requests:
+        scheduler.add_request(request)
+
+    for i, request in enumerate(requests):
+        scheduler.finish_requests(request.request_id,
+                                  RequestStatus.FINISHED_STOPPED)
+        assert 
scheduler.get_num_unfinished_requests() == len(requests) - i - 1 + + +@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [ + (None, None), + (True, 5), +]) +def test_schedule(enable_prefix_caching: Optional[bool], + prompt_logprobs: Optional[int]): + '''Test scheduling. + Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs + ''' + scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching) + requests = create_requests(num_requests=10, + prompt_logprobs=prompt_logprobs) + for request in requests: + scheduler.add_request(request) + + # Test initial scheduling + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == len(requests) + assert len(output.scheduled_cached_reqs) == 0 + assert len(output.finished_req_ids) == 0 + # Verify all requests are scheduled. + for req_id, num_tokens in output.num_scheduled_tokens.items(): + assert num_tokens == len(requests[int(req_id)].prompt_token_ids) + + # Verify requests moved from waiting to running + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == len(requests) + for i, request in enumerate(requests): + assert scheduler.running[i] == request + + +def test_stop_via_update_from_output(): + """Test stopping behavior through update_from_output""" + scheduler = create_scheduler() + + # Test case 1: Stop on EOS token + requests = create_requests(num_requests=2, max_tokens=10) + for req in requests: + req.num_computed_tokens = req.num_tokens + scheduler.requests[req.request_id] = req + scheduler.running.append(req) + scheduler.scheduled_req_ids.add(req.request_id) + + scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 1, + requests[1].request_id: 2 + }, + total_num_scheduled_tokens=3, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [], + requests[1].request_id: [10] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None) + + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={req.request_id: i + for i, req in enumerate(requests)}, + sampled_token_ids=[[EOS_TOKEN_ID], + [10, + 11]], # First request hits EOS, second continues + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify first request stopped, second continues + assert len(scheduler.running) == 1 + assert scheduler.running[0].request_id == requests[1].request_id + assert requests[0].status == RequestStatus.FINISHED_STOPPED + assert requests[0].request_id in scheduler.finished_req_ids + assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID] + assert list(requests[1].output_token_ids) == [10, 11] + + # Test case 2: Stop on custom stop token + scheduler = create_scheduler() + requests = create_requests(num_requests=2, + max_tokens=10, + stop_token_ids=[42, 43]) + for req in requests: + req.num_computed_tokens = req.num_tokens + scheduler.requests[req.request_id] = req + scheduler.running.append(req) + scheduler.scheduled_req_ids.add(req.request_id) + + scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 3, + requests[1].request_id: 2 + }, + total_num_scheduled_tokens=5, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [10, 42], + 
requests[1].request_id: [13] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None) + + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={req.request_id: i + for i, req in enumerate(requests)}, + sampled_token_ids=[[10, 42, 12], + [13, 14]], # First request hits stop token + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify first request stopped on custom token + assert len(scheduler.running) == 1 + assert scheduler.running[0].request_id == requests[1].request_id + assert requests[0].status == RequestStatus.FINISHED_STOPPED + assert requests[0].stop_reason == 42 + assert requests[0].request_id in scheduler.finished_req_ids + assert list(requests[0].output_token_ids) == [10, 42] + assert list(requests[1].output_token_ids) == [13, 14] + + # Test case 3: Stop on max tokens + scheduler = create_scheduler() + requests = create_requests(num_requests=2, max_tokens=2) + for req in requests: + req.num_computed_tokens = req.num_tokens + scheduler.requests[req.request_id] = req + scheduler.running.append(req) + scheduler.scheduled_req_ids.add(req.request_id) + + scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 3, + requests[1].request_id: 1 + }, + total_num_scheduled_tokens=4, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [10, 11], + requests[1].request_id: [] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None) + + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={req.request_id: i + for i, req in enumerate(requests)}, + sampled_token_ids=[[10, 11, 12], + [13]], # First request exceeds max_tokens + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify first request stopped due to length + assert len(scheduler.running) == 1 + assert scheduler.running[0].request_id == requests[1].request_id + assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED + assert requests[0].request_id in scheduler.finished_req_ids + assert list(requests[0].output_token_ids) == [10, 11 + ] # Truncated to max_tokens + assert list(requests[1].output_token_ids) == [13] + + # Test case 4: Ignore EOS flag + scheduler = create_scheduler() + requests = create_requests(num_requests=1, max_tokens=10) + requests[0].sampling_params.ignore_eos = True + requests[0].num_computed_tokens = requests[0].num_tokens + scheduler.requests[requests[0].request_id] = requests[0] + scheduler.running.append(requests[0]) + scheduler.scheduled_req_ids.add(requests[0].request_id) + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={requests[0].request_id: 3}, + total_num_scheduled_tokens=3, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [EOS_TOKEN_ID, 10] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None) + + model_output = ModelRunnerOutput( + req_ids=[requests[0].request_id], + req_id_to_index={requests[0].request_id: 0}, + 
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify request continues past EOS + assert len(scheduler.running) == 1 + assert not requests[0].is_finished() + assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11] diff --git a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py index 45f5e25..5a9082f 100644 --- a/vllm_ascend/attention/attention.py +++ b/vllm_ascend/attention/attention.py @@ -43,7 +43,7 @@ if TYPE_CHECKING: ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata) -def generate_attn_mask(max_seq_len: int, dtype=torch.float16): +def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None): # Construct lower triangle matrix. mask_flag = torch.tril( torch.ones((max_seq_len, max_seq_len), @@ -52,10 +52,11 @@ def generate_attn_mask(max_seq_len: int, dtype=torch.float16): mask_flag = ~mask_flag # Currently for fp16 dtype, the mask value should be set to -inf. # TODO: Eliminate this part in the future. - if dtype == torch.float16: - mask_value = torch.finfo(torch.float32).min - else: - mask_value = 1 + if mask_value is None: + if dtype == torch.float16: + mask_value = torch.finfo(torch.float32).min + else: + mask_value = 1 attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value).to(dtype) return attn_mask @@ -66,12 +67,14 @@ class AttentionMaskBuilder: def __init__(self, attn_mask: torch.Tensor): self._seq_len_cached = attn_mask.shape[0] self.attn_mask_cache = attn_mask + self.splitfuse_mask_value = -10000 @classmethod def initialize_from_len(cls, max_seq_len: int, - dtype: torch.dtype = torch.float16): - return cls(generate_attn_mask(max_seq_len, dtype)) + dtype: torch.dtype = torch.float16, + mask_value: Optional[int] = None): + return cls(generate_attn_mask(max_seq_len, dtype, mask_value)) def update_attn_cache(self, seqlen: int, dtype: torch.dtype, device: torch.device): @@ -97,6 +100,49 @@ class AttentionMaskBuilder: return (self.attn_mask_cache.index_select( 0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous()) + def get_splitfuse_attn_mask( + self, + seq_lens, + query_lens, + position, + dtype, + device, + ) -> torch.Tensor: + max_seq_len = max(seq_lens, default=0) + if max_seq_len <= self._seq_len_cached: + self.update_attn_cache(max_seq_len, dtype, device) + # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation + # is not the same. Fix this in the future when kernel is ready. 
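+            # NOTE: attn_mask_cache[0][1] sits above the diagonal, so it
+            # holds the mask fill value. A positive fill (the "+1"
+            # convention) is rescaled to -10000 for the chunked-prefill
+            # kernel below; a negative fill can be reused as-is.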
+            if self.attn_mask_cache[0][1] > 0:
+                attn_mask = self.get_attn_mask(  # type: ignore
+                    max_seq_len, dtype, device)
+                attn_mask *= -10000
+            else:
+                attn_mask = self.attn_mask_cache
+            return torch.index_select(attn_mask, dim=0,
+                                      index=position)[:, :max_seq_len]
+        total_q_len = sum(query_lens)
+        attn_mask = torch.zeros((total_q_len, max_seq_len),
+                                dtype=dtype,
+                                device="cpu")
+
+        current_row = 0
+        for i in range(len(query_lens)):
+            seq_len = seq_lens[i]
+            q_len = query_lens[i]
+            context_len = seq_len - q_len
+
+            assert context_len >= 0
+            attn_mask[current_row:current_row + q_len,
+                      context_len:] = self.splitfuse_mask_value
+            right_tensor = attn_mask[current_row:current_row + q_len,
+                                     context_len:seq_len]
+            right_tensor.masked_fill_(
+                right_tensor.tril() == self.splitfuse_mask_value, 0)
+            current_row += q_len
+
+        return attn_mask.to(device, non_blocking=True)
+
 
 class AscendAttentionBackend(AttentionBackend):
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 6cd9a3a..22e6b35 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -16,6 +16,7 @@
 #
 
 from dataclasses import dataclass
+from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
@@ -50,7 +51,7 @@ class AscendAttentionBackend(AttentionBackend):
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
-        return (2, num_blocks, block_size, num_kv_heads * head_size)
+        return (2, num_blocks, block_size, num_kv_heads, head_size)
 
     @staticmethod
     def swap_blocks(
@@ -83,6 +84,12 @@ class AscendAttentionBackend(AttentionBackend):
         value_caches[dst_indices] = value_caches[src_indices]
 
 
+class AscendAttentionState(Enum):
+    PrefillOnly = 0
+    DecodeOnly = 1
+    ChunkedPrefill = 2
+
+
 @dataclass
 class AscendMetadata:
     # (batch_size, max_blocks_per_seq).
@@ -104,6 +111,8 @@ class AscendMetadata:
     # FlashAttention has better performance than PagedAttention,
    # but it does not support decode requests.
     is_only_prefill: bool = False
+    # Current state of this attention run.
+    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
 
     attn_mask: Optional[torch.Tensor] = None
 
@@ -139,7 +148,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        self.seq_len_cpu_tensor = None
+        self.key_cache = None
+        self.value_cache = None
 
     def forward(
         self,
@@ -190,30 +200,52 @@ class AscendAttentionBackendImpl(AttentionImpl):
             # TODO: Remove this contiguous in the future.
             value = value.contiguous()
 
+            if kv_cache.numel() > 0:
+                if self.key_cache is None:
+                    self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
+                slots = attn_metadata.slot_mapping
+                torch_npu._npu_reshape_and_cache(key=key,
+                                                 value=value,
+                                                 key_cache=self.key_cache,
+                                                 value_cache=self.value_cache,
+                                                 slot_indices=slots)
+
             if hasattr(layer, 'quant_method'):
                 # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
                 pass
+            # V0-Style scheduler situation.
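+            # Prefill-only batches run FlashAttention directly over the
+            # contiguous query/key/value tensors; the KV cache was already
+            # updated by _npu_reshape_and_cache above.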
+ elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly: + assert attn_metadata is not None + assert attn_metadata.attn_mask is not None + mask = attn_metadata.attn_mask + torch_npu._npu_flash_attention(query=query, + key=key, + value=value, + mask=mask, + seq_len=attn_metadata.seq_lens, + scale_value=self.scale, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + out=output) + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + block_tables = attn_metadata.block_tables + torch_npu._npu_paged_attention( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=block_tables, + context_lens=attn_metadata.context_lens, + out=output) + # Normal V1 situation. else: - if kv_cache.numel() > 0: - key_cache, value_cache = kv_cache[0], kv_cache[1] - num_blocks, block_size, _ = key_cache.shape - key_cache = key_cache.view(num_blocks, block_size, - self.num_kv_heads, self.head_size) - value_cache = value_cache.view(num_blocks, block_size, - self.num_kv_heads, - self.head_size) - slots = attn_metadata.slot_mapping - torch_npu._npu_reshape_and_cache(key=key, - value=value, - key_cache=key_cache, - value_cache=value_cache, - slot_indices=slots) - # use paged attention torch_npu._npu_paged_attention_splitfuse( query=query, - key_cache=key_cache, - value_cache=value_cache, + key_cache=self.key_cache, + value_cache=self.value_cache, mask=attn_metadata.attn_mask, block_table=attn_metadata.block_tables, seq_len=attn_metadata.seq_lens, diff --git a/vllm_ascend/core/__init__.py b/vllm_ascend/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py new file mode 100644 index 0000000..4861bfc --- /dev/null +++ b/vllm_ascend/core/schedule_config.py @@ -0,0 +1,73 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+from dataclasses import dataclass, fields
+from typing import Type, Union
+
+from vllm.config import SchedulerConfig
+
+
+@dataclass
+class AscendSchedulerConfig(SchedulerConfig):
+    enable_chunked_prefill: bool = False
+    policy: str = "fcfs"
+    num_scheduler_steps: int = 1
+    scheduler_cls: Union[str, Type[object]] = (
+        "vllm_ascend.core.scheduler.AscendScheduler")
+
+    @classmethod
+    def initialize_from_config(
+        cls,
+        vllm_scheduler_config: SchedulerConfig,
+        ascend_scheduler_config: dict,
+    ):
+        scheduler_config = {
+            field.name: getattr(vllm_scheduler_config, field.name)
+            for field in fields(vllm_scheduler_config) if field.init
+        }
+        # Override default values of the original SchedulerConfig.
+        scheduler_config["enable_chunked_prefill"] = False
+        scheduler_config["policy"] = "fcfs"
+        scheduler_config["num_scheduler_steps"] = 1
+        scheduler_config["scheduler_cls"] = (
+            "vllm_ascend.core.scheduler.AscendScheduler")
+        # Override params in the original SchedulerConfig with params in
+        # additional_config.ascend_scheduler_config.
+        for k, v in ascend_scheduler_config.items():
+            scheduler_config[k] = v
+        return cls(**scheduler_config)
+
+    def __post_init__(self) -> None:
+        self.max_num_encoder_input_tokens = self.max_num_batched_tokens
+        self.encoder_cache_size = self.max_num_batched_tokens
+        self.chunked_prefill_enabled = self.enable_chunked_prefill
+        if self.policy != "fcfs":
+            raise NotImplementedError(
+                f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
+            )
+        if self.is_multimodal_model:
+            raise NotImplementedError(
+                "currently AscendScheduler only supports LLM models.")
+        if self.num_scheduler_steps > 1:
+            raise NotImplementedError(
+                "currently AscendScheduler doesn't support multi-step.")
+        if self.send_delta_data:
+            raise NotImplementedError(
+                "currently AscendScheduler doesn't support send_delta_data.")
+        if self.delay_factor > 0:
+            raise NotImplementedError(
+                "currently AscendScheduler doesn't support scheduler_delay_factor."
+            )
diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
new file mode 100644
index 0000000..348d7e7
--- /dev/null
+++ b/vllm_ascend/core/scheduler.py
@@ -0,0 +1,305 @@
+from collections import deque
+
+from vllm.logger import init_logger
+from vllm.utils import cdiv
+from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
+from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.request import Request, RequestStatus
+
+logger = init_logger(__name__)
+
+
+class AscendScheduler(Scheduler):
+    """This Scheduler extends vllm's original v1 scheduler
+    with a prefill-first scheduling strategy."""
+
+    def schedule(self) -> SchedulerOutput:
+        if self.scheduler_config.chunked_prefill_enabled:
+            return super().schedule()
+        scheduled_new_reqs: list[Request] = []
+        scheduled_resumed_reqs: list[Request] = []
+        scheduled_running_reqs: list[Request] = []
+        preempted_reqs: list[Request] = []
+
+        req_to_new_block_ids: dict[str, list[int]] = {}
+        num_scheduled_tokens: dict[str, int] = {}
+        token_budget = self.max_num_scheduled_tokens
+        # Spec decode-related.
+        scheduled_spec_decode_tokens: dict[str, list[int]] = {}
+
+        # Record scheduled LoRA requests.
+        scheduled_loras: set[int] = set()
+
+        # Use a temporary deque to collect requests that need to be skipped
+        # and put back at the head of the waiting queue later.
+        skipped_waiting_requests: deque[Request] = deque()
+
+        # Schedule prefill requests first.
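+        # Unlike the default V1 scheduler, waiting (prefill) requests take
+        # strict priority; decodes are only scheduled further below when no
+        # prefill request was scheduled in this step.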
+ while self.waiting and token_budget > 0: + if len(scheduled_new_reqs) == self.max_num_running_reqs: + break + + request = self.waiting[0] + + def skip_cur_request(): + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + + # Check that adding the request still respects the max_loras + # constraint. + if (self.lora_config and request.lora_request and + (len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id not in scheduled_loras)): + # Scheduling would exceed max_loras, skip. + skip_cur_request() + continue + + prompt_limit = self._get_prompt_limit(request) + # Get already-cached tokens. + computed_blocks, num_computed_tokens = ( + self.kv_cache_manager.get_computed_blocks(request)) + num_new_tokens = request.num_prompt_tokens - num_computed_tokens + if (0 < self.scheduler_config.long_prefill_token_threshold < + num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + max_tokens_in_kvcache = (self.kv_cache_config.num_blocks * + self.block_size) + prompt_limit = min(prompt_limit, max_tokens_in_kvcache) + + # Finish request that exceeds prompt_limit or kv cache size. + if num_new_tokens > prompt_limit: + logger.warning( + "Input prompt (%d tokens) is too long" + " and exceeds limit of %d", + num_new_tokens, + prompt_limit, + ) + request.status = RequestStatus.FINISHED_IGNORED + self.finished_req_ids.add(request.request_id) # type: ignore + self.waiting.popleft() + continue + + if num_new_tokens > token_budget: + # Scheduling would exceed token_budget, skip. + skip_cur_request() + continue + + assert num_new_tokens > 0 + watermark = getattr(self.scheduler_config, "watermark", 0.01) + if not self._check_watermark_for_prefill( + request, num_new_tokens, computed_blocks, watermark): + # Scheduling would exceed watermark, skip. + skip_cur_request() + continue + + new_blocks = self.kv_cache_manager.allocate_slots( + request, num_new_tokens, computed_blocks) + if new_blocks is None: + # The request cannot be scheduled. + break + + self.waiting.popleft() + self.running.append(request) + self.scheduled_req_ids.add(request.request_id) + # Check request status. + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) + else: + raise RuntimeError(f"Invalid request status: {request.status}") + + if self.lora_config and request.lora_request: + scheduled_loras.add(request.lora_request.lora_int_id) + req_to_new_block_ids[request.request_id] = [ + b.block_id for b in computed_blocks + new_blocks + ] + # Update request info. + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + + # Put back any skipped requests at the head of the waiting queue + if skipped_waiting_requests: + self.waiting.extendleft(skipped_waiting_requests) + + # If no prefill requests are scheduled, + # Schedule decode requests next. + if len(self.scheduled_req_ids) == 0: + req_index = 0 + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] + if request.request_id in self.scheduled_req_ids: + # This request has already been scheduled. 
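+                    # (It was already picked up by the prefill pass above,
+                    # so don't schedule it again in this step.)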
+                    req_index += 1
+                    continue
+
+                num_new_tokens = (request.num_tokens_with_spec -
+                                  request.num_computed_tokens)
+                if (0 < self.scheduler_config.long_prefill_token_threshold <
+                        num_new_tokens):
+                    num_new_tokens = (
+                        self.scheduler_config.long_prefill_token_threshold)
+                num_new_tokens = min(num_new_tokens, token_budget)
+                assert num_new_tokens == 1
+
+                while True:
+                    new_blocks = self.kv_cache_manager.allocate_slots(
+                        request, num_new_tokens)
+                    if new_blocks is None:
+                        # The request cannot be scheduled.
+                        # Preempt the lowest-priority request.
+                        preempted_req = self.running.pop()
+                        self.kv_cache_manager.free(preempted_req)
+                        preempted_req.status = RequestStatus.PREEMPTED
+                        preempted_req.num_computed_tokens = 0
+                        self.waiting.appendleft(preempted_req)
+                        preempted_reqs.append(preempted_req)
+                        if preempted_req == request:
+                            # No more requests to preempt.
+                            can_schedule = False
+                            break
+                    else:
+                        # The request can be scheduled.
+                        can_schedule = True
+                        break
+                if not can_schedule:
+                    break
+                assert new_blocks is not None
+
+                # Schedule the request.
+                scheduled_running_reqs.append(request)
+                self.scheduled_req_ids.add(request.request_id)
+                req_to_new_block_ids[request.request_id] = [
+                    b.block_id for b in new_blocks
+                ]
+                num_scheduled_tokens[request.request_id] = num_new_tokens
+                token_budget -= num_new_tokens
+                req_index += 1
+
+                # Speculative decode related.
+                if request.spec_token_ids:
+                    num_scheduled_spec_tokens = (num_new_tokens +
+                                                 request.num_computed_tokens -
+                                                 request.num_tokens)
+                    if num_scheduled_spec_tokens > 0:
+                        # Trim spec_token_ids list to num_scheduled_spec_tokens.
+                        del request.spec_token_ids[num_scheduled_spec_tokens:]
+                        scheduled_spec_decode_tokens[request.request_id] = (
+                            request.spec_token_ids)
+
+        # Check if the scheduling constraints are satisfied.
+        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
+        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
+        assert token_budget >= 0
+        assert len(self.running) <= self.max_num_running_reqs
+        assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
+            scheduled_running_reqs) <= len(self.running)
+
+        # Get the longest common prefix among all requests in the running queue.
+        # This can potentially be used for cascade attention.
+        num_common_prefix_blocks = 0
+        if self.running:
+            any_request = self.running[0]
+            num_common_prefix_blocks = (
+                self.kv_cache_manager.get_num_common_prefix_blocks(
+                    any_request, len(self.running)))
+
+        # Construct the scheduler output.
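+        # New requests carry their full prompt payload; resumed and running
+        # requests are sent as lighter-weight cached-request data.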
+        new_reqs_data = [
+            NewRequestData.from_request(req,
+                                        req_to_new_block_ids[req.request_id])
+            for req in scheduled_new_reqs
+        ]
+        resumed_reqs_data = [
+            self._make_cached_request_data(
+                req,
+                num_scheduled_tokens[req.request_id],
+                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
+                req_to_new_block_ids[req.request_id],
+                resumed_from_preemption=True,
+            ) for req in scheduled_resumed_reqs
+        ]
+        running_reqs_data = [
+            self._make_cached_request_data(
+                req,
+                num_scheduled_tokens[req.request_id],
+                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
+                req_to_new_block_ids[req.request_id],
+                resumed_from_preemption=False,
+            ) for req in scheduled_running_reqs
+        ]
+        scheduler_output = SchedulerOutput(
+            scheduled_new_reqs=new_reqs_data,
+            scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
+            num_scheduled_tokens=num_scheduled_tokens,
+            total_num_scheduled_tokens=total_num_scheduled_tokens,
+            scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=num_common_prefix_blocks,
+            # finished_req_ids is an existing state in the scheduler,
+            # instead of being newly scheduled in this step.
+            # It contains the request IDs that are finished in between
+            # the previous and the current steps.
+            finished_req_ids=self.finished_req_ids,  # type: ignore
+            free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+        )
+
+        # Advance the number of computed tokens for the request AFTER
+        # the request is scheduled.
+        # 1. The scheduler_output of the current step has to include the
+        #    original number of scheduled tokens to determine input IDs.
+        # 2. Advancing the number of computed tokens here allows us to
+        #    schedule the prefill request again immediately in the next
+        #    scheduling step.
+        # 3. If some tokens (e.g. spec tokens) are rejected later, the number
+        #    of computed tokens will be adjusted in update_from_output.
+        for req_id, num_scheduled_token in num_scheduled_tokens.items():
+            self.requests[req_id].num_computed_tokens += num_scheduled_token
+
+        self.finished_req_ids = set()  # type: ignore
+        return scheduler_output
+
+    def _check_watermark_for_prefill(self,
+                                     request,
+                                     num_new_tokens,
+                                     computed_blocks,
+                                     watermark=0.01):
+        computed_blocks = computed_blocks or []
+        watermark_blocks = self.kv_cache_config.num_blocks * watermark
+        num_computed_tokens = (request.num_computed_tokens +
+                               len(computed_blocks) * self.block_size)
+        num_required_blocks = cdiv(num_new_tokens + num_computed_tokens,
+                                   self.block_size)
+        req_blocks = self.kv_cache_manager.req_to_blocks[request.request_id]
+        num_new_blocks = (num_required_blocks - len(req_blocks) -
+                          len(computed_blocks))
+        num_evictable_computed_blocks = sum(1 for blk in computed_blocks
+                                            if blk.ref_cnt == 0)
+        # If the number of free blocks would drop below the watermark after
+        # allocating, don't allocate.
+        if (self.kv_cache_manager.block_pool.get_num_free_blocks() -
+                num_evictable_computed_blocks -
+                num_new_blocks) < watermark_blocks:
+            return False
+        return True
+
+    def _get_prompt_limit(self, request: Request) -> int:
+        if (self.scheduler_config.chunked_prefill_enabled
+                and not self.scheduler_config.is_multi_step):
+            prompt_limit = self.scheduler_config.max_model_len
+        else:
+            prompt_limit = min(
+                self.scheduler_config.max_model_len,
+                self.scheduler_config.max_num_batched_tokens,
+            )
+
+        # Model is fine-tuned with long context. Return the fine-tuned max_len.
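+        # (long_lora_max_len is only set for adapters trained with an
+        # extended context window, e.g. LongLoRA.)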
+        if request.lora_request and request.lora_request.long_lora_max_len:
+            assert prompt_limit <= request.lora_request.long_lora_max_len
+            return request.lora_request.long_lora_max_len
+        else:
+            return prompt_limit
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index c54bdfe..ff35808 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -132,6 +132,22 @@ class NPUPlatform(Platform):
             )
             cache_config.enable_prefix_caching = False
 
+        if envs.VLLM_USE_V1:
+            # Activate custom ops for v1.
+            vllm_config.compilation_config.custom_ops = ["all"]
+            additional_config = vllm_config.additional_config
+            # If ascend_scheduler_config exists in additional_config,
+            # extends the original scheduler_config to use AscendScheduler.
+            if additional_config and additional_config.get(
+                    "ascend_scheduler_config", None) is not None:
+                additional_scheduler_config = additional_config.get(
+                    "ascend_scheduler_config")
+                from vllm_ascend.core.schedule_config import \
+                    AscendSchedulerConfig
+                ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config(
+                    vllm_config.scheduler_config, additional_scheduler_config)
+                vllm_config.scheduler_config = ascend_scheduler_config
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1, use_mla):
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index d9beb35..f5438e3 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -25,7 +25,7 @@ import numpy as np
 import numpy.typing as npt
 import torch
 import torch.nn as nn
-from vllm.attention import AttentionType
+from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import get_pp_group
@@ -37,7 +37,8 @@ from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
-from vllm.utils import DeviceMemoryProfiler, LayerBlockType, cdiv
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
+                        LayerBlockType, cdiv)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
@@ -45,15 +46,14 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
-from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
+from vllm_ascend.attention.attention import AttentionMaskBuilder
+from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
                                                 AscendMetadata)
 from vllm_ascend.platform import NPUPlatform
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
 
-NPU_PAGED_ATTENTION_MASK_VALUE = -10000
-
 
 class NPUModelRunner:
 
@@ -74,6 +74,32 @@ class NPUModelRunner:
         self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
             vllm_config.parallel_config, LayerBlockType.attention)
         self.hidden_size = self.model_config.get_hidden_size()
+        self.dtype = self.model_config.dtype
+        cache_config = vllm_config.cache_config
+        if cache_config.cache_dtype == "auto":
+            self.kv_cache_dtype = self.dtype
+        else:
+            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
+                cache_config.cache_dtype]
+
+        self.head_size = self.model_config.get_head_size()
+        self.attn_backend = get_attn_backend(
+            self.head_size,
+            self.dtype,
+            self.kv_cache_dtype,
+            self.block_size,
+            self.model_config.is_attention_free,
+            use_mla=self.model_config.use_mla,
+        )
+        if self.attn_backend is None:
+            error_msg = (
+                f"Error with get_attn_backend: {self.head_size=}, "
+                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
+                f"{self.model_config.is_attention_free=}, "
+                f"{self.model_config.use_mla=}")
+            logger.error(error_msg)
+            raise NotImplementedError(
+                "Non-Attention backend is not supported by V1 NPUModelRunner.")
 
         # Multi-modal data support
         self.input_registry = INPUT_REGISTRY
@@ -135,7 +161,7 @@ class NPUModelRunner:
         self.inputs_embeds = torch.zeros(
             (self.max_num_tokens, self.hidden_size),
-            dtype=self.model_config.dtype,
+            dtype=self.dtype,
             device=self.device)
 
         # OPTIMIZATION: Cache the tensors rather than creating them every step.
@@ -183,13 +209,8 @@ class NPUModelRunner:
         mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
         self.attn_mask_len = min(self.model_config.max_model_len,
                                  int(mask_len))
-        self.attn_mask_npu = torch.full(
-            (self.attn_mask_len, self.attn_mask_len),
-            NPU_PAGED_ATTENTION_MASK_VALUE,
-            device=self.device,
-            dtype=self.vllm_config.model_config.dtype)
-        self.attn_mask_npu.masked_fill_(
-            self.attn_mask_npu.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
+        self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
+            self.attn_mask_len, self.dtype)
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -346,35 +367,20 @@ class NPUModelRunner:
     def get_model(self) -> nn.Module:
         return self.model
 
-    def _make_attention_mask(self, seq_lens, query_lens,
-                             position) -> torch.Tensor:
-        max_seq_len = max(seq_lens, default=0)
-        if max_seq_len <= self.attn_mask_len:
-            return torch.index_select(self.attn_mask_npu,
-                                      dim=0,
-                                      index=position)[:, :max_seq_len]
-
-        total_q_len = sum(query_lens)
-        attn_mask = torch.zeros((total_q_len, max_seq_len),
-                                dtype=self.vllm_config.model_config.dtype,
-                                device="cpu")
-
-        current_row = 0
-        for i in range(len(query_lens)):
-            seq_len = seq_lens[i]
-            q_len = query_lens[i]
-            context_len = seq_len - q_len
-
-            assert context_len >= 0
-            attn_mask[current_row:current_row + q_len,
-                      context_len:] = NPU_PAGED_ATTENTION_MASK_VALUE
-            right_tensor = attn_mask[current_row:current_row + q_len,
-                                     context_len:seq_len]
-            right_tensor.mask_fill_(
-                right_tensor.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
-            current_row += q_len
-
-        return attn_mask.to(self.device, non_blocking=True)
+    def _make_attention_mask(self, seq_lens, query_lens, position,
+                             attn_state) -> torch.Tensor:
+        # Chunked-prefill situation.
+        if attn_state == AscendAttentionState.ChunkedPrefill:
+            return self.attn_mask_builder.get_splitfuse_attn_mask(
+                seq_lens, query_lens, position, self.dtype, self.device)
+        # Prefill-only situation.
+        elif attn_state == AscendAttentionState.PrefillOnly:
+            max_seq_len = max(seq_lens, default=0)
+            return self.attn_mask_builder.get_attn_mask(
+                max_seq_len, self.dtype, self.device)
+        # Decode-only situation.
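+        # Decode-only batches attend to the full context of each sequence,
+        # so the paged-attention kernel needs no explicit mask tensor.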
+ else: + return None def _process_reqs( self, @@ -408,6 +414,9 @@ class NPUModelRunner: cu_num_tokens = np.cumsum(num_scheduled_tokens) cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, num_scheduled_tokens) + sample_indices = cu_num_tokens - 1 + sample_indices = torch.from_numpy(sample_indices).to(self.device, + non_blocking=True) arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets positions_np = self.positions_np[:total_num_scheduled_tokens] @@ -437,9 +446,18 @@ class NPUModelRunner: slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to( self.device, non_blocking=True) + attn_state = AscendAttentionState.ChunkedPrefill + if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens): + attn_state = AscendAttentionState.PrefillOnly + elif np.all(num_scheduled_tokens == 1): + attn_state = AscendAttentionState.DecodeOnly + else: + attn_state = AscendAttentionState.ChunkedPrefill + attn_mask = self._make_attention_mask(seq_lens=seq_lens, query_lens=num_scheduled_tokens, - position=positions) + position=positions, + attn_state=attn_state) attn_metadata = AscendMetadata( seq_lens=query_lens, @@ -448,6 +466,7 @@ class NPUModelRunner: block_tables=( self.input_batch.block_table.get_device_tensor()[:num_reqs]), attn_mask=attn_mask, + attn_state=attn_state, ) # Prepare input_ids @@ -472,7 +491,7 @@ class NPUModelRunner: inputs_embeds=None, ) - return hidden_states[cu_num_tokens - 1] + return hidden_states[sample_indices] @torch.inference_mode() def execute_model( @@ -636,7 +655,7 @@ class NPUModelRunner: self.intermediate_tensors = ( self.model.make_empty_intermediate_tensors( batch_size=self.max_num_tokens, - dtype=self.model_config.dtype, + dtype=self.dtype, device=self.device)) intermediate_tensors = IntermediateTensors({ k: v[:self.max_num_tokens] @@ -708,6 +727,7 @@ class NPUModelRunner: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ + import torch_npu kv_caches: Dict[str, torch.Tensor] = {} for kv_cache_group in kv_cache_config.kv_cache_groups: kv_cache_spec = kv_cache_group.kv_cache_spec @@ -724,13 +744,14 @@ class NPUModelRunner: # the min of all `num_blocks`. Verify it here. assert num_blocks >= kv_cache_config.num_blocks if isinstance(kv_cache_spec, FullAttentionSpec): - kv_cache_shape = AscendAttentionBackend.get_kv_cache_shape( + kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype kv_caches[layer_name] = torch.zeros(kv_cache_shape, dtype=dtype, device=self.device) + torch_npu.npu_format_cast(kv_caches[layer_name], 2) else: # TODO: add new branches when introducing more types of # KV cache specs. diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 0bd18d5..73dde65 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -40,6 +40,7 @@ from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.worker_base import WorkerBase from vllm_ascend.platform import NPUPlatform +from vllm_ascend.utils import try_register_lib from vllm_ascend.worker.model_runner_v1 import NPUModelRunner @@ -66,6 +67,11 @@ class NPUWorker(WorkerBase): rank=rank, distributed_init_method=distributed_init_method, is_driver_worker=is_driver_worker) + # Try to import mindie_turbo to accelerate vLLM inference. + try_register_lib( + "mindie_turbo", + "MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo." 
+ ) if self.cache_config.cache_dtype == "auto": self.cache_dtype = self.model_config.dtype else:
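
Usage note: with this patch, the V0-style `AscendScheduler` is opted into by passing an `ascend_scheduler_config` entry through `additional_config`; `NPUPlatform.check_and_update_config` then swaps `vllm_config.scheduler_config` for an `AscendSchedulerConfig`. A minimal sketch follows, assuming the installed vLLM exposes `EngineArgs.additional_config` and the V1 engine is active (`VLLM_USE_V1=1`); the model name is illustrative, and any keys placed inside the dict are assumed to be valid `AscendSchedulerConfig` fields that override its defaults:

```python
from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",  # illustrative; any supported model
    additional_config={
        # The mere presence of this key activates AscendScheduler
        # (chunked prefill disabled, FCFS policy, single-step).
        "ascend_scheduler_config": {},
    },
)
outputs = llm.generate("Hello, my name is")
```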