feat(remote_model): support variable remote backend for model loader (#3964)
Signed-off-by: wangyu <wangyu.steph@bytedance.com>
python/sglang/__init__.py
@@ -32,6 +32,7 @@ from sglang.lang.choices import (
 )
 from sglang.utils import LazyImport

+ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
 Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
 LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
 OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
@@ -67,6 +68,7 @@ __all__ = [
     "greedy_token_selection",
     "token_length_normalized",
     "unconditional_likelihood_normalized",
+    "ServerArgs",
     "Anthropic",
     "LiteLLM",
     "OpenAI",

python/sglang/srt/configs/load_config.py
@@ -22,6 +22,7 @@ class LoadFormat(str, enum.Enum):
     MISTRAL = "mistral"
     LAYERED = "layered"
     JAX = "jax"
+    REMOTE = "remote"


 @dataclass

python/sglang/srt/configs/model_config.py
@@ -51,13 +51,14 @@ class ModelConfig:
         self.quantization = quantization

         # Parse args
+        self.maybe_pull_model_tokenizer_from_remote()
         self.model_override_args = json.loads(model_override_args)
         kwargs = {}
         if override_config_file and override_config_file.strip():
             kwargs["_configuration_file"] = override_config_file.strip()

         self.hf_config = get_config(
-            model_path,
+            self.model_path,
             trust_remote_code=trust_remote_code,
             revision=revision,
             model_override_args=self.model_override_args,

@@ -318,6 +319,29 @@ class ModelConfig:
         eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
         return eos_ids

+    def maybe_pull_model_tokenizer_from_remote(self) -> None:
+        """Pull the model config files to a temporary local directory
+        when the model path is a remote URL."""
+        from sglang.srt.connector import create_remote_connector
+        from sglang.srt.utils import is_remote_url
+
+        if is_remote_url(self.model_path):
+            logger.info("Pulling model configs from remote...")
+            # BaseConnector implements __del__() to clean up the local dir.
+            # Since the config files need to exist for the whole lifetime of
+            # the process, we deliberately do NOT use a `with` statement,
+            # which would close the client and delete the directory.
+            client = create_remote_connector(self.model_path)
+            client.pull_files(allow_pattern=["*config.json"])
+            self.model_weights = self.model_path
+            self.model_path = client.get_local_dir()


 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.

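To make the new hook concrete, here is a hedged usage sketch. The Redis URL is hypothetical and must already hold the model's config files (e.g. written by RemoteModelLoader.save_model further below). For a remote model path, the *config.json files land in a temporary directory, model_path is redirected there so the HF config machinery keeps working, and the original URL is preserved in model_weights for the weight loader:

    from sglang.srt.configs.model_config import ModelConfig

    config = ModelConfig("redis://localhost:6379/llama-3-8b")
    print(config.model_path)     # temporary local dir now holding *config.json
    print(config.model_weights)  # redis://localhost:6379/llama-3-8b
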
python/sglang/srt/connector/__init__.py (new file)
@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0

import enum
import logging

from sglang.srt.connector.base_connector import (
    BaseConnector,
    BaseFileConnector,
    BaseKVConnector,
)
from sglang.srt.connector.redis import RedisConnector
from sglang.srt.connector.s3 import S3Connector
from sglang.srt.utils import parse_connector_type

logger = logging.getLogger(__name__)


class ConnectorType(str, enum.Enum):
    FS = "filesystem"
    KV = "KV"


def create_remote_connector(url, device="cpu") -> BaseConnector:
    connector_type = parse_connector_type(url)
    if connector_type == "redis":
        return RedisConnector(url)
    elif connector_type == "s3":
        return S3Connector(url)
    else:
        raise ValueError(f"Invalid connector type: {url}")


def get_connector_type(client: BaseConnector) -> ConnectorType:
    if isinstance(client, BaseKVConnector):
        return ConnectorType.KV
    if isinstance(client, BaseFileConnector):
        return ConnectorType.FS

    raise ValueError(f"Invalid connector type: {client}")


__all__ = [
    "BaseConnector",
    "BaseFileConnector",
    "BaseKVConnector",
    "RedisConnector",
    "S3Connector",
    "ConnectorType",
    "create_remote_connector",
    "get_connector_type",
]

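A minimal usage sketch of the factory above (bucket name hypothetical; requires boto3 credentials in the environment). The backend is chosen purely by the URL scheme; pull_files() and get_local_dir() come from BaseConnector in the next file:

    from sglang.srt.connector import create_remote_connector

    with create_remote_connector("s3://my-bucket/llama-3-8b") as client:
        client.pull_files(allow_pattern=["*config.json"])  # configs only
        print(client.get_local_dir())  # temp dir populated from the bucket
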
python/sglang/srt/connector/base_connector.py (new file)
@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0

import os
import shutil
import signal
import tempfile
from abc import ABC, abstractmethod
from typing import Generator, List, Optional, Tuple

import torch


class BaseConnector(ABC):
    """
    For fs connector such as s3:
    <connector_type>://<path>/<filename>

    For kv connector such as redis:
    <connector_type>://<host>:<port>/<model_name>/keys/<key>
    <connector_type>://<host>:<port>/<model_name>/files/<filename>
    """

    def __init__(self, url: str, device: torch.device = "cpu"):
        self.url = url
        self.device = device
        self.closed = False
        self.local_dir = tempfile.mkdtemp()
        for sig in (signal.SIGINT, signal.SIGTERM):
            existing_handler = signal.getsignal(sig)
            signal.signal(sig, self._close_by_signal(existing_handler))

    def get_local_dir(self):
        return self.local_dir

    @abstractmethod
    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        raise NotImplementedError()

    @abstractmethod
    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        raise NotImplementedError()

    def close(self):
        if self.closed:
            return

        self.closed = True
        if os.path.exists(self.local_dir):
            shutil.rmtree(self.local_dir)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def _close_by_signal(self, existing_handler=None):

        def new_handler(signum, frame):
            self.close()
            if existing_handler:
                existing_handler(signum, frame)

        return new_handler


class BaseKVConnector(BaseConnector):

    @abstractmethod
    def get(self, key: str) -> Optional[torch.Tensor]:
        raise NotImplementedError()

    @abstractmethod
    def getstr(self, key: str) -> Optional[str]:
        raise NotImplementedError()

    @abstractmethod
    def set(self, key: str, obj: torch.Tensor) -> None:
        raise NotImplementedError()

    @abstractmethod
    def setstr(self, key: str, obj: str) -> None:
        raise NotImplementedError()

    @abstractmethod
    def list(self, prefix: str) -> List[str]:
        raise NotImplementedError()


class BaseFileConnector(BaseConnector):
    """
    List full file names from remote fs path and filter by allow pattern.

    Args:
        allow_pattern: A list of patterns of which files to pull.

    Returns:
        list[str]: List of full paths allowed by the pattern
    """

    @abstractmethod
    def glob(self, allow_pattern: str) -> List[str]:
        raise NotImplementedError()

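The ABC above is the whole contract a new backend has to satisfy. As an illustration only (not part of this commit), here is a hypothetical local:// connector that serves safetensors files from a local directory; every name in it is invented for the sketch:

    import fnmatch
    import os
    import shutil
    from typing import Generator, List, Optional, Tuple

    import torch
    from safetensors.torch import load_file

    from sglang.srt.connector.base_connector import BaseFileConnector


    class LocalDirConnector(BaseFileConnector):
        """Treat local://<path> as a remote source; purely illustrative."""

        def __init__(self, url: str, device: torch.device = "cpu"):
            super().__init__(url, device)
            self.src = url.removeprefix("local://")

        def glob(self, allow_pattern: str) -> List[str]:
            return [
                os.path.join(self.src, f)
                for f in sorted(os.listdir(self.src))
                if fnmatch.fnmatch(f, allow_pattern)
            ]

        def pull_files(
            self,
            allow_pattern: Optional[List[str]] = None,
            ignore_pattern: Optional[List[str]] = None,
        ) -> None:
            # Copy matching files into the temp dir created by BaseConnector.
            for f in sorted(os.listdir(self.src)):
                if allow_pattern and not any(
                    fnmatch.fnmatch(f, p) for p in allow_pattern
                ):
                    continue
                if ignore_pattern and any(
                    fnmatch.fnmatch(f, p) for p in ignore_pattern
                ):
                    continue
                shutil.copy(os.path.join(self.src, f), self.local_dir)

        def weight_iterator(
            self, rank: int = 0
        ) -> Generator[Tuple[str, torch.Tensor], None, None]:
            for f in self.glob("*.safetensors"):
                yield from load_file(f).items()

Registering such a backend would be one more branch in create_remote_connector above.
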
python/sglang/srt/connector/redis.py (new file)
@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Generator, List, Optional, Tuple
from urllib.parse import urlparse

import torch

from sglang.srt.connector import BaseKVConnector
from sglang.srt.connector.serde import create_serde
from sglang.srt.connector.utils import pull_files_from_db

logger = logging.getLogger(__name__)


class RedisConnector(BaseKVConnector):

    def __init__(self, url: str, device: torch.device = "cpu"):
        import redis

        super().__init__(url, device)
        parsed_url = urlparse(url)
        self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
        self.model_name = parsed_url.path.lstrip("/")
        # TODO: more serde options
        self.s, self.d = create_serde("safe")

    def get(self, key: str) -> Optional[torch.Tensor]:
        val = self.connection.get(key)

        if val is None:
            logger.error("Key %s not found", key)
            return None

        return self.d.from_bytes(val)

    def getstr(self, key: str) -> Optional[str]:
        val = self.connection.get(key)
        if val is None:
            logger.error("Key %s not found", key)
            return None

        return val.decode("utf-8")

    def set(self, key: str, tensor: torch.Tensor) -> None:
        assert tensor is not None
        self.connection.set(key, self.s.to_bytes(tensor))

    def setstr(self, key: str, obj: str) -> None:
        self.connection.set(key, obj)

    def list(self, prefix: str) -> List[str]:
        cursor = 0
        all_keys: List[bytes] = []

        while True:
            ret: Tuple[int, List[bytes]] = self.connection.scan(
                cursor=cursor, match=f"{prefix}*"
            )  # type: ignore
            cursor, keys = ret
            all_keys.extend(keys)
            if cursor == 0:
                break

        return [key.decode("utf-8") for key in all_keys]

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        keys = self.list(f"{self.model_name}/keys/rank_{rank}/")
        for key in keys:
            val = self.get(key)
            key = key.removeprefix(f"{self.model_name}/keys/rank_{rank}/")
            yield key, val

    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)

    def close(self):
        self.connection.close()
        super().close()

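A sketch of the key layout this connector assumes (Redis address and model name hypothetical): weights live under <model_name>/keys/rank_<rank>/ and auxiliary files under <model_name>/files/, matching what RemoteModelLoader.save_model writes further below. Note that create_serde builds the deserializer with torch.uint8 (see serde/__init__.py below), so a uint8 tensor round-trips exactly:

    import torch

    from sglang.srt.connector.redis import RedisConnector

    conn = RedisConnector("redis://localhost:6379/llama-3-8b")
    conn.set("llama-3-8b/keys/rank_0/lm_head.weight",
             torch.zeros(4, 4, dtype=torch.uint8))
    conn.setstr("llama-3-8b/files/config.json",
                '{"architectures": ["LlamaForCausalLM"]}')

    # weight_iterator() lists the rank prefix and yields (stripped_key, tensor):
    for name, tensor in conn.weight_iterator(rank=0):
        print(name, tuple(tensor.shape))  # lm_head.weight (4, 4)
    conn.close()
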
python/sglang/srt/connector/s3.py (new file)
@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0

import fnmatch
import os
from pathlib import Path
from typing import Generator, Optional, Tuple

import torch

from sglang.srt.connector import BaseFileConnector


def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
    return [
        path
        for path in paths
        if any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
    ]


def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
    return [
        path
        for path in paths
        if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
    ]


def list_files(
    s3,
    path: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None,
) -> tuple[str, str, list[str]]:
    """
    List files from an S3 path and filter by pattern.

    Args:
        s3: S3 client to use.
        path: The S3 path to list from.
        allow_pattern: A list of patterns of which files to pull.
        ignore_pattern: A list of patterns of which files not to pull.

    Returns:
        tuple[str, str, list[str]]: A tuple where:
            - The first element is the bucket name.
            - The second element is the prefix inside the bucket,
              as a directory-like string.
            - The third element is the list of object keys that survive
              the allow/ignore filters.
    """
    parts = path.removeprefix("s3://").split("/")
    prefix = "/".join(parts[1:])
    bucket_name = parts[0]

    objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
    paths = [obj["Key"] for obj in objects.get("Contents", [])]

    paths = _filter_ignore(paths, ["*/"])
    if allow_pattern is not None:
        paths = _filter_allow(paths, allow_pattern)

    if ignore_pattern is not None:
        paths = _filter_ignore(paths, ignore_pattern)

    return bucket_name, prefix, paths


class S3Connector(BaseFileConnector):

    def __init__(self, url: str) -> None:
        import boto3

        super().__init__(url)
        self.client = boto3.client("s3")

    def glob(self, allow_pattern: Optional[list[str]] = None) -> list[str]:
        bucket_name, _, paths = list_files(
            self.client, path=self.url, allow_pattern=allow_pattern
        )
        return [f"s3://{bucket_name}/{path}" for path in paths]

    def pull_files(
        self,
        allow_pattern: Optional[list[str]] = None,
        ignore_pattern: Optional[list[str]] = None,
    ) -> None:
        """
        Pull files from S3 storage into the connector's temporary directory.

        Args:
            allow_pattern: A list of patterns of which files to pull.
            ignore_pattern: A list of patterns of which files not to pull.
        """
        bucket_name, base_dir, files = list_files(
            self.client, self.url, allow_pattern, ignore_pattern
        )
        if len(files) == 0:
            return

        for file in files:
            destination_file = os.path.join(
                self.local_dir, file.removeprefix(base_dir)
            )
            local_dir = Path(destination_file).parent
            os.makedirs(local_dir, exist_ok=True)
            self.client.download_file(bucket_name, file, destination_file)

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        from sglang.srt.model_loader.weight_utils import (
            runai_safetensors_weights_iterator,
        )

        # Only safetensors files are supported for now.
        hf_weights_files = self.glob(allow_pattern=["*.safetensors"])
        return runai_safetensors_weights_iterator(hf_weights_files)

    def close(self):
        self.client.close()
        super().close()

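The two filter helpers are pure functions and compose as follows; the first _filter_ignore call in list_files drops directory-style keys ending in "/":

    paths = ["model/config.json", "model/model-00001.safetensors", "logs/"]
    paths = _filter_ignore(paths, ["*/"])            # drops "logs/"
    print(_filter_allow(paths, ["*config.json"]))    # ['model/config.json']
    print(_filter_ignore(paths, ["*.safetensors"]))  # ['model/config.json']
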
python/sglang/srt/connector/serde/__init__.py (new file)
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0

# inspired by LMCache
from typing import Optional, Tuple

import torch

from sglang.srt.connector.serde.safe_serde import SafeDeserializer, SafeSerializer
from sglang.srt.connector.serde.serde import Deserializer, Serializer


def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
    s: Optional[Serializer] = None
    d: Optional[Deserializer] = None

    if serde_type == "safe":
        s = SafeSerializer()
        d = SafeDeserializer(torch.uint8)
    else:
        raise ValueError(f"Unknown serde type: {serde_type}")

    return s, d


__all__ = [
    "Serializer",
    "Deserializer",
    "SafeSerializer",
    "SafeDeserializer",
    "create_serde",
]

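A round-trip sketch for the "safe" serde pair: the safetensors framing keeps shape and dtype alongside the data, and since create_serde constructs the deserializer with torch.uint8, a uint8 tensor survives unchanged:

    import torch

    from sglang.srt.connector.serde import create_serde

    s, d = create_serde("safe")
    t = torch.arange(6, dtype=torch.uint8).reshape(2, 3)
    blob = s.to_bytes(t)  # bytes suitable for a Redis value
    assert torch.equal(d.from_bytes(blob), t)
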
python/sglang/srt/connector/serde/safe_serde.py (new file)
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Union

import torch
from safetensors.torch import load, save

from sglang.srt.connector.serde.serde import Deserializer, Serializer


class SafeSerializer(Serializer):

    def __init__(self):
        super().__init__()

    def to_bytes(self, t: torch.Tensor) -> bytes:
        return save({"tensor_bytes": t.cpu().contiguous()})


class SafeDeserializer(Deserializer):

    def __init__(self, dtype):
        super().__init__(dtype)

    def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype)

    def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        return self.from_bytes_normal(b)

python/sglang/srt/connector/serde/serde.py (new file)
@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0

import abc
from abc import ABC, abstractmethod

import torch


class Serializer(ABC):

    @abstractmethod
    def to_bytes(self, t: torch.Tensor) -> bytes:
        """
        Serialize a pytorch tensor to bytes. The serialized bytes should contain
        both the data and the metadata (shape, dtype, etc.) of the tensor.

        Input:
            t: the input pytorch tensor, can be on any device, in any shape,
               with any dtype

        Returns:
            bytes: the serialized bytes
        """
        raise NotImplementedError


class Deserializer(metaclass=abc.ABCMeta):

    def __init__(self, dtype):
        self.dtype = dtype

    @abstractmethod
    def from_bytes(self, bs: bytes) -> torch.Tensor:
        """
        Deserialize a pytorch tensor from bytes.

        Input:
            bs: a stream of bytes

        Output:
            torch.Tensor: the deserialized pytorch tensor
        """
        raise NotImplementedError

python/sglang/srt/connector/utils.py (new file)
@@ -0,0 +1,35 @@
# SPDX-License-Identifier: Apache-2.0

import os
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

from sglang.srt.connector import BaseConnector


def parse_model_name(url: str) -> str:
    """
    Parse the model name from the url.
    Only used for db connector.
    """
    parsed_url = urlparse(url)
    return parsed_url.path.lstrip("/")


def pull_files_from_db(
    connector: BaseConnector,
    model_name: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None,
) -> None:
    prefix = f"{model_name}/files/"
    local_dir = connector.get_local_dir()
    files = connector.list(prefix)

    for file in files:
        destination_file = os.path.join(local_dir, file.removeprefix(prefix))
        # Use a separate name here: rebinding `local_dir` would corrupt the
        # join target for subsequent files.
        dest_dir = Path(destination_file).parent
        os.makedirs(dest_dir, exist_ok=True)
        with open(destination_file, "wb") as f:
            f.write(connector.getstr(file).encode("utf-8"))

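parse_model_name simply takes the URL path, which is how the Redis key prefixes above are derived:

    from sglang.srt.connector.utils import parse_model_name

    assert parse_model_name("redis://localhost:6379/llama-3-8b") == "llama-3-8b"
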
python/sglang/srt/entrypoints/engine.py
@@ -27,6 +27,9 @@ import signal
 import threading
 from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union

+import zmq
+import zmq.asyncio

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

@@ -44,6 +47,8 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
+    RpcReqInput,
+    RpcReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,

@@ -57,6 +62,7 @@ from sglang.srt.utils import (
     MultiprocessingSerializer,
     assert_pkg_version,
     configure_logger,
+    get_zmq_socket,
     kill_process_tree,
     launch_dummy_health_check_server,
     maybe_set_triton_cache_manager,

@@ -102,15 +108,25 @@ class Engine:
         # Shutdown the subprocesses automatically when the program exits
         atexit.register(self.shutdown)

+        # Allocate ports for inter-process communications
+        port_args = PortArgs.init_new(server_args)
+        logger.info(f"{server_args=}")
+
         # Launch subprocesses
         tokenizer_manager, scheduler_info = _launch_subprocesses(
-            server_args=server_args
+            server_args=server_args,
+            port_args=port_args,
         )

         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.scheduler_info = scheduler_info

+        context = zmq.Context(2)
+        self.send_to_rpc = get_zmq_socket(
+            context, zmq.DEALER, port_args.rpc_ipc_name, True
+        )
+
     def generate(
         self,
         # The input prompt. It can be a single prompt or a batch of prompts.

@@ -350,6 +366,23 @@ class Engine:
             self.tokenizer_manager.resume_memory_occupation(obj, None)
         )

+    def collective_rpc(self, method: str, **kwargs):
+        """Execute an RPC call on all scheduler processes."""
+        obj = RpcReqInput(method=method, parameters=kwargs)
+        self.send_to_rpc.send_pyobj(obj)
+        # Block until the scheduler replies on the DEALER socket.
+        recv_req = self.send_to_rpc.recv_pyobj()
+        assert isinstance(recv_req, RpcReqOutput)
+        assert recv_req.success, recv_req.message
+
+    def save_remote_model(self, **kwargs):
+        self.collective_rpc("save_remote_model", **kwargs)
+
+    def save_sharded_model(self, **kwargs):
+        self.collective_rpc("save_sharded_model", **kwargs)


 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments

@@ -408,7 +441,9 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)


-def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
+def _launch_subprocesses(
+    server_args: ServerArgs, port_args: Optional[PortArgs] = None
+) -> Tuple[TokenizerManager, Dict]:
     """
     Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """

@@ -418,8 +453,9 @@ def _launch_subprocesses(
     _set_envs_and_config(server_args)

     # Allocate ports for inter-process communications
-    port_args = PortArgs.init_new(server_args)
-    logger.info(f"{server_args=}")
+    if port_args is None:
+        port_args = PortArgs.init_new(server_args)
+        logger.info(f"{server_args=}")

     # If using model from www.modelscope.cn, first download the model.
     server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(

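Putting the Engine-side pieces together, a hedged sketch of the new RPC path (model path and Redis URL hypothetical): collective_rpc sends an RpcReqInput over the DEALER socket created in __init__, each Scheduler runs the named method, and the reply is awaited synchronously:

    import sglang as sgl

    engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")
    engine.save_remote_model(url="redis://localhost:6379/llama-3-8b")
    engine.shutdown()
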
python/sglang/srt/hf_transformers_utils.py
@@ -37,6 +37,8 @@ from sglang.srt.configs import (
     MultiModalityConfig,
     Qwen2_5_VLConfig,
 )
+from sglang.srt.connector import create_remote_connector
+from sglang.srt.utils import is_remote_url

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,

@@ -155,6 +157,14 @@ def get_tokenizer(
         kwargs["gguf_file"] = tokenizer_name
         tokenizer_name = Path(tokenizer_name).parent

+    if is_remote_url(tokenizer_name):
+        # BaseConnector implements __del__() to clean up the local dir.
+        # Since the tokenizer files need to exist for the whole lifetime of
+        # the process, we deliberately do NOT use a `with` statement, which
+        # would close the client and delete the directory.
+        client = create_remote_connector(tokenizer_name)
+        client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
+        tokenizer_name = client.get_local_dir()
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,

python/sglang/srt/managers/io_struct.py
@@ -723,3 +723,15 @@ class SeparateReasoningReqInput:
 class VertexGenerateReqInput:
     instances: List[dict]
     parameters: Optional[dict] = None


+@dataclass
+class RpcReqInput:
+    method: str
+    parameters: Optional[Dict] = None
+
+
+@dataclass
+class RpcReqOutput:
+    success: bool
+    message: str

python/sglang/srt/managers/scheduler.py
@@ -32,6 +32,7 @@ import psutil
 import setproctitle
 import torch
 import zmq
+from torch.distributed import barrier

 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig

@@ -59,6 +60,8 @@ from sglang.srt.managers.io_struct import (
     ReleaseMemoryOccupationReqOutput,
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
+    RpcReqInput,
+    RpcReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
     TokenizedEmbeddingReqInput,

@@ -193,8 +196,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.detokenizer_ipc_name, False
             )
+
+            self.recv_from_rpc = get_zmq_socket(
+                context, zmq.DEALER, port_args.rpc_ipc_name, False
+            )
         else:
             self.recv_from_tokenizer = None
+            self.recv_from_rpc = None
             self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

@@ -376,6 +384,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
+                (RpcReqInput, self.handle_rpc_request),
             ]
         )

@@ -549,6 +558,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 except zmq.ZMQError:
                     break
                 recv_reqs.append(recv_req)
+
+            while True:
+                try:
+                    recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+                recv_reqs.append(recv_rpc)
         else:
             recv_reqs = None

@@ -600,7 +616,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
             output = self._request_dispatcher(recv_req)
             if output is not None:
-                self.send_to_tokenizer.send_pyobj(output)
+                if isinstance(output, RpcReqOutput):
+                    if self.recv_from_rpc is not None:
+                        self.recv_from_rpc.send_pyobj(output)
+                else:
+                    self.send_to_tokenizer.send_pyobj(output)

     def handle_generate_request(
         self,

@@ -1492,6 +1512,47 @@ class Scheduler(SchedulerOutputProcessorMixin):
             server_args=global_server_args_dict,
         )

+    def handle_rpc_request(self, recv_req: RpcReqInput):
+        # Handle RPC requests
+        logger.info(
+            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
+        )
+
+        success = True
+        exc = None
+        try:
+            func = getattr(self, recv_req.method)
+            func(recv_req.parameters)
+        except Exception as e:
+            success = False
+            exc = e
+            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")
+
+        # Keep all TP ranks in lockstep before replying.
+        barrier()
+        return RpcReqOutput(success, "" if not exc else str(exc))
+
+    def save_remote_model(self, params):
+        url = params["url"]
+
+        if isinstance(self.tp_worker, TpModelWorkerClient):
+            worker = self.tp_worker.worker
+        else:
+            worker = self.tp_worker
+
+        worker.model_runner.save_remote_model(url)
+
+    def save_sharded_model(self, params):
+        if isinstance(self.tp_worker, TpModelWorkerClient):
+            worker = self.tp_worker.worker
+        else:
+            worker = self.tp_worker
+
+        worker.model_runner.save_sharded_model(
+            path=params["path"],
+            pattern=params["pattern"],
+            max_size=params["max_size"],
+        )

     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
         to_del = []

python/sglang/srt/model_executor/model_runner.py
@@ -1009,6 +1009,22 @@ class ModelRunner:
             return False
         return rope_scaling.get("type", None) == "mrope"

+    def save_remote_model(self, url: str):
+        from sglang.srt.model_loader.loader import RemoteModelLoader
+
+        logger.info(f"Saving model to {url}")
+        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)
+
+    def save_sharded_model(
+        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
+    ):
+        from sglang.srt.model_loader.loader import ShardedStateLoader
+
+        logger.info(
+            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
+        )
+        ShardedStateLoader.save_model(self.model, path, pattern, max_size)


 def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
     params_dict = dict(model.named_parameters())

python/sglang/srt/model_loader/loader.py
@@ -9,6 +9,7 @@ import json
 import logging
 import math
 import os
+import time
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast

@@ -25,6 +26,12 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.connector import (
+    ConnectorType,
+    create_remote_connector,
+    get_connector_type,
+)
+from sglang.srt.connector.utils import parse_model_name
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,

@@ -46,6 +53,7 @@ from sglang.srt.model_loader.weight_utils import (
     np_cache_weights_iterator,
     pt_weights_iterator,
     safetensors_weights_iterator,
+    set_runai_streamer_env,
 )
 from sglang.srt.utils import (
     get_bool_env_var,

@@ -490,7 +498,7 @@ class ShardedStateLoader(BaseModelLoader):
     Model loader that directly loads each worker's model state dict, which
     enables a fast load path for large tensor-parallel models where each worker
     only needs to read its own shard rather than the entire checkpoint. See
-    `examples/save_sharded_state.py` for creating a sharded checkpoint.
+    `examples/runtime/engine/save_sharded_state.py` for creating a sharded checkpoint.
     """

     DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

@@ -1204,6 +1212,153 @@ class GGUFModelLoader(BaseModelLoader):
         return model


+class RemoteModelLoader(BaseModelLoader):
+    """Model loader that can load tensors from a remote database or
+    filesystem."""
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        # TODO @DellCurry: move to s3 connector only
+        set_runai_streamer_env(load_config)
+
+    def _get_weights_iterator_kv(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote KV storage."""
+        assert get_connector_type(client) == ConnectorType.KV
+        rank = get_tensor_model_parallel_rank()
+        return client.weight_iterator(rank)
+
+    def _get_weights_iterator_fs(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote file storage."""
+        assert get_connector_type(client) == ConnectorType.FS
+        return client.weight_iterator()
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        pass
+
+    @staticmethod
+    def save_model(
+        model: torch.nn.Module,
+        model_path: str,
+        url: str,
+    ) -> None:
+        with create_remote_connector(url) as client:
+            assert get_connector_type(client) == ConnectorType.KV
+            model_name = parse_model_name(url)
+            rank = get_tensor_model_parallel_rank()
+            state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+            for key, tensor in state_dict.items():
+                r_key = f"{model_name}/keys/rank_{rank}/{key}"
+                client.set(r_key, tensor)
+
+            for root, _, files in os.walk(model_path):
+                for file_name in files:
+                    # ignore hidden files
+                    if file_name.startswith("."):
+                        continue
+                    # push non-weight files (configs, tokenizer, etc.) as strings
+                    if os.path.splitext(file_name)[1] not in (
+                        ".bin",
+                        ".pt",
+                        ".safetensors",
+                    ):
+                        file_path = os.path.join(root, file_name)
+                        with open(file_path, encoding="utf-8") as file:
+                            file_content = file.read()
+                        f_key = f"{model_name}/files/{file_name}"
+                        client.setstr(f_key, file_content)
+
+    def _load_model_from_remote_kv(self, model: nn.Module, client):
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                quant_method.process_weights_after_loading(module)
+        weights_iterator = self._get_weights_iterator_kv(client)
+        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+        for key, tensor in weights_iterator:
+            # If loading with LoRA enabled, additional padding may
+            # be added to certain parameters. We only load into a
+            # narrowed view of the parameter data.
+            param_data = state_dict[key].data
+            param_shape = state_dict[key].shape
+            for dim, size in enumerate(tensor.shape):
+                if size < param_shape[dim]:
+                    param_data = param_data.narrow(dim, 0, size)
+            if tensor.shape != param_shape:
+                logger.warning(
+                    "loading tensor of shape %s into parameter '%s' of shape %s",
+                    tensor.shape,
+                    key,
+                    param_shape,
+                )
+            param_data.copy_(tensor)
+            state_dict.pop(key)
+        if state_dict:
+            raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+    def _load_model_from_remote_fs(
+        self, model, client, model_config: ModelConfig, device_config: DeviceConfig
+    ) -> nn.Module:
+
+        target_device = torch.device(device_config.device)
+        with set_default_torch_dtype(model_config.dtype):
+            model.load_weights(self._get_weights_iterator_fs(client))
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    # When quant methods need to process weights after loading
+                    # (for repacking, quantizing, etc), they expect parameters
+                    # to be on the global target device. This scope is for the
+                    # case where cpu offloading is used, where we will move the
+                    # parameters onto device for processing and back off after.
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        logger.info("Loading weights from remote storage ...")
+        start = time.perf_counter()
+        load_config = self.load_config
+
+        assert load_config.load_format == LoadFormat.REMOTE, (
+            f"RemoteModelLoader requires load format {LoadFormat.REMOTE}, "
+            f"got {load_config.load_format}"
+        )
+
+        model_weights = model_config.model_path
+        if hasattr(model_config, "model_weights"):
+            model_weights = model_config.model_weights
+
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config)
+                for _, module in model.named_modules():
+                    quant_method = getattr(module, "quant_method", None)
+                    if quant_method is not None:
+                        quant_method.process_weights_after_loading(module)
+
+            with create_remote_connector(model_weights, device_config.device) as client:
+                connector_type = get_connector_type(client)
+                if connector_type == ConnectorType.KV:
+                    self._load_model_from_remote_kv(model, client)
+                elif connector_type == ConnectorType.FS:
+                    self._load_model_from_remote_fs(
+                        model, client, model_config, device_config
+                    )
+
+        end = time.perf_counter()
+        logger.info("Loaded weights from remote storage in %.2f seconds.", end - start)
+        return model.eval()


 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""

@@ -1225,4 +1380,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.LAYERED:
         return LayeredModelLoader(load_config)

+    if load_config.load_format == LoadFormat.REMOTE:
+        return RemoteModelLoader(load_config)
+
     return DefaultModelLoader(load_config)

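The counterpart to save_model above: once the weights sit behind a connector URL, passing that URL as the model path selects RemoteModelLoader (see the server_args hunk below, which maps remote URLs to load_format="remote"). A hedged sketch, assuming a hypothetical Redis deployment populated by the earlier save_remote_model example:

    import sglang as sgl

    engine = sgl.Engine(model_path="redis://localhost:6379/llama-3-8b")
    print(engine.generate("The capital of France is")["text"])
    engine.shutdown()
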
python/sglang/srt/model_loader/weight_utils.py
@@ -585,6 +585,51 @@ def composed_weight_loader(
     return composed_loader


+def runai_safetensors_weights_iterator(
+    hf_weights_files: List[str],
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model safetensor files."""
+    from runai_model_streamer import SafetensorsStreamer
+
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    with SafetensorsStreamer() as streamer:
+        for st_file in tqdm(
+            hf_weights_files,
+            desc="Loading safetensors using Runai Model Streamer",
+            disable=not enable_tqdm,
+            bar_format=_BAR_FORMAT,
+        ):
+            streamer.stream_file(st_file)
+            yield from streamer.get_tensors()
+
+
+def set_runai_streamer_env(load_config: LoadConfig):
+    if load_config.model_loader_extra_config:
+        extra_config = load_config.model_loader_extra_config
+
+        if "concurrency" in extra_config and isinstance(
+            extra_config.get("concurrency"), int
+        ):
+            os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
+                extra_config.get("concurrency")
+            )
+
+        if "memory_limit" in extra_config and isinstance(
+            extra_config.get("memory_limit"), int
+        ):
+            os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
+                extra_config.get("memory_limit")
+            )
+
+    runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
+    aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
+    if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
+        os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url


 def initialize_dummy_weights(
     model: torch.nn.Module,
     low: float = -1e-3,

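set_runai_streamer_env reads the model_loader_extra_config dict shown above; a small sketch of the mapping (the config values are illustrative):

    import os

    from sglang.srt.configs.load_config import LoadConfig, LoadFormat
    from sglang.srt.model_loader.weight_utils import set_runai_streamer_env

    cfg = LoadConfig(
        load_format=LoadFormat.REMOTE,
        model_loader_extra_config={"concurrency": 16, "memory_limit": 5 * 2**30},
    )
    set_runai_streamer_env(cfg)
    assert os.environ["RUNAI_STREAMER_CONCURRENCY"] == "16"
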
python/sglang/srt/server_args.py
@@ -30,6 +30,7 @@ from sglang.srt.utils import (
     is_flashinfer_available,
     is_hip,
     is_port_available,
+    is_remote_url,
     is_valid_ipv6_address,
     nullable_str,
 )

@@ -296,6 +297,9 @@ class ServerArgs:
         ) and check_gguf_file(self.model_path):
             self.quantization = self.load_format = "gguf"

+        if is_remote_url(self.model_path):
+            self.load_format = "remote"
+
         # AMD-specific Triton attention KV splits default number
         if is_hip():
             self.triton_attention_num_kv_splits = 16

@@ -345,9 +349,11 @@ class ServerArgs:
                 "safetensors",
                 "npcache",
                 "dummy",
                 "sharded_state",
                 "gguf",
                 "bitsandbytes",
                 "layered",
+                "remote",
             ],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '

@@ -1088,6 +1094,9 @@ class PortArgs:
     # The port for nccl initialization (torch.dist)
     nccl_port: int

+    # The ipc filename for rpc call between Engine and Scheduler
+    rpc_ipc_name: str
+
     @staticmethod
     def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
         port = server_args.port + random.randint(100, 1000)

@@ -1106,6 +1115,7 @@ class PortArgs:
                 scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                 detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                 nccl_port=port,
+                rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
             )
         else:
             # DP attention. Use TCP + port to handle both single-node and multi-node.

@@ -1131,6 +1141,7 @@ class PortArgs:
                 scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
                 detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
                 nccl_port=port,
+                rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
             )

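The effect of the new is_remote_url check, as a sketch (URL hypothetical; assumes ServerArgs can be constructed standalone in your environment):

    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(model_path="s3://my-bucket/llama-3-8b")
    assert args.load_format == "remote"
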
python/sglang/srt/utils.py
@@ -42,6 +42,7 @@ from importlib.util import find_spec
 from io import BytesIO
 from multiprocessing import Pool
 from multiprocessing.reduction import ForkingPickler
+from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union

 import numpy as np

@@ -774,12 +775,22 @@ def get_zmq_socket(
     buf_size = -1

     socket = context.socket(socket_type)
-    if socket_type == zmq.PUSH:
+
+    def set_send_opt():
         socket.setsockopt(zmq.SNDHWM, 0)
         socket.setsockopt(zmq.SNDBUF, buf_size)
-    elif socket_type == zmq.PULL:
+
+    def set_recv_opt():
         socket.setsockopt(zmq.RCVHWM, 0)
         socket.setsockopt(zmq.RCVBUF, buf_size)
+
+    if socket_type == zmq.PUSH:
+        set_send_opt()
+    elif socket_type == zmq.PULL:
+        set_recv_opt()
+    elif socket_type == zmq.DEALER:
+        set_send_opt()
+        set_recv_opt()
     else:
         raise ValueError(f"Unsupported socket type: {socket_type}")

@@ -1572,3 +1583,29 @@ def add_prefix(name: str, prefix: str) -> str:
         The string `prefix.name` if prefix is non-empty, otherwise just `name`.
     """
     return name if not prefix else f"{prefix}.{name}"


+def is_remote_url(url: Union[str, Path]) -> bool:
+    """
+    Check if the URL is a remote URL of the format:
+    <connector_type>://<host>:<port>/<model_name>
+    """
+    if isinstance(url, Path):
+        return False
+
+    pattern = r"(.+)://(.*)"
+    m = re.match(pattern, url)
+    return m is not None
+
+
+def parse_connector_type(url: str) -> str:
+    """
+    Parse the connector type from the URL of the format:
+    <connector_type>://<path>
+    """
+    pattern = r"(.+)://(.*)"
+    m = re.match(pattern, url)
+    if m is None:
+        return ""
+
+    return m.group(1)

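Quick checks of the two URL helpers (pure string/regex logic, safe to run anywhere):

    from pathlib import Path

    from sglang.srt.utils import is_remote_url, parse_connector_type

    assert is_remote_url("redis://localhost:6379/llama-3-8b")
    assert not is_remote_url(Path("/models/llama-3-8b"))
    assert parse_connector_type("s3://my-bucket/llama-3-8b") == "s3"
    assert parse_connector_type("/models/llama-3-8b") == ""
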