Files
2026-03-10 13:31:25 +08:00

180 lines
6.8 KiB
Python

################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import asyncio
import os
import socket
from typing import Optional
import torch
from fastcore.basics import patch_to
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.tracing import init_tracer
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm_br import envs as envs_br
from vllm_br.utils import (create_cpu_all_reduce_shared_mem,
get_cpu_all_reduce_shared_mem)
# Module-level logger, named after this module and created with vLLM's
# init_logger helper.
logger = init_logger(__name__)
@patch_to(AsyncLLM)
def __init__(
    self,
    vllm_config: VllmConfig,
    executor_class: type[Executor],
    log_stats: bool,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    use_cached_outputs: bool = False,
    log_requests: bool = True,
    start_engine_loop: bool = True,
    stat_loggers: Optional[list[StatLoggerFactory]] = None,
    client_addresses: Optional[dict[str, str]] = None,
    client_count: int = 1,
    client_index: int = 0,
) -> None:
    """
    Construct an AsyncLLM (installed on the class through ``patch_to``).

    In addition to wiring up the tokenizer, processors, engine core client,
    stat loggers and optional profiler, this constructor creates the CPU
    all-reduce shared memory when ``envs_br.VLLM_BR_USE_CPU_ALL_REDUCE`` is
    non-zero.

    Args:
        vllm_config: global configuration.
        executor_class: an Executor impl, e.g. MultiprocExecutor; handed
            to the engine core client.
        log_stats: Whether to log stats. Also decides whether the default
            stat loggers are enabled.
        usage_context: Usage context of the LLM. Not referenced in this
            constructor body.
        mm_registry: Multi-modal registry passed to the Processor.
        use_cached_outputs: Whether to use cached outputs. Not referenced
            in this constructor body.
        log_requests: Whether to log requests; stored on the instance.
        start_engine_loop: Whether to start the engine loop. Not
            referenced in this constructor body.
        stat_loggers: customized stat loggers for the engine. Passing a
            list turns stats logging on even when ``log_stats`` is False
            (without the default loggers). If not provided, default stat
            loggers will be used.
            PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
            IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
        client_addresses: forwarded unchanged to
            ``EngineCoreClient.make_async_mp_client``.
        client_count: forwarded to the engine core client and to the
            stat logger manager.
        client_index: forwarded to the engine core client.

    Returns:
        None

    Raises:
        ValueError: if ``envs.VLLM_USE_V1`` is falsy.
    """
    if not envs.VLLM_USE_V1:
        v1_disabled_msg = (
            "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
            "This should not happen. As a workaround, try using "
            "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
            "VLLM_USE_V1=0 or 1 and report this issue on Github.")
        raise ValueError(v1_disabled_msg)

    # The shared memory is created up front, before the engine core client
    # (and therefore the background engine process) is started below.
    if envs_br.VLLM_BR_USE_CPU_ALL_REDUCE != 0:
        create_cpu_all_reduce_shared_mem()

    # Ensure we can serialize custom transformer configs.
    maybe_register_config_serialize_by_value()

    model_config = vllm_config.model_config
    observability_config = vllm_config.observability_config

    self.model_config = model_config
    self.vllm_config = vllm_config
    self.observability_config = observability_config
    self.log_requests = log_requests

    # Supplying custom loggers implies stats logging, even if the caller
    # passed log_stats=False.
    has_custom_loggers = stat_loggers is not None
    self.log_stats = log_stats or has_custom_loggers
    if has_custom_loggers and not log_stats:
        logger.info(
            "AsyncLLM created with log_stats=False and non-empty custom "
            "logger list; enabling logging without default stat loggers")

    # Tokenizer (+ ensure liveness if running in another process); skipped
    # entirely when the model config asks for it.
    tokenizer = (None if model_config.skip_tokenizer_init else
                 init_tokenizer_from_configs(model_config=model_config))
    self.tokenizer = tokenizer

    # Processor (converts Inputs --> EngineCoreRequests).
    self.processor = Processor(
        vllm_config=vllm_config,
        tokenizer=tokenizer,
        mm_registry=mm_registry,
    )

    # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
    self.output_processor = OutputProcessor(tokenizer,
                                            log_stats=self.log_stats)

    # Attach a tracer only when an OTLP traces endpoint is configured.
    traces_endpoint = observability_config.otlp_traces_endpoint
    if traces_endpoint is not None:
        self.output_processor.tracer = init_tracer("vllm.llm_engine",
                                                   traces_endpoint)

    # EngineCore (starts the engine in background process).
    self.engine_core = EngineCoreClient.make_async_mp_client(
        vllm_config=vllm_config,
        executor_class=executor_class,
        log_stats=self.log_stats,
        client_addresses=client_addresses,
        client_count=client_count,
        client_index=client_index,
    )

    # Loggers: the attribute always exists; a manager is only built when
    # stats logging is on. Default loggers follow the caller's original
    # log_stats flag, not the derived self.log_stats.
    self.logger_manager: Optional[StatLoggerManager] = None  # type: ignore
    if self.log_stats:
        stat_manager = StatLoggerManager(
            vllm_config=vllm_config,
            engine_idxs=self.engine_core.engine_ranks_managed,
            custom_stat_loggers=stat_loggers,
            enable_default_loggers=log_stats,
            client_count=client_count,
        )
        self.logger_manager = stat_manager
        stat_manager.log_engine_initialized()

    self.output_handler: Optional[asyncio.Task] = None  # type: ignore
    try:
        # Start output handler eagerly if we are in the asyncio eventloop.
        # get_running_loop() raises RuntimeError when no loop is running;
        # a RuntimeError from either call is swallowed here.
        asyncio.get_running_loop()
        self._run_output_handler()
    except RuntimeError:
        pass

    # Optional CPU-only torch profiler, enabled by setting a trace directory.
    profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
    if not profiler_dir:
        self.profiler = None
    else:
        logger.info(
            "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
            profiler_dir)
        # Trace files are tagged with host name and PID of this process.
        worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
        cpu_only = [torch.profiler.ProfilerActivity.CPU]
        with_stack = envs.VLLM_TORCH_PROFILER_WITH_STACK
        trace_handler = torch.profiler.tensorboard_trace_handler(
            profiler_dir,
            worker_name=worker_name,
            use_gzip=True)
        self.profiler = torch.profiler.profile(
            activities=cpu_only,
            with_stack=with_stack,
            on_trace_ready=trace_handler)
@patch_to(AsyncLLM)
def __del__(self):
    """Release the CPU all-reduce shared memory, then shut the engine down.

    The shared-memory handle is fetched once and cleaned up only when one
    exists. ``self.shutdown()`` sits in a ``finally`` clause so it still
    runs if the shared-memory cleanup raises; exceptions escaping a
    finalizer are only reported as warnings, so skipping the shutdown
    would otherwise go unnoticed.
    """
    try:
        shared_mem = get_cpu_all_reduce_shared_mem()
        if shared_mem is not None:
            shared_mem._cleanup()
    finally:
        # Always attempt the engine shutdown, regardless of cleanup outcome.
        self.shutdown()