[Scheduler] Add AscendScheduler. (#543)

This PR adds AscendScheduler to the vLLM v1 engine.
The scheduler currently implements a V0-style, prefill-first scheduling
strategy.
More scheduling policies will be supported by this scheduler in the future.
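
To opt in, pass an ascend_scheduler_config entry through vLLM's additional_config;
the platform hook in this change then rebuilds the engine's scheduler config as an
AscendSchedulerConfig. A minimal usage sketch (assuming the standard LLM entry point
forwards additional_config unchanged; keys inside the inner dict are optional
overrides of AscendSchedulerConfig fields):

    from vllm import LLM

    # An empty dict simply switches the v1 engine to AscendScheduler
    # with its prefill-first defaults.
    llm = LLM(model="facebook/opt-125m",
              additional_config={"ascend_scheduler_config": {}})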

---------

Signed-off-by: hw_whx <wanghexiang7@huawei.com>
Co-authored-by: hw_whx <wanghexiang7@huawei.com>
Author: whx
Date: 2025-04-17 19:31:50 +08:00
Committed by: GitHub
Parent: 697908f5cd
Commit: 20dff4deff
9 changed files with 967 additions and 72 deletions

View File

@@ -0,0 +1,396 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/models/utils.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional
import pytest
import torch
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from vllm_ascend.core.scheduler import AscendScheduler
EOS_TOKEN_ID = 50256
def create_scheduler(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
enable_prefix_caching: Optional[bool] = None,
long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False,
) -> AscendScheduler:
'''Create the scheduler under test.
Args:
model: model under test
max_num_seqs: max number of sequences to schedule
max_num_batched_tokens: max number of tokens to batch
enable_prefix_caching: optionally force APC config
(True/False) or use default
(None)
Returns:
:class:`AscendScheduler` instance
'''
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_num_batched_tokens,
long_prefill_token_threshold=long_prefill_token_threshold,
disable_chunked_mm_input=disable_chunked_mm_input,
)
model_config = ModelConfig(
model=model,
task="auto",
tokenizer=model,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="float16",
seed=42,
)
# Cache config, optionally force APC
kwargs_cache = ({} if enable_prefix_caching is None else {
'enable_prefix_caching': enable_prefix_caching
})
cache_config = CacheConfig(
block_size=16,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
**kwargs_cache,
)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=10000, # A large number of blocks to hold all requests
tensors={},
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(16, 1, 1, torch.float32, False))
],
)
cache_config.num_gpu_blocks = 10000
return AscendScheduler(
scheduler_config,
model_config,
cache_config,
lora_config=None,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)
def create_requests(num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None):
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
if mm_positions is not None:
mm_position = mm_positions[i]
mm_inputs = [MultiModalKwargs({})] * len(mm_position)
else:
mm_position = None
mm_inputs = None
request = Request(
request_id=f"{i}",
prompt=None,
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=0,
)
requests.append(request)
return requests
def test_add_requests():
scheduler = create_scheduler()
requests = create_requests(num_requests=10)
for i, request in enumerate(requests):
scheduler.add_request(request)
assert request.request_id in scheduler.requests
assert len(scheduler.waiting) == i + 1
def test_finish_request():
scheduler = create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id,
RequestStatus.FINISHED_ABORTED)
assert request.request_id not in scheduler.requests
assert len(scheduler.waiting) == 9 - i
def test_get_num_unfinished_requests():
scheduler = create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id,
RequestStatus.FINISHED_STOPPED)
assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1
@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
(None, None),
(True, 5),
])
def test_schedule(enable_prefix_caching: Optional[bool],
prompt_logprobs: Optional[int]):
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
requests = create_requests(num_requests=10,
prompt_logprobs=prompt_logprobs)
for request in requests:
scheduler.add_request(request)
# Test initial scheduling
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert len(output.scheduled_cached_reqs) == 0
assert len(output.finished_req_ids) == 0
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
# Verify requests moved from waiting to running
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == len(requests)
for i, request in enumerate(requests):
assert scheduler.running[i] == request
def test_stop_via_update_from_output():
"""Test stopping behavior through update_from_output"""
scheduler = create_scheduler()
# Test case 1: Stop on EOS token
requests = create_requests(num_requests=2, max_tokens=10)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i
for i, req in enumerate(requests)},
sampled_token_ids=[[EOS_TOKEN_ID],
[10, 11]],  # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped, second continues
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
assert list(requests[1].output_token_ids) == [10, 11]
# Test case 2: Stop on custom stop token
scheduler = create_scheduler()
requests = create_requests(num_requests=2,
max_tokens=10,
stop_token_ids=[42, 43])
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i
for i, req in enumerate(requests)},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped on custom token
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].stop_reason == 42
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 42]
assert list(requests[1].output_token_ids) == [13, 14]
# Test case 3: Stop on max tokens
scheduler = create_scheduler()
requests = create_requests(num_requests=2, max_tokens=2)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i
for i, req in enumerate(requests)},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped due to length
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 11]  # Truncated to max_tokens
assert list(requests[1].output_token_ids) == [13]
# Test case 4: Ignore EOS flag
scheduler = create_scheduler()
requests = create_requests(num_requests=1, max_tokens=10)
requests[0].sampling_params.ignore_eos = True
requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0])
scheduler.scheduled_req_ids.add(requests[0].request_id)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [EOS_TOKEN_ID, 10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
scheduler.update_from_output(scheduler_output, model_output)
# Verify request continues past EOS
assert len(scheduler.running) == 1
assert not requests[0].is_finished()
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]

View File

@@ -43,7 +43,7 @@ if TYPE_CHECKING:
    ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
-def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
+def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
    # Construct lower triangle matrix.
    mask_flag = torch.tril(
        torch.ones((max_seq_len, max_seq_len),
@@ -52,10 +52,11 @@ def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
    mask_flag = ~mask_flag
    # Currently for fp16 dtype, the mask value should be set to -inf.
    # TODO: Eliminate this part in the future.
-    if dtype == torch.float16:
-        mask_value = torch.finfo(torch.float32).min
-    else:
-        mask_value = 1
+    if mask_value is None:
+        if dtype == torch.float16:
+            mask_value = torch.finfo(torch.float32).min
+        else:
+            mask_value = 1
    attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
                                  mask_flag, mask_value).to(dtype)
    return attn_mask
@@ -66,12 +67,14 @@ class AttentionMaskBuilder:
    def __init__(self, attn_mask: torch.Tensor):
        self._seq_len_cached = attn_mask.shape[0]
        self.attn_mask_cache = attn_mask
+        self.splitfuse_mask_value = -10000

    @classmethod
    def initialize_from_len(cls,
                            max_seq_len: int,
-                            dtype: torch.dtype = torch.float16):
-        return cls(generate_attn_mask(max_seq_len, dtype))
+                            dtype: torch.dtype = torch.float16,
+                            mask_value: Optional[int] = None):
+        return cls(generate_attn_mask(max_seq_len, dtype, mask_value))

    def update_attn_cache(self, seqlen: int, dtype: torch.dtype,
                          device: torch.device):
@@ -97,6 +100,49 @@ class AttentionMaskBuilder:
        return (self.attn_mask_cache.index_select(
            0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous())
def get_splitfuse_attn_mask(
self,
seq_lens,
query_lens,
position,
dtype,
device,
) -> torch.Tensor:
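# Build the attention mask consumed by the splitfuse (chunked-prefill) kernel:
# one row per scheduled query token, masking the positions that token must not
# attend to.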
max_seq_len = max(seq_lens, default=0)
if max_seq_len <= self._seq_len_cached:
self.update_attn_cache(max_seq_len, dtype, device)
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
# is not the same. Fix this in the future when kernel is ready.
if self.attn_mask_cache[0][1] > 0:
attn_mask = self.get_attn_mask( # type: ignore
max_seq_len, dtype, device)
attn_mask *= -10000
else:
attn_mask = self.attn_mask_cache
return torch.index_select(attn_mask, dim=0,
index=position)[:, :max_seq_len]
total_q_len = sum(query_lens)
attn_mask = torch.zeros((total_q_len, max_seq_len),
dtype=dtype,
device="cpu")
current_row = 0
for i in range(len(query_lens)):
seq_len = seq_lens[i]
q_len = query_lens[i]
context_len = seq_len - q_len
assert context_len >= 0
attn_mask[current_row:current_row + q_len,
context_len:] = self.splitfuse_mask_value
right_tensor = attn_mask[current_row:current_row + q_len,
context_len:seq_len]
right_tensor.masked_fill_(
right_tensor.tril() == self.splitfuse_mask_value, 0)
current_row += q_len
return attn_mask.to(device, non_blocking=True)
class AscendAttentionBackend(AttentionBackend):

View File

@@ -16,6 +16,7 @@
#
from dataclasses import dataclass
+from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
@@ -50,7 +51,7 @@ class AscendAttentionBackend(AttentionBackend):
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
-        return (2, num_blocks, block_size, num_kv_heads * head_size)
+        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
@@ -83,6 +84,12 @@ class AscendAttentionBackend(AttentionBackend):
        value_caches[dst_indices] = value_caches[src_indices]
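
# Execution mode of one attention step; forward() below selects the NPU kernel
# from it: flash attention for pure prefill, paged attention for pure decode,
# and the splitfuse kernel for mixed chunked prefill.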
class AscendAttentionState(Enum):
PrefillOnly = 0
DecodeOnly = 1
ChunkedPrefill = 2
@dataclass
class AscendMetadata:
    # (batch_size, max_blocks_per_seq).
@@ -104,6 +111,8 @@ class AscendMetadata:
    # FlashAttention has better performance than PagedAttention,
    # but it does not support decode requests.
    is_only_prefill: bool = False
+    # Current state of this attention run.
+    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill

    attn_mask: Optional[torch.Tensor] = None
@@ -139,7 +148,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-        self.seq_len_cpu_tensor = None
+        self.key_cache = None
+        self.value_cache = None

    def forward(
        self,
@@ -190,30 +200,52 @@ class AscendAttentionBackendImpl(AttentionImpl):
        # TODO: Remove this contiguous in the future.
        value = value.contiguous()
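
# Write this step's key/value tensors into the paged KV cache at the scheduled
# slot positions before running the attention kernel.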
if kv_cache.numel() > 0:
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
torch_npu._npu_reshape_and_cache(key=key,
value=value,
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots)
        if hasattr(layer, 'quant_method'):
            # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
            pass
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
torch_npu._npu_flash_attention(query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
block_tables = attn_metadata.block_tables
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=block_tables,
context_lens=attn_metadata.context_lens,
out=output)
# Normal V1 situation.
        else:
-            if kv_cache.numel() > 0:
-                key_cache, value_cache = kv_cache[0], kv_cache[1]
-                num_blocks, block_size, _ = key_cache.shape
-                key_cache = key_cache.view(num_blocks, block_size,
-                                           self.num_kv_heads, self.head_size)
-                value_cache = value_cache.view(num_blocks, block_size,
-                                               self.num_kv_heads,
-                                               self.head_size)
-                slots = attn_metadata.slot_mapping
-                torch_npu._npu_reshape_and_cache(key=key,
-                                                 value=value,
-                                                 key_cache=key_cache,
-                                                 value_cache=value_cache,
-                                                 slot_indices=slots)
            # use paged attention
            torch_npu._npu_paged_attention_splitfuse(
                query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
                mask=attn_metadata.attn_mask,
                block_table=attn_metadata.block_tables,
                seq_len=attn_metadata.seq_lens,

View File

View File

@@ -0,0 +1,73 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from dataclasses import dataclass, fields
from typing import Type, Union
from vllm.config import SchedulerConfig
@dataclass
class AscendSchedulerConfig(SchedulerConfig):
enable_chunked_prefill: bool = False
policy: str = "fcfs"
num_scheduler_steps: int = 1
scheduler_cls: Union[str, Type[object]] = (
"vllm_ascend.core.scheduler.AscendScheduler")
@classmethod
def initialize_from_config(
cls,
vllm_scheduler_config: SchedulerConfig,
ascend_scheduler_config: dict,
):
scheduler_config = {
field.name: getattr(vllm_scheduler_config, field.name)
for field in fields(vllm_scheduler_config) if field.init
}
# Override default values into original SchedulerConfig
scheduler_config["enable_chunked_prefill"] = False
scheduler_config["policy"] = "fcfs"
scheduler_config["num_scheduler_steps"] = 1
scheduler_config["scheduler_cls"] = (
"vllm_ascend.core.scheduler.AscendScheduler")
# Override params in original SchedulerConfig with params in additional_config.ascend_scheduler_config
for k, v in ascend_scheduler_config.items():
scheduler_config[k] = v
return cls(**scheduler_config)
def __post_init__(self) -> None:
self.max_num_encoder_input_tokens = self.max_num_batched_tokens
self.encoder_cache_size = self.max_num_batched_tokens
self.chunked_prefill_enabled = self.enable_chunked_prefill
if self.policy != "fcfs":
raise NotImplementedError(
f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
)
if self.is_multimodal_model:
raise NotImplementedError(
"currently AscendScheduler only supports LLM models.")
if self.num_scheduler_steps > 1:
raise NotImplementedError(
"currently AscendScheduler doesn't support multi-step.")
if self.send_delta_data:
raise NotImplementedError(
"currently AscendScheduler doesn't support send_delta_data.")
if self.delay_factor > 0:
raise NotImplementedError(
"currently AscendScheduler doesn't support scheduler_delay_factor."
)

View File

@@ -0,0 +1,305 @@
from collections import deque
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request, RequestStatus
logger = init_logger(__name__)
class AscendScheduler(Scheduler):
"""This scheduler extends vLLM's original v1 scheduler
with a prefill-first scheduling strategy."""
def schedule(self) -> SchedulerOutput:
if self.scheduler_config.chunked_prefill_enabled:
return super().schedule()
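# Prefill-first (V0-style) scheduling: fill the token budget from the waiting
# queue first; decode requests are only scheduled in steps where no prefill
# request was scheduled.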
scheduled_new_reqs: list[Request] = []
scheduled_resumed_reqs: list[Request] = []
scheduled_running_reqs: list[Request] = []
preempted_reqs: list[Request] = []
req_to_new_block_ids: dict[str, list[int]] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# Record scheduled LoRA requests.
scheduled_loras: set[int] = set()
# Use a temporary deque to collect requests that need to be skipped
# and put back at the head of the waiting queue later
skipped_waiting_requests: deque[Request] = deque()
# Schedule prefill requests first.
while self.waiting and token_budget > 0:
if len(scheduled_new_reqs) == self.max_num_running_reqs:
break
request = self.waiting[0]
def skip_cur_request():
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
# Check that adding the request still respects the max_loras
# constraint.
if (self.lora_config and request.lora_request and
(len(scheduled_loras) == self.lora_config.max_loras
and request.lora_request.lora_int_id not in scheduled_loras)):
# Scheduling would exceed max_loras, skip.
skip_cur_request()
continue
prompt_limit = self._get_prompt_limit(request)
# Get already-cached tokens.
computed_blocks, num_computed_tokens = (
self.kv_cache_manager.get_computed_blocks(request))
num_new_tokens = request.num_prompt_tokens - num_computed_tokens
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
max_tokens_in_kvcache = (self.kv_cache_config.num_blocks *
self.block_size)
prompt_limit = min(prompt_limit, max_tokens_in_kvcache)
# Finish request that exceeds prompt_limit or kv cache size.
if num_new_tokens > prompt_limit:
logger.warning(
"Input prompt (%d tokens) is too long"
" and exceeds limit of %d",
num_new_tokens,
prompt_limit,
)
request.status = RequestStatus.FINISHED_IGNORED
self.finished_req_ids.add(request.request_id) # type: ignore
self.waiting.popleft()
continue
if num_new_tokens > token_budget:
# Scheduling would exceed token_budget, skip.
skip_cur_request()
continue
assert num_new_tokens > 0
watermark = getattr(self.scheduler_config, "watermark", 0.01)
if not self._check_watermark_for_prefill(
request, num_new_tokens, computed_blocks, watermark):
# Scheduling would exceed watermark, skip.
skip_cur_request()
continue
new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens, computed_blocks)
if new_blocks is None:
# The request cannot be scheduled.
break
self.waiting.popleft()
self.running.append(request)
self.scheduled_req_ids.add(request.request_id)
# Check request status.
if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request)
elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request)
else:
raise RuntimeError(f"Invalid request status: {request.status}")
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = [
b.block_id for b in computed_blocks + new_blocks
]
# Update request info.
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.extendleft(skipped_waiting_requests)
# If no prefill requests were scheduled,
# schedule decode requests next.
if len(self.scheduled_req_ids) == 0:
req_index = 0
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
if request.request_id in self.scheduled_req_ids:
# This request has already been scheduled.
req_index += 1
continue
num_new_tokens = (request.num_tokens_with_spec -
request.num_computed_tokens)
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens == 1
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens)
if new_blocks is None:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
self.waiting.appendleft(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt.
can_schedule = False
break
else:
# The request can be scheduled.
can_schedule = True
break
if not can_schedule:
break
assert new_blocks is not None
# Schedule the request.
scheduled_running_reqs.append(request)
self.scheduled_req_ids.add(request.request_id)
req_to_new_block_ids[request.request_id] = [
b.block_id for b in new_blocks
]
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
# Speculative decode related.
if request.spec_token_ids:
num_scheduled_spec_tokens = (num_new_tokens +
request.num_computed_tokens -
request.num_tokens)
if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens.
del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids)
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
scheduled_running_reqs) <= len(self.running)
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
num_common_prefix_blocks = 0
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request, len(self.running)))
# Construct the scheduler output.
new_reqs_data = [
NewRequestData.from_request(req,
req_to_new_block_ids[req.request_id])
for req in scheduled_new_reqs
]
resumed_reqs_data = [
self._make_cached_request_data(
req,
num_scheduled_tokens[req.request_id],
len(scheduled_spec_decode_tokens.get(req.request_id, ())),
req_to_new_block_ids[req.request_id],
resumed_from_preemption=True,
) for req in scheduled_resumed_reqs
]
running_reqs_data = [
self._make_cached_request_data(
req,
num_scheduled_tokens[req.request_id],
len(scheduled_spec_decode_tokens.get(req.request_id, ())),
req_to_new_block_ids[req.request_id],
resumed_from_preemption=False,
) for req in scheduled_running_reqs
]
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs={},
num_common_prefix_blocks=num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=self.finished_req_ids, # type: ignore
free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
structured_output_request_ids={},
grammar_bitmask=None,
)
# Advance the number of computed tokens for the request AFTER
# the request is scheduled.
# 1. The scheduler_output of the current step has to include the
# original number of scheduled tokens to determine input IDs.
# 2. Advance the number of computed tokens here allowing us to
# schedule the prefill request again immediately in the next
# scheduling step.
# 3. If some tokens (e.g. spec tokens) are rejected later, the number of
# computed tokens will be adjusted in update_from_output.
for req_id, num_scheduled_token in num_scheduled_tokens.items():
self.requests[req_id].num_computed_tokens += num_scheduled_token
self.finished_req_ids = set() # type: ignore
return scheduler_output
def _check_watermark_for_prefill(self,
request,
num_new_tokens,
computed_blocks,
watermark=0.01):
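# Admit a prefill only if, after allocating the blocks it needs (reusing
# evictable cached blocks), the free-block pool stays above `watermark`
# (a fraction of all KV cache blocks).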
computed_blocks = computed_blocks or []
watermark_blocks = self.kv_cache_config.num_blocks * watermark
num_computed_tokens = (request.num_computed_tokens +
len(computed_blocks) * self.block_size)
num_required_blocks = cdiv(num_new_tokens + num_computed_tokens,
self.block_size)
req_blocks = self.kv_cache_manager.req_to_blocks[request.request_id]
num_new_blocks = (num_required_blocks - len(req_blocks) -
len(computed_blocks))
num_evictable_computed_blocks = sum(1 for blk in computed_blocks
if blk.ref_cnt == 0)
# If the number of free blocks after allocating is less than the watermark, don't allocate.
if (self.kv_cache_manager.block_pool.get_num_free_blocks() -
num_evictable_computed_blocks -
num_new_blocks) < watermark_blocks:
return False
return True
def _get_prompt_limit(self, request: Request) -> int:
if (self.scheduler_config.chunked_prefill_enabled
and not self.scheduler_config.is_multi_step):
prompt_limit = self.scheduler_config.max_model_len
else:
prompt_limit = min(
self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens,
)
# Model is fine tuned with long context. Return the fine tuned max_len.
if request.lora_request and request.lora_request.long_lora_max_len:
assert prompt_limit <= request.lora_request.long_lora_max_len
return request.lora_request.long_lora_max_len
else:
return prompt_limit

View File

@@ -132,6 +132,22 @@ class NPUPlatform(Platform):
            )
            cache_config.enable_prefix_caching = False
if envs.VLLM_USE_V1:
# Activate custom ops for v1.
vllm_config.compilation_config.custom_ops = ["all"]
additional_config = vllm_config.additional_config
# If ascend_scheduler_config exists in additional_config,
# extend the original scheduler_config to use AscendScheduler.
if additional_config and additional_config.get(
"ascend_scheduler_config", None) is not None:
additional_scheduler_config = additional_config.get(
"ascend_scheduler_config")
from vllm_ascend.core.schedule_config import \
AscendSchedulerConfig
ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config(
vllm_config.scheduler_config, additional_scheduler_config)
vllm_config.scheduler_config = ascend_scheduler_config
    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla):

View File

@@ -25,7 +25,7 @@ import numpy as np
import numpy.typing as npt
import torch
import torch.nn as nn
-from vllm.attention import AttentionType
+from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
@@ -37,7 +37,8 @@ from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
-from vllm.utils import DeviceMemoryProfiler, LayerBlockType, cdiv
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
+                        LayerBlockType, cdiv)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
@@ -45,15 +46,14 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
-from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
+from vllm_ascend.attention.attention import AttentionMaskBuilder
+from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
                                                 AscendMetadata)
from vllm_ascend.platform import NPUPlatform

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

-NPU_PAGED_ATTENTION_MASK_VALUE = -10000


class NPUModelRunner:
@@ -74,6 +74,32 @@ class NPUModelRunner:
        self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
            vllm_config.parallel_config, LayerBlockType.attention)
        self.hidden_size = self.model_config.get_hidden_size()
self.dtype = self.model_config.dtype
cache_config = vllm_config.cache_config
if cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
self.head_size = self.model_config.get_head_size()
self.attn_backend = get_attn_backend(
self.head_size,
self.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
)
if self.attn_backend is None:
error_msg = (
f"Error with get_att_backend: {self.head_size=}, "
f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{self.model_config.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 NPUModelRunner.")
        # Multi-modal data support
        self.input_registry = INPUT_REGISTRY
@@ -135,7 +161,7 @@ class NPUModelRunner:
        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
-            dtype=self.model_config.dtype,
+            dtype=self.dtype,
            device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
@@ -183,13 +209,8 @@ class NPUModelRunner:
        mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
        self.attn_mask_len = min(self.model_config.max_model_len,
                                 int(mask_len))
-        self.attn_mask_npu = torch.full(
-            (self.attn_mask_len, self.attn_mask_len),
-            NPU_PAGED_ATTENTION_MASK_VALUE,
-            device=self.device,
-            dtype=self.vllm_config.model_config.dtype)
-        self.attn_mask_npu.masked_fill_(
-            self.attn_mask_npu.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
+        self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
+            self.attn_mask_len, self.dtype)

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
@@ -346,35 +367,20 @@ class NPUModelRunner:
    def get_model(self) -> nn.Module:
        return self.model

-    def _make_attention_mask(self, seq_lens, query_lens,
-                             position) -> torch.Tensor:
-        max_seq_len = max(seq_lens, default=0)
-        if max_seq_len <= self.attn_mask_len:
-            return torch.index_select(self.attn_mask_npu,
-                                      dim=0,
-                                      index=position)[:, :max_seq_len]
-        total_q_len = sum(query_lens)
-        attn_mask = torch.zeros((total_q_len, max_seq_len),
-                                dtype=self.vllm_config.model_config.dtype,
-                                device="cpu")
-        current_row = 0
-        for i in range(len(query_lens)):
-            seq_len = seq_lens[i]
-            q_len = query_lens[i]
-            context_len = seq_len - q_len
-            assert context_len >= 0
-            attn_mask[current_row:current_row + q_len,
-                      context_len:] = NPU_PAGED_ATTENTION_MASK_VALUE
-            right_tensor = attn_mask[current_row:current_row + q_len,
-                                     context_len:seq_len]
-            right_tensor.mask_fill_(
-                right_tensor.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
-            current_row += q_len
-        return attn_mask.to(self.device, non_blocking=True)
+    def _make_attention_mask(self, seq_lens, query_lens, position,
+                             attn_state) -> torch.Tensor:
+        # Chunk Prefill situation.
+        if attn_state == AscendAttentionState.ChunkedPrefill:
+            return self.attn_mask_builder.get_splitfuse_attn_mask(
+                seq_lens, query_lens, position, self.dtype, self.device)
+        # Prefill-only situation.
+        elif attn_state == AscendAttentionState.PrefillOnly:
+            max_seq_len = max(seq_lens, default=0)
+            return self.attn_mask_builder.get_attn_mask(
+                max_seq_len, self.dtype, self.device)
+        # Decode-only situation.
+        else:
+            return None

    def _process_reqs(
        self,
@@ -408,6 +414,9 @@ class NPUModelRunner:
        cu_num_tokens = np.cumsum(num_scheduled_tokens)
        cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                    num_scheduled_tokens)
+        sample_indices = cu_num_tokens - 1
+        sample_indices = torch.from_numpy(sample_indices).to(self.device,
+                                                             non_blocking=True)
        arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets

        positions_np = self.positions_np[:total_num_scheduled_tokens]
@@ -437,9 +446,18 @@ class NPUModelRunner:
        slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
            self.device, non_blocking=True)
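
# Classify this step to pick the attention kernel: all requests computing their
# full sequences (no cached context) means pure prefill; every request advancing
# by exactly one token means decode-only; anything else is chunked prefill.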
attn_state = AscendAttentionState.ChunkedPrefill
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
attn_state = AscendAttentionState.PrefillOnly
elif np.all(num_scheduled_tokens == 1):
attn_state = AscendAttentionState.DecodeOnly
else:
attn_state = AscendAttentionState.ChunkedPrefill
        attn_mask = self._make_attention_mask(seq_lens=seq_lens,
                                              query_lens=num_scheduled_tokens,
-                                              position=positions)
+                                              position=positions,
+                                              attn_state=attn_state)

        attn_metadata = AscendMetadata(
            seq_lens=query_lens,
@@ -448,6 +466,7 @@ class NPUModelRunner:
            block_tables=(
                self.input_batch.block_table.get_device_tensor()[:num_reqs]),
            attn_mask=attn_mask,
+            attn_state=attn_state,
        )

        # Prepare input_ids
@@ -472,7 +491,7 @@ class NPUModelRunner:
            inputs_embeds=None,
        )

-        return hidden_states[cu_num_tokens - 1]
+        return hidden_states[sample_indices]

    @torch.inference_mode()
    def execute_model(
@@ -636,7 +655,7 @@ class NPUModelRunner:
            self.intermediate_tensors = (
                self.model.make_empty_intermediate_tensors(
                    batch_size=self.max_num_tokens,
-                    dtype=self.model_config.dtype,
+                    dtype=self.dtype,
                    device=self.device))
        intermediate_tensors = IntermediateTensors({
            k: v[:self.max_num_tokens]
@@ -708,6 +727,7 @@ class NPUModelRunner:
            kv_cache_config: Configuration for the KV cache, including the KV
                cache size of each layer
        """
+        import torch_npu

        kv_caches: Dict[str, torch.Tensor] = {}
        for kv_cache_group in kv_cache_config.kv_cache_groups:
            kv_cache_spec = kv_cache_group.kv_cache_spec
@@ -724,13 +744,14 @@ class NPUModelRunner:
                # the min of all `num_blocks`. Verify it here.
                assert num_blocks >= kv_cache_config.num_blocks
                if isinstance(kv_cache_spec, FullAttentionSpec):
-                    kv_cache_shape = AscendAttentionBackend.get_kv_cache_shape(
+                    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype
                    kv_caches[layer_name] = torch.zeros(kv_cache_shape,
                                                        dtype=dtype,
                                                        device=self.device)
+                    torch_npu.npu_format_cast(kv_caches[layer_name], 2)
                else:
                    # TODO: add new branches when introducing more types of
                    # KV cache specs.

View File

@@ -40,6 +40,7 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.worker_base import WorkerBase

from vllm_ascend.platform import NPUPlatform
+from vllm_ascend.utils import try_register_lib
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -66,6 +67,11 @@ class NPUWorker(WorkerBase):
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)
# Try to import mindie_turbo to accelerate vLLM inference.
try_register_lib(
"mindie_turbo",
"MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo."
)
        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else: