[PD] Remove the requirement of config file for mooncake backend (#5460)
This commit is contained in:
@@ -121,7 +121,7 @@ class DecodePreallocQueue:
|
||||
kv_args.aux_item_lens = [
|
||||
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
|
||||
]
|
||||
kv_args.ib_device = "mock-ib-device"
|
||||
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
||||
kv_args.gpu_id = self.scheduler.gpu_id
|
||||
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
||||
kv_manager = kv_manager_class(
|
||||
|
||||
@@ -99,8 +99,12 @@ class MooncakeKVManager(BaseKVManager):
|
||||
disaggregation_mode: DisaggregationMode,
|
||||
server_args: ServerArgs,
|
||||
):
|
||||
self.engine = MooncakeTransferEngine()
|
||||
self.kv_args = args
|
||||
self.engine = MooncakeTransferEngine(
|
||||
hostname=get_local_ip_by_remote(),
|
||||
gpu_id=self.kv_args.gpu_id,
|
||||
ib_device=self.kv_args.ib_device,
|
||||
)
|
||||
self.disaggregation_mode = disaggregation_mode
|
||||
# for p/d multi node infer
|
||||
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
||||
@@ -503,52 +507,8 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
||||
self.thread.start()
|
||||
|
||||
def _setup_routes(self):
|
||||
self.app.router.add_route("*", "/metadata", self._handle_metadata)
|
||||
self.app.router.add_route("*", "/route", self._handle_route)
|
||||
|
||||
async def _handle_metadata(self, request: web.Request):
|
||||
key = request.query.get("key", "")
|
||||
|
||||
if request.method == "GET":
|
||||
return await self._handle_metadata_get(key)
|
||||
elif request.method == "PUT":
|
||||
return await self._handle_metadata_put(key, request)
|
||||
elif request.method == "DELETE":
|
||||
return await self._handle_metadata_delete(key)
|
||||
return web.Response(
|
||||
text="Method not allowed", status=405, content_type="application/json"
|
||||
)
|
||||
|
||||
async def _handle_metadata_get(self, key):
|
||||
async with self.lock:
|
||||
value = self.store.get(key)
|
||||
if value is None:
|
||||
return web.Response(
|
||||
text="metadata not found", status=404, content_type="application/json"
|
||||
)
|
||||
return web.Response(body=value, status=200, content_type="application/json")
|
||||
|
||||
async def _handle_metadata_put(self, key, request):
|
||||
data = await request.read()
|
||||
async with self.lock:
|
||||
self.store[key] = data
|
||||
return web.Response(
|
||||
text="metadata updated", status=200, content_type="application/json"
|
||||
)
|
||||
|
||||
async def _handle_metadata_delete(self, key):
|
||||
async with self.lock:
|
||||
if key not in self.store:
|
||||
return web.Response(
|
||||
text="metadata not found",
|
||||
status=404,
|
||||
content_type="application/json",
|
||||
)
|
||||
del self.store[key]
|
||||
return web.Response(
|
||||
text="metadata deleted", status=200, content_type="application/json"
|
||||
)
|
||||
|
||||
async def _handle_route(self, request: web.Request):
|
||||
method = request.method
|
||||
if method == "PUT":
|
||||
|
||||
@@ -1,45 +1,14 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MooncakeTransferEngineConfig:
|
||||
local_hostname: str
|
||||
metadata_server: str
|
||||
protocol: str
|
||||
device_name: str
|
||||
|
||||
@staticmethod
|
||||
def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
|
||||
"""Load the config from a JSON file."""
|
||||
with open(file_path) as fin:
|
||||
config = json.load(fin)
|
||||
return MooncakeTransferEngineConfig(
|
||||
local_hostname=config.get("local_hostname", None),
|
||||
metadata_server=config.get("metadata_server"),
|
||||
protocol=config.get("protocol", "rdma"),
|
||||
device_name=config.get("device_name", ""),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_from_env() -> "MooncakeTransferEngineConfig":
|
||||
"""Load config from a file specified in the environment variable."""
|
||||
config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
|
||||
if config_file_path is None:
|
||||
raise ValueError(
|
||||
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
|
||||
)
|
||||
return MooncakeTransferEngineConfig.from_file(config_file_path)
|
||||
|
||||
|
||||
class MooncakeTransferEngine:
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, hostname: str, gpu_id: int, ib_device: Optional[str] = None):
|
||||
try:
|
||||
from mooncake.engine import TransferEngine
|
||||
except ImportError as e:
|
||||
@@ -50,43 +19,43 @@ class MooncakeTransferEngine:
|
||||
) from e
|
||||
|
||||
self.engine = TransferEngine()
|
||||
self.hostname = hostname
|
||||
self.gpu_id = gpu_id
|
||||
self.ib_device = ib_device
|
||||
|
||||
try:
|
||||
self.config = MooncakeTransferEngineConfig.load_from_env()
|
||||
logger.info("Mooncake Configuration loaded successfully.")
|
||||
except ValueError as e:
|
||||
logger.error(e)
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error("An error occurred while loading the configuration: %s", exc)
|
||||
raise
|
||||
|
||||
self.config = MooncakeTransferEngineConfig.load_from_env()
|
||||
|
||||
session_suffix = "_" + str(uuid.uuid4())
|
||||
self.session_id = self.config.local_hostname + session_suffix
|
||||
self.initialize(
|
||||
self.session_id,
|
||||
self.config.metadata_server,
|
||||
self.config.protocol,
|
||||
self.config.device_name,
|
||||
hostname=self.hostname,
|
||||
device_name=self.ib_device,
|
||||
)
|
||||
self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
|
||||
|
||||
def register(self, ptr, length):
|
||||
self.engine.register_memory(ptr, length)
|
||||
ret_value = self.engine.register_memory(ptr, length)
|
||||
if ret_value != 0:
|
||||
logger.error("Mooncake memory registration failed.")
|
||||
raise RuntimeError("Mooncake memory registration failed.")
|
||||
|
||||
def deregister(self, ptr):
|
||||
self.engine.unregister_memory(ptr)
|
||||
ret_value = self.engine.unregister_memory(ptr)
|
||||
if ret_value != 0:
|
||||
logger.error("Mooncake memory deregistration failed.")
|
||||
raise RuntimeError("Mooncake memory deregistration failed.")
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
local_hostname: str,
|
||||
metadata_server: str,
|
||||
protocol: str,
|
||||
device_name: str,
|
||||
hostname: str,
|
||||
device_name: Optional[str],
|
||||
) -> None:
|
||||
"""Initialize the mooncake instance."""
|
||||
self.engine.initialize(local_hostname, metadata_server, protocol, device_name)
|
||||
ret_value = self.engine.initialize(
|
||||
hostname,
|
||||
"P2PHANDSHAKE",
|
||||
"rdma",
|
||||
device_name if device_name is not None else "",
|
||||
)
|
||||
if ret_value != 0:
|
||||
logger.error("Mooncake Transfer Engine initialization failed.")
|
||||
raise RuntimeError("Mooncake Transfer Engine initialization failed.")
|
||||
|
||||
def transfer_sync(
|
||||
self, session_id: str, buffer: int, peer_buffer_address: int, length: int
|
||||
@@ -97,12 +66,12 @@ class MooncakeTransferEngine:
|
||||
session_id, buffer, peer_buffer_address, length
|
||||
)
|
||||
if ret < 0:
|
||||
logger.error("Transfer Return Error")
|
||||
raise Exception("Transfer Return Error")
|
||||
logger.error("Mooncake Transfer Engine Return Error.")
|
||||
raise RuntimeError("Mooncake Transfer Engine Return Error.")
|
||||
return ret
|
||||
|
||||
def get_localhost(self):
|
||||
return self.config.local_hostname
|
||||
return self.hostname
|
||||
|
||||
def get_session_id(self):
|
||||
return self.session_id
|
||||
|
||||
@@ -103,7 +103,7 @@ class PrefillBootstrapQueue:
|
||||
kv_args.aux_item_lens = [
|
||||
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
|
||||
]
|
||||
kv_args.ib_device = "mock-ib-device"
|
||||
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
||||
kv_args.gpu_id = self.scheduler.gpu_id
|
||||
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
||||
kv_manager = kv_manager_class(
|
||||
|
||||
@@ -196,6 +196,7 @@ class ServerArgs:
|
||||
disaggregation_mode: str = "null"
|
||||
disaggregation_bootstrap_port: int = 8998
|
||||
disaggregation_transfer_backend: str = "mooncake"
|
||||
disaggregation_ib_device: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Expert parallelism
|
||||
@@ -1193,6 +1194,12 @@ class ServerArgs:
|
||||
default=ServerArgs.disaggregation_transfer_backend,
|
||||
help="The backend for disaggregation transfer. Default is mooncake.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disaggregation-ib-device",
|
||||
type=str,
|
||||
default=ServerArgs.disaggregation_ib_device,
|
||||
help="The ib device for disaggregation transfer. Default is None, it will be detected automatically if using the mooncake backend.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
|
||||
Reference in New Issue
Block a user