# enginex-vastai-va16-vllm/vllm_vacc/vllm/v1/engine/llm_engine.py

from collections.abc import Mapping
from copy import copy
from typing import Any, Callable, Optional, Union

from typing_extensions import TypeVar

import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
# from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
                                     StatLoggerFactory)
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats

logger = init_logger(__name__)


class LLMEngine:
    """Legacy LLMEngine for backwards compatibility."""

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        # VACC supports a speculative depth of exactly 1 (spec_num = 1).
        from .vllm_config_checker import check_spec_model
        check_spec_model(vllm_config)

        from vllm.v1.engine.llm_engine import LLMEngine as DefaultLLM
        return DefaultLLM(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            log_stats=(not disable_log_stats),
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)
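
    # A minimal sketch of what the imported checker might enforce. The real
    # body of .vllm_config_checker.check_spec_model is not shown in this
    # file; the implementation below is an assumption based only on the
    # "spec_num = 1" comment above.
    #
    #     def check_spec_model(vllm_config: VllmConfig) -> None:
    #         spec = vllm_config.speculative_config
    #         if spec is not None and spec.num_speculative_tokens != 1:
    #             raise ValueError(
    #                 "VACC supports num_speculative_tokens == 1, got "
    #                 f"{spec.num_speculative_tokens}")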

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # VACC supports a speculative depth of exactly 1 (spec_num = 1).
        from .vllm_config_checker import check_spec_model
        check_spec_model(vllm_config)

        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        from vllm.v1.engine.llm_engine import LLMEngine as DefaultLLM
        return DefaultLLM(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=enable_multiprocessing)

    def add_requests(
        self,
        items: list[tuple[
            str,  # request_id
            PromptType,  # prompt
            Union[SamplingParams, PoolingParams],  # params
            Optional[float],  # arrival_time
            Optional[LoRARequest],  # lora_request
            Optional[dict],  # tokenization_kwargs
            Optional[Mapping[str, str]],  # trace_headers
            # Optional[PromptAdapterRequest],  # prompt_adapter_request
            int,  # priority
        ]],
    ) -> None:
        """Submit a batch of requests to EngineCore, triggering a single
        ADD_BULK instead of one ADD per request."""
        core_reqs: list[EngineCoreRequest] = []
        for (request_id, prompt, params, arrival_time, lora_request,
             tokenization_kwargs, trace_headers, priority) in items:
            # Reuse the parsing entry point of the existing
            # one-request-at-a-time flow.
            prompt_str, request = self.processor.process_inputs(
                request_id=request_id,
                prompt=prompt,
                params=params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                # prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )

            n = params.n if isinstance(params, SamplingParams) else 1
            if n == 1:
                # Make a new RequestState and queue.
                self.output_processor.add_request(request, prompt_str, None, 0)
                # Collect the request for the bulk hand-off to EngineCore.
                core_reqs.append(request)
                continue

            # Fan out child requests (for n > 1). The output_processor must
            # index every request_id that actually enters the engine, so each
            # child is registered individually against its parent.
            parent_req = ParentRequest(request_id, params)
            for idx in range(n):
                request_id, params = parent_req.get_child_info(idx)
                child_request = request if idx == n - 1 else copy(request)
                child_request.request_id = request_id
                child_request.sampling_params = params
                # Make a new RequestState and queue.
                self.output_processor.add_request(child_request, prompt_str,
                                                  parent_req, idx)
                # Collect the child for the bulk hand-off to EngineCore.
                core_reqs.append(child_request)

        # Key step: hand everything to the Core in one call; the
        # EngineCoreClient sends a single ADD_BULK message.
        self.engine_core.add_requests(core_reqs)
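
# Usage sketch for the bulk path (illustrative; request ids, prompts, and
# sampling settings below are placeholders, not part of this file). Each
# tuple follows the layout documented in add_requests: (request_id, prompt,
# params, arrival_time, lora_request, tokenization_kwargs, trace_headers,
# priority).
#
#     import time
#
#     engine.add_requests([
#         ("req-0", "Hello, world!", SamplingParams(max_tokens=16),
#          time.time(), None, None, None, 0),
#         ("req-1", "Tell me a story.", SamplingParams(n=2, max_tokens=32),
#          time.time(), None, None, None, 0),
#     ])
#     # "req-1" has n=2, so it fans out into two child requests before the
#     # single ADD_BULK message is sent.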