From e2e2ab70e062e0da78a4ae339542cb38080e8d28 Mon Sep 17 00:00:00 2001 From: Vincent Date: Fri, 28 Mar 2025 00:42:13 -0400 Subject: [PATCH] IPv6 support (#3949) Signed-off-by: Brayden Zhong --- python/sglang/srt/server_args.py | 11 +- python/sglang/srt/utils.py | 32 +++++ test/srt/test_server_args.py | 237 ++++++++++++++++++++++++++++++- 3 files changed, 276 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6a0166b41..a04882736 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -24,6 +24,7 @@ from typing import List, Optional from sglang.srt.hf_transformers_utils import check_gguf_file from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.utils import ( + configure_ipv6, get_amdgpu_memory_capacity, get_device, get_hpu_memory_capacity, @@ -52,7 +53,7 @@ class ServerArgs: dtype: str = "auto" kv_cache_dtype: str = "auto" quantization: Optional[str] = None - quantization_param_path: nullable_str = None + quantization_param_path: Optional[str] = None context_length: Optional[int] = None device: Optional[str] = None served_model_name: Optional[str] = None @@ -140,7 +141,7 @@ class ServerArgs: # Double Sparsity enable_double_sparsity: bool = False - ds_channel_config_path: str = None + ds_channel_config_path: Optional[str] = None ds_heavy_channel_num: int = 32 ds_heavy_token_num: int = 256 ds_heavy_channel_type: str = "qk" @@ -173,7 +174,7 @@ class ServerArgs: enable_memory_saver: bool = False allow_auto_truncate: bool = False enable_custom_logit_processor: bool = False - tool_call_parser: str = None + tool_call_parser: Optional[str] = None enable_hierarchical_cache: bool = False hicache_ratio: float = 2.0 enable_flashinfer_mla: bool = False @@ -1205,8 +1206,12 @@ class PortArgs: # DP attention. Use TCP + port to handle both single-node and multi-node. 
def configure_ipv6(dist_init_addr: str):
    """Split a bracketed ``[ipv6]:port`` address into its port and host.

    Args:
        dist_init_addr: Address in the form ``[<ipv6>]:<port>``, for example
            ``"[2001:db8::1]:25000"``.

    Returns:
        A ``(port, host)`` tuple -- note the order. ``port`` is an ``int`` and
        ``host`` is the IPv6 literal with its surrounding brackets kept
        (``"[2001:db8::1]"``).

    Raises:
        ValueError: If either bracket is missing, the bracketed text is
            rejected by ``is_valid_ipv6_address``, ``]`` is followed by
            anything other than ``:``, or the port is absent, not an integer,
            or outside 1-65535.
    """
    addr = dist_init_addr

    # The slice below assumes the first character is '['. Check it here rather
    # than relying on the caller, otherwise a stray leading character would be
    # silently dropped before validation.
    if not addr.startswith("["):
        raise ValueError("invalid IPv6 address format: missing '['")

    end = addr.find("]")
    if end == -1:
        raise ValueError("invalid IPv6 address format: missing ']'")

    host = addr[: end + 1]

    # Validate the bare address (brackets stripped) before looking at the
    # port, so a bad address is reported as such and not as a format problem.
    if not is_valid_ipv6_address(host[1:end]):
        raise ValueError(f"invalid IPv6 address: {host}")

    remainder = addr[end + 1 :]
    if remainder and not remainder.startswith(":"):
        raise ValueError("received IPv6 address format: expected ':' after ']'")

    port_str = remainder[1:]
    if not port_str:
        raise ValueError(
            "a port must be specified in IPv6 address (format: [ipv6]:port)"
        )

    try:
        port = int(port_str)
    except ValueError:
        # The message already carries the offending text; drop the inner
        # int() traceback to keep the error readable.
        raise ValueError(f"invalid port in IPv6 address: '{port_str}'") from None

    # int() accepts values such as "-5" or "99999" that are not usable TCP ports.
    if not 1 <= port <= 65535:
        raise ValueError(
            f"invalid port in IPv6 address: '{port_str}' (must be 1-65535)"
        )
    return port, host
class TestPortArgs(unittest.TestCase):
    """Tests for ``PortArgs.init_new`` address handling, driven by a mocked ``ServerArgs``."""

    # ``configure_ipv6`` is defined in ``sglang.srt.utils`` and looks the
    # validator up in that module, so this is the namespace that has to be
    # patched for the mock to influence the code under test.
    _IPV6_VALIDATOR = "sglang.srt.utils.is_valid_ipv6_address"

    _ENDPOINT_FIELDS = (
        "tokenizer_ipc_name",
        "scheduler_input_ipc_name",
        "detokenizer_ipc_name",
    )

    def setUp(self):
        # Stub the port probe for every test so results never depend on which
        # ports happen to be free on the machine running the suite.
        patcher = patch(
            "sglang.srt.server_args.is_port_available", return_value=True
        )
        patcher.start()
        self.addCleanup(patcher.stop)

    @staticmethod
    def _make_server_args(**attrs):
        """Return a ``MagicMock`` server-args object with port 30000 plus ``attrs``.

        Only the attributes passed in are pinned; anything else stays an
        auto-created ``MagicMock`` attribute.
        """
        server_args = MagicMock()
        server_args.port = 30000
        for name, value in attrs.items():
            setattr(server_args, name, value)
        return server_args

    def _assert_endpoints_start_with(self, port_args, prefix, fields=None):
        """Assert each named endpoint starts with ``prefix`` and ``nccl_port`` is an int."""
        for field in fields or self._ENDPOINT_FIELDS:
            value = getattr(port_args, field)
            self.assertTrue(value.startswith(prefix), f"{field}={value!r}")
        self.assertIsInstance(port_args.nccl_port, int)

    def _assert_init_new_raises(self, dist_init_addr, exc_type, fragment=None):
        """Assert a two-node DP-attention ``init_new`` raises ``exc_type``.

        When ``fragment`` is given it must appear in the exception message.
        """
        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr=dist_init_addr
        )
        with self.assertRaises(exc_type) as ctx:
            PortArgs.init_new(server_args)
        if fragment is not None:
            self.assertIn(fragment, str(ctx.exception))

    @patch("sglang.srt.server_args.tempfile.NamedTemporaryFile")
    def test_init_new_standard_case(self, mock_temp_file):
        """Without DP attention, every endpoint is an ``ipc://`` name."""
        mock_temp_file.return_value.name = "temp_file"
        server_args = self._make_server_args(enable_dp_attention=False)

        port_args = PortArgs.init_new(server_args)

        self._assert_endpoints_start_with(port_args, "ipc://")

    def test_init_new_with_single_node_dp_attention(self):
        """Single node with no dist-init address falls back to 127.0.0.1 over TCP."""
        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=1, dist_init_addr=None
        )

        port_args = PortArgs.init_new(server_args)

        self._assert_endpoints_start_with(port_args, "tcp://127.0.0.1:")

    def test_init_new_with_dp_rank(self):
        """An explicit ``dp_rank`` selects the scheduler input port."""
        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=1, dist_init_addr="192.168.1.1:25000"
        )

        port_args = PortArgs.init_new(server_args, dp_rank=2)

        # Expected port for base 25000 and dp_rank=2; the offset formula lives
        # in PortArgs.init_new and is not restated here.
        self.assertTrue(port_args.scheduler_input_ipc_name.endswith(":25006"))
        self._assert_endpoints_start_with(
            port_args,
            "tcp://192.168.1.1:",
            fields=("tokenizer_ipc_name", "detokenizer_ipc_name"),
        )

    def test_init_new_with_ipv4_address(self):
        """A ``host:port`` IPv4 address is used as the TCP host for all endpoints."""
        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="192.168.1.1:25000"
        )

        port_args = PortArgs.init_new(server_args)

        self._assert_endpoints_start_with(port_args, "tcp://192.168.1.1:")

    def test_init_new_with_malformed_ipv4_address(self):
        """An IPv4 address without a port trips the host:port assertion."""
        self._assert_init_new_raises(
            "192.168.1.1",
            AssertionError,
            "please provide --dist-init-addr as host:port",
        )

    def test_init_new_with_malformed_ipv4_address_invalid_port(self):
        """A non-numeric IPv4 port raises ``ValueError``."""
        self._assert_init_new_raises("192.168.1.1:abc", ValueError)

    def test_init_new_with_ipv6_address(self):
        """A bracketed IPv6 address keeps its brackets in the TCP endpoints."""
        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="[2001:db8::1]:25000"
        )

        with patch(self._IPV6_VALIDATOR, return_value=True):
            port_args = PortArgs.init_new(server_args)

        self._assert_endpoints_start_with(port_args, "tcp://[2001:db8::1]:")

    def test_init_new_with_invalid_ipv6_address(self):
        """An address rejected by the validator raises ``ValueError``."""
        with patch(self._IPV6_VALIDATOR, return_value=False):
            self._assert_init_new_raises(
                "[invalid-ipv6]:25000", ValueError, "invalid IPv6 address"
            )

    def test_init_new_with_malformed_ipv6_address_missing_bracket(self):
        """A missing closing bracket is reported as a format error."""
        self._assert_init_new_raises(
            "[2001:db8::1:25000", ValueError, "invalid IPv6 address format"
        )

    def test_init_new_with_malformed_ipv6_address_missing_port(self):
        """A bracketed address with no port is rejected."""
        with patch(self._IPV6_VALIDATOR, return_value=True):
            self._assert_init_new_raises(
                "[2001:db8::1]",
                ValueError,
                "a port must be specified in IPv6 address",
            )

    def test_init_new_with_malformed_ipv6_address_invalid_port(self):
        """A non-numeric IPv6 port is rejected."""
        with patch(self._IPV6_VALIDATOR, return_value=True):
            self._assert_init_new_raises(
                "[2001:db8::1]:abcde", ValueError, "invalid port in IPv6 address"
            )

    def test_init_new_with_malformed_ipv6_address_wrong_separator(self):
        """Anything other than ':' after the closing bracket is rejected."""
        with patch(self._IPV6_VALIDATOR, return_value=True):
            self._assert_init_new_raises(
                "[2001:db8::1]#25000", ValueError, "expected ':' after ']'"
            )


if __name__ == "__main__":
    unittest.main()