[PD] Support get local ip from NIC for PD disaggregation (#7237)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
shangmingc
2025-06-18 08:19:26 +08:00
committed by GitHub
parent 0650e5176f
commit ceaa85c9e6
2 changed files with 46 additions and 12 deletions

View File

@@ -35,12 +35,7 @@ from sglang.srt.disaggregation.common.utils import (
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
get_free_port,
get_int_env_var,
get_ip,
get_local_ip_by_remote,
)
from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto
logger = logging.getLogger(__name__)
@@ -130,8 +125,9 @@ class MooncakeKVManager(BaseKVManager):
is_mla_backend: Optional[bool] = False,
):
self.kv_args = args
self.local_ip = get_local_ip_auto()
self.engine = MooncakeTransferEngine(
hostname=get_local_ip_by_remote(),
hostname=self.local_ip,
gpu_id=self.kv_args.gpu_id,
ib_device=self.kv_args.ib_device,
)
@@ -432,7 +428,7 @@ class MooncakeKVManager(BaseKVManager):
def start_prefill_thread(self):
self.rank_port = get_free_port()
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
def bootstrap_thread():
"""This thread recvs pre-alloc notification from the decode engine"""
@@ -471,7 +467,7 @@ class MooncakeKVManager(BaseKVManager):
def start_decode_thread(self):
self.rank_port = get_free_port()
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
def decode_thread():
while True:
@@ -620,7 +616,7 @@ class MooncakeKVManager(BaseKVManager):
"role": "Prefill",
"tp_size": self.tp_size,
"dp_size": self.dp_size,
"rank_ip": get_local_ip_by_remote(),
"rank_ip": self.local_ip,
"rank_port": self.rank_port,
"engine_rank": self.kv_args.engine_rank,
}
@@ -953,7 +949,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
sock.send_multipart(
[
"None".encode("ascii"),
get_local_ip_by_remote().encode("ascii"),
self.kv_mgr.local_ip.encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.session_id.encode("ascii"),
packed_kv_data_ptrs,
@@ -983,7 +979,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
sock.send_multipart(
[
str(self.bootstrap_room).encode("ascii"),
get_local_ip_by_remote().encode("ascii"),
self.kv_mgr.local_ip.encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.session_id.encode("ascii"),
kv_indices.tobytes() if not is_dummy else b"",