init v0.11.0rc0
@@ -20,14 +20,19 @@ from typing import Type, Union

from vllm.config import SchedulerConfig

MAX_INT = 2147483647


@dataclass
class AscendSchedulerConfig(SchedulerConfig):
    enable_chunked_prefill: bool = False
    max_long_partial_prefills: int = MAX_INT
    long_prefill_token_threshold: int = MAX_INT
    policy: str = "fcfs"
    num_scheduler_steps: int = 1
    scheduler_cls: Union[str, Type[object]] = (
        "vllm_ascend.core.scheduler.AscendScheduler")
    enable_pd_transfer: bool = False
    decode_max_num_seqs: int = 0

    @classmethod
    def initialize_from_config(
@@ -41,10 +46,13 @@ class AscendSchedulerConfig(SchedulerConfig):
        }
        # Override default values into original SchedulerConfig
        scheduler_config["enable_chunked_prefill"] = False
        scheduler_config["max_long_partial_prefills"] = None
        scheduler_config["long_prefill_token_threshold"] = None
        scheduler_config["policy"] = "fcfs"
        scheduler_config["num_scheduler_steps"] = 1
        scheduler_config["scheduler_cls"] = (
            "vllm_ascend.core.scheduler.AscendScheduler")
        scheduler_config["enable_pd_transfer"] = False
        scheduler_config["decode_max_num_seqs"] = 0
        # Override params in original SchedulerConfig with params in ascend_scheduler_config
        for k, _ in scheduler_config.items():
            if hasattr(ascend_scheduler_config, k):
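A minimal sketch of the override-and-merge pattern that `initialize_from_config` uses above: copy the base scheduler config, force the Ascend defaults, then let any user-supplied keys win. The class and function names below are hypothetical stand-ins, not the vllm-ascend API.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict


@dataclass
class BaseSchedulerConfig:
    # Hypothetical, trimmed stand-in for vLLM's SchedulerConfig.
    max_num_seqs: int = 256
    enable_chunked_prefill: bool = True
    policy: str = "priority"


def merge_scheduler_config(base: BaseSchedulerConfig,
                           ascend_overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Copy the base config, force Ascend defaults, then apply user overrides."""
    merged = {f.name: getattr(base, f.name) for f in fields(base)}
    # Ascend-specific defaults, mirroring the diff above.
    merged["enable_chunked_prefill"] = False
    merged["policy"] = "fcfs"
    # User-supplied keys win over the forced defaults.
    for k in merged:
        if k in ascend_overrides:
            merged[k] = ascend_overrides[k]
    return merged


if __name__ == "__main__":
    cfg = merge_scheduler_config(BaseSchedulerConfig(), {"max_num_seqs": 128})
    print(cfg)  # enable_chunked_prefill=False, policy='fcfs', max_num_seqs=128
```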
@@ -65,20 +73,36 @@ class AscendSchedulerConfig(SchedulerConfig):
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")
        # concurrent partial prefills. Default is inf
        if self.max_long_partial_prefills is None:
            self.max_long_partial_prefills = MAX_INT
            self.long_prefill_token_threshold = MAX_INT

        if self.long_prefill_token_threshold is None or \
                self.long_prefill_token_threshold <= 0:
            if self.max_model_len is None:
                self.long_prefill_token_threshold = MAX_INT
            else:
                self.long_prefill_token_threshold = \
                    max(1, int(self.max_model_len * 0.04))

        if self.max_long_partial_prefills < 0:
            raise ValueError(
                f"max_long_partial_prefills must be non-negative, but got "
                f"{self.max_long_partial_prefills}")
        if self.long_prefill_token_threshold < 0:
            raise ValueError(
                f"long_prefill_token_threshold must be non-negative, but got "
                f"{self.long_prefill_token_threshold}")

        if self.policy != "fcfs":
            raise NotImplementedError(
                f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
            )
        if self.is_multimodal_model:
            raise NotImplementedError(
                "currently AscendScheduler only supports LLM models.")
        if self.num_scheduler_steps > 1:
            raise NotImplementedError(
                "currently AscendScheduler doesn't support multi-step.")
        if self.send_delta_data:
            raise NotImplementedError(
                "currently AscendScheduler doesn't support send_delta_data.")
        if self.delay_factor > 0:
        if getattr(self, "scheduler_delay_factor", 0) > 0:
            raise NotImplementedError(
                "currently AscendScheduler doesn't support scheduler_delay_factor."
            )
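The threshold defaulting above can be read in isolation: when the user does not set `long_prefill_token_threshold`, it falls back to roughly 4% of the model's context length. A standalone sketch with hypothetical argument names:

```python
from typing import Optional

MAX_INT = 2147483647  # sentinel meaning "effectively unlimited"


def resolve_long_prefill_token_threshold(threshold: Optional[int],
                                         max_model_len: Optional[int]) -> int:
    """Mirror of the defaulting logic in the hunk above (sketch only)."""
    if threshold is not None and threshold > 0:
        return threshold
    if max_model_len is None:
        return MAX_INT
    # A prompt longer than ~4% of the context window counts as a "long" prefill.
    return max(1, int(max_model_len * 0.04))


# e.g. a 32k context window => prompts over 1310 tokens are treated as long prefills
assert resolve_long_prefill_token_threshold(None, 32768) == 1310
assert resolve_long_prefill_token_threshold(0, None) == MAX_INT
```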
@@ -23,6 +23,7 @@ from vllm.distributed.kv_events import KVEventBatch
from vllm.logger import logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs
@@ -31,13 +32,6 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager

from vllm_ascend.utils import vllm_version_is

if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
else:
    KVCacheBlocks = None


class AscendScheduler(Scheduler):
    """This Scheduler extends vllm's original v1 scheduler
@@ -58,6 +52,15 @@ class AscendScheduler(Scheduler):
        self.scheduled_req_ids: set[str] = set()
        self.running: list[Request] = []

        self.finished_prefill_reqs: deque[Request] = deque()
        enable_pd_transfer = getattr(self.scheduler_config,
                                     'enable_pd_transfer', False)
        decode_max_num_seqs = getattr(self.scheduler_config,
                                      'decode_max_num_seqs', 0)
        self.phase = "" if not enable_pd_transfer else "prefill"
        self.decode_max_num_running_reqs = max(self.max_num_running_reqs,
                                               decode_max_num_seqs)

    def schedule(self) -> SchedulerOutput:
        if self.scheduler_config.chunked_prefill_enabled:
            return super().schedule()
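A condensed sketch of the prefill/decode phase bookkeeping that `__init__` adds above. The class below is a toy model with only the relevant attributes; the real scheduler inherits everything else from vLLM's `Scheduler`.

```python
from collections import deque


class PdPhaseState:
    """Toy model of the PD-transfer state added in AscendScheduler.__init__ (sketch)."""

    def __init__(self, max_num_running_reqs: int,
                 enable_pd_transfer: bool = False,
                 decode_max_num_seqs: int = 0) -> None:
        self.running: list = []
        # Requests that finished prefill and are parked until the decode phase.
        self.finished_prefill_reqs: deque = deque()
        # "" disables phase switching; "prefill" starts the two-phase flow.
        self.phase = "prefill" if enable_pd_transfer else ""
        # The decode phase may admit more concurrent sequences than prefill.
        self.decode_max_num_running_reqs = max(max_num_running_reqs,
                                               decode_max_num_seqs)


state = PdPhaseState(max_num_running_reqs=16,
                     enable_pd_transfer=True,
                     decode_max_num_seqs=64)
print(state.phase, state.decode_max_num_running_reqs)  # prefill 64
```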
@@ -66,12 +69,14 @@ class AscendScheduler(Scheduler):
        scheduled_running_reqs: list[Request] = []
        preempted_reqs: list[Request] = []

        if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
            req_to_new_block_ids: dict[str, list[list[int]]] = {}
        else:
            req_to_new_blocks: dict[str, KVCacheBlocks] = {}
        req_to_new_blocks: dict[str, KVCacheBlocks] = {}
        num_scheduled_tokens: dict[str, int] = {}
        token_budget = self.max_num_scheduled_tokens

        # Encoder-related.
        scheduled_encoder_inputs: dict[str, list[int]] = {}
        encoder_budget = self.max_num_encoder_input_tokens

        # Spec decode-related.
        scheduled_spec_decode_tokens: dict[str, list[int]] = {}
@@ -85,9 +90,33 @@ class AscendScheduler(Scheduler):
        # and put back at the head of the waiting queue later
        skipped_waiting_requests: deque[Request] = deque()

        if self.phase == "prefill":
            remaining_running_reqs = []
            for request in self.running:
                # Move requests that have finished prefill to finished_prefill_reqs.
                if request.num_tokens > request.num_prompt_tokens:
                    self.finished_prefill_reqs.append(request)
                else:
                    remaining_running_reqs.append(request)
            self.running = remaining_running_reqs
            # All requests are prefilled: change phase to decode.
            if not self.waiting and not self.running:
                self.phase = "decode"
        # Skip long prompt requests in prefill stage.
        # long_prefill_budget is float('inf') if not used.
        if self.vllm_config.scheduler_config.long_prefill_token_threshold == 0:
            long_prefill_budget = float('inf')
            long_prefill_token_threshold = float('inf')
        else:
            long_prefill_budget = self.vllm_config.scheduler_config.max_long_partial_prefills
            long_prefill_token_threshold = self.vllm_config.scheduler_config.long_prefill_token_threshold

        # Schedule prefill requests first.
        while self.waiting and token_budget > 0:
            if len(self.running) == self.max_num_running_reqs:
            if len(self.running) == (self.decode_max_num_running_reqs
                                     if self.phase == "decode" else
                                     self.max_num_running_reqs):
                break

            request = self.waiting[0]
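A runnable sketch of the phase-transition step above: requests whose prefill is complete (they hold at least one token beyond the prompt) are parked, and the scheduler flips to the decode phase once nothing is left to prefill. The `Req` class and the numbers are illustrative, not vLLM types.

```python
from collections import deque
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Req:  # illustrative stand-in for vllm.v1.request.Request
    req_id: str
    num_prompt_tokens: int
    num_tokens: int  # prompt + generated tokens so far


def advance_prefill_phase(running: List[Req], waiting: List[Req],
                          finished_prefill: deque) -> Tuple[List[Req], str]:
    """Park finished-prefill requests; switch to decode when everything is prefilled."""
    still_prefilling = []
    for req in running:
        if req.num_tokens > req.num_prompt_tokens:  # prefill done, first token emitted
            finished_prefill.append(req)
        else:
            still_prefilling.append(req)
    phase = "decode" if not waiting and not still_prefilling else "prefill"
    return still_prefilling, phase


parked: deque = deque()
running, phase = advance_prefill_phase(
    [Req("a", 100, 101), Req("b", 100, 100)], [], parked)
print([r.req_id for r in running], [r.req_id for r in parked], phase)
# ['b'] ['a'] prefill
```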
@@ -139,6 +168,9 @@ class AscendScheduler(Scheduler):
            num_new_local_computed_tokens = 0
            num_computed_tokens = request.num_computed_tokens

            encoder_inputs_to_schedule = None
            new_encoder_budget = encoder_budget

            # P/D: loading remote KV, do not allocate for new work.
            if load_kv_async:
                assert num_external_computed_tokens > 0
@@ -176,6 +208,17 @@ class AscendScheduler(Scheduler):
            assert num_new_tokens > 0
            blocks = new_computed_blocks.blocks[0]

            # Schedule encoder inputs.
            if request.has_encoder_inputs:
                (encoder_inputs_to_schedule, num_new_tokens,
                 new_encoder_budget) = self._try_schedule_encoder_inputs(
                     request, num_computed_tokens, num_new_tokens,
                     encoder_budget)
                if num_new_tokens == 0 or len(
                        encoder_inputs_to_schedule) == 0:
                    # The request cannot be scheduled.
                    break

            watermark = getattr(self.scheduler_config, "watermark", 0.01)
            if not self._check_watermark_for_prefill(request, num_new_tokens,
                                                     blocks, watermark):
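The body of `_check_watermark_for_prefill` is not part of this hunk. The sketch below only illustrates the general idea of a watermark check (reject a prefill if admitting it would leave less than a fraction `watermark` of KV-cache blocks free); the function name, parameters, and formula are assumptions, not the vllm-ascend implementation.

```python
def watermark_allows_prefill(num_new_tokens: int,
                             num_free_blocks: int,
                             total_blocks: int,
                             block_size: int,
                             watermark: float = 0.01) -> bool:
    """Hypothetical watermark check: keep `watermark` of the KV cache in reserve."""
    blocks_needed = -(-num_new_tokens // block_size)  # ceiling division
    reserved_blocks = int(total_blocks * watermark)
    return num_free_blocks - blocks_needed >= reserved_blocks


# With 1000 blocks of 128 tokens and a 1% watermark, a 4096-token prefill
# needs 32 blocks and is admitted only if at least 42 blocks are currently free.
print(watermark_allows_prefill(4096, 42, 1000, 128))  # True
print(watermark_allows_prefill(4096, 41, 1000, 128))  # False
```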
@@ -183,6 +226,11 @@ class AscendScheduler(Scheduler):
                skip_cur_request()
                continue

            if num_new_tokens > long_prefill_token_threshold \
                    and long_prefill_budget <= 0:
                skip_cur_request()
                continue

            new_blocks = self.kv_cache_manager.allocate_slots(
                request,
                num_new_tokens + num_external_computed_tokens,
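Together with the budget decrement further down, the check above caps how many "long" prefills may be admitted per scheduling step. A self-contained sketch of that admission rule; the request list is made up and the thresholds mirror the local variables in the diff.

```python
from typing import List


def admit_prefills(prompt_lengths: List[int],
                   long_prefill_token_threshold: float,
                   max_long_partial_prefills: float) -> List[int]:
    """Admit prompts in order, limiting how many exceed the 'long prefill' threshold."""
    admitted = []
    long_prefill_budget = max_long_partial_prefills
    for n in prompt_lengths:
        if n > long_prefill_token_threshold and long_prefill_budget <= 0:
            continue  # skip: too many long prefills already admitted this step
        admitted.append(n)
        if n > long_prefill_token_threshold:
            long_prefill_budget -= 1
    return admitted


# Threshold 1310 tokens, at most 2 long prefills per step:
print(admit_prefills([4000, 500, 3000, 2000, 800], 1310, 2))
# [4000, 500, 3000, 800]  -- the third long prompt (2000) is skipped
```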
@@ -227,26 +275,41 @@ class AscendScheduler(Scheduler):

            if self.lora_config and request.lora_request:
                scheduled_loras.add(request.lora_request.lora_int_id)
            if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
                req_to_new_block_ids[request.request_id] = (
                    self.kv_cache_manager.get_block_ids(request.request_id))
            else:
                req_to_new_blocks[
                    request.request_id] = self.kv_cache_manager.get_blocks(
                        request.request_id)
            req_to_new_blocks[
                request.request_id] = self.kv_cache_manager.get_blocks(
                    request.request_id)
            # Update request info.
            num_scheduled_tokens[request.request_id] = num_new_tokens
            token_budget -= num_new_tokens
            if num_new_tokens > long_prefill_token_threshold:
                long_prefill_budget -= 1
            request.status = RequestStatus.RUNNING
            request.num_computed_tokens = num_computed_tokens
            # Count the number of prefix cached tokens.
            if request.num_cached_tokens < 0:
                request.num_cached_tokens = num_computed_tokens

            # Encoder-related.
            if encoder_inputs_to_schedule:
                scheduled_encoder_inputs[request.request_id] = (
                    encoder_inputs_to_schedule)
                # Allocate the encoder cache.
                for i in encoder_inputs_to_schedule:
                    self.encoder_cache_manager.allocate(request, i)
                encoder_budget = new_encoder_budget

        # Put back any skipped requests at the head of the waiting queue
        if skipped_waiting_requests:
            self.waiting.extendleft(skipped_waiting_requests)

        if self.phase == "decode":
            while len(
                    self.running
            ) < self.decode_max_num_running_reqs and self.finished_prefill_reqs:
                request = self.finished_prefill_reqs.popleft()
                self.running.append(request)

        # If no prefill requests are scheduled,
        # Schedule decode requests next.
        if len(self.scheduled_req_ids) == 0:
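In the decode phase, the running set is topped up from the parked queue before decode scheduling proceeds. A minimal sketch of that refill step; the types and cap value are illustrative.

```python
from collections import deque
from typing import List


def refill_running_for_decode(running: List[str], finished_prefill: deque,
                              decode_max_num_running_reqs: int) -> None:
    """Promote parked (prefill-complete) requests until the decode cap is reached."""
    while len(running) < decode_max_num_running_reqs and finished_prefill:
        running.append(finished_prefill.popleft())


running = ["r0", "r1"]
parked: deque = deque(["r2", "r3", "r4"])
refill_running_for_decode(running, parked, decode_max_num_running_reqs=4)
print(running, list(parked))  # ['r0', 'r1', 'r2', 'r3'] ['r4']
```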
@@ -267,6 +330,16 @@ class AscendScheduler(Scheduler):
                num_new_tokens = min(
                    num_new_tokens,
                    self.max_model_len - request.num_computed_tokens)

                # Schedule encoder inputs.
                encoder_inputs_to_schedule = None
                new_encoder_budget = encoder_budget
                if request.has_encoder_inputs:
                    (encoder_inputs_to_schedule, num_new_tokens,
                     new_encoder_budget) = self._try_schedule_encoder_inputs(
                         request, request.num_computed_tokens, num_new_tokens,
                         encoder_budget)

                # Check that adding the request still respects the max_loras
                # constraint.
                if self.lora_config and request.lora_request and (
@@ -322,11 +395,7 @@ class AscendScheduler(Scheduler):
                # Schedule the request.
                scheduled_running_reqs.append(request)
                self.scheduled_req_ids.add(request.request_id)
                if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
                    req_to_new_block_ids[request.request_id] = (
                        new_blocks.get_block_ids())
                else:
                    req_to_new_blocks[request.request_id] = new_blocks
                req_to_new_blocks[request.request_id] = new_blocks
                num_scheduled_tokens[request.request_id] = num_new_tokens
                token_budget -= num_new_tokens
                req_index += 1
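Both the prefill and decode loops draw from the same per-step token budget (`token_budget = self.max_num_scheduled_tokens`, reduced by every `num_new_tokens` that gets scheduled). A compact sketch of that accounting with made-up request sizes:

```python
from typing import Dict, List


def schedule_within_budget(requested_tokens: List[int],
                           max_num_scheduled_tokens: int) -> Dict[int, int]:
    """Greedily grant token requests until the per-step budget is exhausted."""
    token_budget = max_num_scheduled_tokens
    num_scheduled_tokens: Dict[int, int] = {}
    for req_id, num_new_tokens in enumerate(requested_tokens):
        if num_new_tokens > token_budget:
            break  # budget exhausted for this scheduling step
        num_scheduled_tokens[req_id] = num_new_tokens
        token_budget -= num_new_tokens
    assert sum(num_scheduled_tokens.values()) <= max_num_scheduled_tokens
    return num_scheduled_tokens


print(schedule_within_budget([4096, 2048, 1024, 512], 7000))
# {0: 4096, 1: 2048}  -- the 1024-token request no longer fits this step
```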
@@ -342,6 +411,15 @@ class AscendScheduler(Scheduler):
                    scheduled_spec_decode_tokens[request.request_id] = (
                        request.spec_token_ids)

                # Encoder-related.
                if encoder_inputs_to_schedule:
                    scheduled_encoder_inputs[request.request_id] = (
                        encoder_inputs_to_schedule)
                    # Allocate the encoder cache.
                    for i in encoder_inputs_to_schedule:
                        self.encoder_cache_manager.allocate(request, i)
                    encoder_budget = new_encoder_budget

                # Record scheduled LoRA requests.
                if self.lora_config and request.lora_request:
                    scheduled_loras.add(request.lora_request.lora_int_id)
@@ -350,7 +428,9 @@ class AscendScheduler(Scheduler):
        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
        assert token_budget >= 0
        assert len(self.running) <= self.max_num_running_reqs
        assert len(
            self.running
        ) <= self.decode_max_num_running_reqs if self.phase == "decode" else self.max_num_running_reqs
        assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
            scheduled_running_reqs) <= len(self.running)
@@ -365,67 +445,36 @@ class AscendScheduler(Scheduler):
                any_request, len(self.running)))

        # Construct the scheduler output.
        if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
            new_reqs_data = [
                NewRequestData.from_request(
                    req, req_to_new_block_ids[req.request_id])
                for req in scheduled_new_reqs
            ]
            cached_reqs_data = self._make_cached_request_data(
                scheduled_running_reqs, scheduled_resumed_reqs,
                num_scheduled_tokens, scheduled_spec_decode_tokens,
                req_to_new_block_ids)
        else:
            new_reqs_data = [
                NewRequestData.from_request(
                    req, req_to_new_blocks[req.request_id].get_block_ids())
                for req in scheduled_new_reqs
            ]
        new_reqs_data = [
            NewRequestData.from_request(
                req, req_to_new_blocks[req.request_id].get_block_ids())
            for req in scheduled_new_reqs
        ]

            cached_reqs_data = self._make_cached_request_data(
                scheduled_running_reqs, scheduled_resumed_reqs,
                num_scheduled_tokens, scheduled_spec_decode_tokens,
                req_to_new_blocks)
        cached_reqs_data = self._make_cached_request_data(
            scheduled_running_reqs, scheduled_resumed_reqs,
            num_scheduled_tokens, scheduled_spec_decode_tokens,
            req_to_new_blocks)
        scheduled_cached_reqs = cached_reqs_data

        if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
            scheduler_output = SchedulerOutput(
                scheduled_new_reqs=new_reqs_data,
                scheduled_cached_reqs=scheduled_cached_reqs,
                num_scheduled_tokens=num_scheduled_tokens,
                total_num_scheduled_tokens=total_num_scheduled_tokens,
                scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
                scheduled_encoder_inputs={},
                num_common_prefix_blocks=num_common_prefix_blocks,
                # finished_req_ids is an existing state in the scheduler,
                # instead of being newly scheduled in this step.
                # It contains the request IDs that are finished in between
                # the previous and the current steps.
                finished_req_ids=self.finished_req_ids,  # type: ignore
                free_encoder_input_ids=self.encoder_cache_manager.
                get_freed_ids(),
                structured_output_request_ids={},
                grammar_bitmask=None,
            )
        else:
            scheduler_output = SchedulerOutput(
                scheduled_new_reqs=new_reqs_data,
                scheduled_cached_reqs=scheduled_cached_reqs,
                num_scheduled_tokens=num_scheduled_tokens,
                total_num_scheduled_tokens=total_num_scheduled_tokens,
                scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
                scheduled_encoder_inputs={},
                num_common_prefix_blocks=num_common_prefix_blocks,
                # finished_req_ids is an existing state in the scheduler,
                # instead of being newly scheduled in this step.
                # It contains the request IDs that are finished in between
                # the previous and the current steps.
                finished_req_ids=self.finished_req_ids,  # type: ignore
                free_encoder_mm_hashes=self.encoder_cache_manager.
                get_freed_mm_hashes(),
                structured_output_request_ids={},
                grammar_bitmask=None,
            )
        scheduler_output = SchedulerOutput(
            scheduled_new_reqs=new_reqs_data,
            scheduled_cached_reqs=scheduled_cached_reqs,
            num_scheduled_tokens=num_scheduled_tokens,
            total_num_scheduled_tokens=total_num_scheduled_tokens,
            scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
            scheduled_encoder_inputs=scheduled_encoder_inputs,
            num_common_prefix_blocks=num_common_prefix_blocks,
            # finished_req_ids is an existing state in the scheduler,
            # instead of being newly scheduled in this step.
            # It contains the request IDs that are finished in between
            # the previous and the current steps.
            finished_req_ids=self.finished_req_ids,  # type: ignore
            free_encoder_mm_hashes=self.encoder_cache_manager.
            get_freed_mm_hashes(),
            structured_output_request_ids={},
            grammar_bitmask=None,
        )

        # NOTE(Kuntai): this function is designed for multiple purposes:
        # 1. Plan the KV cache store