forked from EngineX-MetaX/enginex-c_series-vllm
[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
558
vllm/v1/engine/async_llm.py
Normal file
558
vllm/v1/engine/async_llm.py
Normal file
@@ -0,0 +1,558 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from copy import copy
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Device, cdiv
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
|
||||
from vllm.v1.engine.output_processor import (OutputProcessor,
|
||||
RequestOutputCollector)
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
|
||||
setup_default_loggers)
|
||||
from vllm.v1.metrics.prometheus import shutdown_prometheus
|
||||
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AsyncLLM(EngineClient):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
use_cached_outputs: bool = False,
|
||||
log_requests: bool = True,
|
||||
start_engine_loop: bool = True,
|
||||
stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||
client_addresses: Optional[dict[str, str]] = None,
|
||||
client_index: int = 0,
|
||||
) -> None:
|
||||
"""
|
||||
Create an AsyncLLM.
|
||||
|
||||
Args:
|
||||
vllm_config: global configuration.
|
||||
executor_class: an Executor impl, e.g. MultiprocExecutor.
|
||||
log_stats: Whether to log stats.
|
||||
usage_context: Usage context of the LLM.
|
||||
mm_registry: Multi-modal registry.
|
||||
use_cached_outputs: Whether to use cached outputs.
|
||||
log_requests: Whether to log requests.
|
||||
start_engine_loop: Whether to start the engine loop.
|
||||
stat_loggers: customized stat loggers for the engine.
|
||||
If not provided, default stat loggers will be used.
|
||||
PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
|
||||
IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise ValueError(
|
||||
"Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
|
||||
"This should not happen. As a workaround, try using "
|
||||
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
|
||||
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
|
||||
|
||||
# Ensure we can serialize custom transformer configs
|
||||
maybe_register_config_serialize_by_value()
|
||||
|
||||
self.model_config = vllm_config.model_config
|
||||
self.vllm_config = vllm_config
|
||||
self.log_requests = log_requests
|
||||
self.log_stats = log_stats
|
||||
|
||||
# Set up stat loggers; independent set for each DP rank.
|
||||
self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
|
||||
vllm_config=vllm_config,
|
||||
log_stats=self.log_stats,
|
||||
engine_num=vllm_config.parallel_config.data_parallel_size,
|
||||
custom_stat_loggers=stat_loggers,
|
||||
)
|
||||
|
||||
# Tokenizer (+ ensure liveness if running in another process).
|
||||
self.tokenizer = init_tokenizer_from_configs(
|
||||
model_config=vllm_config.model_config,
|
||||
scheduler_config=vllm_config.scheduler_config,
|
||||
lora_config=vllm_config.lora_config)
|
||||
|
||||
# Processor (converts Inputs --> EngineCoreRequests).
|
||||
self.processor = Processor(
|
||||
vllm_config=vllm_config,
|
||||
tokenizer=self.tokenizer,
|
||||
mm_registry=mm_registry,
|
||||
)
|
||||
|
||||
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
|
||||
self.output_processor = OutputProcessor(self.tokenizer,
|
||||
log_stats=self.log_stats)
|
||||
|
||||
# EngineCore (starts the engine in background process).
|
||||
|
||||
self.engine_core = EngineCoreClient.make_async_mp_client(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=self.log_stats,
|
||||
client_addresses=client_addresses,
|
||||
client_index=client_index,
|
||||
)
|
||||
if self.stat_loggers:
|
||||
for stat_logger in self.stat_loggers[0]:
|
||||
stat_logger.log_engine_initialized()
|
||||
self.output_handler: Optional[asyncio.Task] = None
|
||||
try:
|
||||
# Start output handler eagerly if we are in the asyncio eventloop.
|
||||
asyncio.get_running_loop()
|
||||
self._run_output_handler()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(
|
||||
cls,
|
||||
vllm_config: VllmConfig,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||
disable_log_requests: bool = False,
|
||||
disable_log_stats: bool = False,
|
||||
client_addresses: Optional[dict[str, str]] = None,
|
||||
client_index: int = 0,
|
||||
) -> "AsyncLLM":
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise ValueError(
|
||||
"Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
|
||||
"This should not happen. As a workaround, try using "
|
||||
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
|
||||
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
|
||||
|
||||
# Create the LLMEngine.
|
||||
return cls(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
start_engine_loop=start_engine_loop,
|
||||
stat_loggers=stat_loggers,
|
||||
log_requests=not disable_log_requests,
|
||||
log_stats=not disable_log_stats,
|
||||
usage_context=usage_context,
|
||||
client_addresses=client_addresses,
|
||||
client_index=client_index,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(
|
||||
cls,
|
||||
engine_args: AsyncEngineArgs,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||
) -> "AsyncLLM":
|
||||
"""Create an AsyncLLM from the EngineArgs."""
|
||||
|
||||
# Create the engine configs.
|
||||
vllm_config = engine_args.create_engine_config(usage_context)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
# Create the AsyncLLM.
|
||||
return cls(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_requests=not engine_args.disable_log_requests,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
start_engine_loop=start_engine_loop,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown, cleaning up the background proc and IPC."""
|
||||
|
||||
shutdown_prometheus()
|
||||
|
||||
if engine_core := getattr(self, "engine_core", None):
|
||||
engine_core.shutdown()
|
||||
|
||||
if handler := getattr(self, "output_handler", None):
|
||||
handler.cancel()
|
||||
|
||||
async def add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: PromptType,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
) -> RequestOutputCollector:
|
||||
"""Add new request to the AsyncLLM."""
|
||||
|
||||
if self.errored:
|
||||
raise EngineDeadError()
|
||||
|
||||
assert isinstance(params, SamplingParams), \
|
||||
"Pooling is not supported in V1"
|
||||
|
||||
# Create a new output collector for the request.
|
||||
queue = RequestOutputCollector(output_kind=params.output_kind)
|
||||
|
||||
# Convert Input --> Request.
|
||||
prompt_str, request = self.processor.process_inputs(
|
||||
request_id, prompt, params, arrival_time, lora_request,
|
||||
tokenization_kwargs, trace_headers, prompt_adapter_request,
|
||||
priority, data_parallel_rank)
|
||||
|
||||
if params.n == 1:
|
||||
await self._add_request(request, prompt_str, None, 0, queue)
|
||||
return queue
|
||||
|
||||
# Fan out child requests (for n>1).
|
||||
parent_request = ParentRequest(request_id, params)
|
||||
for idx in range(params.n):
|
||||
request_id, params = parent_request.get_child_info(idx)
|
||||
child_request = request if idx == params.n - 1 else copy(request)
|
||||
child_request.request_id = request_id
|
||||
child_request.sampling_params = params
|
||||
await self._add_request(child_request, prompt_str, parent_request,
|
||||
idx, queue)
|
||||
return queue
|
||||
|
||||
async def _add_request(self, request: EngineCoreRequest,
|
||||
prompt: Optional[str],
|
||||
parent_req: Optional[ParentRequest], index: int,
|
||||
queue: RequestOutputCollector):
|
||||
|
||||
# Add the request to OutputProcessor (this process).
|
||||
self.output_processor.add_request(request, prompt, parent_req, index,
|
||||
queue)
|
||||
|
||||
# Add the EngineCoreRequest to EngineCore (separate process).
|
||||
await self.engine_core.add_request_async(request)
|
||||
|
||||
if self.log_requests:
|
||||
logger.info("Added request %s.", request.request_id)
|
||||
|
||||
# TODO: we should support multiple prompts in one call, as you
|
||||
# can do with LLM.generate. So that for multi-prompt completion
|
||||
# requests we don't need to send multiple messages to core proc,
|
||||
# and so we don't need multiple streams which then get
|
||||
# re-multiplexed in the API server anyhow.
|
||||
async def generate(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
"""
|
||||
Main function called by the API server to kick off a request
|
||||
* 1) Making an AsyncStream corresponding to the Request.
|
||||
* 2) Processing the Input.
|
||||
* 3) Adding the Request to the Detokenizer.
|
||||
* 4) Adding the Request to the EngineCore (separate process).
|
||||
|
||||
A separate output_handler loop runs in a background AsyncIO task,
|
||||
pulling outputs from EngineCore and putting them into the
|
||||
per-request AsyncStream.
|
||||
|
||||
The caller of generate() iterates the returned AsyncGenerator,
|
||||
returning the RequestOutput back to the caller.
|
||||
"""
|
||||
|
||||
try:
|
||||
# We start the output_handler on the first call to generate() so
|
||||
# we can call __init__ before the event loop, which enables us
|
||||
# to handle startup failure gracefully in the OpenAI server.
|
||||
self._run_output_handler()
|
||||
|
||||
q = await self.add_request(
|
||||
request_id,
|
||||
prompt,
|
||||
sampling_params,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
)
|
||||
|
||||
# The output_handler task pushes items into the queue.
|
||||
# This task pulls from the queue and yields to caller.
|
||||
finished = False
|
||||
while not finished:
|
||||
# Note: drain queue without await if possible (avoids
|
||||
# task switching under load which helps performance).
|
||||
out = q.get_nowait() or await q.get()
|
||||
|
||||
# Note: both OutputProcessor and EngineCore handle their
|
||||
# own request cleanup based on finished.
|
||||
finished = out.finished
|
||||
yield out
|
||||
|
||||
# If the request is disconnected by the client, generate()
|
||||
# is cancelled or the generator is garbage collected. So,
|
||||
# we abort the request if we end up here.
|
||||
except (asyncio.CancelledError, GeneratorExit):
|
||||
await self.abort(request_id)
|
||||
if self.log_requests:
|
||||
logger.info("Request %s aborted.", request_id)
|
||||
raise
|
||||
|
||||
# Engine is dead. Do not abort since we shut down.
|
||||
except EngineDeadError:
|
||||
if self.log_requests:
|
||||
logger.info("Request %s failed (engine dead).", request_id)
|
||||
raise
|
||||
|
||||
# Request validation error.
|
||||
except ValueError:
|
||||
if self.log_requests:
|
||||
logger.info("Request %s failed (bad request).", request_id)
|
||||
raise
|
||||
|
||||
# Unexpected error in the generate() task (possibly recoverable).
|
||||
except Exception as e:
|
||||
await self.abort(request_id)
|
||||
if self.log_requests:
|
||||
logger.info("Request %s failed.", request_id)
|
||||
raise EngineGenerateError() from e
|
||||
|
||||
def _run_output_handler(self):
|
||||
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
|
||||
|
||||
if self.output_handler is not None:
|
||||
return
|
||||
|
||||
# Ensure that the task doesn't have a circular ref back to the AsyncLLM
|
||||
# object, or else it won't be garbage collected and cleaned up properly.
|
||||
engine_core = self.engine_core
|
||||
output_processor = self.output_processor
|
||||
log_stats = self.log_stats
|
||||
stat_loggers = self.stat_loggers if log_stats else None
|
||||
|
||||
async def output_handler():
|
||||
try:
|
||||
while True:
|
||||
# 1) Pull EngineCoreOutputs from the EngineCore.
|
||||
outputs = await engine_core.get_output_async()
|
||||
num_outputs = len(outputs.outputs)
|
||||
|
||||
iteration_stats = IterationStats() if (
|
||||
log_stats and num_outputs) else None
|
||||
|
||||
# Split outputs into chunks of at most
|
||||
# VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
|
||||
# event loop for too long.
|
||||
if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
|
||||
slices = (outputs.outputs, )
|
||||
else:
|
||||
slices = np.array_split(
|
||||
outputs.outputs,
|
||||
cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
|
||||
|
||||
for i, outputs_slice in enumerate(slices):
|
||||
# 2) Process EngineCoreOutputs.
|
||||
processed_outputs = output_processor.process_outputs(
|
||||
outputs_slice, outputs.timestamp, iteration_stats)
|
||||
# NOTE: RequestOutputs are pushed to their queues.
|
||||
assert not processed_outputs.request_outputs
|
||||
|
||||
# Allow other asyncio tasks to run between chunks
|
||||
if i + 1 < len(slices):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# 3) Abort any reqs that finished due to stop strings.
|
||||
await engine_core.abort_requests_async(
|
||||
processed_outputs.reqs_to_abort)
|
||||
|
||||
# 4) Logging.
|
||||
# TODO(rob): make into a coroutine and launch it in
|
||||
# background thread once Prometheus overhead is non-trivial.
|
||||
if stat_loggers:
|
||||
AsyncLLM._record_stats(
|
||||
stat_loggers[outputs.engine_index],
|
||||
scheduler_stats=outputs.scheduler_stats,
|
||||
iteration_stats=iteration_stats,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("AsyncLLM output_handler failed.")
|
||||
output_processor.propagate_error(e)
|
||||
|
||||
self.output_handler = asyncio.create_task(output_handler())
|
||||
|
||||
async def abort(self, request_id: str) -> None:
|
||||
"""Abort RequestId in OutputProcessor and EngineCore."""
|
||||
|
||||
request_ids = self.output_processor.abort_requests((request_id, ))
|
||||
await self.engine_core.abort_requests_async(request_ids)
|
||||
|
||||
if self.log_requests:
|
||||
logger.info("Aborted request %s.", request_id)
|
||||
|
||||
@staticmethod
|
||||
def _record_stats(
|
||||
stat_loggers: list[StatLoggerBase],
|
||||
scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats],
|
||||
):
|
||||
"""static so that it can be used from the output_handler task
|
||||
without a circular ref to AsyncLLM."""
|
||||
for stat_logger in stat_loggers:
|
||||
stat_logger.record(scheduler_stats=scheduler_stats,
|
||||
iteration_stats=iteration_stats)
|
||||
|
||||
def encode(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
pooling_params: PoolingParams,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
priority: int = 0,
|
||||
):
|
||||
raise ValueError("Not Supported on V1 yet.")
|
||||
|
||||
async def get_vllm_config(self) -> VllmConfig:
|
||||
return self.vllm_config
|
||||
|
||||
async def get_model_config(self) -> ModelConfig:
|
||||
return self.model_config
|
||||
|
||||
async def get_decoding_config(self):
|
||||
raise ValueError("Not Supported on V1 yet.")
|
||||
|
||||
async def get_input_preprocessor(self) -> InputPreprocessor:
|
||||
return self.processor.input_preprocessor
|
||||
|
||||
async def get_tokenizer(
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> AnyTokenizer:
|
||||
return self.tokenizer.get_lora_tokenizer(lora_request)
|
||||
|
||||
async def is_tracing_enabled(self) -> bool:
|
||||
return False
|
||||
|
||||
async def do_log_stats(
|
||||
self,
|
||||
scheduler_outputs=None,
|
||||
model_output=None,
|
||||
) -> None:
|
||||
for loggers in self.stat_loggers:
|
||||
for stat_logger in loggers:
|
||||
stat_logger.log()
|
||||
|
||||
async def check_health(self) -> None:
|
||||
logger.debug("Called check_health.")
|
||||
|
||||
async def start_profile(self) -> None:
|
||||
await self.engine_core.profile_async(True)
|
||||
|
||||
async def stop_profile(self) -> None:
|
||||
await self.engine_core.profile_async(False)
|
||||
|
||||
async def reset_mm_cache(self) -> None:
|
||||
self.processor.mm_registry.reset_processor_cache()
|
||||
self.processor.mm_input_cache_client.reset()
|
||||
await self.engine_core.reset_mm_cache_async()
|
||||
|
||||
async def reset_prefix_cache(self,
|
||||
device: Optional[Device] = None) -> None:
|
||||
if device == Device.CPU:
|
||||
raise ValueError("Not supported on CPU.")
|
||||
await self.engine_core.reset_prefix_cache_async()
|
||||
|
||||
async def sleep(self, level: int = 1) -> None:
|
||||
await self.engine_core.sleep_async(level)
|
||||
|
||||
async def wake_up(self, tags: Optional[list[str]] = None) -> None:
|
||||
await self.engine_core.wake_up_async(tags)
|
||||
|
||||
async def is_sleeping(self) -> bool:
|
||||
return await self.engine_core.is_sleeping_async()
|
||||
|
||||
async def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
"""Load a new LoRA adapter into the engine for future requests."""
|
||||
return await self.engine_core.add_lora_async(lora_request)
|
||||
|
||||
async def remove_lora(self, lora_id: int) -> bool:
|
||||
"""Remove an already loaded LoRA adapter."""
|
||||
return await self.engine_core.remove_lora_async(lora_id)
|
||||
|
||||
async def list_loras(self) -> set[int]:
|
||||
"""List all registered adapters."""
|
||||
return await self.engine_core.list_loras_async()
|
||||
|
||||
async def pin_lora(self, lora_id: int) -> bool:
|
||||
"""Prevent an adapter from being evicted."""
|
||||
return await self.engine_core.pin_lora_async(lora_id)
|
||||
|
||||
async def collective_rpc(self,
|
||||
method: str,
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict] = None):
|
||||
"""
|
||||
Perform a collective RPC call to the given path.
|
||||
"""
|
||||
return await self.engine_core.collective_rpc_async(
|
||||
method, timeout, args, kwargs)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
# Is None before the loop is started.
|
||||
return self.output_handler is None or not self.output_handler.done()
|
||||
|
||||
@property
|
||||
def is_stopped(self) -> bool:
|
||||
return self.errored
|
||||
|
||||
@property
|
||||
def errored(self) -> bool:
|
||||
return self.engine_core.resources.engine_dead or not self.is_running
|
||||
|
||||
@property
|
||||
def dead_error(self) -> BaseException:
|
||||
return EngineDeadError()
|
||||
Reference in New Issue
Block a user