# Source: sglang/python/sglang/srt/disaggregation/mooncake/transfer_engine.py
# Snapshot: 2025-06-25 18:55:24 -07:00 (125 lines, 4.2 KiB, Python)

import json
import logging
from dataclasses import dataclass
from typing import List, Optional
logger = logging.getLogger(__name__)
class MooncakeTransferEngine:
    """Wrapper around Mooncake's ``TransferEngine`` used for memory transfers.

    The wrapper owns one engine instance, initializes it on construction and
    exposes memory (de)registration plus synchronous single/batch writes.

    Registration, deregistration and transfer failures are best-effort: they
    are logged at DEBUG level (including the underlying exception, if any)
    instead of being raised, so that a failed request does not stop the
    calling thread. Only engine initialization and a missing batch API raise.
    """

    def __init__(self, hostname: str, gpu_id: int, ib_device: Optional[str] = None):
        """Create and initialize the underlying Mooncake engine.

        Args:
            hostname: Local host name passed to the engine and used as the
                first half of ``session_id``.
            gpu_id: Stored on the instance; not used by this class itself.
            ib_device: Optional device name forwarded to the engine. ``None``
                is forwarded as an empty string.

        Raises:
            ImportError: If the ``mooncake`` package is not installed.
            RuntimeError: If the engine reports an initialization failure.
        """
        # Imported lazily so that this module can be imported without
        # mooncake being installed.
        try:
            from mooncake.engine import TransferEngine
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run SGLang with MooncakeTransferEngine."
            ) from e

        self.engine = TransferEngine()
        self.hostname = hostname
        self.gpu_id = gpu_id
        self.ib_device = ib_device

        self.initialize(
            hostname=self.hostname,
            device_name=self.ib_device,
        )

        # The RPC port is only queried after the engine has been initialized.
        self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"

    def register(self, ptr, length) -> None:
        """Register the memory region ``ptr``/``length`` with the engine.

        Never raises: a non-zero engine status or an exception from the
        engine is logged at DEBUG level only.
        """
        error: Optional[BaseException] = None
        try:
            ret_value = self.engine.register_memory(ptr, length)
        except Exception as exc:
            # Mark register as failed, but keep the cause for the log record.
            ret_value = -1
            error = exc

        if ret_value != 0:
            logger.debug(
                "Mooncake memory registration %s failed.", ptr, exc_info=error
            )

    def deregister(self, ptr) -> None:
        """Unregister the memory region starting at ``ptr`` from the engine.

        Never raises: a non-zero engine status or an exception from the
        engine is logged at DEBUG level only.
        """
        error: Optional[BaseException] = None
        try:
            ret_value = self.engine.unregister_memory(ptr)
        except Exception as exc:
            # Mark deregister as failed, but keep the cause for the log record.
            ret_value = -1
            error = exc

        if ret_value != 0:
            logger.debug(
                "Mooncake memory deregistration %s failed.", ptr, exc_info=error
            )

    def initialize(
        self,
        hostname: str,
        device_name: Optional[str],
    ) -> None:
        """Initialize the mooncake instance.

        Args:
            hostname: Host name forwarded to the engine.
            device_name: Device name forwarded to the engine; ``None`` is
                forwarded as an empty string.

        Raises:
            RuntimeError: If the engine returns a non-zero status.
        """
        # NOTE(review): "P2PHANDSHAKE" and "rdma" are presumably the metadata
        # mode and the transport protocol -- confirm against the Mooncake API.
        ret_value = self.engine.initialize(
            hostname,
            "P2PHANDSHAKE",
            "rdma",
            device_name if device_name is not None else "",
        )
        if ret_value != 0:
            logger.error("Mooncake Transfer Engine initialization failed.")
            raise RuntimeError("Mooncake Transfer Engine initialization failed.")

    def transfer_sync(
        self, session_id: str, buffer: int, peer_buffer_address: int, length: int
    ) -> int:
        """Synchronously transfer data to the specified address.

        Returns:
            The engine's status code, or ``-1`` if the engine raised. A
            negative value means the transfer failed.
        """
        error: Optional[BaseException] = None
        try:
            # On first use, a queue pair is constructed from session_id (which
            # contains the remote IP) and cached; later transfers reuse the
            # cached queue pair.
            ret = self.engine.transfer_sync_write(
                session_id, buffer, peer_buffer_address, length
            )
        except Exception as exc:
            # Mark transfer request as failed, but keep the cause for the log.
            ret = -1
            error = exc

        if ret < 0:
            # Do not raise here: some failed transfer requests are acceptable
            # and must not stop the executing thread.
            logger.debug(
                "Failed to transfer data from %s to %s - %s.",
                buffer,
                session_id,
                peer_buffer_address,
                exc_info=error,
            )

        return ret

    def batch_transfer_sync(
        self,
        session_id: str,
        buffers: List[int],
        peer_buffer_addresses: List[int],
        lengths: List[int],
    ) -> int:
        """Synchronously transfer data to the specified addresses in batches.

        Returns:
            The engine's status code, or ``-1`` if the engine raised. A
            negative value means the transfer failed.

        Raises:
            RuntimeError: If the installed engine has no
                ``batch_transfer_sync_write`` method (package too old).
        """
        error: Optional[BaseException] = None
        try:
            ret = self.engine.batch_transfer_sync_write(
                session_id, buffers, peer_buffer_addresses, lengths
            )
        except Exception as exc:
            # Inform user to upgrade mooncake-transfer-engine >= 0.3.4.post2
            if not hasattr(self.engine, "batch_transfer_sync_write"):
                raise RuntimeError(
                    "Mooncake's batch transfer requires mooncake-transfer-engine >= 0.3.4.post2. "
                    "Please upgrade Mooncake by 'pip install mooncake-transfer-engine --upgrade'"
                ) from exc
            ret = -1
            error = exc

        if ret < 0:
            logger.debug(
                "Failed to batch transfer data. Buffers: %s, Session: %s, Peer addresses: %s",
                buffers,
                session_id,
                peer_buffer_addresses,
                exc_info=error,
            )
        return ret

    def get_session_id(self):
        """Return the ``"<hostname>:<rpc_port>"`` string built in ``__init__``."""
        return self.session_id