init
vllm_vacc/vllm/v1/engine/llm_engine.py (new normal file, 176 lines)
@@ -0,0 +1,176 @@
from collections.abc import Mapping
from copy import copy
from typing import Any, Callable, Optional, Union

from typing_extensions import TypeVar

import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
# from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
                                     StatLoggerFactory)
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats

logger = init_logger(__name__)


class LLMEngine:
    """Legacy LLMEngine for backwards compatibility."""

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        # vacc supports spec_num = 1 only; validate the speculative config.
        from .vllm_config_checker import check_spec_model
        check_spec_model(vllm_config)

        from vllm.v1.engine.llm_engine import LLMEngine as DefaultLLM
        return DefaultLLM(vllm_config=vllm_config,
                          executor_class=Executor.get_class(vllm_config),
                          log_stats=(not disable_log_stats),
                          usage_context=usage_context,
                          stat_loggers=stat_loggers,
                          multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)
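
    # A minimal usage sketch (editor's illustration, not part of this commit);
    # "facebook/opt-125m" is just a placeholder model:
    #
    #     vllm_config = EngineArgs(model="facebook/opt-125m").create_engine_config()
    #     engine = LLMEngine.from_vllm_config(vllm_config)
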
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # vacc supports spec_num = 1 only; validate the speculative config.
        from .vllm_config_checker import check_spec_model
        check_spec_model(vllm_config)

        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        from vllm.v1.engine.llm_engine import LLMEngine as DefaultLLM
        return DefaultLLM(vllm_config=vllm_config,
                          executor_class=executor_class,
                          log_stats=not engine_args.disable_log_stats,
                          usage_context=usage_context,
                          stat_loggers=stat_loggers,
                          multiprocess_mode=enable_multiprocessing)
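
    # A minimal usage sketch (editor's illustration, not part of this commit);
    # the model name is a placeholder:
    #
    #     engine = LLMEngine.from_engine_args(
    #         EngineArgs(model="facebook/opt-125m"))
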
    def add_requests(
        self,
        items: list[tuple[
            str,  # request_id
            PromptType,  # prompt
            Union[SamplingParams, PoolingParams],  # params
            Optional[float],  # arrival_time
            Optional[LoRARequest],  # lora_request
            Optional[dict],  # tokenization_kwargs
            Optional[dict],  # trace_headers
            # Optional[PromptAdapterRequest],  # prompt_adapter_request
            int,  # priority
        ]],
    ) -> None:
        """Submit a batch of requests to EngineCore at once, triggering a
        single ADD_BULK round trip."""
        core_reqs: list[EngineCoreRequest] = []

        for (request_id, prompt, params, arrival_time, lora_request,
             tokenization_kwargs, trace_headers, priority) in items:

            # Reuse the parsing entry point of the existing one-by-one flow.
            prompt_str, request = self.processor.process_inputs(
                request_id=request_id,
                prompt=prompt,
                params=params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                # prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )

            n = params.n if isinstance(params, SamplingParams) else 1

            if n == 1:
                # Make a new RequestState and queue.
                self.output_processor.add_request(request, prompt_str, None, 0)
                # Add the request to EngineCore.
                core_reqs.append(request)
                continue

            # Fan out child requests (for n > 1): the output_processor must
            # index every request id that actually enters the engine, so each
            # child gets its own RequestState.
            parent_req = ParentRequest(request_id, params)
            for idx in range(n):
                request_id, params = parent_req.get_child_info(idx)
                child_request = request if idx == n - 1 else copy(request)
                child_request.request_id = request_id
                child_request.sampling_params = params

                # Make a new RequestState and queue.
                self.output_processor.add_request(child_request, prompt_str,
                                                  parent_req, idx)
                # Add the request to EngineCore.
                core_reqs.append(child_request)
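                # Note (editor's comment): get_child_info(idx) yields a unique
                # per-child request id plus per-child sampling params; the
                # exact id format is an implementation detail of ParentRequest.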

        # Key step: hand everything to EngineCore in one call; the
        # EngineCoreClient sends a single ADD_BULK message.
        self.engine_core.add_requests(core_reqs)
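
# A minimal usage sketch (editor's illustration, not part of this commit).
# The classmethods above return the upstream v1 LLMEngine, so this assumes
# add_requests is made available on that object as well; the model name and
# prompt are placeholders.
#
#     import time
#
#     engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
#     engine.add_requests([
#         ("req-0", "Hello, my name is",
#          SamplingParams(n=2, max_tokens=16),
#          time.time(),   # arrival_time
#          None,          # lora_request
#          None,          # tokenization_kwargs
#          None,          # trace_headers
#          0),            # priority
#     ])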