[PD] Remove the requirement of config file for mooncake backend (#5460)
This commit is contained in:
@@ -121,7 +121,7 @@ class DecodePreallocQueue:
|
|||||||
kv_args.aux_item_lens = [
|
kv_args.aux_item_lens = [
|
||||||
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
|
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
|
||||||
]
|
]
|
||||||
kv_args.ib_device = "mock-ib-device"
|
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
||||||
kv_args.gpu_id = self.scheduler.gpu_id
|
kv_args.gpu_id = self.scheduler.gpu_id
|
||||||
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
||||||
kv_manager = kv_manager_class(
|
kv_manager = kv_manager_class(
|
||||||
|
|||||||
@@ -99,8 +99,12 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
disaggregation_mode: DisaggregationMode,
|
disaggregation_mode: DisaggregationMode,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
):
|
):
|
||||||
self.engine = MooncakeTransferEngine()
|
|
||||||
self.kv_args = args
|
self.kv_args = args
|
||||||
|
self.engine = MooncakeTransferEngine(
|
||||||
|
hostname=get_local_ip_by_remote(),
|
||||||
|
gpu_id=self.kv_args.gpu_id,
|
||||||
|
ib_device=self.kv_args.ib_device,
|
||||||
|
)
|
||||||
self.disaggregation_mode = disaggregation_mode
|
self.disaggregation_mode = disaggregation_mode
|
||||||
# for p/d multi node infer
|
# for p/d multi node infer
|
||||||
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
||||||
@@ -503,52 +507,8 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
|||||||
self.thread.start()
|
self.thread.start()
|
||||||
|
|
||||||
def _setup_routes(self):
|
def _setup_routes(self):
|
||||||
self.app.router.add_route("*", "/metadata", self._handle_metadata)
|
|
||||||
self.app.router.add_route("*", "/route", self._handle_route)
|
self.app.router.add_route("*", "/route", self._handle_route)
|
||||||
|
|
||||||
async def _handle_metadata(self, request: web.Request):
|
|
||||||
key = request.query.get("key", "")
|
|
||||||
|
|
||||||
if request.method == "GET":
|
|
||||||
return await self._handle_metadata_get(key)
|
|
||||||
elif request.method == "PUT":
|
|
||||||
return await self._handle_metadata_put(key, request)
|
|
||||||
elif request.method == "DELETE":
|
|
||||||
return await self._handle_metadata_delete(key)
|
|
||||||
return web.Response(
|
|
||||||
text="Method not allowed", status=405, content_type="application/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _handle_metadata_get(self, key):
|
|
||||||
async with self.lock:
|
|
||||||
value = self.store.get(key)
|
|
||||||
if value is None:
|
|
||||||
return web.Response(
|
|
||||||
text="metadata not found", status=404, content_type="application/json"
|
|
||||||
)
|
|
||||||
return web.Response(body=value, status=200, content_type="application/json")
|
|
||||||
|
|
||||||
async def _handle_metadata_put(self, key, request):
|
|
||||||
data = await request.read()
|
|
||||||
async with self.lock:
|
|
||||||
self.store[key] = data
|
|
||||||
return web.Response(
|
|
||||||
text="metadata updated", status=200, content_type="application/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _handle_metadata_delete(self, key):
|
|
||||||
async with self.lock:
|
|
||||||
if key not in self.store:
|
|
||||||
return web.Response(
|
|
||||||
text="metadata not found",
|
|
||||||
status=404,
|
|
||||||
content_type="application/json",
|
|
||||||
)
|
|
||||||
del self.store[key]
|
|
||||||
return web.Response(
|
|
||||||
text="metadata deleted", status=200, content_type="application/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _handle_route(self, request: web.Request):
|
async def _handle_route(self, request: web.Request):
|
||||||
method = request.method
|
method = request.method
|
||||||
if method == "PUT":
|
if method == "PUT":
|
||||||
|
|||||||
@@ -1,45 +1,14 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MooncakeTransferEngineConfig:
|
|
||||||
local_hostname: str
|
|
||||||
metadata_server: str
|
|
||||||
protocol: str
|
|
||||||
device_name: str
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
|
|
||||||
"""Load the config from a JSON file."""
|
|
||||||
with open(file_path) as fin:
|
|
||||||
config = json.load(fin)
|
|
||||||
return MooncakeTransferEngineConfig(
|
|
||||||
local_hostname=config.get("local_hostname", None),
|
|
||||||
metadata_server=config.get("metadata_server"),
|
|
||||||
protocol=config.get("protocol", "rdma"),
|
|
||||||
device_name=config.get("device_name", ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load_from_env() -> "MooncakeTransferEngineConfig":
|
|
||||||
"""Load config from a file specified in the environment variable."""
|
|
||||||
config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
|
|
||||||
if config_file_path is None:
|
|
||||||
raise ValueError(
|
|
||||||
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
|
|
||||||
)
|
|
||||||
return MooncakeTransferEngineConfig.from_file(config_file_path)
|
|
||||||
|
|
||||||
|
|
||||||
class MooncakeTransferEngine:
|
class MooncakeTransferEngine:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, hostname: str, gpu_id: int, ib_device: Optional[str] = None):
|
||||||
try:
|
try:
|
||||||
from mooncake.engine import TransferEngine
|
from mooncake.engine import TransferEngine
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@@ -50,43 +19,43 @@ class MooncakeTransferEngine:
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
self.engine = TransferEngine()
|
self.engine = TransferEngine()
|
||||||
|
self.hostname = hostname
|
||||||
|
self.gpu_id = gpu_id
|
||||||
|
self.ib_device = ib_device
|
||||||
|
|
||||||
try:
|
|
||||||
self.config = MooncakeTransferEngineConfig.load_from_env()
|
|
||||||
logger.info("Mooncake Configuration loaded successfully.")
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(e)
|
|
||||||
raise
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("An error occurred while loading the configuration: %s", exc)
|
|
||||||
raise
|
|
||||||
|
|
||||||
self.config = MooncakeTransferEngineConfig.load_from_env()
|
|
||||||
|
|
||||||
session_suffix = "_" + str(uuid.uuid4())
|
|
||||||
self.session_id = self.config.local_hostname + session_suffix
|
|
||||||
self.initialize(
|
self.initialize(
|
||||||
self.session_id,
|
hostname=self.hostname,
|
||||||
self.config.metadata_server,
|
device_name=self.ib_device,
|
||||||
self.config.protocol,
|
|
||||||
self.config.device_name,
|
|
||||||
)
|
)
|
||||||
|
self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
|
||||||
|
|
||||||
def register(self, ptr, length):
|
def register(self, ptr, length):
|
||||||
self.engine.register_memory(ptr, length)
|
ret_value = self.engine.register_memory(ptr, length)
|
||||||
|
if ret_value != 0:
|
||||||
|
logger.error("Mooncake memory registration failed.")
|
||||||
|
raise RuntimeError("Mooncake memory registration failed.")
|
||||||
|
|
||||||
def deregister(self, ptr):
|
def deregister(self, ptr):
|
||||||
self.engine.unregister_memory(ptr)
|
ret_value = self.engine.unregister_memory(ptr)
|
||||||
|
if ret_value != 0:
|
||||||
|
logger.error("Mooncake memory deregistration failed.")
|
||||||
|
raise RuntimeError("Mooncake memory deregistration failed.")
|
||||||
|
|
||||||
def initialize(
|
def initialize(
|
||||||
self,
|
self,
|
||||||
local_hostname: str,
|
hostname: str,
|
||||||
metadata_server: str,
|
device_name: Optional[str],
|
||||||
protocol: str,
|
|
||||||
device_name: str,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the mooncake instance."""
|
"""Initialize the mooncake instance."""
|
||||||
self.engine.initialize(local_hostname, metadata_server, protocol, device_name)
|
ret_value = self.engine.initialize(
|
||||||
|
hostname,
|
||||||
|
"P2PHANDSHAKE",
|
||||||
|
"rdma",
|
||||||
|
device_name if device_name is not None else "",
|
||||||
|
)
|
||||||
|
if ret_value != 0:
|
||||||
|
logger.error("Mooncake Transfer Engine initialization failed.")
|
||||||
|
raise RuntimeError("Mooncake Transfer Engine initialization failed.")
|
||||||
|
|
||||||
def transfer_sync(
|
def transfer_sync(
|
||||||
self, session_id: str, buffer: int, peer_buffer_address: int, length: int
|
self, session_id: str, buffer: int, peer_buffer_address: int, length: int
|
||||||
@@ -97,12 +66,12 @@ class MooncakeTransferEngine:
|
|||||||
session_id, buffer, peer_buffer_address, length
|
session_id, buffer, peer_buffer_address, length
|
||||||
)
|
)
|
||||||
if ret < 0:
|
if ret < 0:
|
||||||
logger.error("Transfer Return Error")
|
logger.error("Mooncake Transfer Engine Return Error.")
|
||||||
raise Exception("Transfer Return Error")
|
raise RuntimeError("Mooncake Transfer Engine Return Error.")
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def get_localhost(self):
|
def get_localhost(self):
|
||||||
return self.config.local_hostname
|
return self.hostname
|
||||||
|
|
||||||
def get_session_id(self):
|
def get_session_id(self):
|
||||||
return self.session_id
|
return self.session_id
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ class PrefillBootstrapQueue:
|
|||||||
kv_args.aux_item_lens = [
|
kv_args.aux_item_lens = [
|
||||||
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
|
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
|
||||||
]
|
]
|
||||||
kv_args.ib_device = "mock-ib-device"
|
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
||||||
kv_args.gpu_id = self.scheduler.gpu_id
|
kv_args.gpu_id = self.scheduler.gpu_id
|
||||||
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
||||||
kv_manager = kv_manager_class(
|
kv_manager = kv_manager_class(
|
||||||
|
|||||||
@@ -196,6 +196,7 @@ class ServerArgs:
|
|||||||
disaggregation_mode: str = "null"
|
disaggregation_mode: str = "null"
|
||||||
disaggregation_bootstrap_port: int = 8998
|
disaggregation_bootstrap_port: int = 8998
|
||||||
disaggregation_transfer_backend: str = "mooncake"
|
disaggregation_transfer_backend: str = "mooncake"
|
||||||
|
disaggregation_ib_device: Optional[str] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Expert parallelism
|
# Expert parallelism
|
||||||
@@ -1193,6 +1194,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.disaggregation_transfer_backend,
|
default=ServerArgs.disaggregation_transfer_backend,
|
||||||
help="The backend for disaggregation transfer. Default is mooncake.",
|
help="The backend for disaggregation transfer. Default is mooncake.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disaggregation-ib-device",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.disaggregation_ib_device,
|
||||||
|
help="The ib device for disaggregation transfer. Default is None, it will be detected automatically if using the mooncake backend.",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
|
|||||||
Reference in New Issue
Block a user