[Vulnerability]feat(conn): set bootstrap server host (#9931)
This commit is contained in:
@@ -131,4 +131,4 @@ class BaseKVReceiver(ABC):
|
|||||||
|
|
||||||
class BaseKVBootstrapServer(ABC):
|
class BaseKVBootstrapServer(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, port: int): ...
|
def __init__(self, host: str, port: int): ...
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ class CommonKVManager(BaseKVManager):
|
|||||||
self.is_mla_backend = is_mla_backend
|
self.is_mla_backend = is_mla_backend
|
||||||
self.disaggregation_mode = disaggregation_mode
|
self.disaggregation_mode = disaggregation_mode
|
||||||
# for p/d multi node infer
|
# for p/d multi node infer
|
||||||
|
self.bootstrap_host = server_args.host
|
||||||
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
||||||
self.dist_init_addr = server_args.dist_init_addr
|
self.dist_init_addr = server_args.dist_init_addr
|
||||||
self.tp_size = server_args.tp_size
|
self.tp_size = server_args.tp_size
|
||||||
@@ -72,6 +73,7 @@ class CommonKVManager(BaseKVManager):
|
|||||||
def _register_to_bootstrap(self):
|
def _register_to_bootstrap(self):
|
||||||
"""Register KVSender to bootstrap server via HTTP POST."""
|
"""Register KVSender to bootstrap server via HTTP POST."""
|
||||||
if self.dist_init_addr:
|
if self.dist_init_addr:
|
||||||
|
# multi node: bootstrap server's host is dist_init_addr
|
||||||
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
||||||
if self.dist_init_addr.endswith("]"):
|
if self.dist_init_addr.endswith("]"):
|
||||||
host = self.dist_init_addr
|
host = self.dist_init_addr
|
||||||
@@ -80,7 +82,8 @@ class CommonKVManager(BaseKVManager):
|
|||||||
else:
|
else:
|
||||||
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
||||||
else:
|
else:
|
||||||
host = get_ip()
|
# single node: bootstrap server's host is same as http server's host
|
||||||
|
host = self.bootstrap_host
|
||||||
host = maybe_wrap_ipv6_address(host)
|
host = maybe_wrap_ipv6_address(host)
|
||||||
|
|
||||||
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
||||||
@@ -308,7 +311,8 @@ class CommonKVReceiver(BaseKVReceiver):
|
|||||||
|
|
||||||
|
|
||||||
class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
||||||
def __init__(self, port: int):
|
def __init__(self, host: str, port: int):
|
||||||
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.app = web.Application()
|
self.app = web.Application()
|
||||||
self.store = dict()
|
self.store = dict()
|
||||||
@@ -412,7 +416,7 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
|||||||
self._runner = web.AppRunner(self.app)
|
self._runner = web.AppRunner(self.app)
|
||||||
self._loop.run_until_complete(self._runner.setup())
|
self._loop.run_until_complete(self._runner.setup())
|
||||||
|
|
||||||
site = web.TCPSite(self._runner, port=self.port)
|
site = web.TCPSite(self._runner, host=self.host, port=self.port)
|
||||||
self._loop.run_until_complete(site.start())
|
self._loop.run_until_complete(site.start())
|
||||||
self._loop.run_forever()
|
self._loop.run_forever()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import logging
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
@@ -218,8 +218,10 @@ class DecodePreallocQueue:
|
|||||||
|
|
||||||
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
||||||
kv_args.gpu_id = self.scheduler.gpu_id
|
kv_args.gpu_id = self.scheduler.gpu_id
|
||||||
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
kv_manager_class: Type[BaseKVManager] = get_kv_class(
|
||||||
kv_manager = kv_manager_class(
|
self.transfer_backend, KVClassType.MANAGER
|
||||||
|
)
|
||||||
|
kv_manager: BaseKVManager = kv_manager_class(
|
||||||
kv_args,
|
kv_args,
|
||||||
DisaggregationMode.DECODE,
|
DisaggregationMode.DECODE,
|
||||||
self.scheduler.server_args,
|
self.scheduler.server_args,
|
||||||
|
|||||||
@@ -175,6 +175,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
self.disaggregation_mode = disaggregation_mode
|
self.disaggregation_mode = disaggregation_mode
|
||||||
self.init_engine()
|
self.init_engine()
|
||||||
# for p/d multi node infer
|
# for p/d multi node infer
|
||||||
|
self.bootstrap_host = server_args.host
|
||||||
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
||||||
self.dist_init_addr = server_args.dist_init_addr
|
self.dist_init_addr = server_args.dist_init_addr
|
||||||
self.attn_tp_size = get_attention_tp_size()
|
self.attn_tp_size = get_attention_tp_size()
|
||||||
@@ -1020,6 +1021,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
def _register_to_bootstrap(self):
|
def _register_to_bootstrap(self):
|
||||||
"""Register KVSender to bootstrap server via HTTP POST."""
|
"""Register KVSender to bootstrap server via HTTP POST."""
|
||||||
if self.dist_init_addr:
|
if self.dist_init_addr:
|
||||||
|
# multi node case: bootstrap server's host is dist_init_addr
|
||||||
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
||||||
if self.dist_init_addr.endswith("]"):
|
if self.dist_init_addr.endswith("]"):
|
||||||
host = self.dist_init_addr
|
host = self.dist_init_addr
|
||||||
@@ -1028,7 +1030,8 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
else:
|
else:
|
||||||
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
||||||
else:
|
else:
|
||||||
host = get_ip()
|
# single node case: bootstrap server's host is same as http server's host
|
||||||
|
host = self.bootstrap_host
|
||||||
host = maybe_wrap_ipv6_address(host)
|
host = maybe_wrap_ipv6_address(host)
|
||||||
|
|
||||||
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
||||||
@@ -1545,7 +1548,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
|
|
||||||
|
|
||||||
class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
||||||
def __init__(self, port: int):
|
def __init__(self, host: str, port: int):
|
||||||
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.app = web.Application()
|
self.app = web.Application()
|
||||||
self.store = dict()
|
self.store = dict()
|
||||||
@@ -1673,7 +1677,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
|||||||
self._runner = web.AppRunner(self.app, access_log=access_log)
|
self._runner = web.AppRunner(self.app, access_log=access_log)
|
||||||
self._loop.run_until_complete(self._runner.setup())
|
self._loop.run_until_complete(self._runner.setup())
|
||||||
|
|
||||||
site = web.TCPSite(self._runner, port=self.port)
|
site = web.TCPSite(self._runner, host=self.host, port=self.port)
|
||||||
self._loop.run_until_complete(site.start())
|
self._loop.run_until_complete(site.start())
|
||||||
self._loop.run_forever()
|
self._loop.run_forever()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import logging
|
|||||||
import threading
|
import threading
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -140,8 +140,10 @@ class PrefillBootstrapQueue:
|
|||||||
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
||||||
kv_args.gpu_id = self.scheduler.gpu_id
|
kv_args.gpu_id = self.scheduler.gpu_id
|
||||||
|
|
||||||
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
kv_manager_class: Type[BaseKVManager] = get_kv_class(
|
||||||
kv_manager = kv_manager_class(
|
self.transfer_backend, KVClassType.MANAGER
|
||||||
|
)
|
||||||
|
kv_manager: BaseKVManager = kv_manager_class(
|
||||||
kv_args,
|
kv_args,
|
||||||
DisaggregationMode.PREFILL,
|
DisaggregationMode.PREFILL,
|
||||||
self.scheduler.server_args,
|
self.scheduler.server_args,
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import random
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, List, Optional, Type, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -213,7 +213,9 @@ class KVClassType(Enum):
|
|||||||
BOOTSTRAP_SERVER = "bootstrap_server"
|
BOOTSTRAP_SERVER = "bootstrap_server"
|
||||||
|
|
||||||
|
|
||||||
def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
def get_kv_class(
|
||||||
|
transfer_backend: TransferBackend, class_type: KVClassType
|
||||||
|
) -> Optional[Type]:
|
||||||
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
|
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
|
||||||
|
|
||||||
if transfer_backend == TransferBackend.MOONCAKE:
|
if transfer_backend == TransferBackend.MOONCAKE:
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
@@ -53,6 +54,7 @@ from fastapi import BackgroundTasks
|
|||||||
|
|
||||||
from sglang.srt.aio_rwlock import RWLock
|
from sglang.srt.aio_rwlock import RWLock
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
|
from sglang.srt.disaggregation.base import BaseKVBootstrapServer
|
||||||
from sglang.srt.disaggregation.utils import (
|
from sglang.srt.disaggregation.utils import (
|
||||||
DisaggregationMode,
|
DisaggregationMode,
|
||||||
KVClassType,
|
KVClassType,
|
||||||
@@ -479,11 +481,12 @@ class TokenizerManager:
|
|||||||
# Start kv bootstrap server on prefill
|
# Start kv bootstrap server on prefill
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
# only start bootstrap server on prefill tm
|
# only start bootstrap server on prefill tm
|
||||||
kv_bootstrap_server_class = get_kv_class(
|
kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
|
||||||
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
|
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
|
||||||
)
|
)
|
||||||
self.bootstrap_server = kv_bootstrap_server_class(
|
self.bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
|
||||||
self.server_args.disaggregation_bootstrap_port
|
host=self.server_args.host,
|
||||||
|
port=self.server_args.disaggregation_bootstrap_port,
|
||||||
)
|
)
|
||||||
is_create_store = (
|
is_create_store = (
|
||||||
self.server_args.node_rank == 0
|
self.server_args.node_rank == 0
|
||||||
|
|||||||
Reference in New Issue
Block a user