158 lines
6.5 KiB
Python
158 lines
6.5 KiB
Python
################################################################################
|
|
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
################################################################################
|
|
|
|
import os
|
|
import time
|
|
from typing import Optional
|
|
|
|
from fastcore.basics import patch_to
|
|
|
|
from vllm.config import ParallelConfig, VllmConfig
|
|
from vllm.logger import logger
|
|
from vllm.v1.core.kv_cache_utils import (generate_scheduler_kv_cache_config,
|
|
get_kv_cache_configs)
|
|
from vllm.v1.engine import EngineCoreOutputs
|
|
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
|
|
|
|
|
@patch_to(EngineCore)
def _initialize_kv_caches(
        self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
    """Size, allocate and warm up the KV cache for this engine core.

    Installed on ``EngineCore`` via ``fastcore.patch_to``.

    Args:
        vllm_config: Engine configuration, forwarded to
            ``get_kv_cache_configs`` together with the KV cache specs from
            the model executor and one memory budget per spec entry.

    Returns:
        ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``.
        ``num_cpu_blocks`` is always 0 here.

    Side effects:
        Sets ``self.available_gpu_memory_for_kv_cache`` when the model has
        any KV cache spec, and calls
        ``self.model_executor.initialize_from_config`` with the computed
        configs.
    """
    start = time.perf_counter()

    # Get all kv cache needed by the model. One memory budget entry is
    # required per spec entry (checked by the assert below).
    kv_cache_specs = self.model_executor.get_kv_cache_specs()

    if any(kv_cache_specs):
        if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
            # Elastic-EP scale-up launch: obtain the KV cache budget through
            # ParallelConfig.sync_kv_cache_memory_size over the DP group
            # instead of profiling locally (presumably the value chosen by
            # the already-running ranks -- confirm).
            dp_group = getattr(self, "dp_group", None)
            assert dp_group is not None
            self.available_gpu_memory_for_kv_cache = \
                ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
            available_gpu_memory = ([self.available_gpu_memory_for_kv_cache] *
                                    len(kv_cache_specs))
        else:
            # Profile the peak memory usage of the model to determine how
            # much memory can be allocated for kv cache.
            available_gpu_memory = (
                self.model_executor.determine_available_memory())
            self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
    else:
        # Attention-free model: there is no KV cache to size, but the
        # executor's memory probe is still run exactly once and its result
        # is passed on instead of a list of zeros.
        # NOTE(review): confirm this backend needs the probe (e.g. for its
        # profiling side effects) for attention-free models.
        available_gpu_memory = self.model_executor.determine_available_memory()
    assert len(kv_cache_specs) == len(available_gpu_memory)

    kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
                                            available_gpu_memory)
    scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
        kv_cache_configs)
    num_gpu_blocks = scheduler_kv_cache_config.num_blocks
    num_cpu_blocks = 0

    # Initialize kv cache and warmup the execution
    self.model_executor.initialize_from_config(kv_cache_configs)

    # perf_counter is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments.
    elapsed = time.perf_counter() - start
    logger.info(("init engine (profile, create kv cache, "
                 "warmup model) took %.2f seconds"), elapsed)
    return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
|
|
|
|
|
|
@patch_to(EngineCore)
def step_with_batch_queue(
        self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
    """Schedule and execute batches with the batch queue.

    Note that if there is nothing to output in this step, None is returned.

    The execution flow is as follows:

    1. If the scheduler has requests, schedule a new batch and submit it
       without blocking. If that batch has scheduled tokens, the queue still
       has room, and the oldest in-flight batch is not done yet, return
       ``(None, True)`` immediately. In other words, filling the batch
       queue has a higher priority than collecting model outputs.
    2. Otherwise, block until the oldest batch in the queue is finished.
    3. Update the scheduler from that batch's output (skipped when the
       finished batch had no scheduled tokens) and, with speculative
       decoding, forward the draft token ids to the scheduler.

    Returns:
        ``(outputs, model_executed)``. ``outputs`` is ``None`` when nothing
        was collected in this call. ``model_executed`` is True when the
        batch scheduled during this call had any scheduled tokens.
    """
    batch_queue = self.batch_queue
    assert batch_queue is not None

    # Try to schedule a new batch if the batch queue is not full, but
    # the scheduler may return an empty batch if all requests are scheduled.
    # Note that this is not blocking.
    assert len(batch_queue) < self.batch_queue_size

    model_executed = False
    if self.scheduler.has_requests():
        scheduler_output = self.scheduler.schedule()
        future = self.model_executor.execute_model(scheduler_output,
                                                   non_block=True)
        batch_queue.appendleft(
            (future, scheduler_output))  # type: ignore[arg-type]

        model_executed = scheduler_output.total_num_scheduled_tokens > 0
        if (model_executed and len(batch_queue) < self.batch_queue_size
                and not batch_queue[-1][0].done()):
            # Don't block on next worker response unless the queue is full
            # or there are no more requests to schedule.
            return None, True

    elif not batch_queue:
        # Queue is empty. We should not reach here since this method should
        # only be called when the scheduler contains requests or the queue
        # is non-empty.
        return None, False

    # Block until the oldest in-flight result is available.
    future, scheduler_output = batch_queue.pop()
    model_output = self.execute_model_with_error_logging(
        lambda _: future.result(), scheduler_output)

    if scheduler_output.total_num_scheduled_tokens == 0:
        # The finished batch was empty, so there is nothing to feed back to
        # the scheduler. Still report ``model_executed`` (the state of the
        # batch scheduled in this call), consistent with the path below.
        return None, model_executed

    engine_core_outputs = self.scheduler.update_from_output(
        scheduler_output, model_output)

    if self.use_spec_decode:
        # Draft token ids are carried on the model output itself rather than
        # fetched separately from the executor; tag them with the output's
        # request ids before handing them to the scheduler.
        draft_token_ids = model_output.draft_token_ids
        if draft_token_ids is not None:
            draft_token_ids.req_ids = model_output.req_ids
            self.scheduler.update_draft_token_ids(draft_token_ids)

    return engine_core_outputs, model_executed
|
|
|
|
|
|
@patch_to(EngineCoreProc)
def _process_engine_step(self) -> bool:
    """Run one engine-core step and enqueue whatever it produced.

    Called only when there are unfinished local requests.

    Returns:
        The ``model_executed`` flag reported by ``self.step_fn()``.
    """
    step_outputs, executed = self.step_fn()

    # Every item of the returned mapping is queued individually; a None or
    # empty result enqueues nothing.
    if step_outputs:
        for item in step_outputs.items():
            self.output_queue.put_nowait(item)

    # NOTE: the post-step hook (`self.post_step(model_executed)`) is
    # deliberately not called. In this file, spec-decode draft tokens are
    # forwarded to the scheduler inside the patched `step_with_batch_queue`.
    # NOTE(review): confirm no other `step_fn` variant relies on post_step.
    return executed
|