[Scheduler] Add AscendScheduler. (#543)
This PR adds AscendScheduler to vllm v1 engine. This scheduler currently supports v0-style prefill-first scheduling strategy. In the future more schedule methods will be supported by this scheduler. --------- Signed-off-by: hw_whx <wanghexiang7@huawei.com> Co-authored-by: hw_whx <wanghexiang7@huawei.com>
This commit is contained in:
396
tests/scheduler/test_scheduler.py
Normal file
396
tests/scheduler/test_scheduler.py
Normal file
@@ -0,0 +1,396 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
# Adapted from vllm-project/vllm/blob/main/tests/models/utils.py
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
|
||||||
|
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
|
KVCacheGroupSpec)
|
||||||
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
|
from vllm.v1.request import Request, RequestStatus
|
||||||
|
from vllm.v1.structured_output import StructuredOutputManager
|
||||||
|
|
||||||
|
from vllm_ascend.core.scheduler import AscendScheduler
|
||||||
|
|
||||||
|
EOS_TOKEN_ID = 50256
|
||||||
|
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
model: str = "facebook/opt-125m",
|
||||||
|
max_num_seqs: int = 16,
|
||||||
|
max_num_batched_tokens: int = 8192,
|
||||||
|
enable_prefix_caching: Optional[bool] = None,
|
||||||
|
long_prefill_token_threshold: int = 0,
|
||||||
|
disable_chunked_mm_input: bool = False,
|
||||||
|
) -> AscendScheduler:
|
||||||
|
'''Create scheduler under test.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: model under test
|
||||||
|
max_num_seqs: max sequences to schedule
|
||||||
|
max_num_batch_tokens: max num tokens to batch
|
||||||
|
enable_prefix_caching: optionally force APC config
|
||||||
|
(True/False) or use default
|
||||||
|
(None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:class:`Scheduler` instance
|
||||||
|
'''
|
||||||
|
scheduler_config = SchedulerConfig(
|
||||||
|
max_num_seqs=max_num_seqs,
|
||||||
|
max_num_batched_tokens=max_num_batched_tokens,
|
||||||
|
max_model_len=max_num_batched_tokens,
|
||||||
|
long_prefill_token_threshold=long_prefill_token_threshold,
|
||||||
|
disable_chunked_mm_input=disable_chunked_mm_input,
|
||||||
|
)
|
||||||
|
model_config = ModelConfig(
|
||||||
|
model=model,
|
||||||
|
task="auto",
|
||||||
|
tokenizer=model,
|
||||||
|
tokenizer_mode="auto",
|
||||||
|
trust_remote_code=True,
|
||||||
|
dtype="float16",
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
# Cache config, optionally force APC
|
||||||
|
kwargs_cache = ({} if enable_prefix_caching is None else {
|
||||||
|
'enable_prefix_caching': enable_prefix_caching
|
||||||
|
})
|
||||||
|
cache_config = CacheConfig(
|
||||||
|
block_size=16,
|
||||||
|
gpu_memory_utilization=0.9,
|
||||||
|
swap_space=0,
|
||||||
|
cache_dtype="auto",
|
||||||
|
**kwargs_cache,
|
||||||
|
)
|
||||||
|
vllm_config = VllmConfig(
|
||||||
|
scheduler_config=scheduler_config,
|
||||||
|
model_config=model_config,
|
||||||
|
cache_config=cache_config,
|
||||||
|
)
|
||||||
|
kv_cache_config = KVCacheConfig(
|
||||||
|
num_blocks=10000, # A large number of blocks to hold all requests
|
||||||
|
tensors={},
|
||||||
|
kv_cache_groups=[
|
||||||
|
KVCacheGroupSpec(['layer'],
|
||||||
|
FullAttentionSpec(16, 1, 1, torch.float32, False))
|
||||||
|
],
|
||||||
|
)
|
||||||
|
cache_config.num_gpu_blocks = 10000
|
||||||
|
return AscendScheduler(
|
||||||
|
scheduler_config,
|
||||||
|
model_config,
|
||||||
|
cache_config,
|
||||||
|
lora_config=None,
|
||||||
|
kv_cache_config=kv_cache_config,
|
||||||
|
log_stats=True,
|
||||||
|
structured_output_manager=StructuredOutputManager(vllm_config),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_requests(num_requests: int,
|
||||||
|
num_tokens: int = 10,
|
||||||
|
mm_positions: Optional[list[PlaceholderRange]] = None,
|
||||||
|
max_tokens: int = 16,
|
||||||
|
stop_token_ids: Optional[list[int]] = None,
|
||||||
|
prompt_logprobs: Optional[int] = None):
|
||||||
|
sampling_params = SamplingParams(ignore_eos=False,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
stop_token_ids=stop_token_ids,
|
||||||
|
prompt_logprobs=prompt_logprobs)
|
||||||
|
requests = []
|
||||||
|
for i in range(num_requests):
|
||||||
|
if mm_positions is not None:
|
||||||
|
mm_position = mm_positions[i]
|
||||||
|
mm_inputs = [MultiModalKwargs({})] * len(mm_position)
|
||||||
|
else:
|
||||||
|
mm_position = None
|
||||||
|
mm_inputs = None
|
||||||
|
request = Request(
|
||||||
|
request_id=f"{i}",
|
||||||
|
prompt=None,
|
||||||
|
prompt_token_ids=[i] * num_tokens,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
multi_modal_inputs=mm_inputs,
|
||||||
|
multi_modal_placeholders=mm_position,
|
||||||
|
multi_modal_hashes=None,
|
||||||
|
eos_token_id=EOS_TOKEN_ID,
|
||||||
|
arrival_time=0,
|
||||||
|
)
|
||||||
|
requests.append(request)
|
||||||
|
return requests
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_requests():
|
||||||
|
scheduler = create_scheduler()
|
||||||
|
requests = create_requests(num_requests=10)
|
||||||
|
|
||||||
|
for i, request in enumerate(requests):
|
||||||
|
scheduler.add_request(request)
|
||||||
|
assert request.request_id in scheduler.requests
|
||||||
|
assert len(scheduler.waiting) == i + 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_finish_request():
|
||||||
|
scheduler = create_scheduler()
|
||||||
|
requests = create_requests(num_requests=10)
|
||||||
|
for request in requests:
|
||||||
|
scheduler.add_request(request)
|
||||||
|
|
||||||
|
for i, request in enumerate(requests):
|
||||||
|
scheduler.finish_requests(request.request_id,
|
||||||
|
RequestStatus.FINISHED_ABORTED)
|
||||||
|
assert request.request_id not in scheduler.requests
|
||||||
|
assert len(scheduler.waiting) == 9 - i
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_num_unfinished_requests():
|
||||||
|
scheduler = create_scheduler()
|
||||||
|
requests = create_requests(num_requests=10)
|
||||||
|
for request in requests:
|
||||||
|
scheduler.add_request(request)
|
||||||
|
|
||||||
|
for i, request in enumerate(requests):
|
||||||
|
scheduler.finish_requests(request.request_id,
|
||||||
|
RequestStatus.FINISHED_STOPPED)
|
||||||
|
assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
|
||||||
|
(None, None),
|
||||||
|
(True, 5),
|
||||||
|
])
|
||||||
|
def test_schedule(enable_prefix_caching: Optional[bool],
|
||||||
|
prompt_logprobs: Optional[int]):
|
||||||
|
'''Test scheduling.
|
||||||
|
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
|
||||||
|
'''
|
||||||
|
scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
|
||||||
|
requests = create_requests(num_requests=10,
|
||||||
|
prompt_logprobs=prompt_logprobs)
|
||||||
|
for request in requests:
|
||||||
|
scheduler.add_request(request)
|
||||||
|
|
||||||
|
# Test initial scheduling
|
||||||
|
output = scheduler.schedule()
|
||||||
|
assert len(output.scheduled_new_reqs) == len(requests)
|
||||||
|
assert len(output.scheduled_cached_reqs) == 0
|
||||||
|
assert len(output.finished_req_ids) == 0
|
||||||
|
# Verify all requests are scheduled.
|
||||||
|
for req_id, num_tokens in output.num_scheduled_tokens.items():
|
||||||
|
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
|
||||||
|
|
||||||
|
# Verify requests moved from waiting to running
|
||||||
|
assert len(scheduler.waiting) == 0
|
||||||
|
assert len(scheduler.running) == len(requests)
|
||||||
|
for i, request in enumerate(requests):
|
||||||
|
assert scheduler.running[i] == request
|
||||||
|
|
||||||
|
|
||||||
|
def test_stop_via_update_from_output():
|
||||||
|
"""Test stopping behavior through update_from_output"""
|
||||||
|
scheduler = create_scheduler()
|
||||||
|
|
||||||
|
# Test case 1: Stop on EOS token
|
||||||
|
requests = create_requests(num_requests=2, max_tokens=10)
|
||||||
|
for req in requests:
|
||||||
|
req.num_computed_tokens = req.num_tokens
|
||||||
|
scheduler.requests[req.request_id] = req
|
||||||
|
scheduler.running.append(req)
|
||||||
|
scheduler.scheduled_req_ids.add(req.request_id)
|
||||||
|
|
||||||
|
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||||
|
scheduled_cached_reqs=[],
|
||||||
|
num_scheduled_tokens={
|
||||||
|
requests[0].request_id: 1,
|
||||||
|
requests[1].request_id: 2
|
||||||
|
},
|
||||||
|
total_num_scheduled_tokens=3,
|
||||||
|
scheduled_encoder_inputs={},
|
||||||
|
scheduled_spec_decode_tokens={
|
||||||
|
requests[0].request_id: [],
|
||||||
|
requests[1].request_id: [10]
|
||||||
|
},
|
||||||
|
num_common_prefix_blocks=0,
|
||||||
|
finished_req_ids=set(),
|
||||||
|
free_encoder_input_ids=[],
|
||||||
|
structured_output_request_ids={},
|
||||||
|
grammar_bitmask=None)
|
||||||
|
|
||||||
|
model_output = ModelRunnerOutput(
|
||||||
|
req_ids=[req.request_id for req in requests],
|
||||||
|
req_id_to_index={req.request_id: i
|
||||||
|
for i, req in enumerate(requests)},
|
||||||
|
sampled_token_ids=[[EOS_TOKEN_ID],
|
||||||
|
[10,
|
||||||
|
11]], # First request hits EOS, second continues
|
||||||
|
spec_token_ids=None,
|
||||||
|
logprobs=None,
|
||||||
|
prompt_logprobs_dict={})
|
||||||
|
|
||||||
|
scheduler.update_from_output(scheduler_output, model_output)
|
||||||
|
|
||||||
|
# Verify first request stopped, second continues
|
||||||
|
assert len(scheduler.running) == 1
|
||||||
|
assert scheduler.running[0].request_id == requests[1].request_id
|
||||||
|
assert requests[0].status == RequestStatus.FINISHED_STOPPED
|
||||||
|
assert requests[0].request_id in scheduler.finished_req_ids
|
||||||
|
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
|
||||||
|
assert list(requests[1].output_token_ids) == [10, 11]
|
||||||
|
|
||||||
|
# Test case 2: Stop on custom stop token
|
||||||
|
scheduler = create_scheduler()
|
||||||
|
requests = create_requests(num_requests=2,
|
||||||
|
max_tokens=10,
|
||||||
|
stop_token_ids=[42, 43])
|
||||||
|
for req in requests:
|
||||||
|
req.num_computed_tokens = req.num_tokens
|
||||||
|
scheduler.requests[req.request_id] = req
|
||||||
|
scheduler.running.append(req)
|
||||||
|
scheduler.scheduled_req_ids.add(req.request_id)
|
||||||
|
|
||||||
|
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||||
|
scheduled_cached_reqs=[],
|
||||||
|
num_scheduled_tokens={
|
||||||
|
requests[0].request_id: 3,
|
||||||
|
requests[1].request_id: 2
|
||||||
|
},
|
||||||
|
total_num_scheduled_tokens=5,
|
||||||
|
scheduled_encoder_inputs={},
|
||||||
|
scheduled_spec_decode_tokens={
|
||||||
|
requests[0].request_id: [10, 42],
|
||||||
|
requests[1].request_id: [13]
|
||||||
|
},
|
||||||
|
num_common_prefix_blocks=0,
|
||||||
|
finished_req_ids=set(),
|
||||||
|
free_encoder_input_ids=[],
|
||||||
|
structured_output_request_ids={},
|
||||||
|
grammar_bitmask=None)
|
||||||
|
|
||||||
|
model_output = ModelRunnerOutput(
|
||||||
|
req_ids=[req.request_id for req in requests],
|
||||||
|
req_id_to_index={req.request_id: i
|
||||||
|
for i, req in enumerate(requests)},
|
||||||
|
sampled_token_ids=[[10, 42, 12],
|
||||||
|
[13, 14]], # First request hits stop token
|
||||||
|
spec_token_ids=None,
|
||||||
|
logprobs=None,
|
||||||
|
prompt_logprobs_dict={})
|
||||||
|
|
||||||
|
scheduler.update_from_output(scheduler_output, model_output)
|
||||||
|
|
||||||
|
# Verify first request stopped on custom token
|
||||||
|
assert len(scheduler.running) == 1
|
||||||
|
assert scheduler.running[0].request_id == requests[1].request_id
|
||||||
|
assert requests[0].status == RequestStatus.FINISHED_STOPPED
|
||||||
|
assert requests[0].stop_reason == 42
|
||||||
|
assert requests[0].request_id in scheduler.finished_req_ids
|
||||||
|
assert list(requests[0].output_token_ids) == [10, 42]
|
||||||
|
assert list(requests[1].output_token_ids) == [13, 14]
|
||||||
|
|
||||||
|
# Test case 3: Stop on max tokens
|
||||||
|
scheduler = create_scheduler()
|
||||||
|
requests = create_requests(num_requests=2, max_tokens=2)
|
||||||
|
for req in requests:
|
||||||
|
req.num_computed_tokens = req.num_tokens
|
||||||
|
scheduler.requests[req.request_id] = req
|
||||||
|
scheduler.running.append(req)
|
||||||
|
scheduler.scheduled_req_ids.add(req.request_id)
|
||||||
|
|
||||||
|
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||||
|
scheduled_cached_reqs=[],
|
||||||
|
num_scheduled_tokens={
|
||||||
|
requests[0].request_id: 3,
|
||||||
|
requests[1].request_id: 1
|
||||||
|
},
|
||||||
|
total_num_scheduled_tokens=4,
|
||||||
|
scheduled_encoder_inputs={},
|
||||||
|
scheduled_spec_decode_tokens={
|
||||||
|
requests[0].request_id: [10, 11],
|
||||||
|
requests[1].request_id: []
|
||||||
|
},
|
||||||
|
num_common_prefix_blocks=0,
|
||||||
|
finished_req_ids=set(),
|
||||||
|
free_encoder_input_ids=[],
|
||||||
|
structured_output_request_ids={},
|
||||||
|
grammar_bitmask=None)
|
||||||
|
|
||||||
|
model_output = ModelRunnerOutput(
|
||||||
|
req_ids=[req.request_id for req in requests],
|
||||||
|
req_id_to_index={req.request_id: i
|
||||||
|
for i, req in enumerate(requests)},
|
||||||
|
sampled_token_ids=[[10, 11, 12],
|
||||||
|
[13]], # First request exceeds max_tokens
|
||||||
|
spec_token_ids=None,
|
||||||
|
logprobs=None,
|
||||||
|
prompt_logprobs_dict={})
|
||||||
|
|
||||||
|
scheduler.update_from_output(scheduler_output, model_output)
|
||||||
|
|
||||||
|
# Verify first request stopped due to length
|
||||||
|
assert len(scheduler.running) == 1
|
||||||
|
assert scheduler.running[0].request_id == requests[1].request_id
|
||||||
|
assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
|
||||||
|
assert requests[0].request_id in scheduler.finished_req_ids
|
||||||
|
assert list(requests[0].output_token_ids) == [10, 11
|
||||||
|
] # Truncated to max_tokens
|
||||||
|
assert list(requests[1].output_token_ids) == [13]
|
||||||
|
|
||||||
|
# Test case 4: Ignore EOS flag
|
||||||
|
scheduler = create_scheduler()
|
||||||
|
requests = create_requests(num_requests=1, max_tokens=10)
|
||||||
|
requests[0].sampling_params.ignore_eos = True
|
||||||
|
requests[0].num_computed_tokens = requests[0].num_tokens
|
||||||
|
scheduler.requests[requests[0].request_id] = requests[0]
|
||||||
|
scheduler.running.append(requests[0])
|
||||||
|
scheduler.scheduled_req_ids.add(requests[0].request_id)
|
||||||
|
|
||||||
|
scheduler_output = SchedulerOutput(
|
||||||
|
scheduled_new_reqs=[],
|
||||||
|
scheduled_cached_reqs=[],
|
||||||
|
num_scheduled_tokens={requests[0].request_id: 3},
|
||||||
|
total_num_scheduled_tokens=3,
|
||||||
|
scheduled_encoder_inputs={},
|
||||||
|
scheduled_spec_decode_tokens={
|
||||||
|
requests[0].request_id: [EOS_TOKEN_ID, 10]
|
||||||
|
},
|
||||||
|
num_common_prefix_blocks=0,
|
||||||
|
finished_req_ids=set(),
|
||||||
|
free_encoder_input_ids=[],
|
||||||
|
structured_output_request_ids={},
|
||||||
|
grammar_bitmask=None)
|
||||||
|
|
||||||
|
model_output = ModelRunnerOutput(
|
||||||
|
req_ids=[requests[0].request_id],
|
||||||
|
req_id_to_index={requests[0].request_id: 0},
|
||||||
|
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
||||||
|
spec_token_ids=None,
|
||||||
|
logprobs=None,
|
||||||
|
prompt_logprobs_dict={})
|
||||||
|
|
||||||
|
scheduler.update_from_output(scheduler_output, model_output)
|
||||||
|
|
||||||
|
# Verify request continues past EOS
|
||||||
|
assert len(scheduler.running) == 1
|
||||||
|
assert not requests[0].is_finished()
|
||||||
|
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
|
||||||
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
|
|||||||
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
||||||
|
|
||||||
|
|
||||||
def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
|
def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
|
||||||
# Construct lower triangle matrix.
|
# Construct lower triangle matrix.
|
||||||
mask_flag = torch.tril(
|
mask_flag = torch.tril(
|
||||||
torch.ones((max_seq_len, max_seq_len),
|
torch.ones((max_seq_len, max_seq_len),
|
||||||
@@ -52,10 +52,11 @@ def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
|
|||||||
mask_flag = ~mask_flag
|
mask_flag = ~mask_flag
|
||||||
# Currently for fp16 dtype, the mask value should be set to -inf.
|
# Currently for fp16 dtype, the mask value should be set to -inf.
|
||||||
# TODO: Eliminate this part in the future.
|
# TODO: Eliminate this part in the future.
|
||||||
if dtype == torch.float16:
|
if mask_value is None:
|
||||||
mask_value = torch.finfo(torch.float32).min
|
if dtype == torch.float16:
|
||||||
else:
|
mask_value = torch.finfo(torch.float32).min
|
||||||
mask_value = 1
|
else:
|
||||||
|
mask_value = 1
|
||||||
attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
|
attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
|
||||||
mask_flag, mask_value).to(dtype)
|
mask_flag, mask_value).to(dtype)
|
||||||
return attn_mask
|
return attn_mask
|
||||||
@@ -66,12 +67,14 @@ class AttentionMaskBuilder:
|
|||||||
def __init__(self, attn_mask: torch.Tensor):
|
def __init__(self, attn_mask: torch.Tensor):
|
||||||
self._seq_len_cached = attn_mask.shape[0]
|
self._seq_len_cached = attn_mask.shape[0]
|
||||||
self.attn_mask_cache = attn_mask
|
self.attn_mask_cache = attn_mask
|
||||||
|
self.splitfuse_mask_value = -10000
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize_from_len(cls,
|
def initialize_from_len(cls,
|
||||||
max_seq_len: int,
|
max_seq_len: int,
|
||||||
dtype: torch.dtype = torch.float16):
|
dtype: torch.dtype = torch.float16,
|
||||||
return cls(generate_attn_mask(max_seq_len, dtype))
|
mask_value: Optional[int] = None):
|
||||||
|
return cls(generate_attn_mask(max_seq_len, dtype, mask_value))
|
||||||
|
|
||||||
def update_attn_cache(self, seqlen: int, dtype: torch.dtype,
|
def update_attn_cache(self, seqlen: int, dtype: torch.dtype,
|
||||||
device: torch.device):
|
device: torch.device):
|
||||||
@@ -97,6 +100,49 @@ class AttentionMaskBuilder:
|
|||||||
return (self.attn_mask_cache.index_select(
|
return (self.attn_mask_cache.index_select(
|
||||||
0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous())
|
0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous())
|
||||||
|
|
||||||
|
def get_splitfuse_attn_mask(
|
||||||
|
self,
|
||||||
|
seq_lens,
|
||||||
|
query_lens,
|
||||||
|
position,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
max_seq_len = max(seq_lens, default=0)
|
||||||
|
if max_seq_len <= self._seq_len_cached:
|
||||||
|
self.update_attn_cache(max_seq_len, dtype, device)
|
||||||
|
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
|
||||||
|
# is not the same. Fix this in the future when kernel is ready.
|
||||||
|
if self.attn_mask_cache[0][1] > 0:
|
||||||
|
attn_mask = self.get_attn_mask( # type: ignore
|
||||||
|
max_seq_len, dtype, device)
|
||||||
|
attn_mask *= -10000
|
||||||
|
else:
|
||||||
|
attn_mask = self.attn_mask_cache
|
||||||
|
return torch.index_select(attn_mask, dim=0,
|
||||||
|
index=position)[:, :max_seq_len]
|
||||||
|
total_q_len = sum(query_lens)
|
||||||
|
attn_mask = torch.zeros((total_q_len, max_seq_len),
|
||||||
|
dtype=dtype,
|
||||||
|
device="cpu")
|
||||||
|
|
||||||
|
current_row = 0
|
||||||
|
for i in range(len(query_lens)):
|
||||||
|
seq_len = seq_lens[i]
|
||||||
|
q_len = query_lens[i]
|
||||||
|
context_len = seq_len - q_len
|
||||||
|
|
||||||
|
assert context_len >= 0
|
||||||
|
attn_mask[current_row:current_row + q_len,
|
||||||
|
context_len:] = self.splitfuse_mask_value
|
||||||
|
right_tensor = attn_mask[current_row:current_row + q_len,
|
||||||
|
context_len:seq_len]
|
||||||
|
right_tensor.mask_fill_(
|
||||||
|
right_tensor.tril() == self.splitfuse_mask_value, 0)
|
||||||
|
current_row += q_len
|
||||||
|
|
||||||
|
return attn_mask.to(device, non_blocking=True)
|
||||||
|
|
||||||
|
|
||||||
class AscendAttentionBackend(AttentionBackend):
|
class AscendAttentionBackend(AttentionBackend):
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -50,7 +51,7 @@ class AscendAttentionBackend(AttentionBackend):
|
|||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
) -> Tuple[int, ...]:
|
) -> Tuple[int, ...]:
|
||||||
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def swap_blocks(
|
def swap_blocks(
|
||||||
@@ -83,6 +84,12 @@ class AscendAttentionBackend(AttentionBackend):
|
|||||||
value_caches[dst_indices] = value_caches[src_indices]
|
value_caches[dst_indices] = value_caches[src_indices]
|
||||||
|
|
||||||
|
|
||||||
|
class AscendAttentionState(Enum):
|
||||||
|
PrefillOnly = 0
|
||||||
|
DecodeOnly = 1
|
||||||
|
ChunkedPrefill = 2
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AscendMetadata:
|
class AscendMetadata:
|
||||||
# (batch_size, max_blocks_per_seq).
|
# (batch_size, max_blocks_per_seq).
|
||||||
@@ -104,6 +111,8 @@ class AscendMetadata:
|
|||||||
# FlashAttention has better performance than PageAtttention,
|
# FlashAttention has better performance than PageAtttention,
|
||||||
# but it does not support decode requests.
|
# but it does not support decode requests.
|
||||||
is_only_prefill: bool = False
|
is_only_prefill: bool = False
|
||||||
|
# Current state of this attention run.
|
||||||
|
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||||
|
|
||||||
attn_mask: Optional[torch.Tensor] = None
|
attn_mask: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
@@ -139,7 +148,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
assert self.num_heads % self.num_kv_heads == 0
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
self.seq_len_cpu_tensor = None
|
self.key_cache = None
|
||||||
|
self.value_cache = None
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -190,30 +200,52 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
# TODO: Remove this contiguous in the future.
|
# TODO: Remove this contiguous in the future.
|
||||||
value = value.contiguous()
|
value = value.contiguous()
|
||||||
|
|
||||||
|
if kv_cache.numel() > 0:
|
||||||
|
if self.key_cache is None:
|
||||||
|
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
||||||
|
slots = attn_metadata.slot_mapping
|
||||||
|
torch_npu._npu_reshape_and_cache(key=key,
|
||||||
|
value=value,
|
||||||
|
key_cache=self.key_cache,
|
||||||
|
value_cache=self.value_cache,
|
||||||
|
slot_indices=slots)
|
||||||
|
|
||||||
if hasattr(layer, 'quant_method'):
|
if hasattr(layer, 'quant_method'):
|
||||||
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
|
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
|
||||||
pass
|
pass
|
||||||
|
# V0-Style scheduler situation.
|
||||||
|
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
|
||||||
|
assert attn_metadata is not None
|
||||||
|
assert attn_metadata.attn_mask is not None
|
||||||
|
mask = attn_metadata.attn_mask
|
||||||
|
torch_npu._npu_flash_attention(query=query,
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
mask=mask,
|
||||||
|
seq_len=attn_metadata.seq_lens,
|
||||||
|
scale_value=self.scale,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
out=output)
|
||||||
|
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||||
|
block_tables = attn_metadata.block_tables
|
||||||
|
torch_npu._npu_paged_attention(
|
||||||
|
query=query,
|
||||||
|
key_cache=self.key_cache,
|
||||||
|
value_cache=self.value_cache,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
scale_value=self.scale,
|
||||||
|
block_table=block_tables,
|
||||||
|
context_lens=attn_metadata.context_lens,
|
||||||
|
out=output)
|
||||||
|
# Normal V1 situation.
|
||||||
else:
|
else:
|
||||||
if kv_cache.numel() > 0:
|
|
||||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
|
||||||
num_blocks, block_size, _ = key_cache.shape
|
|
||||||
key_cache = key_cache.view(num_blocks, block_size,
|
|
||||||
self.num_kv_heads, self.head_size)
|
|
||||||
value_cache = value_cache.view(num_blocks, block_size,
|
|
||||||
self.num_kv_heads,
|
|
||||||
self.head_size)
|
|
||||||
slots = attn_metadata.slot_mapping
|
|
||||||
torch_npu._npu_reshape_and_cache(key=key,
|
|
||||||
value=value,
|
|
||||||
key_cache=key_cache,
|
|
||||||
value_cache=value_cache,
|
|
||||||
slot_indices=slots)
|
|
||||||
|
|
||||||
# use paged attention
|
# use paged attention
|
||||||
torch_npu._npu_paged_attention_splitfuse(
|
torch_npu._npu_paged_attention_splitfuse(
|
||||||
query=query,
|
query=query,
|
||||||
key_cache=key_cache,
|
key_cache=self.key_cache,
|
||||||
value_cache=value_cache,
|
value_cache=self.value_cache,
|
||||||
mask=attn_metadata.attn_mask,
|
mask=attn_metadata.attn_mask,
|
||||||
block_table=attn_metadata.block_tables,
|
block_table=attn_metadata.block_tables,
|
||||||
seq_len=attn_metadata.seq_lens,
|
seq_len=attn_metadata.seq_lens,
|
||||||
|
|||||||
0
vllm_ascend/core/__init__.py
Normal file
0
vllm_ascend/core/__init__.py
Normal file
73
vllm_ascend/core/schedule_config.py
Normal file
73
vllm_ascend/core/schedule_config.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
from dataclasses import dataclass, fields
|
||||||
|
from typing import Type, Union
|
||||||
|
|
||||||
|
from vllm.config import SchedulerConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AscendSchedulerConfig(SchedulerConfig):
|
||||||
|
enable_chunked_prefill: bool = False
|
||||||
|
policy: str = "fcfs"
|
||||||
|
num_scheduler_steps: int = 1
|
||||||
|
scheduler_cls: Union[str, Type[object]] = (
|
||||||
|
"vllm_ascend.core.scheduler.AscendScheduler")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def initialize_from_config(
|
||||||
|
cls,
|
||||||
|
vllm_scheduler_config: SchedulerConfig,
|
||||||
|
ascend_scheduler_config: dict,
|
||||||
|
):
|
||||||
|
scheduler_config = {
|
||||||
|
field.name: getattr(vllm_scheduler_config, field.name)
|
||||||
|
for field in fields(vllm_scheduler_config) if field.init
|
||||||
|
}
|
||||||
|
# Override default values into original SchedulerConfig
|
||||||
|
scheduler_config["enable_chunked_prefill"] = False
|
||||||
|
scheduler_config["policy"] = "fcfs"
|
||||||
|
scheduler_config["num_scheduler_steps"] = 1
|
||||||
|
scheduler_config["scheduler_cls"] = (
|
||||||
|
"vllm_ascend.core.scheduler.AscendScheduler")
|
||||||
|
# Override params in original SchedulerConfig with params in additional_config.ascend_scheduler_config
|
||||||
|
for k, v in ascend_scheduler_config.items():
|
||||||
|
scheduler_config[k] = v
|
||||||
|
return cls(**scheduler_config)
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
self.max_num_encoder_input_tokens = self.max_num_batched_tokens
|
||||||
|
self.encoder_cache_size = self.max_num_batched_tokens
|
||||||
|
self.chunked_prefill_enabled = self.enable_chunked_prefill
|
||||||
|
if self.policy != "fcfs":
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
|
||||||
|
)
|
||||||
|
if self.is_multimodal_model:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"currently AscendScheduler only supports LLM modles.")
|
||||||
|
if self.num_scheduler_steps > 1:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"currently AscendScheduler doesn't support multi-step.")
|
||||||
|
if self.send_delta_data:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"currently AscendScheduler doesn't support send_delta_data.")
|
||||||
|
if self.delay_factor > 0:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"currently AscendScheduler doesn't support scheduler_delay_factor."
|
||||||
|
)
|
||||||
305
vllm_ascend/core/scheduler.py
Normal file
305
vllm_ascend/core/scheduler.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
from collections import deque
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils import cdiv
|
||||||
|
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
||||||
|
from vllm.v1.core.sched.scheduler import Scheduler
|
||||||
|
from vllm.v1.request import Request, RequestStatus
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AscendScheduler(Scheduler):
    """This Scheduler extends vllm's original v1 scheduler
    with prefill-first scheduling strategy.

    Unlike the chunked-prefill default, one step schedules either prefills
    (drained from ``self.waiting``) or, only when no prefill was scheduled,
    single-token decodes from ``self.running`` (v0-style behavior).
    """

    def schedule(self) -> SchedulerOutput:
        """Build one scheduling step's SchedulerOutput.

        Falls back to the parent scheduler when chunked prefill is enabled;
        otherwise schedules whole-prompt prefills first and decodes only if
        no prefill fit this step.
        """
        if self.scheduler_config.chunked_prefill_enabled:
            return super().schedule()
        scheduled_new_reqs: list[Request] = []
        scheduled_resumed_reqs: list[Request] = []
        scheduled_running_reqs: list[Request] = []
        preempted_reqs: list[Request] = []

        req_to_new_block_ids: dict[str, list[int]] = {}
        num_scheduled_tokens: dict[str, int] = {}
        token_budget = self.max_num_scheduled_tokens
        # Spec decode-related.
        scheduled_spec_decode_tokens: dict[str, list[int]] = {}

        # Record scheduled LoRA requests.
        scheduled_loras: set[int] = set()

        # Use a temporary deque to collect requests that need to be skipped
        # and put back at the head of the waiting queue later
        skipped_waiting_requests: deque[Request] = deque()

        # Schedule prefill requests first.
        while self.waiting and token_budget > 0:
            if len(scheduled_new_reqs) == self.max_num_running_reqs:
                break

            request = self.waiting[0]

            # Closure over the current `request`; moves it from the waiting
            # queue into the skipped deque (restored after the loop).
            def skip_cur_request():
                self.waiting.popleft()
                skipped_waiting_requests.appendleft(request)

            # Check that adding the request still respects the max_loras
            # constraint.
            if (self.lora_config and request.lora_request and
                (len(scheduled_loras) == self.lora_config.max_loras
                 and request.lora_request.lora_int_id not in scheduled_loras)):
                # Scheduling would exceed max_loras, skip.
                skip_cur_request()
                continue

            prompt_limit = self._get_prompt_limit(request)
            # Get already-cached tokens.
            computed_blocks, num_computed_tokens = (
                self.kv_cache_manager.get_computed_blocks(request))
            # Whole remaining prompt is scheduled at once (no chunking here).
            num_new_tokens = request.num_prompt_tokens - num_computed_tokens
            if (0 < self.scheduler_config.long_prefill_token_threshold <
                    num_new_tokens):
                num_new_tokens = (
                    self.scheduler_config.long_prefill_token_threshold)
            max_tokens_in_kvcache = (self.kv_cache_config.num_blocks *
                                     self.block_size)
            prompt_limit = min(prompt_limit, max_tokens_in_kvcache)

            # Finish request that exceeds prompt_limit or kv cache size.
            if num_new_tokens > prompt_limit:
                logger.warning(
                    "Input prompt (%d tokens) is too long"
                    " and exceeds limit of %d",
                    num_new_tokens,
                    prompt_limit,
                )
                request.status = RequestStatus.FINISHED_IGNORED
                self.finished_req_ids.add(request.request_id)  # type: ignore
                self.waiting.popleft()
                continue

            if num_new_tokens > token_budget:
                # Scheduling would exceed token_budget, skip.
                skip_cur_request()
                continue

            assert num_new_tokens > 0
            watermark = getattr(self.scheduler_config, "watermark", 0.01)
            if not self._check_watermark_for_prefill(
                    request, num_new_tokens, computed_blocks, watermark):
                # Scheduling would exceed watermark, skip.
                skip_cur_request()
                continue

            new_blocks = self.kv_cache_manager.allocate_slots(
                request, num_new_tokens, computed_blocks)
            if new_blocks is None:
                # The request cannot be scheduled.
                break

            self.waiting.popleft()
            self.running.append(request)
            self.scheduled_req_ids.add(request.request_id)
            # Check request status.
            if request.status == RequestStatus.WAITING:
                scheduled_new_reqs.append(request)
            elif request.status == RequestStatus.PREEMPTED:
                scheduled_resumed_reqs.append(request)
            else:
                raise RuntimeError(f"Invalid request status: {request.status}")

            if self.lora_config and request.lora_request:
                scheduled_loras.add(request.lora_request.lora_int_id)
            req_to_new_block_ids[request.request_id] = [
                b.block_id for b in computed_blocks + new_blocks
            ]
            # Update request info.
            num_scheduled_tokens[request.request_id] = num_new_tokens
            token_budget -= num_new_tokens
            request.status = RequestStatus.RUNNING
            request.num_computed_tokens = num_computed_tokens

        # Put back any skipped requests at the head of the waiting queue
        # (appendleft above reversed their order; extendleft reverses again,
        # restoring the original FCFS order).
        if skipped_waiting_requests:
            self.waiting.extendleft(skipped_waiting_requests)

        # If no prefill requests are scheduled,
        # Schedule decode requests next.
        if len(self.scheduled_req_ids) == 0:
            req_index = 0
            while req_index < len(self.running) and token_budget > 0:
                request = self.running[req_index]
                if request.request_id in self.scheduled_req_ids:
                    # This request has already been scheduled.
                    req_index += 1
                    continue

                num_new_tokens = (request.num_tokens_with_spec -
                                  request.num_computed_tokens)
                if (0 < self.scheduler_config.long_prefill_token_threshold <
                        num_new_tokens):
                    num_new_tokens = (
                        self.scheduler_config.long_prefill_token_threshold)
                num_new_tokens = min(num_new_tokens, token_budget)
                # Decode steps schedule exactly one token per request here.
                assert num_new_tokens == 1

                while True:
                    new_blocks = self.kv_cache_manager.allocate_slots(
                        request, num_new_tokens)
                    if new_blocks is None:
                        # The request cannot be scheduled.
                        # Preempt the lowest-priority request.
                        preempted_req = self.running.pop()
                        self.kv_cache_manager.free(preempted_req)
                        preempted_req.status = RequestStatus.PREEMPTED
                        # Preempted requests restart from scratch on resume.
                        preempted_req.num_computed_tokens = 0
                        self.waiting.appendleft(preempted_req)
                        preempted_reqs.append(preempted_req)
                        if preempted_req == request:
                            # No more request to preempt.
                            can_schedule = False
                            break
                    else:
                        # The request can be scheduled.
                        can_schedule = True
                        break
                if not can_schedule:
                    break
                assert new_blocks is not None

                # Schedule the request.
                scheduled_running_reqs.append(request)
                self.scheduled_req_ids.add(request.request_id)
                req_to_new_block_ids[request.request_id] = [
                    b.block_id for b in new_blocks
                ]
                num_scheduled_tokens[request.request_id] = num_new_tokens
                token_budget -= num_new_tokens
                req_index += 1

                # Speculative decode related.
                if request.spec_token_ids:
                    num_scheduled_spec_tokens = (num_new_tokens +
                                                 request.num_computed_tokens -
                                                 request.num_tokens)
                    if num_scheduled_spec_tokens > 0:
                        # Trim spec_token_ids list to num_scheduled_spec_tokens.
                        del request.spec_token_ids[num_scheduled_spec_tokens:]
                        scheduled_spec_decode_tokens[request.request_id] = (
                            request.spec_token_ids)

        # Check if the scheduling constraints are satisfied.
        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
        assert token_budget >= 0
        assert len(self.running) <= self.max_num_running_reqs
        assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
            scheduled_running_reqs) <= len(self.running)

        # Get the longest common prefix among all requests in the running queue.
        # This can be potentially used for cascade attention.
        num_common_prefix_blocks = 0
        if self.running:
            any_request = self.running[0]
            num_common_prefix_blocks = (
                self.kv_cache_manager.get_num_common_prefix_blocks(
                    any_request, len(self.running)))

        # Construct the scheduler output.
        new_reqs_data = [
            NewRequestData.from_request(req,
                                        req_to_new_block_ids[req.request_id])
            for req in scheduled_new_reqs
        ]
        resumed_reqs_data = [
            self._make_cached_request_data(
                req,
                num_scheduled_tokens[req.request_id],
                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
                req_to_new_block_ids[req.request_id],
                resumed_from_preemption=True,
            ) for req in scheduled_resumed_reqs
        ]
        running_reqs_data = [
            self._make_cached_request_data(
                req,
                num_scheduled_tokens[req.request_id],
                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
                req_to_new_block_ids[req.request_id],
                resumed_from_preemption=False,
            ) for req in scheduled_running_reqs
        ]
        scheduler_output = SchedulerOutput(
            scheduled_new_reqs=new_reqs_data,
            scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
            num_scheduled_tokens=num_scheduled_tokens,
            total_num_scheduled_tokens=total_num_scheduled_tokens,
            scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
            scheduled_encoder_inputs={},
            num_common_prefix_blocks=num_common_prefix_blocks,
            # finished_req_ids is an existing state in the scheduler,
            # instead of being newly scheduled in this step.
            # It contains the request IDs that are finished in between
            # the previous and the current steps.
            finished_req_ids=self.finished_req_ids,  # type: ignore
            free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
            structured_output_request_ids={},
            grammar_bitmask=None,
        )

        # Advance the number of computed tokens for the request AFTER
        # the request is scheduled.
        # 1. The scheduler_output of the current step has to include the
        #    original number of scheduled tokens to determine input IDs.
        # 2. Advance the number of computed tokens here allowing us to
        #    schedule the prefill request again immediately in the next
        #    scheduling step.
        # 3. If some tokens (e.g. spec tokens) are rejected later, the number of
        #    computed tokens will be adjusted in update_from_output.
        for req_id, num_scheduled_token in num_scheduled_tokens.items():
            self.requests[req_id].num_computed_tokens += num_scheduled_token

        self.finished_req_ids = set()  # type: ignore
        return scheduler_output

    def _check_watermark_for_prefill(self,
                                     request,
                                     num_new_tokens,
                                     computed_blocks,
                                     watermark=0.01):
        """Return True if allocating `num_new_tokens` for `request` would
        leave at least `watermark` (a fraction of all KV-cache blocks) free.

        `computed_blocks` are the prefix-cache hits for this request;
        evictable ones (ref_cnt == 0) are counted back into the free pool.
        """
        computed_blocks = computed_blocks or []
        watermark_blocks = self.kv_cache_config.num_blocks * watermark
        num_computed_tokens = (request.num_computed_tokens +
                               len(computed_blocks) * self.block_size)
        num_required_blocks = cdiv(num_new_tokens + num_computed_tokens,
                                   self.block_size)
        req_blocks = self.kv_cache_manager.req_to_blocks[request.request_id]
        num_new_blocks = (num_required_blocks - len(req_blocks) -
                          len(computed_blocks))
        num_evictable_computed_blocks = sum(1 for blk in computed_blocks
                                            if blk.ref_cnt == 0)
        # If number of free blocks is less than water mark after allocating, don't allocate.
        if (self.kv_cache_manager.block_pool.get_num_free_blocks() -
                num_evictable_computed_blocks -
                num_new_blocks) < watermark_blocks:
            return False
        return True

    def _get_prompt_limit(self, request: Request) -> int:
        """Return the maximum prompt length (in tokens) accepted for `request`.

        Long-context LoRA adapters override the config-derived limit with
        their fine-tuned maximum length.
        """
        if (self.scheduler_config.chunked_prefill_enabled
                and not self.scheduler_config.is_multi_step):
            prompt_limit = self.scheduler_config.max_model_len
        else:
            prompt_limit = min(
                self.scheduler_config.max_model_len,
                self.scheduler_config.max_num_batched_tokens,
            )

        # Model is fine tuned with long context. Return the fine tuned max_len.
        if request.lora_request and request.lora_request.long_lora_max_len:
            assert prompt_limit <= request.lora_request.long_lora_max_len
            return request.lora_request.long_lora_max_len
        else:
            return prompt_limit
|
||||||
@@ -132,6 +132,22 @@ class NPUPlatform(Platform):
|
|||||||
)
|
)
|
||||||
cache_config.enable_prefix_caching = False
|
cache_config.enable_prefix_caching = False
|
||||||
|
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
# Activate custom ops for v1.
|
||||||
|
vllm_config.compilation_config.custom_ops = ["all"]
|
||||||
|
additional_config = vllm_config.additional_config
|
||||||
|
# If ascend_scheduler_config exists in additional_config,
|
||||||
|
# extents original scheduler_config to use AscendScheduler.
|
||||||
|
if additional_config and additional_config.get(
|
||||||
|
"ascend_scheduler_config", None) is not None:
|
||||||
|
additional_scheduler_config = additional_config.get(
|
||||||
|
"ascend_scheduler_config")
|
||||||
|
from vllm_ascend.core.schedule_config import \
|
||||||
|
AscendSchedulerConfig
|
||||||
|
ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config(
|
||||||
|
vllm_config.scheduler_config, additional_scheduler_config)
|
||||||
|
vllm_config.scheduler_config = ascend_scheduler_config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
|
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
|
||||||
kv_cache_dtype, block_size, use_v1, use_mla):
|
kv_cache_dtype, block_size, use_v1, use_mla):
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import numpy as np
|
|||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from vllm.attention import AttentionType
|
from vllm.attention import AttentionType, get_attn_backend
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.parallel_state import get_pp_group
|
from vllm.distributed.parallel_state import get_pp_group
|
||||||
@@ -37,7 +37,8 @@ from vllm.model_executor.model_loader import get_model
|
|||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||||
from vllm.sampling_params import SamplingType
|
from vllm.sampling_params import SamplingType
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import DeviceMemoryProfiler, LayerBlockType, cdiv
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||||
|
LayerBlockType, cdiv)
|
||||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheSpec)
|
KVCacheSpec)
|
||||||
@@ -45,15 +46,14 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
|
|||||||
from vllm.v1.utils import bind_kv_cache
|
from vllm.v1.utils import bind_kv_cache
|
||||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
|
|
||||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
|
from vllm_ascend.attention.attention import AttentionMaskBuilder
|
||||||
|
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
||||||
AscendMetadata)
|
AscendMetadata)
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|
||||||
NPU_PAGED_ATTENTION_MASK_VALUE = -10000
|
|
||||||
|
|
||||||
|
|
||||||
class NPUModelRunner:
|
class NPUModelRunner:
|
||||||
|
|
||||||
@@ -74,6 +74,32 @@ class NPUModelRunner:
|
|||||||
self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
|
self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
|
||||||
vllm_config.parallel_config, LayerBlockType.attention)
|
vllm_config.parallel_config, LayerBlockType.attention)
|
||||||
self.hidden_size = self.model_config.get_hidden_size()
|
self.hidden_size = self.model_config.get_hidden_size()
|
||||||
|
self.dtype = self.model_config.dtype
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
if cache_config.cache_dtype == "auto":
|
||||||
|
self.kv_cache_dtype = self.dtype
|
||||||
|
else:
|
||||||
|
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||||
|
cache_config.cache_dtype]
|
||||||
|
|
||||||
|
self.head_size = self.model_config.get_head_size()
|
||||||
|
self.attn_backend = get_attn_backend(
|
||||||
|
self.head_size,
|
||||||
|
self.dtype,
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
self.block_size,
|
||||||
|
self.model_config.is_attention_free,
|
||||||
|
use_mla=self.model_config.use_mla,
|
||||||
|
)
|
||||||
|
if self.attn_backend is None:
|
||||||
|
error_msg = (
|
||||||
|
f"Error with get_att_backend: {self.head_size=}, "
|
||||||
|
f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
|
||||||
|
f"{self.model_config.is_attention_free=}, "
|
||||||
|
f"{self.model_config.use_mla=}")
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Non-Attention backend is not supported by V1 NPUModelRunner.")
|
||||||
|
|
||||||
# Multi-modal data support
|
# Multi-modal data support
|
||||||
self.input_registry = INPUT_REGISTRY
|
self.input_registry = INPUT_REGISTRY
|
||||||
@@ -135,7 +161,7 @@ class NPUModelRunner:
|
|||||||
|
|
||||||
self.inputs_embeds = torch.zeros(
|
self.inputs_embeds = torch.zeros(
|
||||||
(self.max_num_tokens, self.hidden_size),
|
(self.max_num_tokens, self.hidden_size),
|
||||||
dtype=self.model_config.dtype,
|
dtype=self.dtype,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
|
||||||
# OPTIMIZATION: Cache the tensors rather than creating them every step.
|
# OPTIMIZATION: Cache the tensors rather than creating them every step.
|
||||||
@@ -183,13 +209,8 @@ class NPUModelRunner:
|
|||||||
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
|
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
|
||||||
self.attn_mask_len = min(self.model_config.max_model_len,
|
self.attn_mask_len = min(self.model_config.max_model_len,
|
||||||
int(mask_len))
|
int(mask_len))
|
||||||
self.attn_mask_npu = torch.full(
|
self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
|
||||||
(self.attn_mask_len, self.attn_mask_len),
|
self.attn_mask_len, self.dtype)
|
||||||
NPU_PAGED_ATTENTION_MASK_VALUE,
|
|
||||||
device=self.device,
|
|
||||||
dtype=self.vllm_config.model_config.dtype)
|
|
||||||
self.attn_mask_npu.masked_fill_(
|
|
||||||
self.attn_mask_npu.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
|
|
||||||
|
|
||||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||||
"""Update the cached states and the persistent batch with the scheduler
|
"""Update the cached states and the persistent batch with the scheduler
|
||||||
@@ -346,35 +367,20 @@ class NPUModelRunner:
|
|||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def _make_attention_mask(self, seq_lens, query_lens,
|
def _make_attention_mask(self, seq_lens, query_lens, position,
|
||||||
position) -> torch.Tensor:
|
attn_state) -> torch.Tensor:
|
||||||
max_seq_len = max(seq_lens, default=0)
|
# Chunk Prefill situation.
|
||||||
if max_seq_len <= self.attn_mask_len:
|
if attn_state == AscendAttentionState.ChunkedPrefill:
|
||||||
return torch.index_select(self.attn_mask_npu,
|
return self.attn_mask_builder.get_splitfuse_attn_mask(
|
||||||
dim=0,
|
seq_lens, query_lens, position, self.dtype, self.device)
|
||||||
index=position)[:, :max_seq_len]
|
# Prefill-only situation.
|
||||||
|
elif attn_state == AscendAttentionState.PrefillOnly:
|
||||||
total_q_len = sum(query_lens)
|
max_seq_len = max(seq_lens, default=0)
|
||||||
attn_mask = torch.zeros((total_q_len, max_seq_len),
|
return self.attn_mask_builder.get_attn_mask(
|
||||||
dtype=self.vllm_config.model_config.dtype,
|
max_seq_len, self.dtype, self.device)
|
||||||
device="cpu")
|
# Decode-only situation.
|
||||||
|
else:
|
||||||
current_row = 0
|
return None
|
||||||
for i in range(len(query_lens)):
|
|
||||||
seq_len = seq_lens[i]
|
|
||||||
q_len = query_lens[i]
|
|
||||||
context_len = seq_len - q_len
|
|
||||||
|
|
||||||
assert context_len >= 0
|
|
||||||
attn_mask[current_row:current_row + q_len,
|
|
||||||
context_len:] = NPU_PAGED_ATTENTION_MASK_VALUE
|
|
||||||
right_tensor = attn_mask[current_row:current_row + q_len,
|
|
||||||
context_len:seq_len]
|
|
||||||
right_tensor.mask_fill_(
|
|
||||||
right_tensor.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
|
|
||||||
current_row += q_len
|
|
||||||
|
|
||||||
return attn_mask.to(self.device, non_blocking=True)
|
|
||||||
|
|
||||||
def _process_reqs(
|
def _process_reqs(
|
||||||
self,
|
self,
|
||||||
@@ -408,6 +414,9 @@ class NPUModelRunner:
|
|||||||
cu_num_tokens = np.cumsum(num_scheduled_tokens)
|
cu_num_tokens = np.cumsum(num_scheduled_tokens)
|
||||||
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
|
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
|
||||||
num_scheduled_tokens)
|
num_scheduled_tokens)
|
||||||
|
sample_indices = cu_num_tokens - 1
|
||||||
|
sample_indices = torch.from_numpy(sample_indices).to(self.device,
|
||||||
|
non_blocking=True)
|
||||||
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
|
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
|
||||||
|
|
||||||
positions_np = self.positions_np[:total_num_scheduled_tokens]
|
positions_np = self.positions_np[:total_num_scheduled_tokens]
|
||||||
@@ -437,9 +446,18 @@ class NPUModelRunner:
|
|||||||
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
|
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
|
||||||
self.device, non_blocking=True)
|
self.device, non_blocking=True)
|
||||||
|
|
||||||
|
attn_state = AscendAttentionState.ChunkedPrefill
|
||||||
|
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
||||||
|
attn_state = AscendAttentionState.PrefillOnly
|
||||||
|
elif np.all(num_scheduled_tokens == 1):
|
||||||
|
attn_state = AscendAttentionState.DecodeOnly
|
||||||
|
else:
|
||||||
|
attn_state = AscendAttentionState.ChunkedPrefill
|
||||||
|
|
||||||
attn_mask = self._make_attention_mask(seq_lens=seq_lens,
|
attn_mask = self._make_attention_mask(seq_lens=seq_lens,
|
||||||
query_lens=num_scheduled_tokens,
|
query_lens=num_scheduled_tokens,
|
||||||
position=positions)
|
position=positions,
|
||||||
|
attn_state=attn_state)
|
||||||
|
|
||||||
attn_metadata = AscendMetadata(
|
attn_metadata = AscendMetadata(
|
||||||
seq_lens=query_lens,
|
seq_lens=query_lens,
|
||||||
@@ -448,6 +466,7 @@ class NPUModelRunner:
|
|||||||
block_tables=(
|
block_tables=(
|
||||||
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
|
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
|
attn_state=attn_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare input_ids
|
# Prepare input_ids
|
||||||
@@ -472,7 +491,7 @@ class NPUModelRunner:
|
|||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
return hidden_states[cu_num_tokens - 1]
|
return hidden_states[sample_indices]
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
@@ -636,7 +655,7 @@ class NPUModelRunner:
|
|||||||
self.intermediate_tensors = (
|
self.intermediate_tensors = (
|
||||||
self.model.make_empty_intermediate_tensors(
|
self.model.make_empty_intermediate_tensors(
|
||||||
batch_size=self.max_num_tokens,
|
batch_size=self.max_num_tokens,
|
||||||
dtype=self.model_config.dtype,
|
dtype=self.dtype,
|
||||||
device=self.device))
|
device=self.device))
|
||||||
intermediate_tensors = IntermediateTensors({
|
intermediate_tensors = IntermediateTensors({
|
||||||
k: v[:self.max_num_tokens]
|
k: v[:self.max_num_tokens]
|
||||||
@@ -708,6 +727,7 @@ class NPUModelRunner:
|
|||||||
kv_cache_config: Configuration for the KV cache, including the KV
|
kv_cache_config: Configuration for the KV cache, including the KV
|
||||||
cache size of each layer
|
cache size of each layer
|
||||||
"""
|
"""
|
||||||
|
import torch_npu
|
||||||
kv_caches: Dict[str, torch.Tensor] = {}
|
kv_caches: Dict[str, torch.Tensor] = {}
|
||||||
for kv_cache_group in kv_cache_config.kv_cache_groups:
|
for kv_cache_group in kv_cache_config.kv_cache_groups:
|
||||||
kv_cache_spec = kv_cache_group.kv_cache_spec
|
kv_cache_spec = kv_cache_group.kv_cache_spec
|
||||||
@@ -724,13 +744,14 @@ class NPUModelRunner:
|
|||||||
# the min of all `num_blocks`. Verify it here.
|
# the min of all `num_blocks`. Verify it here.
|
||||||
assert num_blocks >= kv_cache_config.num_blocks
|
assert num_blocks >= kv_cache_config.num_blocks
|
||||||
if isinstance(kv_cache_spec, FullAttentionSpec):
|
if isinstance(kv_cache_spec, FullAttentionSpec):
|
||||||
kv_cache_shape = AscendAttentionBackend.get_kv_cache_shape(
|
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||||
num_blocks, kv_cache_spec.block_size,
|
num_blocks, kv_cache_spec.block_size,
|
||||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||||
dtype = kv_cache_spec.dtype
|
dtype = kv_cache_spec.dtype
|
||||||
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
|
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
torch_npu.npu_format_cast(kv_caches[layer_name], 2)
|
||||||
else:
|
else:
|
||||||
# TODO: add new branches when introducing more types of
|
# TODO: add new branches when introducing more types of
|
||||||
# KV cache specs.
|
# KV cache specs.
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ from vllm.v1.utils import bind_kv_cache
|
|||||||
from vllm.v1.worker.worker_base import WorkerBase
|
from vllm.v1.worker.worker_base import WorkerBase
|
||||||
|
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
|
from vllm_ascend.utils import try_register_lib
|
||||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||||
|
|
||||||
|
|
||||||
@@ -66,6 +67,11 @@ class NPUWorker(WorkerBase):
|
|||||||
rank=rank,
|
rank=rank,
|
||||||
distributed_init_method=distributed_init_method,
|
distributed_init_method=distributed_init_method,
|
||||||
is_driver_worker=is_driver_worker)
|
is_driver_worker=is_driver_worker)
|
||||||
|
# Try to import mindie_turbo to accelerate vLLM inference.
|
||||||
|
try_register_lib(
|
||||||
|
"mindie_turbo",
|
||||||
|
"MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo."
|
||||||
|
)
|
||||||
if self.cache_config.cache_dtype == "auto":
|
if self.cache_config.cache_dtype == "auto":
|
||||||
self.cache_dtype = self.model_config.dtype
|
self.cache_dtype = self.model_config.dtype
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user