176 lines
7.6 KiB
Python
176 lines
7.6 KiB
Python
|
||
from collections.abc import Mapping
|
||
from copy import copy
|
||
from typing import Any, Callable, Optional, Union
|
||
|
||
from typing_extensions import TypeVar
|
||
|
||
import vllm.envs as envs
|
||
from vllm.config import ParallelConfig, VllmConfig
|
||
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
||
from vllm.engine.arg_utils import EngineArgs
|
||
from vllm.inputs import PromptType
|
||
from vllm.logger import init_logger
|
||
from vllm.lora.request import LoRARequest
|
||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||
from vllm.pooling_params import PoolingParams
|
||
# from vllm.prompt_adapter.request import PromptAdapterRequest
|
||
from vllm.sampling_params import SamplingParams
|
||
from vllm.usage.usage_lib import UsageContext
|
||
from vllm.utils import Device
|
||
from vllm.v1.engine.core_client import EngineCoreClient
|
||
from vllm.v1.engine.output_processor import OutputProcessor
|
||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||
from vllm.v1.engine.processor import Processor
|
||
from vllm.v1.executor.abstract import Executor
|
||
from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
|
||
StatLoggerFactory)
|
||
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
|
||
from vllm.v1.metrics.stats import IterationStats
|
||
from vllm.v1.engine import EngineCoreRequest
|
||
|
||
logger = init_logger(__name__)
|
||
class LLMEngine:
    """Legacy LLMEngine wrapper for backwards compatibility.

    Construction is delegated to the default v1 ``LLMEngine`` after running a
    vendor-specific config check (``check_spec_model``). In addition, this
    class provides :meth:`add_requests`, a bulk-submission path that sends a
    whole batch of requests to the EngineCore in a single ADD_BULK message
    instead of one message per request.
    """

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Create an engine from an already-built ``VllmConfig``.

        Args:
            vllm_config: Fully resolved engine configuration.
            usage_context: Usage-tracking context tag.
            stat_loggers: Optional custom stat-logger factories.
            disable_log_stats: If True, per-iteration stats logging is off.

        Returns:
            An instance of the *default* v1 ``LLMEngine`` (not ``cls``).
        """
        # vacc supports spec_num = 1 only; reject incompatible spec configs.
        from .vllm_config_checker import check_spec_model
        check_spec_model(vllm_config)

        # Delegate to the stock v1 implementation.
        from vllm.v1.engine.llm_engine import LLMEngine as DefaultLLM
        return DefaultLLM(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            log_stats=(not disable_log_stats),
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        Returns an instance of the *default* v1 ``LLMEngine`` (not ``cls``).
        ``enable_multiprocessing`` is forced on when
        ``VLLM_ENABLE_V1_MULTIPROCESSING`` is set in the environment.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # vacc supports spec_num = 1 only; reject incompatible spec configs.
        from .vllm_config_checker import check_spec_model
        check_spec_model(vllm_config)

        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine via the stock v1 implementation.
        from vllm.v1.engine.llm_engine import LLMEngine as DefaultLLM
        return DefaultLLM(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=enable_multiprocessing)

    def add_requests(
        self,
        items: list[tuple[
            str,  # request_id
            PromptType,  # prompt
            Union[SamplingParams, PoolingParams],  # params
            Optional[float],  # arrival_time
            Optional[LoRARequest],  # lora_request
            Optional[dict],  # tokenization_kwargs
            Optional[dict],  # trace_headers
            int,  # priority
        ]],
    ) -> None:
        """Submit a batch of requests to the EngineCore in one ADD_BULK call.

        Each item is run through the same per-request input pipeline
        (``self.processor.process_inputs``) and registered with the output
        processor. Parallel sampling (``SamplingParams.n > 1``) is fanned out
        into child requests via ``ParentRequest``. All resulting
        ``EngineCoreRequest`` objects are then handed to the core in a single
        ``self.engine_core.add_requests`` call, so the client sends one
        ADD_BULK message instead of one RPC per request.
        """
        core_reqs: list[EngineCoreRequest] = []

        for (request_id, prompt, params, arrival_time, lora_request,
             tokenization_kwargs, trace_headers, priority) in items:

            # Reuse the existing single-request parsing entry point.
            prompt_str, request = self.processor.process_inputs(
                request_id=request_id,
                prompt=prompt,
                params=params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            n = params.n if isinstance(params, SamplingParams) else 1

            if n == 1:
                # Make a new RequestState and queue.
                self.output_processor.add_request(request, prompt_str, None, 0)
                # Queue the request for the bulk EngineCore submission.
                core_reqs.append(request)
                continue

            # Fan out child requests (for n > 1).
            parent_req = ParentRequest(request_id, params)
            for idx in range(n):
                # Distinct names: do NOT rebind the outer loop's
                # `request_id`/`params` mid-iteration.
                child_req_id, child_params = parent_req.get_child_info(idx)
                # The last child may reuse the parsed request object; earlier
                # children need their own shallow copy.
                child_request = request if idx == n - 1 else copy(request)
                child_request.request_id = child_req_id
                child_request.sampling_params = child_params

                # Make a new RequestState and queue.
                self.output_processor.add_request(child_request, prompt_str,
                                                  parent_req, idx)
                core_reqs.append(child_request)

        # Key step: hand everything to the core at once. EngineCoreClient
        # sends a single ADD_BULK message for the whole batch.
        self.engine_core.add_requests(core_reqs)
|
||
|
||
|