diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py index 584530e69..3f5877ea3 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -131,4 +131,4 @@ class BaseKVReceiver(ABC): class BaseKVBootstrapServer(ABC): @abstractmethod - def __init__(self, port: int): ... + def __init__(self, host: str, port: int): ... diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index da6cc7217..b23cb2d68 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -47,6 +47,7 @@ class CommonKVManager(BaseKVManager): self.is_mla_backend = is_mla_backend self.disaggregation_mode = disaggregation_mode # for p/d multi node infer + self.bootstrap_host = server_args.host self.bootstrap_port = server_args.disaggregation_bootstrap_port self.dist_init_addr = server_args.dist_init_addr self.tp_size = server_args.tp_size @@ -72,6 +73,7 @@ class CommonKVManager(BaseKVManager): def _register_to_bootstrap(self): """Register KVSender to bootstrap server via HTTP POST.""" if self.dist_init_addr: + # multi node: bootstrap server's host is dist_init_addr if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6] if self.dist_init_addr.endswith("]"): host = self.dist_init_addr @@ -80,7 +82,8 @@ class CommonKVManager(BaseKVManager): else: host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0]) else: - host = get_ip() + # single node: bootstrap server's host is same as http server's host + host = self.bootstrap_host host = maybe_wrap_ipv6_address(host) bootstrap_server_url = f"{host}:{self.bootstrap_port}" @@ -308,7 +311,8 @@ class CommonKVReceiver(BaseKVReceiver): class CommonKVBootstrapServer(BaseKVBootstrapServer): - def __init__(self, port: int): + def __init__(self, host: str, port: int): + self.host = host self.port = port self.app = web.Application() self.store = dict() @@ -412,7 +416,7 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer): self._runner = web.AppRunner(self.app) self._loop.run_until_complete(self._runner.setup()) - site = web.TCPSite(self._runner, port=self.port) + site = web.TCPSite(self._runner, host=self.host, port=self.port) self._loop.run_until_complete(site.start()) self._loop.run_forever() except Exception as e: diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index b9ce9bbff..528719f28 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -24,7 +24,7 @@ import logging from collections import deque from dataclasses import dataclass from http import HTTPStatus -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union import torch from torch.distributed import ProcessGroup @@ -218,8 +218,10 @@ class DecodePreallocQueue: kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device kv_args.gpu_id = self.scheduler.gpu_id - kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER) - kv_manager = kv_manager_class( + kv_manager_class: Type[BaseKVManager] = get_kv_class( + self.transfer_backend, KVClassType.MANAGER + ) + kv_manager: BaseKVManager = kv_manager_class( kv_args, DisaggregationMode.DECODE, self.scheduler.server_args, diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index e59497dc9..c744e110d 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -175,6 +175,7 @@ class MooncakeKVManager(BaseKVManager): self.disaggregation_mode = disaggregation_mode self.init_engine() # for p/d multi node infer + self.bootstrap_host = server_args.host self.bootstrap_port = server_args.disaggregation_bootstrap_port self.dist_init_addr = server_args.dist_init_addr self.attn_tp_size = get_attention_tp_size() @@ -1020,6 +1021,7 @@ class MooncakeKVManager(BaseKVManager): def _register_to_bootstrap(self): """Register KVSender to bootstrap server via HTTP POST.""" if self.dist_init_addr: + # multi node case: bootstrap server's host is dist_init_addr if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6] if self.dist_init_addr.endswith("]"): host = self.dist_init_addr @@ -1028,7 +1030,8 @@ class MooncakeKVManager(BaseKVManager): else: host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0]) else: - host = get_ip() + # single node case: bootstrap server's host is same as http server's host + host = self.bootstrap_host host = maybe_wrap_ipv6_address(host) bootstrap_server_url = f"{host}:{self.bootstrap_port}" @@ -1545,7 +1548,8 @@ class MooncakeKVReceiver(BaseKVReceiver): class MooncakeKVBootstrapServer(BaseKVBootstrapServer): - def __init__(self, port: int): + def __init__(self, host: str, port: int): + self.host = host self.port = port self.app = web.Application() self.store = dict() @@ -1673,7 +1677,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): self._runner = web.AppRunner(self.app, access_log=access_log) self._loop.run_until_complete(self._runner.setup()) - site = web.TCPSite(self._runner, port=self.port) + site = web.TCPSite(self._runner, host=self.host, port=self.port) self._loop.run_until_complete(site.start()) self._loop.run_forever() except Exception as e: diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 9b80bd4ff..b70748250 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -23,7 +23,7 @@ import logging import threading from collections import deque from http import HTTPStatus -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Type import torch @@ -140,8 +140,10 @@ class PrefillBootstrapQueue: kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device kv_args.gpu_id = self.scheduler.gpu_id - kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER) - kv_manager = kv_manager_class( + kv_manager_class: Type[BaseKVManager] = get_kv_class( + self.transfer_backend, KVClassType.MANAGER + ) + kv_manager: BaseKVManager = kv_manager_class( kv_args, DisaggregationMode.PREFILL, self.scheduler.server_args, diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index efe867e5a..43770e3e2 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -5,7 +5,7 @@ import random from collections import deque from contextlib import nullcontext from enum import Enum -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, List, Optional, Type, Union import numpy as np import torch @@ -213,7 +213,9 @@ class KVClassType(Enum): BOOTSTRAP_SERVER = "bootstrap_server" -def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): +def get_kv_class( + transfer_backend: TransferBackend, class_type: KVClassType +) -> Optional[Type]: from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender if transfer_backend == TransferBackend.MOONCAKE: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d38534e60..d23d1a628 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -40,6 +40,7 @@ from typing import ( List, Optional, Tuple, + Type, TypeVar, Union, ) @@ -53,6 +54,7 @@ from fastapi import BackgroundTasks from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.disaggregation.base import BaseKVBootstrapServer from sglang.srt.disaggregation.utils import ( DisaggregationMode, KVClassType, @@ -479,11 +481,12 @@ class TokenizerManager: # Start kv boostrap server on prefill if self.disaggregation_mode == DisaggregationMode.PREFILL: # only start bootstrap server on prefill tm - kv_bootstrap_server_class = get_kv_class( + kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class( self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER ) - self.bootstrap_server = kv_bootstrap_server_class( - self.server_args.disaggregation_bootstrap_port + self.bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class( + host=self.server_args.host, + port=self.server_args.disaggregation_bootstrap_port, ) is_create_store = ( self.server_args.node_rank == 0