[PD] Support get local ip from NIC for PD disaggregation (#7237)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -35,12 +35,7 @@ from sglang.srt.disaggregation.common.utils import (
|
||||
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
get_free_port,
|
||||
get_int_env_var,
|
||||
get_ip,
|
||||
get_local_ip_by_remote,
|
||||
)
|
||||
from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -130,8 +125,9 @@ class MooncakeKVManager(BaseKVManager):
|
||||
is_mla_backend: Optional[bool] = False,
|
||||
):
|
||||
self.kv_args = args
|
||||
self.local_ip = get_local_ip_auto()
|
||||
self.engine = MooncakeTransferEngine(
|
||||
hostname=get_local_ip_by_remote(),
|
||||
hostname=self.local_ip,
|
||||
gpu_id=self.kv_args.gpu_id,
|
||||
ib_device=self.kv_args.ib_device,
|
||||
)
|
||||
@@ -432,7 +428,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
|
||||
def start_prefill_thread(self):
|
||||
self.rank_port = get_free_port()
|
||||
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
|
||||
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
|
||||
|
||||
def bootstrap_thread():
|
||||
"""This thread recvs pre-alloc notification from the decode engine"""
|
||||
@@ -471,7 +467,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
|
||||
def start_decode_thread(self):
|
||||
self.rank_port = get_free_port()
|
||||
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
|
||||
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
|
||||
|
||||
def decode_thread():
|
||||
while True:
|
||||
@@ -620,7 +616,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
"role": "Prefill",
|
||||
"tp_size": self.tp_size,
|
||||
"dp_size": self.dp_size,
|
||||
"rank_ip": get_local_ip_by_remote(),
|
||||
"rank_ip": self.local_ip,
|
||||
"rank_port": self.rank_port,
|
||||
"engine_rank": self.kv_args.engine_rank,
|
||||
}
|
||||
@@ -953,7 +949,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
sock.send_multipart(
|
||||
[
|
||||
"None".encode("ascii"),
|
||||
get_local_ip_by_remote().encode("ascii"),
|
||||
self.kv_mgr.local_ip.encode("ascii"),
|
||||
str(self.kv_mgr.rank_port).encode("ascii"),
|
||||
self.session_id.encode("ascii"),
|
||||
packed_kv_data_ptrs,
|
||||
@@ -983,7 +979,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
sock.send_multipart(
|
||||
[
|
||||
str(self.bootstrap_room).encode("ascii"),
|
||||
get_local_ip_by_remote().encode("ascii"),
|
||||
self.kv_mgr.local_ip.encode("ascii"),
|
||||
str(self.kv_mgr.rank_port).encode("ascii"),
|
||||
self.session_id.encode("ascii"),
|
||||
kv_indices.tobytes() if not is_dummy else b"",
|
||||
|
||||
@@ -2141,6 +2141,44 @@ def get_free_port():
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def get_local_ip_auto() -> str:
    """Resolve the local IP, honoring the ``SGLANG_LOCAL_IP_NIC`` variable.

    When ``SGLANG_LOCAL_IP_NIC`` is present in the environment (any value,
    including an empty string), its value is passed to
    ``get_local_ip_by_nic``. When it is absent, ``get_local_ip_by_remote``
    is used instead.

    Returns:
        The IP address string produced by the selected helper.
    """
    nic_name = os.environ.get("SGLANG_LOCAL_IP_NIC")
    if nic_name is None:
        # No NIC requested: use the remote-based lookup.
        return get_local_ip_by_remote()
    return get_local_ip_by_nic(nic_name)
|
||||
|
||||
|
||||
def get_local_ip_by_nic(interface: str) -> str:
    """Return a usable IP address bound to the network interface ``interface``.

    Selection order:

    1. The first IPv4 address that is neither ``127.0.0.1`` nor ``0.0.0.0``.
    2. Otherwise the first IPv6 address that is not ``::1`` and does not
       start with ``fe80::``, with any ``%zone`` suffix stripped.
    3. Otherwise the result of ``get_local_ip_by_remote()``.

    Args:
        interface: Name of the network interface to inspect.

    Returns:
        The selected IP address as a string.

    Raises:
        ImportError: If the ``netifaces`` package is not installed.
        ValueError: If ``netifaces.ifaddresses`` fails with ``ValueError`` or
            ``OSError``; the underlying error is attached as ``__cause__``.
    """
    try:
        import netifaces
    except ImportError as e:
        raise ImportError(
            "Environment variable SGLANG_LOCAL_IP_NIC requires package netifaces, please install it through 'pip install netifaces'"
        ) from e

    # Only the interface query can fail with ValueError/OSError, so keep the
    # try body minimal and chain the original error for easier debugging.
    try:
        addresses = netifaces.ifaddresses(interface)
    except (ValueError, OSError) as e:
        raise ValueError(
            "Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
        ) from e

    # IPv4 is preferred over IPv6.
    for addr_info in addresses.get(netifaces.AF_INET, ()):
        ip = addr_info.get("addr")
        if ip and ip not in ("127.0.0.1", "0.0.0.0"):
            return ip

    for addr_info in addresses.get(netifaces.AF_INET6, ()):
        ip = addr_info.get("addr")
        if ip and not ip.startswith("fe80::") and ip != "::1":
            # Drop the "%<zone>" suffix that may follow the address.
            return ip.split("%")[0]

    # Fallback
    return get_local_ip_by_remote()
|
||||
|
||||
|
||||
def get_local_ip_by_remote() -> str:
|
||||
# try ipv4
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
|
||||
Reference in New Issue
Block a user