from collections.abc import Mapping
from copy import copy
from typing import Any, Callable, Optional, Union
from typing_extensions import TypeVar
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
# from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
                                     StatLoggerFactory)
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
from vllm.v1.engine import EngineCoreRequest

logger = init_logger(__name__)


class LLMEngine:
    """Legacy LLMEngine for backwards compatibility."""

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        # vacc only supports spec_num = 1; validate the speculative config
        # before constructing the engine.
        from .vllm_config_checker import check_spec_model
        check_spec_model(vllm_config)

        from vllm.v1.engine.llm_engine import LLMEngine as DefaultLLM
        return DefaultLLM(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            log_stats=(not disable_log_stats),
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)
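
    # A minimal sketch of what a checker like check_spec_model might do
    # (hypothetical: the real .vllm_config_checker module is not shown here,
    # and the speculative_config attribute names are assumptions):
    #
    #     def check_spec_model(vllm_config: VllmConfig) -> None:
    #         spec = vllm_config.speculative_config
    #         if spec is not None and spec.num_speculative_tokens != 1:
    #             raise ValueError(
    #                 "vacc only supports num_speculative_tokens == 1")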

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # vacc only supports spec_num = 1; validate the speculative config.
        from .vllm_config_checker import check_spec_model
        check_spec_model(vllm_config)

        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        from vllm.v1.engine.llm_engine import LLMEngine as DefaultLLM
        return DefaultLLM(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=enable_multiprocessing)

    def add_requests(
        self,
        items: list[tuple[
            str,  # request_id
            PromptType,  # prompt
            Union[SamplingParams, PoolingParams],  # params
            Optional[float],  # arrival_time
            Optional[LoRARequest],  # lora_request
            Optional[dict],  # tokenization_kwargs
            Optional[dict],  # trace_headers
            # Optional[PromptAdapterRequest],  # prompt_adapter_request
            int,  # priority
        ]],
    ) -> None:
        """Submit a batch of requests to EngineCore, triggering a single
        ADD_BULK message rather than one ADD per request."""
        core_reqs: list[EngineCoreRequest] = []
        for (request_id, prompt, params, arrival_time, lora_request,
             tokenization_kwargs, trace_headers, priority) in items:
            # Reuse the per-request parsing entry point from the existing
            # one-by-one flow.
            prompt_str, request = self.processor.process_inputs(
                request_id=request_id,
                prompt=prompt,
                params=params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                # prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )

            n = params.n if isinstance(params, SamplingParams) else 1
            if n == 1:
                # Make a new RequestState and queue.
                self.output_processor.add_request(request, prompt_str, None, 0)
                # Defer submission: collect the request for the bulk add below.
                core_reqs.append(request)
                continue
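
            # The fan-out below relies on ParentRequest.get_child_info(idx),
            # which yields a unique child request id plus per-child
            # SamplingParams; the last child reuses the parsed request object,
            # earlier children get shallow copies.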
            # Fan out child requests (for n > 1).
            parent_req = ParentRequest(request_id, params)
            for idx in range(n):
                request_id, params = parent_req.get_child_info(idx)
                child_request = request if idx == n - 1 else copy(request)
                child_request.request_id = request_id
                child_request.sampling_params = params
                # Make a new RequestState and queue.
                self.output_processor.add_request(child_request, prompt_str,
                                                  parent_req, idx)
                # Defer submission: collect the child for the bulk add below.
                core_reqs.append(child_request)
        # At this point output_processor has indexed every request id that
        # actually enters the engine: SamplingParams with n > 1 were split
        # into parent/child requests above; everything else was registered
        # as a single entry.
        # Key step: hand everything to the core in one call; EngineCoreClient
        # sends a single ADD_BULK message.
        self.engine_core.add_requests(core_reqs)
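
    # Illustrative batch submission (hypothetical ids, prompts, and params;
    # the tuple layout matches the `items` signature above):
    #
    #     engine.add_requests([
    #         ("req-0", "Hello", SamplingParams(max_tokens=16),
    #          None, None, None, None, 0),
    #         ("req-1", "World", SamplingParams(n=2, max_tokens=16),
    #          None, None, None, None, 0),
    #     ])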