228 lines
8.6 KiB
Python
228 lines
8.6 KiB
Python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.v1.engine.core_client import (
|
|
EngineCoreClient,
|
|
InprocClient,
|
|
SyncMPClient,
|
|
AsyncMPClient,
|
|
DPAsyncMPClient,
|
|
DPLBAsyncMPClient,
|
|
)
|
|
from vllm.v1.engine import EngineCoreRequest
|
|
from vllm.config import VllmConfig
|
|
from vllm.v1.executor import Executor
|
|
|
|
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
|
|
|
class EngineCoreClient_MluHiack(EngineCoreClient):
    """Hijacked EngineCoreClient that selects the MLU-appropriate async MP client."""

    @staticmethod
    def make_async_mp_client(
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "MPClient":
        """Build the async multiprocess client matching the parallel config.

        =============================
        Modify by vllm_mlu
        =============================
        @brief: disagg use DPAsyncMPClient instead of DPLBAsyncMPClient.
        """
        parallel_config = vllm_config.parallel_config
        client_args = (
            vllm_config,
            executor_class,
            log_stats,
            client_addresses,
            client_count,
            client_index,
        )
        # No data parallelism: plain async client.
        if parallel_config.data_parallel_size <= 1:
            return AsyncMPClient(*client_args)
        # External load balancer, or disaggregation via kv transfer
        # (proxy assigns the DP rank): one client per DP rank.
        if (parallel_config.data_parallel_external_lb
                or vllm_config.kv_transfer_config is not None):
            return DPAsyncMPClient(*client_args)
        # Internal load balancer - client balances to all DP ranks.
        return DPLBAsyncMPClient(*client_args)
|
|
class InprocClient_MluHiack(InprocClient):
    """Hijacked InprocClient: forwards MLU-specific utility calls straight to
    the in-process engine core."""

    def get_hfu_info(self, batch, input_len, output_len):
        # Utilization statistics for the given workload shape
        # (presumably HFU = hardware FLOPs utilization — TODO confirm).
        return self.engine_core.get_hfu_info(batch, input_len, output_len)

    def get_latency(self):
        return self.engine_core.get_latency()

    def get_memory_usage(self):
        return self.engine_core.get_memory_usage()

    def recapture_model(self, prefill_enable_mlugraph: bool,
                        batch_size: int, input_len: int):
        # Ask the engine core to re-capture the model for this configuration.
        return self.engine_core.recapture_model(
            prefill_enable_mlugraph, batch_size, input_len)

    def init_metric(self, use_unchunk_sched: bool, min_prefill_batch: int):
        return self.engine_core.init_metric(use_unchunk_sched, min_prefill_batch)

    def start_scheduler_profile(self):
        self.engine_core.start_scheduler_profile()

    def stop_scheduler_profile(self):
        self.engine_core.stop_scheduler_profile()

    def response_remote_alloc_once(self) -> None:
        self.engine_core.response_remote_alloc_once()
|
|
class SyncMPClient_MluHiack(SyncMPClient):
    """Hijacked SyncMPClient: exposes MLU-specific utilities over the
    synchronous call_utility RPC channel."""

    def get_hfu_info(self, batch, input_len, output_len):
        # Utilization statistics for the given workload shape
        # (presumably HFU = hardware FLOPs utilization — TODO confirm).
        try:
            return self.call_utility("get_hfu_info", batch, input_len, output_len)
        except Exception as e:
            # Chain the original exception so the underlying RPC failure
            # (traceback and cause) is not lost.
            raise RuntimeError(f"Failed to get HFU info: {str(e)}") from e

    def get_latency(self):
        return self.call_utility("get_latency")

    def get_memory_usage(self):
        return self.call_utility("get_memory_usage")

    def recapture_model(self,
                        prefill_enable_mlugraph: bool,
                        batch_size: int,
                        input_len: int):
        # Ask the remote engine core to re-capture the model for this config.
        return self.call_utility("recapture_model",
                                 prefill_enable_mlugraph, batch_size, input_len)

    def init_metric(self, use_unchunk_sched: bool, min_prefill_batch: int):
        return self.call_utility("init_metric",
                                 use_unchunk_sched,
                                 min_prefill_batch)

    def start_scheduler_profile(self):
        self.call_utility("start_scheduler_profile")

    def stop_scheduler_profile(self):
        self.call_utility("stop_scheduler_profile")

    def response_remote_alloc_once(self) -> None:
        self.call_utility("response_remote_alloc_once")
|
|
class AsyncMPClient_MluHijack(AsyncMPClient):
    """Hijacked AsyncMPClient: async wrappers around the MLU scheduler
    utility RPCs."""

    async def start_scheduler_profile(self) -> None:
        """Begin scheduler profiling on the remote engine core."""
        await self.call_utility_async("start_scheduler_profile")

    async def stop_scheduler_profile(self) -> None:
        """Stop scheduler profiling on the remote engine core."""
        await self.call_utility_async("stop_scheduler_profile")

    async def response_remote_alloc_once(self) -> None:
        """Trigger a single remote-allocation response on the engine core."""
        await self.call_utility_async("response_remote_alloc_once")
|
|
class DPAsyncMPClient_MluHijack(DPAsyncMPClient):
    """Hijacked DPAsyncMPClient: honors a DP rank pre-assigned by the proxy."""

    def get_core_engine_for_request(self, request: EngineCoreRequest):
        """Pick the core engine serving *request*.

        =============================
        Modify by vllm_mlu
        =============================
        @brief: disagg need proxy to assign dp_rank
        """
        dp_rank = request.data_parallel_rank
        if dp_rank is None:
            # No rank assigned by the proxy: fall back to the default engine.
            return self.core_engine
        # Engines are already in rank order.
        return self.core_engines[dp_rank]
|
|
# Register every MLU hijack on the upstream vLLM client classes.
MluHijackObject.apply_hijack(EngineCoreClient,
                             EngineCoreClient.make_async_mp_client,
                             EngineCoreClient_MluHiack.make_async_mp_client)

# Methods added (by name) to both the in-process and sync MP clients.
_UTILITY_METHODS = (
    "get_hfu_info",
    "get_latency",
    "get_memory_usage",
    "recapture_model",
    "init_metric",
    "start_scheduler_profile",
    "stop_scheduler_profile",
    "response_remote_alloc_once",
)

for _name in _UTILITY_METHODS:
    MluHijackObject.apply_hijack(InprocClient, _name,
                                 getattr(InprocClient_MluHiack, _name))

for _name in _UTILITY_METHODS:
    MluHijackObject.apply_hijack(SyncMPClient, _name,
                                 getattr(SyncMPClient_MluHiack, _name))

# Async client only gains the profiling / remote-alloc hooks.
for _name in ("start_scheduler_profile",
              "stop_scheduler_profile",
              "response_remote_alloc_once"):
    MluHijackObject.apply_hijack(AsyncMPClient, _name,
                                 getattr(AsyncMPClient_MluHijack, _name))

MluHijackObject.apply_hijack(DPAsyncMPClient,
                             DPAsyncMPClient.get_core_engine_for_request,
                             DPAsyncMPClient_MluHijack.get_core_engine_for_request)