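"""vllm_vacc patches for ``vllm.engine.multiprocessing.engine``: request
handlers that log prompt lengths when requests are added or aborted, plus a
``run_mp_engine`` wrapper that restricts speculative decoding to a single
speculative token."""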
import signal

from vllm.config import VllmConfig
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, RPCAbortRequest,
                                         RPCError, RPCProcessRequest)
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
from vllm.usage.usage_lib import UsageContext

logger = init_logger(__name__)


class MQLLMEngine:

    def _handle_process_request(self, request: RPCProcessRequest):
        """Handle RPCProcessRequest by adding it to the LLMEngine."""
        request_id = request.request_id

        if self._errored_with is not None:
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=True,
                               exception=ENGINE_DEAD_ERROR(self._errored_with))
            self._send_outputs(rpc_err)

        try:
            self.engine.add_request(
                request_id=request_id,
                prompt=request.prompt,
                params=request.params,
                lora_request=request.lora_request,
                trace_headers=request.trace_headers,
                prompt_adapter_request=request.prompt_adapter_request,
                priority=request.priority)

            if self.log_requests:
                from vllm.engine.multiprocessing.engine import logger

                # The prompt may be a TokensPrompt dict (carrying
                # 'prompt_token_ids') or a plain text prompt, so guard the
                # dict access before logging token counts.
                if (isinstance(request.prompt, dict)
                        and request.prompt.get('prompt_token_ids')
                        is not None):
                    logger.info("Added request: %s, prompt length: %s",
                                request.request_id,
                                len(request.prompt['prompt_token_ids']))
                else:
                    logger.info("Added request %s.", request.request_id)

        except Exception as e:
            # We do not set self._errored = True here, since the error
            # is due to an issue adding this request to the engine,
            # rather than an issue with the engine itself.
            is_errored = self._errored_with is not None
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=is_errored,
                               exception=e)
            self._send_outputs(rpc_err)

            # Remove the failed request from the engine.
            self.engine.abort_request(request_id)

    def _handle_abort_request(self, request: RPCAbortRequest):
        """Handle RPCAbortRequest by removing the request from the engine."""
        self.engine.abort_request(request.request_id)
        if self.log_requests:
            from vllm.engine.multiprocessing.engine import logger
            import vllm_vacc.vllm.model_executor.models.vars as global_vars
            logger.info("Aborted request: %s, prompt length: %s",
                        request.request_id, global_vars.DO_SEQ_LENS)


def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
                  ipc_path: str, disable_log_stats: bool,
                  disable_log_requests: bool, engine_alive):

    # Patch: this backend supports exactly one speculative token per step.
    speculative_mode = getattr(vllm_config, 'speculative_config',
                               None) is not None
    if speculative_mode and getattr(vllm_config.speculative_config,
                                    'num_speculative_tokens', 1) != 1:
        raise ValueError(
            'run_mp_engine: only num_speculative_tokens == 1 is supported, '
            'but got '
            f'{vllm_config.speculative_config.num_speculative_tokens}')

    default_model_infos = "default"
    if speculative_mode and hasattr(vllm_config.speculative_config, 'method'):
        default_model_infos = vllm_config.speculative_config.method

    from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
    vllm_vacc_config_manager().update_model_infos(default_model_infos)

    try:
        # Ensure we can serialize transformer config before spawning
        maybe_register_config_serialize_by_value()

        # Use the upstream engine class rather than the partial patch class
        # defined above.
        from vllm.engine.multiprocessing.engine import (MQLLMEngine,
                                                        signal_handler)
        engine = MQLLMEngine.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            disable_log_stats=disable_log_stats,
            disable_log_requests=disable_log_requests,
            ipc_path=ipc_path)

        signal.signal(signal.SIGTERM, signal_handler)

        engine.start()

    except BaseException as e:
        logger.exception(e)
        engine_alive.value = False
        raise e
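

# Illustrative usage (not part of the original file): run_mp_engine is meant
# to run in a background process, with `engine_alive` as a shared flag the
# parent can poll after a crash. A minimal sketch, assuming a VllmConfig
# built elsewhere as `vllm_config` and a hypothetical IPC path:
#
#     import multiprocessing
#     from vllm.usage.usage_lib import UsageContext
#
#     ctx = multiprocessing.get_context("spawn")
#     engine_alive = ctx.Value('b', True, lock=False)
#     proc = ctx.Process(target=run_mp_engine,
#                        args=(vllm_config, UsageContext.OPENAI_API_SERVER,
#                              "ipc:///tmp/mq_engine", False, False,
#                              engine_alive))
#     proc.start()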