IPv6 support (#3949)
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
This commit is contained in:
@@ -24,6 +24,7 @@ from typing import List, Optional
|
|||||||
from sglang.srt.hf_transformers_utils import check_gguf_file
|
from sglang.srt.hf_transformers_utils import check_gguf_file
|
||||||
from sglang.srt.reasoning_parser import ReasoningParser
|
from sglang.srt.reasoning_parser import ReasoningParser
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
|
configure_ipv6,
|
||||||
get_amdgpu_memory_capacity,
|
get_amdgpu_memory_capacity,
|
||||||
get_device,
|
get_device,
|
||||||
get_hpu_memory_capacity,
|
get_hpu_memory_capacity,
|
||||||
@@ -52,7 +53,7 @@ class ServerArgs:
|
|||||||
dtype: str = "auto"
|
dtype: str = "auto"
|
||||||
kv_cache_dtype: str = "auto"
|
kv_cache_dtype: str = "auto"
|
||||||
quantization: Optional[str] = None
|
quantization: Optional[str] = None
|
||||||
quantization_param_path: nullable_str = None
|
quantization_param_path: Optional[str] = None
|
||||||
context_length: Optional[int] = None
|
context_length: Optional[int] = None
|
||||||
device: Optional[str] = None
|
device: Optional[str] = None
|
||||||
served_model_name: Optional[str] = None
|
served_model_name: Optional[str] = None
|
||||||
@@ -140,7 +141,7 @@ class ServerArgs:
|
|||||||
|
|
||||||
# Double Sparsity
|
# Double Sparsity
|
||||||
enable_double_sparsity: bool = False
|
enable_double_sparsity: bool = False
|
||||||
ds_channel_config_path: str = None
|
ds_channel_config_path: Optional[str] = None
|
||||||
ds_heavy_channel_num: int = 32
|
ds_heavy_channel_num: int = 32
|
||||||
ds_heavy_token_num: int = 256
|
ds_heavy_token_num: int = 256
|
||||||
ds_heavy_channel_type: str = "qk"
|
ds_heavy_channel_type: str = "qk"
|
||||||
@@ -173,7 +174,7 @@ class ServerArgs:
|
|||||||
enable_memory_saver: bool = False
|
enable_memory_saver: bool = False
|
||||||
allow_auto_truncate: bool = False
|
allow_auto_truncate: bool = False
|
||||||
enable_custom_logit_processor: bool = False
|
enable_custom_logit_processor: bool = False
|
||||||
tool_call_parser: str = None
|
tool_call_parser: Optional[str] = None
|
||||||
enable_hierarchical_cache: bool = False
|
enable_hierarchical_cache: bool = False
|
||||||
hicache_ratio: float = 2.0
|
hicache_ratio: float = 2.0
|
||||||
enable_flashinfer_mla: bool = False
|
enable_flashinfer_mla: bool = False
|
||||||
@@ -1205,8 +1206,12 @@ class PortArgs:
|
|||||||
# DP attention. Use TCP + port to handle both single-node and multi-node.
|
# DP attention. Use TCP + port to handle both single-node and multi-node.
|
||||||
if server_args.nnodes == 1 and server_args.dist_init_addr is None:
|
if server_args.nnodes == 1 and server_args.dist_init_addr is None:
|
||||||
dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
|
dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
|
||||||
|
elif server_args.dist_init_addr.startswith("["): # ipv6 address
|
||||||
|
port_num, host = configure_ipv6(server_args.dist_init_addr)
|
||||||
|
dist_init_addr = (host, str(port_num))
|
||||||
else:
|
else:
|
||||||
dist_init_addr = server_args.dist_init_addr.split(":")
|
dist_init_addr = server_args.dist_init_addr.split(":")
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
len(dist_init_addr) == 2
|
len(dist_init_addr) == 2
|
||||||
), "please provide --dist-init-addr as host:port of head node"
|
), "please provide --dist-init-addr as host:port of head node"
|
||||||
|
|||||||
@@ -1630,6 +1630,38 @@ def is_valid_ipv6_address(address: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def configure_ipv6(dist_init_addr: str):
    """Split a bracketed IPv6 endpoint of the form ``[ipv6]:port``.

    Args:
        dist_init_addr: Endpoint string such as ``"[2001:db8::1]:25000"``.

    Returns:
        A ``(port, host)`` tuple. Note the order: the port (as ``int``) comes
        first, followed by the host *including* its surrounding brackets
        (e.g. ``"[2001:db8::1]"``).

    Raises:
        ValueError: If the closing bracket is missing, the bracketed address
            is rejected by ``is_valid_ipv6_address``, the character after
            ``]`` is not ``:``, the port is missing or not an integer, or the
            port lies outside ``0..65535``.
    """
    addr = dist_init_addr
    end = addr.find("]")
    if end == -1:
        raise ValueError("invalid IPv6 address format: missing ']'")

    # Keep the brackets on the host; callers embed it directly in a URL.
    host = addr[: end + 1]

    # Validate the bare address first so that a bad address is reported as
    # such, rather than being mistaken for a formatting problem below.
    if not is_valid_ipv6_address(host[1:end]):
        raise ValueError(f"invalid IPv6 address: {host}")

    remainder = addr[end + 1 :]
    if remainder and not remainder.startswith(":"):
        raise ValueError("invalid IPv6 address format: expected ':' after ']'")

    # Covers both "[addr]" (nothing after the bracket) and "[addr]:" (empty port).
    port_str = remainder[1:]
    if not port_str:
        raise ValueError(
            "a port must be specified in IPv6 address (format: [ipv6]:port)"
        )

    try:
        port = int(port_str)
    except ValueError as err:
        raise ValueError(f"invalid port in IPv6 address: '{port_str}'") from err

    # int() happily parses negatives and arbitrarily large numbers; reject
    # them here instead of failing later when the socket is bound.
    if not 0 <= port <= 65535:
        raise ValueError(f"invalid port in IPv6 address: '{port_str}'")

    return port, host
|
||||||
|
|
||||||
|
|
||||||
def rank0_print(msg: str):
|
def rank0_print(msg: str):
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
import unittest
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from sglang.srt.server_args import prepare_server_args
|
from sglang.srt.server_args import PortArgs, ServerArgs, prepare_server_args
|
||||||
from sglang.test.test_utils import CustomTestCase
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
@@ -22,5 +23,239 @@ class TestPrepareServerArgs(CustomTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPortArgs(unittest.TestCase):
    """Tests for ``PortArgs.init_new`` across IPC, IPv4 and IPv6 endpoints.

    ``configure_ipv6`` is imported from ``sglang.srt.utils`` and looks up
    ``is_valid_ipv6_address`` in that module, so the validator is patched
    there; patching the name in ``sglang.srt.server_args`` would not affect
    the parsing under test.
    """

    @staticmethod
    def _make_server_args(**attrs):
        """Return a ``MagicMock`` server-args object with port 30000 plus *attrs*."""
        server_args = MagicMock()
        server_args.port = 30000
        for name, value in attrs.items():
            setattr(server_args, name, value)
        return server_args

    @patch("sglang.srt.server_args.is_port_available")
    @patch("sglang.srt.server_args.tempfile.NamedTemporaryFile")
    def test_init_new_standard_case(self, mock_temp_file, mock_is_port_available):
        mock_is_port_available.return_value = True
        mock_temp_file.return_value.name = "temp_file"

        server_args = self._make_server_args(enable_dp_attention=False)

        port_args = PortArgs.init_new(server_args)

        self.assertTrue(port_args.tokenizer_ipc_name.startswith("ipc://"))
        self.assertTrue(port_args.scheduler_input_ipc_name.startswith("ipc://"))
        self.assertTrue(port_args.detokenizer_ipc_name.startswith("ipc://"))
        self.assertIsInstance(port_args.nccl_port, int)

    @patch("sglang.srt.server_args.is_port_available")
    def test_init_new_with_single_node_dp_attention(self, mock_is_port_available):
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=1, dist_init_addr=None
        )

        port_args = PortArgs.init_new(server_args)

        self.assertTrue(port_args.tokenizer_ipc_name.startswith("tcp://127.0.0.1:"))
        self.assertTrue(
            port_args.scheduler_input_ipc_name.startswith("tcp://127.0.0.1:")
        )
        self.assertTrue(port_args.detokenizer_ipc_name.startswith("tcp://127.0.0.1:"))
        self.assertIsInstance(port_args.nccl_port, int)

    @patch("sglang.srt.server_args.is_port_available")
    def test_init_new_with_dp_rank(self, mock_is_port_available):
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=1, dist_init_addr="192.168.1.1:25000"
        )

        port_args = PortArgs.init_new(server_args, dp_rank=2)

        self.assertTrue(port_args.scheduler_input_ipc_name.endswith(":25006"))

        self.assertTrue(port_args.tokenizer_ipc_name.startswith("tcp://192.168.1.1:"))
        self.assertTrue(port_args.detokenizer_ipc_name.startswith("tcp://192.168.1.1:"))
        self.assertIsInstance(port_args.nccl_port, int)

    @patch("sglang.srt.server_args.is_port_available")
    def test_init_new_with_ipv4_address(self, mock_is_port_available):
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="192.168.1.1:25000"
        )

        port_args = PortArgs.init_new(server_args)

        self.assertTrue(port_args.tokenizer_ipc_name.startswith("tcp://192.168.1.1:"))
        self.assertTrue(
            port_args.scheduler_input_ipc_name.startswith("tcp://192.168.1.1:")
        )
        self.assertTrue(port_args.detokenizer_ipc_name.startswith("tcp://192.168.1.1:"))
        self.assertIsInstance(port_args.nccl_port, int)

    @patch("sglang.srt.server_args.is_port_available")
    def test_init_new_with_malformed_ipv4_address(self, mock_is_port_available):
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="192.168.1.1"
        )

        with self.assertRaises(AssertionError) as context:
            PortArgs.init_new(server_args)

        self.assertIn(
            "please provide --dist-init-addr as host:port", str(context.exception)
        )

    @patch("sglang.srt.server_args.is_port_available")
    def test_init_new_with_malformed_ipv4_address_invalid_port(
        self, mock_is_port_available
    ):
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="192.168.1.1:abc"
        )

        with self.assertRaises(ValueError):
            PortArgs.init_new(server_args)

    @patch("sglang.srt.server_args.is_port_available")
    @patch("sglang.srt.utils.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_ipv6_address(
        self, mock_is_valid_ipv6, mock_is_port_available
    ):
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="[2001:db8::1]:25000"
        )

        port_args = PortArgs.init_new(server_args)

        self.assertTrue(port_args.tokenizer_ipc_name.startswith("tcp://[2001:db8::1]:"))
        self.assertTrue(
            port_args.scheduler_input_ipc_name.startswith("tcp://[2001:db8::1]:")
        )
        self.assertTrue(
            port_args.detokenizer_ipc_name.startswith("tcp://[2001:db8::1]:")
        )
        self.assertIsInstance(port_args.nccl_port, int)

    @patch("sglang.srt.server_args.is_port_available")
    @patch("sglang.srt.utils.is_valid_ipv6_address", return_value=False)
    def test_init_new_with_invalid_ipv6_address(
        self, mock_is_valid_ipv6, mock_is_port_available
    ):
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="[invalid-ipv6]:25000"
        )

        with self.assertRaises(ValueError) as context:
            PortArgs.init_new(server_args)

        self.assertIn("invalid IPv6 address", str(context.exception))

    @patch("sglang.srt.server_args.is_port_available")
    def test_init_new_with_malformed_ipv6_address_missing_bracket(
        self, mock_is_port_available
    ):
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="[2001:db8::1:25000"
        )

        with self.assertRaises(ValueError) as context:
            PortArgs.init_new(server_args)

        self.assertIn("invalid IPv6 address format", str(context.exception))

    @patch("sglang.srt.server_args.is_port_available")
    @patch("sglang.srt.utils.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_malformed_ipv6_address_missing_port(
        self, mock_is_valid_ipv6, mock_is_port_available
    ):
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="[2001:db8::1]"
        )

        with self.assertRaises(ValueError) as context:
            PortArgs.init_new(server_args)

        self.assertIn(
            "a port must be specified in IPv6 address", str(context.exception)
        )

    @patch("sglang.srt.server_args.is_port_available")
    @patch("sglang.srt.utils.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_malformed_ipv6_address_invalid_port(
        self, mock_is_valid_ipv6, mock_is_port_available
    ):
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="[2001:db8::1]:abcde"
        )

        with self.assertRaises(ValueError) as context:
            PortArgs.init_new(server_args)

        self.assertIn("invalid port in IPv6 address", str(context.exception))

    @patch("sglang.srt.server_args.is_port_available")
    @patch("sglang.srt.utils.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_malformed_ipv6_address_wrong_separator(
        self, mock_is_valid_ipv6, mock_is_port_available
    ):
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="[2001:db8::1]#25000"
        )

        with self.assertRaises(ValueError) as context:
            PortArgs.init_new(server_args)

        self.assertIn("expected ':' after ']'", str(context.exception))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user