Distinguish bootstrap key only in decode server (#5422)
This commit is contained in:
@@ -28,13 +28,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
from sglang.srt.disaggregation.base import (
|
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
|
||||||
BaseKVManager,
|
|
||||||
BaseKVReceiver,
|
|
||||||
BaseKVSender,
|
|
||||||
KVArgs,
|
|
||||||
KVPoll,
|
|
||||||
)
|
|
||||||
from sglang.srt.disaggregation.utils import (
|
from sglang.srt.disaggregation.utils import (
|
||||||
DisaggregationMode,
|
DisaggregationMode,
|
||||||
KVClassType,
|
KVClassType,
|
||||||
|
|||||||
@@ -329,7 +329,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
"role": "Prefill",
|
"role": "Prefill",
|
||||||
"rank_ip": get_local_ip_by_remote(),
|
"rank_ip": get_local_ip_by_remote(),
|
||||||
"rank_port": self.rank_port,
|
"rank_port": self.rank_port,
|
||||||
"bootstrap_key": f"{bootstrap_server_url}_{self.kv_args.engine_rank}",
|
"engine_rank": self.kv_args.engine_rank,
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -400,28 +400,29 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
self.session_id = self.kv_mgr.get_session_id()
|
self.session_id = self.kv_mgr.get_session_id()
|
||||||
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
|
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
|
||||||
|
|
||||||
self.bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
|
# NOTE: key distinguished by bootstrap_addr and engine_rank
|
||||||
|
bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
|
||||||
|
|
||||||
if self.bootstrap_key not in self.kv_mgr.connection_pool:
|
if bootstrap_key not in self.kv_mgr.connection_pool:
|
||||||
self.bootstrap_info = self._get_bootstrap_info_from_server(
|
self.bootstrap_info = self._get_bootstrap_info_from_server(
|
||||||
self.bootstrap_key
|
self.kv_mgr.kv_args.engine_rank
|
||||||
)
|
)
|
||||||
if self.bootstrap_info is None:
|
if self.bootstrap_info is None:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.kv_mgr.connection_pool[self.bootstrap_key] = self.bootstrap_info
|
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
|
||||||
else:
|
else:
|
||||||
self.bootstrap_info = self.kv_mgr.connection_pool[self.bootstrap_key]
|
self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
|
||||||
|
|
||||||
assert self.bootstrap_info is not None
|
assert self.bootstrap_info is not None
|
||||||
self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
|
self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
|
||||||
|
|
||||||
def _get_bootstrap_info_from_server(self, bootstrap_key: str):
|
def _get_bootstrap_info_from_server(self, engine_rank):
|
||||||
"""Fetch the bootstrap info from the bootstrap server."""
|
"""Fetch the bootstrap info from the bootstrap server."""
|
||||||
try:
|
try:
|
||||||
url = f"http://{self.bootstrap_addr}/route?bootstrap_key={bootstrap_key}"
|
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
|
||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
bootstrap_info = response.json()
|
bootstrap_info = response.json()
|
||||||
@@ -556,28 +557,28 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
|||||||
role = data["role"]
|
role = data["role"]
|
||||||
rank_ip = data["rank_ip"]
|
rank_ip = data["rank_ip"]
|
||||||
rank_port = int(data["rank_port"])
|
rank_port = int(data["rank_port"])
|
||||||
bootstrap_key = data["bootstrap_key"]
|
engine_rank = int(data["engine_rank"])
|
||||||
|
|
||||||
# Add lock to make sure thread-safe
|
# Add lock to make sure thread-safe
|
||||||
if role == "Prefill":
|
if role == "Prefill":
|
||||||
self.prefill_port_table[bootstrap_key] = {
|
self.prefill_port_table[engine_rank] = {
|
||||||
"rank_ip": rank_ip,
|
"rank_ip": rank_ip,
|
||||||
"rank_port": rank_port,
|
"rank_port": rank_port,
|
||||||
}
|
}
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Registered Prefill bootstrap_key: {bootstrap_key} with rank_ip: {rank_ip} and rank_port: {rank_port}"
|
f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return web.Response(text="OK", status=200)
|
return web.Response(text="OK", status=200)
|
||||||
|
|
||||||
async def _handle_route_get(self, request: web.Request):
|
async def _handle_route_get(self, request: web.Request):
|
||||||
bootstrap_key = request.query.get("bootstrap_key")
|
engine_rank = request.query.get("engine_rank")
|
||||||
if not bootstrap_key:
|
if not engine_rank:
|
||||||
return web.Response(text="Missing bootstrap_key", status=400)
|
return web.Response(text="Missing rank", status=400)
|
||||||
|
|
||||||
# Find corresponding prefill info
|
# Find corresponding prefill info
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
bootstrap_info = self.prefill_port_table.get(bootstrap_key)
|
bootstrap_info = self.prefill_port_table.get(int(engine_rank))
|
||||||
|
|
||||||
if bootstrap_info is not None:
|
if bootstrap_info is not None:
|
||||||
return web.json_response(bootstrap_info, status=200)
|
return web.json_response(bootstrap_info, status=200)
|
||||||
|
|||||||
@@ -24,13 +24,7 @@ from typing import TYPE_CHECKING, List, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.disaggregation.base import (
|
from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
|
||||||
BaseKVManager,
|
|
||||||
BaseKVReceiver,
|
|
||||||
BaseKVSender,
|
|
||||||
KVArgs,
|
|
||||||
KVPoll,
|
|
||||||
)
|
|
||||||
from sglang.srt.disaggregation.utils import (
|
from sglang.srt.disaggregation.utils import (
|
||||||
DisaggregationMode,
|
DisaggregationMode,
|
||||||
KVClassType,
|
KVClassType,
|
||||||
|
|||||||
Reference in New Issue
Block a user