Files
enginex-mlu590-vllm/vllm_mlu/v1/engine/core_client.py

228 lines
8.6 KiB
Python
Raw Permalink Normal View History

2026-04-24 09:50:34 +08:00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.v1.engine.core_client import (
EngineCoreClient,
InprocClient,
SyncMPClient,
AsyncMPClient,
DPAsyncMPClient,
DPLBAsyncMPClient,
)
from vllm.v1.engine import EngineCoreRequest
from vllm.config import VllmConfig
from vllm.v1.executor import Executor
from vllm_mlu.mlu_hijack_utils import MluHijackObject
class EngineCoreClient_MluHiack(EngineCoreClient):
    """Holder for the MLU replacement of ``EngineCoreClient.make_async_mp_client``."""

    @staticmethod
    def make_async_mp_client(
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "MPClient":
        """Build the asynchronous multi-process client matching the config.

        Selection rules:

        * ``data_parallel_size`` not greater than 1 -> ``AsyncMPClient``.
        * ``data_parallel_size > 1`` and either ``data_parallel_external_lb``
          is truthy or ``vllm_config.kv_transfer_config`` is set
          -> ``DPAsyncMPClient``.
        * ``data_parallel_size > 1`` otherwise -> ``DPLBAsyncMPClient``.

        All six arguments are forwarded positionally, in declaration order,
        to the constructor of the selected class.

        Returns:
            A new instance of the selected client class.
        """
        dp_settings = vllm_config.parallel_config

        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: disagg use DPAsyncMPClient instead of DPLBAsyncMPClient.
        '''
        if dp_settings.data_parallel_size > 1:
            # A configured KV transfer is treated like an external load
            # balancer: one client per DP rank instead of internal balancing.
            one_client_per_rank = (
                dp_settings.data_parallel_external_lb
                or vllm_config.kv_transfer_config is not None
            )
            selected_cls = (
                DPAsyncMPClient if one_client_per_rank else DPLBAsyncMPClient
            )
        else:
            selected_cls = AsyncMPClient
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        return selected_cls(
            vllm_config,
            executor_class,
            log_stats,
            client_addresses,
            client_count,
            client_index,
        )
class InprocClient_MluHiack(InprocClient):
    """MLU-specific extras for the in-process client.

    Every method delegates directly to the method of the same name on
    ``self.engine_core``.
    """

    def get_hfu_info(self, batch, input_len, output_len):
        """Return ``engine_core.get_hfu_info(batch, input_len, output_len)``."""
        core = self.engine_core
        return core.get_hfu_info(batch, input_len, output_len)

    def get_latency(self):
        """Return whatever ``engine_core.get_latency()`` reports."""
        core = self.engine_core
        return core.get_latency()

    def get_memory_usage(self):
        """Return whatever ``engine_core.get_memory_usage()`` reports."""
        core = self.engine_core
        return core.get_memory_usage()

    def recapture_model(
        self,
        prefill_enable_mlugraph: bool,
        batch_size: int,
        input_len: int,
    ):
        """Forward a recapture request to the engine core and return its result."""
        core = self.engine_core
        return core.recapture_model(prefill_enable_mlugraph, batch_size, input_len)

    def init_metric(self, use_unchunk_sched: bool, min_prefill_batch: int):
        """Forward metric initialisation to the engine core and return its result."""
        core = self.engine_core
        return core.init_metric(use_unchunk_sched, min_prefill_batch)

    def start_scheduler_profile(self):
        """Ask the engine core to start scheduler profiling; returns ``None``."""
        core = self.engine_core
        core.start_scheduler_profile()

    def stop_scheduler_profile(self):
        """Ask the engine core to stop scheduler profiling; returns ``None``."""
        core = self.engine_core
        core.stop_scheduler_profile()

    def response_remote_alloc_once(self) -> None:
        """Invoke ``engine_core.response_remote_alloc_once()``; returns ``None``."""
        core = self.engine_core
        core.response_remote_alloc_once()
class SyncMPClient_MluHiack(SyncMPClient):
    """MLU-specific extras for the synchronous multi-process client.

    Every method forwards to ``self.call_utility`` with the name of the
    engine-core utility to run, followed by its positional arguments.
    """

    def get_hfu_info(self, batch, input_len, output_len):
        """Return the result of the ``get_hfu_info`` utility call.

        Raises:
            RuntimeError: if the utility call raises any ``Exception``. The
                original exception is attached as ``__cause__``.
        """
        try:
            return self.call_utility("get_hfu_info", batch, input_len, output_len)
        except Exception as e:
            # Chain explicitly so the traceback marks the underlying failure
            # as the direct cause rather than an error during handling.
            raise RuntimeError(f"Failed to get HFU info: {e!s}") from e

    def get_latency(self):
        """Return the result of the ``get_latency`` utility call."""
        return self.call_utility("get_latency")

    def get_memory_usage(self):
        """Return the result of the ``get_memory_usage`` utility call."""
        return self.call_utility("get_memory_usage")

    def recapture_model(
        self,
        prefill_enable_mlugraph: bool,
        batch_size: int,
        input_len: int,
    ):
        """Return the result of the ``recapture_model`` utility call."""
        return self.call_utility(
            "recapture_model", prefill_enable_mlugraph, batch_size, input_len
        )

    def init_metric(self, use_unchunk_sched: bool, min_prefill_batch: int):
        """Return the result of the ``init_metric`` utility call."""
        return self.call_utility(
            "init_metric", use_unchunk_sched, min_prefill_batch
        )

    def start_scheduler_profile(self):
        """Run the ``start_scheduler_profile`` utility; returns ``None``."""
        self.call_utility("start_scheduler_profile")

    def stop_scheduler_profile(self):
        """Run the ``stop_scheduler_profile`` utility; returns ``None``."""
        self.call_utility("stop_scheduler_profile")

    def response_remote_alloc_once(self) -> None:
        """Run the ``response_remote_alloc_once`` utility; returns ``None``."""
        self.call_utility("response_remote_alloc_once")
class AsyncMPClient_MluHijack(AsyncMPClient):
    """MLU-specific extras for the asynchronous multi-process client.

    Each coroutine awaits ``self.call_utility_async`` with the name of the
    engine-core utility to run and discards the result.
    """

    async def start_scheduler_profile(self) -> None:
        """Await the ``start_scheduler_profile`` utility call."""
        utility = "start_scheduler_profile"
        await self.call_utility_async(utility)

    async def stop_scheduler_profile(self) -> None:
        """Await the ``stop_scheduler_profile`` utility call."""
        utility = "stop_scheduler_profile"
        await self.call_utility_async(utility)

    async def response_remote_alloc_once(self) -> None:
        """Await the ``response_remote_alloc_once`` utility call."""
        utility = "response_remote_alloc_once"
        await self.call_utility_async(utility)
class DPAsyncMPClient_MluHijack(DPAsyncMPClient):
    """Holder for the MLU replacement of ``get_core_engine_for_request``."""

    def get_core_engine_for_request(self, request: EngineCoreRequest):
        """Choose the core engine that should receive ``request``.

        If ``request.data_parallel_rank`` is set, it is used as an index into
        ``self.core_engines``; otherwise ``self.core_engine`` is returned.
        No range check is done on the rank here.
        """
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: disagg need proxy to assign dp_rank
        '''
        assigned_rank = request.data_parallel_rank
        if assigned_rank is not None:
            # engines are already in rank order
            return self.core_engines[assigned_rank]
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        return self.core_engine
# Register every MLU override with MluHijackObject, in a fixed order.
# NOTE(review): apply_hijack lives in vllm_mlu.mlu_hijack_utils and is not
# visible here; presumably it installs the third argument on the class given
# as the first argument -- confirm there.

# This target is identified by the function object already on the class.
MluHijackObject.apply_hijack(
    EngineCoreClient,
    EngineCoreClient.make_async_mp_client,
    EngineCoreClient_MluHiack.make_async_mp_client,
)

# These targets are identified by method-name strings. Order matters only in
# that it reproduces the original registration sequence exactly.
for _target_cls, _override_cls, _method_names in (
    (
        InprocClient,
        InprocClient_MluHiack,
        (
            "get_hfu_info",
            "get_latency",
            "get_memory_usage",
            "recapture_model",
            "init_metric",
            "start_scheduler_profile",
            "stop_scheduler_profile",
            "response_remote_alloc_once",
        ),
    ),
    (
        SyncMPClient,
        SyncMPClient_MluHiack,
        (
            "get_hfu_info",
            "get_latency",
            "get_memory_usage",
            "recapture_model",
            "init_metric",
            "start_scheduler_profile",
            "stop_scheduler_profile",
            "response_remote_alloc_once",
        ),
    ),
    (
        AsyncMPClient,
        AsyncMPClient_MluHijack,
        (
            "start_scheduler_profile",
            "stop_scheduler_profile",
            "response_remote_alloc_once",
        ),
    ),
):
    for _method_name in _method_names:
        MluHijackObject.apply_hijack(
            _target_cls,
            _method_name,
            getattr(_override_cls, _method_name),
        )
# Drop the loop helpers so the module namespace is the same as before.
del _target_cls, _override_cls, _method_names, _method_name

# This target is again identified by the function object already on the class.
MluHijackObject.apply_hijack(
    DPAsyncMPClient,
    DPAsyncMPClient.get_core_engine_for_request,
    DPAsyncMPClient_MluHijack.get_core_engine_for_request,
)