From 56b991b12dcb320b465c594b9ed7f2290ccdbaa5 Mon Sep 17 00:00:00 2001 From: Jimmy <29097382+jinmingyi1998@users.noreply.github.com> Date: Fri, 19 Sep 2025 13:35:26 +0800 Subject: [PATCH] [Feature]feat(get_ip): unify get_ip_xxx (#10081) --- .../sglang/srt/disaggregation/ascend/conn.py | 4 +- .../device_communicators/shm_broadcast.py | 6 +- python/sglang/srt/utils.py | 108 +++++++++--------- 3 files changed, 58 insertions(+), 60 deletions(-) diff --git a/python/sglang/srt/disaggregation/ascend/conn.py b/python/sglang/srt/disaggregation/ascend/conn.py index b0009fc7c..661a0cc4e 100644 --- a/python/sglang/srt/disaggregation/ascend/conn.py +++ b/python/sglang/srt/disaggregation/ascend/conn.py @@ -13,7 +13,7 @@ from sglang.srt.disaggregation.mooncake.conn import ( MooncakeKVReceiver, MooncakeKVSender, ) -from sglang.srt.utils import get_local_ip_by_remote +from sglang.srt.utils import get_local_ip_auto logger = logging.getLogger(__name__) @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) class AscendKVManager(MooncakeKVManager): def init_engine(self): # TransferEngine initialized on ascend. - local_ip = get_local_ip_by_remote() + local_ip = get_local_ip_auto() self.engine = AscendTransferEngine( hostname=local_ip, npu_id=self.kv_args.gpu_id, diff --git a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py index e5b59e7cc..e956a2592 100644 --- a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py +++ b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py @@ -18,7 +18,7 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore from sglang.srt.utils import ( format_tcp_address, - get_ip, + get_local_ip_auto, get_open_port, is_valid_ipv6_address, ) @@ -191,7 +191,9 @@ class MessageQueue: self.n_remote_reader = n_remote_reader if connect_ip is None: - connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1" + connect_ip = ( + get_local_ip_auto("0.0.0.0") if n_remote_reader > 0 else "127.0.0.1" + ) context = Context() diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 1c9de7b7b..d6c939227 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2005,48 +2005,11 @@ def set_uvicorn_logging_configs(): LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S" -def get_ip() -> str: - # SGLANG_HOST_IP env can be ignore +def get_ip() -> Optional[str]: host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "") if host_ip: return host_ip - - # IP is not set, try to get it from the network interface - - # try ipv4 - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable - return s.getsockname()[0] - except Exception: - pass - - # try ipv6 - try: - s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) - # Google's public DNS server, see - # https://developers.google.com/speed/public-dns/docs/using#addresses - s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable - return s.getsockname()[0] - except Exception: - pass - - # try using hostname - hostname = socket.gethostname() - try: - ip_addr = socket.gethostbyname(hostname) - warnings.warn("using local ip address: {}".format(ip_addr)) - return ip_addr - except Exception: - pass - - warnings.warn( - "Failed to get the IP address, using 0.0.0.0 by default." - "The value can be set by the environment variable" - " SGLANG_HOST_IP or HOST_IP.", - stacklevel=2, - ) - return "0.0.0.0" + return None def get_open_port() -> int: @@ -2305,16 +2268,9 @@ def bind_or_assign(target, source): return source -def get_local_ip_auto() -> str: - interface = os.environ.get("SGLANG_LOCAL_IP_NIC", None) - return ( - get_local_ip_by_nic(interface) - if interface is not None - else get_local_ip_by_remote() - ) - - -def get_local_ip_by_nic(interface: str) -> str: +def get_local_ip_by_nic(interface: str = None) -> Optional[str]: + if not (interface := interface or os.environ.get("SGLANG_LOCAL_IP_NIC", None)): + return None try: import netifaces except ImportError as e: @@ -2335,15 +2291,13 @@ def get_local_ip_by_nic(interface: str) -> str: if ip and not ip.startswith("fe80::") and ip != "::1": return ip.split("%")[0] except (ValueError, OSError) as e: - raise ValueError( - "Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly." + logger.warning( + f"{e} Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly." ) - - # Fallback - return get_local_ip_by_remote() + return None -def get_local_ip_by_remote() -> str: +def get_local_ip_by_remote() -> Optional[str]: # try ipv4 s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) try: @@ -2368,7 +2322,49 @@ def get_local_ip_by_remote() -> str: s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable return s.getsockname()[0] except Exception: - raise ValueError("Can not get local ip") + logger.warning("Can not get local ip by remote") + return None + + +def get_local_ip_auto(fallback: str = None) -> str: + """ + Automatically detect the local IP address using multiple fallback strategies. + + This function attempts to obtain the local IP address through several methods. + If all methods fail, it returns the specified fallback value or raises an exception. + + Args: + fallback (str, optional): Fallback IP address to return if all detection + methods fail. For server applications, explicitly set this to + "0.0.0.0" (IPv4) or "::" (IPv6) to bind to all available interfaces. + Defaults to None. + + Returns: + str: The detected local IP address, or the fallback value if detection fails. + + Raises: + ValueError: If IP detection fails and no fallback value is provided. + + Note: + The function tries detection methods in the following order: + 1. Direct IP detection via get_ip() + 2. Network interface enumeration via get_local_ip_by_nic() + 3. Remote connection method via get_local_ip_by_remote() + """ + if ip := get_ip(): + return ip + logger.debug("get_ip failed") + # Fallback + if ip := get_local_ip_by_nic(): + return ip + logger.debug("get_local_ip_by_nic failed") + # Fallback + if ip := get_local_ip_by_remote(): + return ip + logger.debug("get_local_ip_by_remote failed") + if fallback: + return fallback + raise ValueError("Can not get local ip") def is_page_size_one(server_args):