[PD] Make bootstrap code common between NIXL and Mooncake (#6473)
This commit is contained in:
@@ -47,3 +47,44 @@ $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --dis
|
|||||||
# decode 1
|
# decode 1
|
||||||
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
|
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## NIXL
|
||||||
|
### Requirements
|
||||||
|
|
||||||
|
Install via pip.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install nixl
|
||||||
|
```
|
||||||
|
|
||||||
|
Or build from source - may be required if you already have UCX installed.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/ai-dynamo/nixl.git
|
||||||
|
cd nixl
|
||||||
|
pip install . --config-settings=setup-args="-Ducx_path=/path/to/ucx"
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
### Llama Single Node
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend nixl
|
||||||
|
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend nixl
|
||||||
|
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
|
||||||
|
```
|
||||||
|
|
||||||
|
### DeepSeek Multi-Node
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# prefill 0
|
||||||
|
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
|
||||||
|
# prefill 1
|
||||||
|
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
|
||||||
|
# decode 0
|
||||||
|
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
|
||||||
|
# decode 1
|
||||||
|
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
|
||||||
|
```
|
||||||
|
|||||||
1
python/sglang/srt/disaggregation/common/__init__.py
Normal file
1
python/sglang/srt/disaggregation/common/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .conn import CommonKVBootstrapServer, CommonKVManager, CommonKVReceiver
|
||||||
401
python/sglang/srt/disaggregation/common/conn.py
Normal file
401
python/sglang/srt/disaggregation/common/conn.py
Normal file
@@ -0,0 +1,401 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import socket
|
||||||
|
import threading
|
||||||
|
from functools import cache
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
import requests
|
||||||
|
import zmq
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from sglang.srt.disaggregation.base.conn import (
|
||||||
|
BaseKVBootstrapServer,
|
||||||
|
BaseKVManager,
|
||||||
|
BaseKVReceiver,
|
||||||
|
BaseKVSender,
|
||||||
|
KVArgs,
|
||||||
|
KVPoll,
|
||||||
|
)
|
||||||
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CommonKVManager(BaseKVManager):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
args: KVArgs,
|
||||||
|
disaggregation_mode: DisaggregationMode,
|
||||||
|
server_args: ServerArgs,
|
||||||
|
is_mla_backend: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
self.kv_args = args
|
||||||
|
self.is_mla_backend = is_mla_backend
|
||||||
|
self.disaggregation_mode = disaggregation_mode
|
||||||
|
# for p/d multi node infer
|
||||||
|
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
||||||
|
self.dist_init_addr = server_args.dist_init_addr
|
||||||
|
self.tp_size = server_args.tp_size
|
||||||
|
self.dp_size = server_args.dp_size
|
||||||
|
self.enable_dp_attention = server_args.enable_dp_attention
|
||||||
|
if not server_args.enable_dp_attention and server_args.dp_size != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.rank_port = get_free_port()
|
||||||
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
|
self._register_to_bootstrap()
|
||||||
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
|
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
|
||||||
|
self.prefill_tp_size_table: Dict[str, int] = {}
|
||||||
|
self.prefill_dp_size_table: Dict[str, int] = {}
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _register_to_bootstrap(self):
|
||||||
|
"""Register KVSender to bootstrap server via HTTP POST."""
|
||||||
|
if self.dist_init_addr:
|
||||||
|
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
|
||||||
|
else:
|
||||||
|
ip_address = get_ip()
|
||||||
|
|
||||||
|
bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
|
||||||
|
url = f"http://{bootstrap_server_url}/route"
|
||||||
|
payload = {
|
||||||
|
"role": "Prefill",
|
||||||
|
"tp_size": self.tp_size,
|
||||||
|
"dp_size": self.dp_size,
|
||||||
|
"rank_ip": get_local_ip_by_remote(),
|
||||||
|
"rank_port": self.rank_port,
|
||||||
|
"engine_rank": self.kv_args.engine_rank,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.put(url, json=payload)
|
||||||
|
if response.status_code == 200:
|
||||||
|
logger.debug("Prefill successfully registered to bootstrap server.")
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Prefill Failed to register to bootstrap server: {e}")
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def _connect(self, endpoint: str):
|
||||||
|
socket = zmq.Context().socket(zmq.PUSH)
|
||||||
|
socket.connect(endpoint)
|
||||||
|
return socket
|
||||||
|
|
||||||
|
|
||||||
|
class CommonKVReceiver(BaseKVReceiver):
|
||||||
|
_ctx = zmq.Context()
|
||||||
|
_socket_cache = {}
|
||||||
|
_socket_locks = {}
|
||||||
|
_global_lock = threading.Lock()
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mgr: BaseKVManager,
|
||||||
|
bootstrap_addr: str,
|
||||||
|
bootstrap_room: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self.bootstrap_room = bootstrap_room
|
||||||
|
self.bootstrap_addr = bootstrap_addr
|
||||||
|
self.kv_mgr = mgr
|
||||||
|
|
||||||
|
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
||||||
|
self.prefill_tp_size, self.prefill_dp_size = (
|
||||||
|
self._get_prefill_dp_size_from_server()
|
||||||
|
)
|
||||||
|
if self.prefill_tp_size is None or self.prefill_dp_size is None:
|
||||||
|
logger.error(
|
||||||
|
f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
|
||||||
|
self.prefill_tp_size
|
||||||
|
)
|
||||||
|
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
|
||||||
|
self.prefill_dp_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
|
||||||
|
self.bootstrap_addr
|
||||||
|
]
|
||||||
|
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
|
||||||
|
self.bootstrap_addr
|
||||||
|
]
|
||||||
|
|
||||||
|
# Currently, we don't allow prefill instance and decode instance to
|
||||||
|
# have different TP sizes per DP rank, except for models using MLA.
|
||||||
|
local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
|
||||||
|
prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
|
||||||
|
if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
|
||||||
|
self.target_tp_rank = (
|
||||||
|
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
|
||||||
|
)
|
||||||
|
self.required_dst_info_num = 1
|
||||||
|
self.target_tp_ranks = [self.target_tp_rank]
|
||||||
|
elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
|
||||||
|
assert (
|
||||||
|
self.kv_mgr.is_mla_backend
|
||||||
|
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
|
||||||
|
self.target_tp_rank = (
|
||||||
|
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
|
||||||
|
) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
|
||||||
|
self.required_dst_info_num = (
|
||||||
|
local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
|
||||||
|
)
|
||||||
|
self.target_tp_ranks = [self.target_tp_rank]
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
self.kv_mgr.is_mla_backend
|
||||||
|
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
|
||||||
|
|
||||||
|
# For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
|
||||||
|
self.target_tp_ranks = [
|
||||||
|
rank
|
||||||
|
for rank in range(
|
||||||
|
(self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
|
||||||
|
* (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
|
||||||
|
(self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
|
||||||
|
* (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
|
||||||
|
# multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
|
||||||
|
# or the KVPoll will never be set correctly
|
||||||
|
self.target_tp_rank = self.target_tp_ranks[0]
|
||||||
|
self.required_dst_info_num = 1
|
||||||
|
|
||||||
|
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
||||||
|
|
||||||
|
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
||||||
|
bootstrap_key = (
|
||||||
|
f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if bootstrap_key not in self.kv_mgr.connection_pool:
|
||||||
|
bootstrap_infos = []
|
||||||
|
for target_tp_rank in self.target_tp_ranks:
|
||||||
|
bootstrap_info = self._get_bootstrap_info_from_server(
|
||||||
|
target_tp_rank,
|
||||||
|
self.target_dp_group,
|
||||||
|
)
|
||||||
|
if bootstrap_info is not None:
|
||||||
|
# NOTE: only support MLA for now: select one prefill rank as real rank
|
||||||
|
bootstrap_info["is_dummy"] = not bool(
|
||||||
|
target_tp_rank == self.target_tp_rank
|
||||||
|
or self.target_tp_rank is None
|
||||||
|
)
|
||||||
|
bootstrap_infos.append(bootstrap_info)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
|
||||||
|
)
|
||||||
|
self.bootstrap_infos = bootstrap_infos
|
||||||
|
|
||||||
|
if len(self.bootstrap_infos) == 0:
|
||||||
|
logger.error(
|
||||||
|
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
|
||||||
|
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
|
||||||
|
self._register_kv_args()
|
||||||
|
else:
|
||||||
|
self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
|
||||||
|
|
||||||
|
assert len(self.bootstrap_infos) > 0
|
||||||
|
|
||||||
|
def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
|
||||||
|
"""Fetch the bootstrap info from the bootstrap server."""
|
||||||
|
try:
|
||||||
|
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
|
||||||
|
response = requests.get(url)
|
||||||
|
if response.status_code == 200:
|
||||||
|
bootstrap_info = response.json()
|
||||||
|
return bootstrap_info
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to get prefill server info: {response.status_code}, {response.text}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching prefill info from bootstrap: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_prefill_dp_size_from_server(self) -> int:
|
||||||
|
"""Fetch the prefill parallel info from the bootstrap server."""
|
||||||
|
try:
|
||||||
|
url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
|
||||||
|
response = requests.get(url)
|
||||||
|
if response.status_code == 200:
|
||||||
|
prefill_parallel_info = response.json()
|
||||||
|
return int(prefill_parallel_info["prefill_tp_size"]), int(
|
||||||
|
prefill_parallel_info["prefill_dp_size"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _connect(cls, endpoint: str):
|
||||||
|
with cls._global_lock:
|
||||||
|
if endpoint not in cls._socket_cache:
|
||||||
|
sock = cls._ctx.socket(zmq.PUSH)
|
||||||
|
sock.connect(endpoint)
|
||||||
|
cls._socket_cache[endpoint] = sock
|
||||||
|
cls._socket_locks[endpoint] = threading.Lock()
|
||||||
|
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
||||||
|
|
||||||
|
def _register_kv_args(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def failure_exception(self):
|
||||||
|
raise Exception("Fake KVReceiver Exception")
|
||||||
|
|
||||||
|
|
||||||
|
class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
||||||
|
def __init__(self, port: int):
|
||||||
|
self.port = port
|
||||||
|
self.app = web.Application()
|
||||||
|
self.store = dict()
|
||||||
|
self.lock = asyncio.Lock()
|
||||||
|
self._setup_routes()
|
||||||
|
self.tp_size = None
|
||||||
|
self.dp_size = None
|
||||||
|
self.tp_size_per_dp_rank = None
|
||||||
|
self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
|
||||||
|
|
||||||
|
# Start bootstrap server
|
||||||
|
self.thread = threading.Thread(target=self._run_server, daemon=True)
|
||||||
|
self.run()
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self.thread.start()
|
||||||
|
|
||||||
|
def _setup_routes(self):
|
||||||
|
self.app.router.add_route("*", "/route", self._handle_route)
|
||||||
|
|
||||||
|
async def _handle_route(self, request: web.Request):
|
||||||
|
method = request.method
|
||||||
|
if method == "PUT":
|
||||||
|
return await self._handle_route_put(request)
|
||||||
|
elif method == "GET":
|
||||||
|
return await self._handle_route_get(request)
|
||||||
|
else:
|
||||||
|
return web.Response(
|
||||||
|
text="Method not allowed", status=405, content_type="application/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_route_put(self, request: web.Request):
|
||||||
|
data = await request.json()
|
||||||
|
role = data["role"]
|
||||||
|
tp_size = data["tp_size"]
|
||||||
|
dp_size = data["dp_size"]
|
||||||
|
rank_ip = data["rank_ip"]
|
||||||
|
rank_port = int(data["rank_port"])
|
||||||
|
engine_rank = int(data["engine_rank"])
|
||||||
|
|
||||||
|
if self.tp_size is None:
|
||||||
|
self.tp_size = tp_size
|
||||||
|
|
||||||
|
if self.dp_size is None:
|
||||||
|
self.dp_size = dp_size
|
||||||
|
|
||||||
|
tp_size_per_dp_rank = tp_size // dp_size
|
||||||
|
if self.tp_size_per_dp_rank == None:
|
||||||
|
self.tp_size_per_dp_rank = tp_size_per_dp_rank
|
||||||
|
|
||||||
|
# Add lock to make sure thread-safe
|
||||||
|
if role == "Prefill":
|
||||||
|
dp_group = engine_rank // tp_size_per_dp_rank
|
||||||
|
tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
|
||||||
|
|
||||||
|
async with self.lock:
|
||||||
|
if dp_group not in self.prefill_port_table:
|
||||||
|
self.prefill_port_table[dp_group] = {}
|
||||||
|
|
||||||
|
self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
|
||||||
|
"rank_ip": rank_ip,
|
||||||
|
"rank_port": rank_port,
|
||||||
|
}
|
||||||
|
logger.debug(
|
||||||
|
f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return web.Response(text="OK", status=200)
|
||||||
|
|
||||||
|
async def _handle_route_get(self, request: web.Request):
|
||||||
|
engine_rank = request.query.get("engine_rank")
|
||||||
|
target_dp_group = request.query.get("target_dp_group")
|
||||||
|
if not engine_rank or not target_dp_group:
|
||||||
|
return web.Response(text="Missing inputs for bootstrap server.", status=400)
|
||||||
|
|
||||||
|
# Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
|
||||||
|
if int(engine_rank) == -1 and int(target_dp_group) == -1:
|
||||||
|
prefill_parallel_info = {
|
||||||
|
"prefill_tp_size": self.tp_size,
|
||||||
|
"prefill_dp_size": self.dp_size,
|
||||||
|
}
|
||||||
|
return web.json_response(prefill_parallel_info, status=200)
|
||||||
|
|
||||||
|
# Find corresponding prefill info
|
||||||
|
async with self.lock:
|
||||||
|
bootstrap_info = self.prefill_port_table[int(target_dp_group)][
|
||||||
|
int(engine_rank)
|
||||||
|
]
|
||||||
|
|
||||||
|
if bootstrap_info is not None:
|
||||||
|
return web.json_response(bootstrap_info, status=200)
|
||||||
|
else:
|
||||||
|
return web.Response(text="Bootstrap info not Found", status=404)
|
||||||
|
|
||||||
|
def _run_server(self):
|
||||||
|
try:
|
||||||
|
# Event Loop
|
||||||
|
self._loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(self._loop)
|
||||||
|
|
||||||
|
self._runner = web.AppRunner(self.app)
|
||||||
|
self._loop.run_until_complete(self._runner.setup())
|
||||||
|
|
||||||
|
site = web.TCPSite(self._runner, port=self.port)
|
||||||
|
self._loop.run_until_complete(site.start())
|
||||||
|
self._loop.run_forever()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Server error: {str(e)}")
|
||||||
|
finally:
|
||||||
|
# Cleanup
|
||||||
|
self._loop.run_until_complete(self._runner.cleanup())
|
||||||
|
self._loop.close()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Shutdown"""
|
||||||
|
if self._loop is not None and self._loop.is_running():
|
||||||
|
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||||
|
logger.info("Stopping server loop...")
|
||||||
|
|
||||||
|
if self.thread.is_alive():
|
||||||
|
self.thread.join(timeout=2)
|
||||||
|
logger.info("Server thread stopped")
|
||||||
|
|
||||||
|
def poll(self) -> KVPoll: ...
|
||||||
@@ -29,7 +29,10 @@ from sglang.srt.disaggregation.base.conn import (
|
|||||||
KVPoll,
|
KVPoll,
|
||||||
)
|
)
|
||||||
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
from sglang.srt.disaggregation.utils import (
|
||||||
|
DisaggregationMode,
|
||||||
|
group_concurrent_contiguous,
|
||||||
|
)
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_free_port,
|
get_free_port,
|
||||||
@@ -41,23 +44,6 @@ from sglang.srt.utils import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def group_concurrent_contiguous(
|
|
||||||
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
|
|
||||||
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
|
|
||||||
"""Vectorised NumPy implementation."""
|
|
||||||
if src_indices.size == 0:
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
|
|
||||||
src_groups = np.split(src_indices, brk)
|
|
||||||
dst_groups = np.split(dst_indices, brk)
|
|
||||||
|
|
||||||
src_groups = [g.tolist() for g in src_groups]
|
|
||||||
dst_groups = [g.tolist() for g in dst_groups]
|
|
||||||
|
|
||||||
return src_groups, dst_groups
|
|
||||||
|
|
||||||
|
|
||||||
class KVTransferError(Exception):
|
class KVTransferError(Exception):
|
||||||
def __init__(self, bootstrap_room: int, failure_reason: str):
|
def __init__(self, bootstrap_room: int, failure_reason: str):
|
||||||
super().__init__(failure_reason)
|
super().__init__(failure_reason)
|
||||||
|
|||||||
@@ -18,40 +18,23 @@ import requests
|
|||||||
import zmq
|
import zmq
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from sglang.srt.disaggregation.base.conn import (
|
from sglang.srt.disaggregation.base.conn import BaseKVSender, KVArgs, KVPoll
|
||||||
BaseKVBootstrapServer,
|
from sglang.srt.disaggregation.common.conn import (
|
||||||
BaseKVManager,
|
CommonKVBootstrapServer,
|
||||||
BaseKVReceiver,
|
CommonKVManager,
|
||||||
BaseKVSender,
|
CommonKVReceiver,
|
||||||
KVArgs,
|
)
|
||||||
KVPoll,
|
from sglang.srt.disaggregation.utils import (
|
||||||
|
DisaggregationMode,
|
||||||
|
group_concurrent_contiguous,
|
||||||
)
|
)
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
|
from sglang.srt.utils import get_local_ip_by_remote
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
|
NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
|
||||||
|
|
||||||
|
|
||||||
def group_concurrent_contiguous(
|
|
||||||
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
|
|
||||||
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
|
|
||||||
"""Vectorised NumPy implementation."""
|
|
||||||
if src_indices.size == 0:
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
|
|
||||||
src_groups = np.split(src_indices, brk)
|
|
||||||
dst_groups = np.split(dst_indices, brk)
|
|
||||||
|
|
||||||
src_groups = [g.tolist() for g in src_groups]
|
|
||||||
dst_groups = [g.tolist() for g in dst_groups]
|
|
||||||
|
|
||||||
return src_groups, dst_groups
|
|
||||||
|
|
||||||
|
|
||||||
GUARD = "NixlMsgGuard".encode("ascii")
|
GUARD = "NixlMsgGuard".encode("ascii")
|
||||||
|
|
||||||
|
|
||||||
@@ -61,11 +44,13 @@ class TransferInfo:
|
|||||||
endpoint: str
|
endpoint: str
|
||||||
dst_port: int
|
dst_port: int
|
||||||
agent_metadata: bytes
|
agent_metadata: bytes
|
||||||
|
agent_name: str
|
||||||
dst_kv_ptrs: list[int]
|
dst_kv_ptrs: list[int]
|
||||||
dst_kv_indices: npt.NDArray[np.int64]
|
dst_kv_indices: npt.NDArray[np.int64]
|
||||||
dst_aux_ptrs: list[int]
|
dst_aux_ptrs: list[int]
|
||||||
dst_aux_index: int
|
dst_aux_index: int
|
||||||
dst_gpu_id: int
|
dst_gpu_id: int
|
||||||
|
required_dst_info_num: int
|
||||||
|
|
||||||
def is_dummy(self):
|
def is_dummy(self):
|
||||||
return self.endpoint == ""
|
return self.endpoint == ""
|
||||||
@@ -79,11 +64,13 @@ class TransferInfo:
|
|||||||
endpoint="",
|
endpoint="",
|
||||||
dst_port=0,
|
dst_port=0,
|
||||||
agent_metadata=b"",
|
agent_metadata=b"",
|
||||||
|
agent_name="",
|
||||||
dst_kv_ptrs=[],
|
dst_kv_ptrs=[],
|
||||||
dst_kv_indices=np.array([], dtype=np.int64),
|
dst_kv_indices=np.array([], dtype=np.int64),
|
||||||
dst_aux_ptrs=[],
|
dst_aux_ptrs=[],
|
||||||
dst_aux_index=0,
|
dst_aux_index=0,
|
||||||
dst_gpu_id=0,
|
dst_gpu_id=0,
|
||||||
|
required_dst_info_num=0,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return cls(
|
return cls(
|
||||||
@@ -91,11 +78,13 @@ class TransferInfo:
|
|||||||
endpoint=msg[1].decode("ascii"),
|
endpoint=msg[1].decode("ascii"),
|
||||||
dst_port=int(msg[2].decode("ascii")),
|
dst_port=int(msg[2].decode("ascii")),
|
||||||
agent_metadata=msg[3],
|
agent_metadata=msg[3],
|
||||||
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
|
agent_name=msg[4].decode("ascii"),
|
||||||
dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
|
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
|
||||||
dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
|
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64),
|
||||||
dst_aux_index=int(msg[7].decode("ascii")),
|
dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
|
||||||
dst_gpu_id=int(msg[8].decode("ascii")),
|
dst_aux_index=int(msg[8].decode("ascii")),
|
||||||
|
dst_gpu_id=int(msg[9].decode("ascii")),
|
||||||
|
required_dst_info_num=int(msg[10].decode("ascii")),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -116,7 +105,7 @@ class TransferStatus:
|
|||||||
return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
|
return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
|
||||||
|
|
||||||
|
|
||||||
class NixlKVManager(BaseKVManager):
|
class NixlKVManager(CommonKVManager):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
args: KVArgs,
|
args: KVArgs,
|
||||||
@@ -124,6 +113,7 @@ class NixlKVManager(BaseKVManager):
|
|||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
is_mla_backend: Optional[bool] = False,
|
is_mla_backend: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
|
super().__init__(args, disaggregation_mode, server_args, is_mla_backend)
|
||||||
try:
|
try:
|
||||||
from nixl._api import nixl_agent
|
from nixl._api import nixl_agent
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@@ -133,38 +123,15 @@ class NixlKVManager(BaseKVManager):
|
|||||||
"to run SGLang with NixlTransferEngine."
|
"to run SGLang with NixlTransferEngine."
|
||||||
) from e
|
) from e
|
||||||
self.agent = nixl_agent(str(uuid.uuid4()))
|
self.agent = nixl_agent(str(uuid.uuid4()))
|
||||||
self.kv_args = args
|
|
||||||
self.disaggregation_mode = disaggregation_mode
|
|
||||||
# for p/d multi node infer
|
|
||||||
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
|
||||||
self.dist_init_addr = server_args.dist_init_addr
|
|
||||||
self.tp_size = server_args.tp_size
|
|
||||||
|
|
||||||
self.tp_rank = args.engine_rank
|
|
||||||
self.enable_dp_attention = server_args.enable_dp_attention
|
|
||||||
if self.enable_dp_attention:
|
|
||||||
assert (
|
|
||||||
server_args.dp_size > 1
|
|
||||||
), "If dp_attention is enabled, dp size must be greater than 1 in disaggregation mode."
|
|
||||||
self.dp_size = server_args.dp_size
|
|
||||||
self.tp_size_of_dp = server_args.tp_size // server_args.dp_size
|
|
||||||
self.attn_tp_rank = args.engine_rank % self.tp_size_of_dp
|
|
||||||
self.dp_rank = args.engine_rank // self.tp_size_of_dp
|
|
||||||
|
|
||||||
self.rank_port = None
|
|
||||||
self.server_socket = zmq.Context().socket(zmq.PULL)
|
self.server_socket = zmq.Context().socket(zmq.PULL)
|
||||||
self.register_buffer_to_engine()
|
self.register_buffer_to_engine()
|
||||||
|
|
||||||
self.rank_port = get_free_port()
|
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
|
self.request_status = {}
|
||||||
self.transfer_infos: Dict[int, TransferInfo] = {}
|
self.transfer_infos: Dict[int, TransferInfo] = {}
|
||||||
self.condition = threading.Condition()
|
self.peer_names: Dict[str, str] = {}
|
||||||
self.peer_names: Dict[int, str] = {}
|
|
||||||
self._start_bootstrap_thread()
|
self._start_bootstrap_thread()
|
||||||
self._register_to_bootstrap()
|
|
||||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
# bootstrap key -> (remote_engine_rank -> possible remote source info)
|
|
||||||
self.prefill_peer_infos: Dict[str, list[Dict[int, NixlEngineInfo]]] = {}
|
|
||||||
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
|
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
|
||||||
TransferStatus
|
TransferStatus
|
||||||
)
|
)
|
||||||
@@ -173,6 +140,18 @@ class NixlKVManager(BaseKVManager):
|
|||||||
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
|
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def check_status(self, bootstrap_room: int):
|
||||||
|
return self.request_status[bootstrap_room]
|
||||||
|
|
||||||
|
def update_status(self, bootstrap_room: int, status: KVPoll):
|
||||||
|
if bootstrap_room not in self.request_status:
|
||||||
|
self.request_status[bootstrap_room] = status
|
||||||
|
else:
|
||||||
|
# NOTE: The prefill engine could recv bootstrapping first
|
||||||
|
self.request_status[bootstrap_room] = max(
|
||||||
|
self.request_status[bootstrap_room], status
|
||||||
|
)
|
||||||
|
|
||||||
def register_buffer_to_engine(self):
|
def register_buffer_to_engine(self):
|
||||||
kv_addrs = []
|
kv_addrs = []
|
||||||
for kv_data_ptr, kv_data_len in zip(
|
for kv_data_ptr, kv_data_len in zip(
|
||||||
@@ -193,16 +172,10 @@ class NixlKVManager(BaseKVManager):
|
|||||||
if not self.aux_descs:
|
if not self.aux_descs:
|
||||||
raise Exception("NIXL memory registration failed for aux tensors")
|
raise Exception("NIXL memory registration failed for aux tensors")
|
||||||
|
|
||||||
@cache
|
def _add_remote(self, agent_name: str, agent_metadata: bytes):
|
||||||
def _connect(self, endpoint: str):
|
if agent_name not in self.peer_names:
|
||||||
socket = zmq.Context().socket(zmq.PUSH)
|
self.peer_names[agent_name] = self.agent.add_remote_agent(agent_metadata)
|
||||||
socket.connect(endpoint)
|
return self.peer_names[agent_name]
|
||||||
return socket
|
|
||||||
|
|
||||||
def _add_remote(self, room: int, agent_metadata: bytes):
|
|
||||||
if room not in self.peer_names:
|
|
||||||
self.peer_names[room] = self.agent.add_remote_agent(agent_metadata)
|
|
||||||
return self.peer_names[room]
|
|
||||||
|
|
||||||
def send_kvcache(
|
def send_kvcache(
|
||||||
self,
|
self,
|
||||||
@@ -300,40 +273,38 @@ class NixlKVManager(BaseKVManager):
|
|||||||
assert self.disaggregation_mode == DisaggregationMode.PREFILL
|
assert self.disaggregation_mode == DisaggregationMode.PREFILL
|
||||||
assert not is_last or (is_last and aux_index is not None)
|
assert not is_last or (is_last and aux_index is not None)
|
||||||
|
|
||||||
# Wait for transfer info to be populated by bootstrap thread.
|
reqs_to_be_processed = self.transfer_infos[bootstrap_room].values()
|
||||||
with self.condition:
|
handles = []
|
||||||
self.condition.wait_for(lambda: bootstrap_room in self.transfer_infos)
|
for req in reqs_to_be_processed:
|
||||||
req = self.transfer_infos[bootstrap_room]
|
assert bootstrap_room == req.room
|
||||||
assert bootstrap_room == req.room
|
if req.is_dummy():
|
||||||
|
return []
|
||||||
|
|
||||||
if req.is_dummy():
|
peer_name = self._add_remote(req.agent_name, req.agent_metadata)
|
||||||
return []
|
chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
|
||||||
|
assert len(chunked_dst_kv_indice) == len(kv_indices)
|
||||||
|
|
||||||
peer_name = self._add_remote(bootstrap_room, req.agent_metadata)
|
notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
|
||||||
chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
|
kv_xfer_handle = self.send_kvcache(
|
||||||
assert len(chunked_dst_kv_indice) == len(kv_indices)
|
|
||||||
|
|
||||||
notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
|
|
||||||
kv_xfer_handle = self.send_kvcache(
|
|
||||||
peer_name,
|
|
||||||
kv_indices,
|
|
||||||
req.dst_kv_ptrs,
|
|
||||||
chunked_dst_kv_indice,
|
|
||||||
req.dst_gpu_id,
|
|
||||||
notif,
|
|
||||||
)
|
|
||||||
handles = [kv_xfer_handle]
|
|
||||||
# Only the last chunk we need to send the aux data.
|
|
||||||
if is_last:
|
|
||||||
assert aux_index is not None
|
|
||||||
aux_xfer_handle = self.send_aux(
|
|
||||||
peer_name,
|
peer_name,
|
||||||
aux_index,
|
kv_indices,
|
||||||
req.dst_aux_ptrs,
|
req.dst_kv_ptrs,
|
||||||
req.dst_aux_index,
|
chunked_dst_kv_indice,
|
||||||
str(req.room) + "_aux",
|
req.dst_gpu_id,
|
||||||
|
notif,
|
||||||
)
|
)
|
||||||
handles.append(aux_xfer_handle)
|
handles.append(kv_xfer_handle)
|
||||||
|
# Only the last chunk we need to send the aux data.
|
||||||
|
if is_last:
|
||||||
|
assert aux_index is not None
|
||||||
|
aux_xfer_handle = self.send_aux(
|
||||||
|
peer_name,
|
||||||
|
aux_index,
|
||||||
|
req.dst_aux_ptrs,
|
||||||
|
req.dst_aux_index,
|
||||||
|
str(req.room) + "_aux",
|
||||||
|
)
|
||||||
|
handles.append(aux_xfer_handle)
|
||||||
return handles
|
return handles
|
||||||
|
|
||||||
def update_transfer_status(self):
|
def update_transfer_status(self):
|
||||||
@@ -348,7 +319,7 @@ class NixlKVManager(BaseKVManager):
|
|||||||
room = int(components[0])
|
room = int(components[0])
|
||||||
if components[1] == "kv":
|
if components[1] == "kv":
|
||||||
chunk_id = int(components[2])
|
chunk_id = int(components[2])
|
||||||
is_last = bool(components[3])
|
is_last = bool(int(components[3]))
|
||||||
self.transfer_statuses[room].received_kvs.add(chunk_id)
|
self.transfer_statuses[room].received_kvs.add(chunk_id)
|
||||||
if is_last:
|
if is_last:
|
||||||
self.transfer_statuses[room].num_kvs_expected = chunk_id + 1
|
self.transfer_statuses[room].num_kvs_expected = chunk_id + 1
|
||||||
@@ -360,34 +331,6 @@ class NixlKVManager(BaseKVManager):
|
|||||||
return False
|
return False
|
||||||
return self.transfer_statuses[room].is_done()
|
return self.transfer_statuses[room].is_done()
|
||||||
|
|
||||||
def _register_to_bootstrap(self):
|
|
||||||
"""Register KVSender to bootstrap server via HTTP POST."""
|
|
||||||
if self.dist_init_addr:
|
|
||||||
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
|
|
||||||
else:
|
|
||||||
ip_address = get_ip()
|
|
||||||
|
|
||||||
bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
|
|
||||||
url = f"http://{bootstrap_server_url}/route"
|
|
||||||
payload = {
|
|
||||||
"role": "Prefill",
|
|
||||||
"rank_ip": get_local_ip_by_remote(),
|
|
||||||
"rank_port": self.rank_port,
|
|
||||||
"engine_rank": self.kv_args.engine_rank,
|
|
||||||
"agent_name": self.agent.name,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = requests.put(url, json=payload)
|
|
||||||
if response.status_code == 200:
|
|
||||||
logger.debug("Prefill successfully registered to bootstrap server.")
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Prefill Failed to register to bootstrap server: {e}")
|
|
||||||
|
|
||||||
def _start_bootstrap_thread(self):
|
def _start_bootstrap_thread(self):
|
||||||
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
|
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
|
||||||
|
|
||||||
@@ -405,10 +348,19 @@ class NixlKVManager(BaseKVManager):
|
|||||||
room = waiting_req_bytes[0].decode("ascii")
|
room = waiting_req_bytes[0].decode("ascii")
|
||||||
if room == "None":
|
if room == "None":
|
||||||
continue
|
continue
|
||||||
|
required_dst_info_num = int(waiting_req_bytes[10].decode("ascii"))
|
||||||
room = int(room)
|
room = int(room)
|
||||||
with self.condition:
|
agent_name = waiting_req_bytes[4].decode("ascii")
|
||||||
self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
|
if room not in self.transfer_infos:
|
||||||
self.condition.notify_all()
|
self.transfer_infos[room] = {}
|
||||||
|
self.transfer_infos[room][agent_name] = TransferInfo.from_zmq(
|
||||||
|
waiting_req_bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}")
|
||||||
|
if len(self.transfer_infos[room]) == required_dst_info_num:
|
||||||
|
logger.debug(f"{room=} is bootstrapped")
|
||||||
|
self.update_status(room, KVPoll.WaitingForInput)
|
||||||
|
|
||||||
threading.Thread(target=bootstrap_thread).start()
|
threading.Thread(target=bootstrap_thread).start()
|
||||||
|
|
||||||
@@ -423,6 +375,9 @@ class NixlKVSender(BaseKVSender):
|
|||||||
self.xfer_handles = []
|
self.xfer_handles = []
|
||||||
self.has_sent = False
|
self.has_sent = False
|
||||||
self.chunk_id = 0
|
self.chunk_id = 0
|
||||||
|
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
||||||
|
# inner state
|
||||||
|
self.curr_idx = 0
|
||||||
|
|
||||||
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
|
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
|
||||||
self.num_kv_indices = num_kv_indices
|
self.num_kv_indices = num_kv_indices
|
||||||
@@ -431,9 +386,11 @@ class NixlKVSender(BaseKVSender):
|
|||||||
def send(
|
def send(
|
||||||
self,
|
self,
|
||||||
kv_indices: npt.NDArray[np.int64],
|
kv_indices: npt.NDArray[np.int64],
|
||||||
index_slice: slice,
|
|
||||||
is_last: bool,
|
|
||||||
):
|
):
|
||||||
|
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
|
||||||
|
self.curr_idx += len(kv_indices)
|
||||||
|
is_last = self.curr_idx == self.num_kv_indices
|
||||||
|
|
||||||
new_xfer_handles = self.kv_mgr.add_transfer_request(
|
new_xfer_handles = self.kv_mgr.add_transfer_request(
|
||||||
self.bootstrap_room,
|
self.bootstrap_room,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
@@ -449,7 +406,7 @@ class NixlKVSender(BaseKVSender):
|
|||||||
|
|
||||||
def poll(self) -> KVPoll:
|
def poll(self) -> KVPoll:
|
||||||
if not self.has_sent:
|
if not self.has_sent:
|
||||||
return KVPoll.WaitingForInput # type: ignore
|
return self.kv_mgr.check_status(self.bootstrap_room)
|
||||||
states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles]
|
states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles]
|
||||||
if all([x == "DONE" for x in states]):
|
if all([x == "DONE" for x in states]):
|
||||||
return KVPoll.Success # type: ignore
|
return KVPoll.Success # type: ignore
|
||||||
@@ -461,128 +418,40 @@ class NixlKVSender(BaseKVSender):
|
|||||||
raise Exception("Fake KVSender Exception")
|
raise Exception("Fake KVSender Exception")
|
||||||
|
|
||||||
|
|
||||||
class NixlKVReceiver(BaseKVReceiver):
|
class NixlKVReceiver(CommonKVReceiver):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
mgr: NixlKVManager,
|
mgr: NixlKVManager,
|
||||||
bootstrap_addr: str,
|
bootstrap_addr: str,
|
||||||
bootstrap_room: Optional[int] = None,
|
bootstrap_room: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.bootstrap_room = bootstrap_room
|
|
||||||
self.bootstrap_addr = bootstrap_addr
|
|
||||||
self.kv_mgr = mgr
|
|
||||||
self.started_transfer = False
|
self.started_transfer = False
|
||||||
|
super().__init__(mgr, bootstrap_addr, bootstrap_room)
|
||||||
# NOTE: key distinguished by bootstrap_addr and engine_rank
|
|
||||||
bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
|
|
||||||
|
|
||||||
if bootstrap_key not in self.kv_mgr.prefill_peer_infos:
|
|
||||||
self.bootstrap_info = self._get_bootstrap_info_from_server(
|
|
||||||
self.kv_mgr.kv_args.engine_rank
|
|
||||||
)
|
|
||||||
if self.bootstrap_info is None:
|
|
||||||
logger.error(
|
|
||||||
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.kv_mgr.prefill_peer_infos[bootstrap_key] = self.bootstrap_info
|
|
||||||
else:
|
|
||||||
self.bootstrap_info = self.kv_mgr.prefill_peer_infos[bootstrap_key]
|
|
||||||
assert self.bootstrap_info is not None
|
|
||||||
|
|
||||||
# return a list of remotes in a dict, [(remote_engine_rank -> NixlEngineInfo), ...]
|
|
||||||
# In each dict, there are multiple possible remotes named "equal sources".
|
|
||||||
# We only need to select one to split the traffic. i.e. we totally select len(list) remotes.
|
|
||||||
def _get_bootstrap_info_from_server(
|
|
||||||
self, engine_rank
|
|
||||||
) -> Optional[List[Dict[int, NixlEngineInfo]]]:
|
|
||||||
"""Fetch the bootstrap info from the bootstrap server."""
|
|
||||||
try:
|
|
||||||
if self.kv_mgr.enable_dp_attention:
|
|
||||||
url = f"http://{self.bootstrap_addr}/route"
|
|
||||||
response = requests.get(url)
|
|
||||||
if response.status_code != 200:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to get prefill server info: {response.status_code}, {response.text}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
bootstrap_info = response.json()
|
|
||||||
assert isinstance(bootstrap_info, dict)
|
|
||||||
bootstrap_info = {int(k): v for k, v in bootstrap_info.items()}
|
|
||||||
|
|
||||||
# split out who need to send to this rank.
|
|
||||||
# currently for dpsk mla model, those ranks share the same latent cache.
|
|
||||||
# pick one as the real source
|
|
||||||
|
|
||||||
prefill_tp_size = len(bootstrap_info.keys())
|
|
||||||
|
|
||||||
assert (
|
|
||||||
prefill_tp_size >= self.kv_mgr.tp_size_of_dp
|
|
||||||
), f"Only support Prefill TP size >= Decode TP size of DP, now we have {prefill_tp_size} vs {self.kv_mgr.tp_size_of_dp}"
|
|
||||||
|
|
||||||
num_remote_tp_rank_we_managed = (
|
|
||||||
prefill_tp_size // self.kv_mgr.tp_size_of_dp
|
|
||||||
)
|
|
||||||
|
|
||||||
# We handle [num * self.attn_tp_rank, num * self.attn_tp_rank + num)
|
|
||||||
remote_tp_ranks = list(range(0, prefill_tp_size))
|
|
||||||
# split it into tp_size_of_dp parts and get our part
|
|
||||||
remote_tp_ranks_grouped = [
|
|
||||||
remote_tp_ranks[i : i + num_remote_tp_rank_we_managed]
|
|
||||||
for i in range(0, prefill_tp_size, self.kv_mgr.tp_size_of_dp)
|
|
||||||
]
|
|
||||||
managed_ranks = remote_tp_ranks_grouped[self.kv_mgr.attn_tp_rank]
|
|
||||||
|
|
||||||
assert len(managed_ranks) == num_remote_tp_rank_we_managed
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Rank {self.kv_mgr.kv_args.engine_rank} source can be {managed_ranks}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
rk: bootstrap_info[rk]
|
|
||||||
for rk in bootstrap_info.keys()
|
|
||||||
if rk in managed_ranks
|
|
||||||
}
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
|
|
||||||
response = requests.get(url)
|
|
||||||
if response.status_code == 200:
|
|
||||||
bootstrap_info = response.json()
|
|
||||||
return [{engine_rank: bootstrap_info}]
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to get prefill server info: {response.status_code}, {response.text}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error fetching prefill info from bootstrap: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
@cache
|
|
||||||
def _connect(self, endpoint: str):
|
|
||||||
socket = zmq.Context().socket(zmq.PUSH)
|
|
||||||
socket.connect(endpoint)
|
|
||||||
return socket
|
|
||||||
|
|
||||||
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
||||||
|
for bootstrap_info in self.bootstrap_infos:
|
||||||
assert self.bootstrap_info is not None
|
self.prefill_server_url = (
|
||||||
assert self.bootstrap_room is not None
|
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
||||||
|
|
||||||
for equal_sources in self.bootstrap_info:
|
|
||||||
remote_rank = list(equal_sources.keys())[
|
|
||||||
self.bootstrap_room % len(equal_sources)
|
|
||||||
]
|
|
||||||
self.prefill_server_url = f"{equal_sources[remote_rank]['rank_ip']}:{equal_sources[remote_rank]['rank_port']}"
|
|
||||||
logger.debug(
|
|
||||||
f"Fetched bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}, source: {remote_rank}, all: {list(equal_sources.keys())}"
|
|
||||||
)
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
||||||
|
)
|
||||||
|
is_dummy = bootstrap_info["is_dummy"]
|
||||||
|
|
||||||
|
# TODO: just send "" for indices for dummy
|
||||||
|
if is_dummy:
|
||||||
|
# TODO: need to set success??
|
||||||
|
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
||||||
|
with lock:
|
||||||
|
sock.send_multipart(
|
||||||
|
[
|
||||||
|
GUARD,
|
||||||
|
str(self.bootstrap_room).encode("ascii"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# TODO: send_kv_args earlier
|
||||||
packed_kv_data_ptrs = b"".join(
|
packed_kv_data_ptrs = b"".join(
|
||||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
||||||
)
|
)
|
||||||
@@ -593,30 +462,22 @@ class NixlKVReceiver(BaseKVReceiver):
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
|
f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
|
||||||
)
|
)
|
||||||
self._connect("tcp://" + self.prefill_server_url).send_multipart(
|
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
||||||
[
|
with lock:
|
||||||
GUARD,
|
sock.send_multipart(
|
||||||
str(self.bootstrap_room).encode("ascii"),
|
|
||||||
get_local_ip_by_remote().encode("ascii"),
|
|
||||||
str(self.kv_mgr.rank_port).encode("ascii"),
|
|
||||||
self.kv_mgr.agent.get_agent_metadata(),
|
|
||||||
packed_kv_data_ptrs,
|
|
||||||
kv_indices.tobytes(),
|
|
||||||
packed_aux_data_ptrs,
|
|
||||||
str(aux_index).encode("ascii"),
|
|
||||||
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
for dummy_rank in equal_sources.keys():
|
|
||||||
if dummy_rank == remote_rank:
|
|
||||||
continue
|
|
||||||
dummy_info = equal_sources[dummy_rank]
|
|
||||||
dummy_url = f"{dummy_info['rank_ip']}:{dummy_info['rank_port']}"
|
|
||||||
self._connect("tcp://" + dummy_url).send_multipart(
|
|
||||||
[
|
[
|
||||||
GUARD,
|
GUARD,
|
||||||
str(self.bootstrap_room).encode("ascii"),
|
str(self.bootstrap_room).encode("ascii"),
|
||||||
|
get_local_ip_by_remote().encode("ascii"),
|
||||||
|
str(self.kv_mgr.rank_port).encode("ascii"),
|
||||||
|
self.kv_mgr.agent.get_agent_metadata(),
|
||||||
|
self.kv_mgr.agent.name.encode("ascii"),
|
||||||
|
packed_kv_data_ptrs,
|
||||||
|
kv_indices.tobytes(),
|
||||||
|
packed_aux_data_ptrs,
|
||||||
|
str(aux_index).encode("ascii"),
|
||||||
|
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
|
||||||
|
str(self.required_dst_info_num).encode("ascii"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -632,152 +493,12 @@ class NixlKVReceiver(BaseKVReceiver):
|
|||||||
return KVPoll.Success # type: ignore
|
return KVPoll.Success # type: ignore
|
||||||
return KVPoll.WaitingForInput # type: ignore
|
return KVPoll.WaitingForInput # type: ignore
|
||||||
|
|
||||||
|
def _register_kv_args(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def failure_exception(self):
|
def failure_exception(self):
|
||||||
raise Exception("Fake KVReceiver Exception")
|
raise Exception("Fake KVReceiver Exception")
|
||||||
|
|
||||||
|
|
||||||
class NixlKVBootstrapServer(BaseKVBootstrapServer):
|
class NixlKVBootstrapServer(CommonKVBootstrapServer):
|
||||||
def __init__(self, port: int):
|
pass
|
||||||
logger.debug(f"NixlKVBootstrapServer started on port {port}")
|
|
||||||
self.port = port
|
|
||||||
self.app = web.Application()
|
|
||||||
self.store = dict()
|
|
||||||
self.lock = asyncio.Lock()
|
|
||||||
self._setup_routes()
|
|
||||||
self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {}
|
|
||||||
|
|
||||||
# Start bootstrap server
|
|
||||||
self.thread = threading.Thread(target=self._run_server, daemon=True)
|
|
||||||
self.run()
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
self.thread.start()
|
|
||||||
|
|
||||||
def _setup_routes(self):
|
|
||||||
self.app.router.add_route("*", "/metadata", self._handle_metadata)
|
|
||||||
self.app.router.add_route("*", "/route", self._handle_route)
|
|
||||||
|
|
||||||
async def _handle_metadata(self, request: web.Request):
|
|
||||||
key = request.query.get("key", "")
|
|
||||||
|
|
||||||
if request.method == "GET":
|
|
||||||
return await self._handle_metadata_get(key)
|
|
||||||
elif request.method == "PUT":
|
|
||||||
return await self._handle_metadata_put(key, request)
|
|
||||||
elif request.method == "DELETE":
|
|
||||||
return await self._handle_metadata_delete(key)
|
|
||||||
return web.Response(
|
|
||||||
text="Method not allowed", status=405, content_type="application/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _handle_metadata_get(self, key):
|
|
||||||
async with self.lock:
|
|
||||||
value = self.store.get(key)
|
|
||||||
if value is None:
|
|
||||||
return web.Response(
|
|
||||||
text="metadata not found", status=404, content_type="application/json"
|
|
||||||
)
|
|
||||||
return web.Response(body=value, status=200, content_type="application/json")
|
|
||||||
|
|
||||||
async def _handle_metadata_put(self, key, request):
|
|
||||||
data = await request.read()
|
|
||||||
async with self.lock:
|
|
||||||
self.store[key] = data
|
|
||||||
return web.Response(
|
|
||||||
text="metadata updated", status=200, content_type="application/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _handle_metadata_delete(self, key):
|
|
||||||
async with self.lock:
|
|
||||||
if key not in self.store:
|
|
||||||
return web.Response(
|
|
||||||
text="metadata not found",
|
|
||||||
status=404,
|
|
||||||
content_type="application/json",
|
|
||||||
)
|
|
||||||
del self.store[key]
|
|
||||||
return web.Response(
|
|
||||||
text="metadata deleted", status=200, content_type="application/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _handle_route(self, request: web.Request):
|
|
||||||
method = request.method
|
|
||||||
if method == "PUT":
|
|
||||||
return await self._handle_route_put(request)
|
|
||||||
elif method == "GET":
|
|
||||||
return await self._handle_route_get(request)
|
|
||||||
else:
|
|
||||||
return web.Response(
|
|
||||||
text="Method not allowed", status=405, content_type="application/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _handle_route_put(self, request: web.Request):
|
|
||||||
data = await request.json()
|
|
||||||
role = data["role"]
|
|
||||||
rank_ip = data["rank_ip"]
|
|
||||||
rank_port = int(data["rank_port"])
|
|
||||||
engine_rank = int(data["engine_rank"])
|
|
||||||
agent_name = data["agent_name"]
|
|
||||||
|
|
||||||
if role == "Prefill":
|
|
||||||
async with self.lock:
|
|
||||||
self.prefill_port_table[engine_rank] = {
|
|
||||||
"rank_ip": rank_ip,
|
|
||||||
"rank_port": rank_port,
|
|
||||||
"agent_name": agent_name,
|
|
||||||
}
|
|
||||||
logger.info(
|
|
||||||
f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port} and name: {agent_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return web.Response(text="OK", status=200)
|
|
||||||
|
|
||||||
async def _handle_route_get(self, request: web.Request):
|
|
||||||
engine_rank = request.query.get("engine_rank")
|
|
||||||
if not engine_rank:
|
|
||||||
logger.debug(
|
|
||||||
f"No engine_rank specified, return all {len(self.prefill_port_table)} engine infos as a dict"
|
|
||||||
)
|
|
||||||
# Return a dict of all engine_rank
|
|
||||||
async with self.lock:
|
|
||||||
bootstrap_info = self.prefill_port_table
|
|
||||||
return web.json_response(bootstrap_info, status=200)
|
|
||||||
|
|
||||||
# Find corresponding prefill info
|
|
||||||
async with self.lock:
|
|
||||||
bootstrap_info = self.prefill_port_table.get(int(engine_rank))
|
|
||||||
if bootstrap_info is not None:
|
|
||||||
return web.json_response(bootstrap_info, status=200)
|
|
||||||
else:
|
|
||||||
return web.Response(text="Not Found", status=404)
|
|
||||||
|
|
||||||
def _run_server(self):
|
|
||||||
try:
|
|
||||||
# Event Loop
|
|
||||||
self._loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(self._loop)
|
|
||||||
|
|
||||||
self._runner = web.AppRunner(self.app)
|
|
||||||
self._loop.run_until_complete(self._runner.setup())
|
|
||||||
|
|
||||||
site = web.TCPSite(self._runner, port=self.port)
|
|
||||||
self._loop.run_until_complete(site.start())
|
|
||||||
self._loop.run_forever()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Server error: {str(e)}")
|
|
||||||
finally:
|
|
||||||
# Cleanup
|
|
||||||
self._loop.run_until_complete(self._runner.cleanup())
|
|
||||||
self._loop.close()
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"""Shutdown"""
|
|
||||||
if self._loop is not None and self._loop.is_running():
|
|
||||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
|
||||||
logger.info("Stopping server loop...")
|
|
||||||
|
|
||||||
if self.thread.is_alive():
|
|
||||||
self.thread.join(timeout=2)
|
|
||||||
logger.info("Server thread stopped")
|
|
||||||
|
|
||||||
def poll(self) -> KVPoll: ...
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import requests
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from sglang.srt.utils import get_ip
|
from sglang.srt.utils import get_ip, get_local_ip_by_remote
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.schedule_batch import Req
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
@@ -279,3 +279,20 @@ class MetadataBuffers:
|
|||||||
] = torch.tensor(
|
] = torch.tensor(
|
||||||
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
|
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def group_concurrent_contiguous(
|
||||||
|
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
|
||||||
|
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
|
||||||
|
"""Vectorised NumPy implementation."""
|
||||||
|
if src_indices.size == 0:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
|
||||||
|
src_groups = np.split(src_indices, brk)
|
||||||
|
dst_groups = np.split(dst_indices, brk)
|
||||||
|
|
||||||
|
src_groups = [g.tolist() for g in src_groups]
|
||||||
|
dst_groups = [g.tolist() for g in dst_groups]
|
||||||
|
|
||||||
|
return src_groups, dst_groups
|
||||||
|
|||||||
Reference in New Issue
Block a user