[PD] Support get local ip from NIC for PD disaggregation (#7237)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -35,12 +35,7 @@ from sglang.srt.disaggregation.common.utils import (
|
|||||||
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto
|
||||||
get_free_port,
|
|
||||||
get_int_env_var,
|
|
||||||
get_ip,
|
|
||||||
get_local_ip_by_remote,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -130,8 +125,9 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
is_mla_backend: Optional[bool] = False,
|
is_mla_backend: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
self.kv_args = args
|
self.kv_args = args
|
||||||
|
self.local_ip = get_local_ip_auto()
|
||||||
self.engine = MooncakeTransferEngine(
|
self.engine = MooncakeTransferEngine(
|
||||||
hostname=get_local_ip_by_remote(),
|
hostname=self.local_ip,
|
||||||
gpu_id=self.kv_args.gpu_id,
|
gpu_id=self.kv_args.gpu_id,
|
||||||
ib_device=self.kv_args.ib_device,
|
ib_device=self.kv_args.ib_device,
|
||||||
)
|
)
|
||||||
@@ -432,7 +428,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
|
|
||||||
def start_prefill_thread(self):
|
def start_prefill_thread(self):
|
||||||
self.rank_port = get_free_port()
|
self.rank_port = get_free_port()
|
||||||
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
|
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
|
||||||
|
|
||||||
def bootstrap_thread():
|
def bootstrap_thread():
|
||||||
"""This thread recvs pre-alloc notification from the decode engine"""
|
"""This thread recvs pre-alloc notification from the decode engine"""
|
||||||
@@ -471,7 +467,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
|
|
||||||
def start_decode_thread(self):
|
def start_decode_thread(self):
|
||||||
self.rank_port = get_free_port()
|
self.rank_port = get_free_port()
|
||||||
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
|
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
|
||||||
|
|
||||||
def decode_thread():
|
def decode_thread():
|
||||||
while True:
|
while True:
|
||||||
@@ -620,7 +616,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
"role": "Prefill",
|
"role": "Prefill",
|
||||||
"tp_size": self.tp_size,
|
"tp_size": self.tp_size,
|
||||||
"dp_size": self.dp_size,
|
"dp_size": self.dp_size,
|
||||||
"rank_ip": get_local_ip_by_remote(),
|
"rank_ip": self.local_ip,
|
||||||
"rank_port": self.rank_port,
|
"rank_port": self.rank_port,
|
||||||
"engine_rank": self.kv_args.engine_rank,
|
"engine_rank": self.kv_args.engine_rank,
|
||||||
}
|
}
|
||||||
@@ -953,7 +949,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
sock.send_multipart(
|
sock.send_multipart(
|
||||||
[
|
[
|
||||||
"None".encode("ascii"),
|
"None".encode("ascii"),
|
||||||
get_local_ip_by_remote().encode("ascii"),
|
self.kv_mgr.local_ip.encode("ascii"),
|
||||||
str(self.kv_mgr.rank_port).encode("ascii"),
|
str(self.kv_mgr.rank_port).encode("ascii"),
|
||||||
self.session_id.encode("ascii"),
|
self.session_id.encode("ascii"),
|
||||||
packed_kv_data_ptrs,
|
packed_kv_data_ptrs,
|
||||||
@@ -983,7 +979,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
sock.send_multipart(
|
sock.send_multipart(
|
||||||
[
|
[
|
||||||
str(self.bootstrap_room).encode("ascii"),
|
str(self.bootstrap_room).encode("ascii"),
|
||||||
get_local_ip_by_remote().encode("ascii"),
|
self.kv_mgr.local_ip.encode("ascii"),
|
||||||
str(self.kv_mgr.rank_port).encode("ascii"),
|
str(self.kv_mgr.rank_port).encode("ascii"),
|
||||||
self.session_id.encode("ascii"),
|
self.session_id.encode("ascii"),
|
||||||
kv_indices.tobytes() if not is_dummy else b"",
|
kv_indices.tobytes() if not is_dummy else b"",
|
||||||
|
|||||||
@@ -2141,6 +2141,44 @@ def get_free_port():
|
|||||||
return s.getsockname()[1]
|
return s.getsockname()[1]
|
||||||
|
|
||||||
|
|
||||||
|
def get_local_ip_auto() -> str:
    """Return the local IP address, honoring ``SGLANG_LOCAL_IP_NIC``.

    When the ``SGLANG_LOCAL_IP_NIC`` environment variable is set, its value
    is passed to ``get_local_ip_by_nic`` as the interface name. When it is
    not set, the address comes from ``get_local_ip_by_remote``.
    """
    nic_name = os.getenv("SGLANG_LOCAL_IP_NIC")
    if nic_name is None:
        # No interface requested explicitly: use the default discovery path.
        return get_local_ip_by_remote()
    return get_local_ip_by_nic(nic_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_local_ip_by_nic(interface: str) -> str:
    """Return an IP address configured on the network interface ``interface``.

    IPv4 entries are examined first; ``127.0.0.1`` and ``0.0.0.0`` are
    skipped. If no IPv4 address qualifies, IPv6 entries are examined; ``::1``
    and addresses starting with ``fe80::`` are skipped, and any ``%zone``
    suffix is stripped from the returned address. If the interface yields no
    qualifying address at all, the result of ``get_local_ip_by_remote`` is
    returned instead.

    Args:
        interface: Name of the network interface to query.

    Returns:
        The first qualifying address as a string.

    Raises:
        ImportError: If the ``netifaces`` package cannot be imported.
        ValueError: If ``netifaces.ifaddresses`` raises ``ValueError`` or
            ``OSError`` for ``interface``; the original error is chained as
            ``__cause__``.
    """
    try:
        import netifaces
    except ImportError as e:
        raise ImportError(
            "Environment variable SGLANG_LOCAL_IP_NIC requires package netifaces, please install it through 'pip install netifaces'"
        ) from e

    # Only the interface lookup can raise here; keep the try body minimal and
    # chain the underlying error so the root cause stays visible in tracebacks.
    try:
        addresses = netifaces.ifaddresses(interface)
    except (ValueError, OSError) as e:
        raise ValueError(
            "Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
        ) from e

    # Prefer IPv4 over IPv6.
    for addr_info in addresses.get(netifaces.AF_INET, ()):
        ip = addr_info.get("addr")
        if ip and ip not in ("127.0.0.1", "0.0.0.0"):
            return ip

    for addr_info in addresses.get(netifaces.AF_INET6, ()):
        ip = addr_info.get("addr")
        if ip and not ip.startswith("fe80::") and ip != "::1":
            # Drop a trailing "%<zone>" scope suffix if one is present.
            return ip.split("%")[0]

    # Fallback
    return get_local_ip_by_remote()
|
||||||
|
|
||||||
|
|
||||||
def get_local_ip_by_remote() -> str:
|
def get_local_ip_by_remote() -> str:
|
||||||
# try ipv4
|
# try ipv4
|
||||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
|||||||
Reference in New Issue
Block a user