feat(remote_model): support variable remote backend for model loader (#3964)
Signed-off-by: wangyu <wangyu.steph@bytedance.com>
This commit is contained in:
51
python/sglang/srt/connector/__init__.py
Normal file
51
python/sglang/srt/connector/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import enum
|
||||
import logging
|
||||
|
||||
from sglang.srt.connector.base_connector import (
|
||||
BaseConnector,
|
||||
BaseFileConnector,
|
||||
BaseKVConnector,
|
||||
)
|
||||
from sglang.srt.connector.redis import RedisConnector
|
||||
from sglang.srt.connector.s3 import S3Connector
|
||||
from sglang.srt.utils import parse_connector_type
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectorType(str, enum.Enum):
    """Category of a remote connector backend.

    Inherits from ``str`` so members compare equal to their string values.
    """

    # File-system style backends (e.g. S3) that expose whole files.
    FS = "filesystem"
    # Key-value style backends (e.g. Redis) that expose tensors by key.
    KV = "KV"
|
||||
|
||||
|
||||
def create_remote_connector(url, device="cpu") -> BaseConnector:
    """Create a connector for the remote backend encoded in *url*.

    Args:
        url: Remote location, e.g. ``redis://<host>:<port>/<model_name>``
            or ``s3://<bucket>/<path>``. The scheme selects the backend.
        device: Device tensors should be loaded onto; forwarded to
            connectors that support it.

    Returns:
        A concrete :class:`BaseConnector` for the given scheme.

    Raises:
        ValueError: If the url scheme is not a supported connector type.
    """
    connector_type = parse_connector_type(url)
    if connector_type == "redis":
        # Fix: forward the requested device instead of silently dropping it
        # (RedisConnector.__init__ accepts a device argument).
        return RedisConnector(url, device)
    elif connector_type == "s3":
        return S3Connector(url)
    else:
        raise ValueError(f"Invalid connector type: {url}")
|
||||
|
||||
|
||||
def get_connector_type(client: BaseConnector) -> ConnectorType:
    """Classify *client* as a KV or filesystem connector.

    Raises:
        ValueError: If *client* matches neither connector category.
    """
    # KV is checked first, mirroring the original precedence.
    dispatch = (
        (BaseKVConnector, ConnectorType.KV),
        (BaseFileConnector, ConnectorType.FS),
    )
    for base_cls, kind in dispatch:
        if isinstance(client, base_cls):
            return kind

    raise ValueError(f"Invalid connector type: {client}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseConnector",
|
||||
"BaseFileConnector",
|
||||
"BaseKVConnector",
|
||||
"RedisConnector",
|
||||
"S3Connector",
|
||||
"ConnectorType",
|
||||
"create_remote_connector",
|
||||
"get_connector_type",
|
||||
]
|
||||
112
python/sglang/srt/connector/base_connector.py
Normal file
112
python/sglang/srt/connector/base_connector.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class BaseConnector(ABC):
    """Abstract base for remote model-weight connectors.

    URL layout by connector family:

    For fs connector such as s3:
    <connector_type>://<path>/<filename>

    For kv connector such as redis:
    <connector_type>://<host>:<port>/<model_name>/keys/<key>
    <connector_type>://<host>:<port>/<model_name>/files/<filename>
    """

    def __init__(self, url: str, device: torch.device = "cpu"):
        self.url = url
        self.device = device
        self.closed = False
        # Scratch directory that pulled files are mirrored into; removed
        # again in close().
        self.local_dir = tempfile.mkdtemp()
        # Chain our cleanup in front of any existing SIGINT/SIGTERM handler
        # so the temp dir is removed even when the process is interrupted.
        # NOTE(review): signal.signal raises if called outside the main
        # thread — confirm connectors are only constructed on the main thread.
        for sig in (signal.SIGINT, signal.SIGTERM):
            existing_handler = signal.getsignal(sig)
            signal.signal(sig, self._close_by_signal(existing_handler))

    def get_local_dir(self):
        """Return the local scratch directory files are pulled into."""
        return self.local_dir

    @abstractmethod
    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield ``(weight_name, tensor)`` pairs for the given rank."""
        raise NotImplementedError()

    @abstractmethod
    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        """Download remote files matching the patterns into ``local_dir``."""
        raise NotImplementedError()

    def close(self):
        """Remove the scratch directory; safe to call more than once."""
        if self.closed:
            return

        self.closed = True
        if os.path.exists(self.local_dir):
            shutil.rmtree(self.local_dir)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Context-manager exit triggers the same idempotent cleanup.
        self.close()

    def __del__(self):
        # Best-effort cleanup when the connector is garbage collected.
        self.close()

    def _close_by_signal(self, existing_handler=None):
        # Wrap close() so it runs before any previously installed handler.

        def new_handler(signum, frame):
            self.close()
            if existing_handler:
                existing_handler(signum, frame)

        return new_handler
|
||||
|
||||
|
||||
class BaseKVConnector(BaseConnector):
    """Connector whose backend is a key-value store (e.g. Redis)."""

    @abstractmethod
    def get(self, key: str) -> Optional[torch.Tensor]:
        """Return the tensor stored under *key*, or ``None`` if missing."""
        raise NotImplementedError()

    @abstractmethod
    def getstr(self, key: str) -> Optional[str]:
        """Return the string stored under *key*, or ``None`` if missing."""
        raise NotImplementedError()

    @abstractmethod
    def set(self, key: str, obj: torch.Tensor) -> None:
        """Store tensor *obj* under *key*."""
        raise NotImplementedError()

    @abstractmethod
    def setstr(self, key: str, obj: str) -> None:
        """Store string *obj* under *key*."""
        raise NotImplementedError()

    @abstractmethod
    def list(self, prefix: str) -> List[str]:
        """Return all keys starting with *prefix*."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class BaseFileConnector(BaseConnector):
    """Connector whose backend is a file store (e.g. S3)."""

    @abstractmethod
    def glob(self, allow_pattern: Optional[List[str]]) -> List[str]:
        """
        List full file names from remote fs path and filter by allow pattern.

        Args:
            allow_pattern: A list of patterns of which files to pull.

        Returns:
            list[str]: List of full paths allowed by the pattern
        """
        raise NotImplementedError()
|
||||
85
python/sglang/srt/connector/redis.py
Normal file
85
python/sglang/srt/connector/redis.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import logging
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.connector import BaseKVConnector
|
||||
from sglang.srt.connector.serde import create_serde
|
||||
from sglang.srt.connector.utils import pull_files_from_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisConnector(BaseKVConnector):
    """KV connector backed by Redis (``redis://<host>:<port>/<model_name>``)."""

    def __init__(self, url: str, device: torch.device = "cpu"):
        import redis

        super().__init__(url, device)
        parsed_url = urlparse(url)
        # NOTE(review): parsed_url.port is None when the url omits the port —
        # confirm callers always supply an explicit port.
        self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
        self.model_name = parsed_url.path.lstrip("/")
        # TODO: more serde options
        self.s, self.d = create_serde("safe")

    def get(self, key: str) -> Optional[torch.Tensor]:
        """Return the deserialized tensor under *key*, or None (logged) if absent."""
        val = self.connection.get(key)

        if val is None:
            logger.error("Key %s not found", key)
            return None

        return self.d.from_bytes(val)

    def getstr(self, key: str) -> Optional[str]:
        """Return the UTF-8 string under *key*, or None (logged) if absent."""
        val = self.connection.get(key)
        if val is None:
            logger.error("Key %s not found", key)
            return None

        return val.decode("utf-8")

    def set(self, key: str, tensor: torch.Tensor) -> None:
        """Serialize *tensor* and store it under *key*."""
        assert tensor is not None
        self.connection.set(key, self.s.to_bytes(tensor))

    def setstr(self, key: str, obj: str) -> None:
        """Store string *obj* under *key*."""
        self.connection.set(key, obj)

    def list(self, prefix: str) -> List[str]:
        """Return all keys matching ``<prefix>*`` using cursor-based SCAN."""
        cursor = 0
        all_keys: List[bytes] = []

        while True:
            ret: Tuple[int, List[bytes]] = self.connection.scan(
                cursor=cursor, match=f"{prefix}*"
            )  # type: ignore
            cursor, keys = ret
            all_keys.extend(keys)
            # SCAN returns cursor 0 once the full keyspace has been covered.
            if cursor == 0:
                break

        return [key.decode("utf-8") for key in all_keys]

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, Optional[torch.Tensor]], None, None]:
        """Yield ``(weight_name, tensor)`` pairs stored under this model's rank prefix.

        A pair's tensor is None when the key vanished between list() and get().
        """
        keys = self.list(f"{self.model_name}/keys/rank_{rank}/")
        for key in keys:
            val = self.get(key)
            # Strip the storage prefix so callers see bare weight names.
            key = key.removeprefix(f"{self.model_name}/keys/rank_{rank}/")
            yield key, val

    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        """Mirror this model's ``files/`` entries into the local scratch dir."""
        pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)

    def close(self):
        """Close the Redis connection, then remove the scratch dir."""
        self.connection.close()
        super().close()
|
||||
122
python/sglang/srt/connector/s3.py
Normal file
122
python/sglang/srt/connector/s3.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import fnmatch
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Generator, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.connector import BaseFileConnector
|
||||
|
||||
|
||||
def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
|
||||
return [
|
||||
path
|
||||
for path in paths
|
||||
if any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
|
||||
]
|
||||
|
||||
|
||||
def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
|
||||
return [
|
||||
path
|
||||
for path in paths
|
||||
if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
|
||||
]
|
||||
|
||||
|
||||
def list_files(
|
||||
s3,
|
||||
path: str,
|
||||
allow_pattern: Optional[list[str]] = None,
|
||||
ignore_pattern: Optional[list[str]] = None,
|
||||
) -> tuple[str, str, list[str]]:
|
||||
"""
|
||||
List files from S3 path and filter by pattern.
|
||||
|
||||
Args:
|
||||
s3: S3 client to use.
|
||||
path: The S3 path to list from.
|
||||
allow_pattern: A list of patterns of which files to pull.
|
||||
ignore_pattern: A list of patterns of which files not to pull.
|
||||
|
||||
Returns:
|
||||
tuple[str, str, list[str]]: A tuple where:
|
||||
- The first element is the bucket name
|
||||
- The second element is string represent the bucket
|
||||
and the prefix as a dir like string
|
||||
- The third element is a list of files allowed or
|
||||
disallowed by pattern
|
||||
"""
|
||||
parts = path.removeprefix("s3://").split("/")
|
||||
prefix = "/".join(parts[1:])
|
||||
bucket_name = parts[0]
|
||||
|
||||
objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
|
||||
paths = [obj["Key"] for obj in objects.get("Contents", [])]
|
||||
|
||||
paths = _filter_ignore(paths, ["*/"])
|
||||
if allow_pattern is not None:
|
||||
paths = _filter_allow(paths, allow_pattern)
|
||||
|
||||
if ignore_pattern is not None:
|
||||
paths = _filter_ignore(paths, ignore_pattern)
|
||||
|
||||
return bucket_name, prefix, paths
|
||||
|
||||
|
||||
class S3Connector(BaseFileConnector):
    """File connector backed by an S3 bucket (``s3://<bucket>/<path>``)."""

    def __init__(self, url: str, device: torch.device = "cpu") -> None:
        """Create a boto3 S3 client for *url*.

        Args:
            url: ``s3://`` location of the model files.
            device: Forwarded to :class:`BaseConnector`; added (with the
                same default) for signature consistency with other
                connectors so callers may pass it uniformly.
        """
        import boto3

        super().__init__(url, device)
        self.client = boto3.client("s3")

    def glob(self, allow_pattern: Optional[list[str]] = None) -> list[str]:
        """Return full ``s3://`` URLs under ``self.url`` matching *allow_pattern*."""
        bucket_name, _, paths = list_files(
            self.client, path=self.url, allow_pattern=allow_pattern
        )
        return [f"s3://{bucket_name}/{path}" for path in paths]

    def pull_files(
        self,
        allow_pattern: Optional[list[str]] = None,
        ignore_pattern: Optional[list[str]] = None,
    ) -> None:
        """
        Pull files from S3 storage into the temporary directory.

        Args:
            allow_pattern: A list of patterns of which files to pull.
            ignore_pattern: A list of patterns of which files not to pull.
        """
        bucket_name, base_dir, files = list_files(
            self.client, self.url, allow_pattern, ignore_pattern
        )
        if len(files) == 0:
            return

        for file in files:
            # Mirror the remote layout below the local scratch dir.
            destination_file = os.path.join(self.local_dir, file.removeprefix(base_dir))
            # Renamed from "local_dir" to avoid shadowing self.local_dir's name.
            parent_dir = Path(destination_file).parent
            os.makedirs(parent_dir, exist_ok=True)
            self.client.download_file(bucket_name, file, destination_file)

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield ``(weight_name, tensor)`` pairs from remote safetensors files."""
        from sglang.srt.model_loader.weight_utils import (
            runai_safetensors_weights_iterator,
        )

        # only support safetensor files now
        hf_weights_files = self.glob(allow_pattern=["*.safetensors"])
        return runai_safetensors_weights_iterator(hf_weights_files)

    def close(self):
        """Close the boto3 client, then remove the scratch dir."""
        self.client.close()
        super().close()
|
||||
31
python/sglang/srt/connector/serde/__init__.py
Normal file
31
python/sglang/srt/connector/serde/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# inspired by LMCache
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.connector.serde.safe_serde import SafeDeserializer, SafeSerializer
|
||||
from sglang.srt.connector.serde.serde import Deserializer, Serializer
|
||||
|
||||
|
||||
def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
    """Build a matching (serializer, deserializer) pair.

    Args:
        serde_type: Name of the serde implementation; only ``"safe"``
            (safetensors-based) is currently supported.

    Raises:
        ValueError: For an unrecognized *serde_type*.
    """
    # Guard clause instead of if/else: reject unknown types up front.
    if serde_type != "safe":
        raise ValueError(f"Unknown serde type: {serde_type}")

    return SafeSerializer(), SafeDeserializer(torch.uint8)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Serializer",
|
||||
"Deserializer",
|
||||
"SafeSerializer",
|
||||
"SafeDeserializer",
|
||||
"create_serde",
|
||||
]
|
||||
29
python/sglang/srt/connector/serde/safe_serde.py
Normal file
29
python/sglang/srt/connector/serde/safe_serde.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load, save
|
||||
|
||||
from sglang.srt.connector.serde.serde import Deserializer, Serializer
|
||||
|
||||
|
||||
class SafeSerializer(Serializer):
    """Serializer that encodes a tensor with ``safetensors.torch.save``."""

    def __init__(self):
        super().__init__()

    def to_bytes(self, t: torch.Tensor) -> bytes:
        """Encode *t* (moved to CPU and made contiguous) as safetensors bytes."""
        payload = {"tensor_bytes": t.cpu().contiguous()}
        return save(payload)
|
||||
|
||||
|
||||
class SafeDeserializer(Deserializer):
    """Deserializer for payloads produced by :class:`SafeSerializer`."""

    def __init__(self, dtype):
        super().__init__(dtype)

    def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        """Decode safetensors bytes and cast to the configured dtype."""
        decoded = load(bytes(b))["tensor_bytes"]
        return decoded.to(dtype=self.dtype)

    def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        """Entry point required by :class:`Deserializer`."""
        return self.from_bytes_normal(b)
|
||||
43
python/sglang/srt/connector/serde/serde.py
Normal file
43
python/sglang/srt/connector/serde/serde.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import abc
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Serializer(ABC):
    """Interface for turning a tensor into a self-describing byte string."""

    @abstractmethod
    def to_bytes(self, t: torch.Tensor) -> bytes:
        """Serialize *t* to bytes.

        The byte string must carry both the tensor data and enough
        metadata (shape, dtype, etc.) to reconstruct the tensor later.

        Input:
            t: the input pytorch tensor; any device, any shape, any dtype.

        Returns:
            bytes: the serialized bytes
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class Deserializer(metaclass=abc.ABCMeta):
    """Interface for reconstructing a tensor from serialized bytes."""

    def __init__(self, dtype):
        # Target dtype that concrete implementations cast decoded tensors to.
        self.dtype = dtype

    @abstractmethod
    def from_bytes(self, bs: bytes) -> torch.Tensor:
        """Deserialize a pytorch tensor from bytes.

        Input:
            bs: a stream of bytes

        Output:
            torch.Tensor: the deserialized pytorch tensor
        """
        raise NotImplementedError
|
||||
35
python/sglang/srt/connector/utils.py
Normal file
35
python/sglang/srt/connector/utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sglang.srt.connector import BaseConnector
|
||||
|
||||
|
||||
def parse_model_name(url: str) -> str:
    """Extract the model name from a db-connector url.

    For ``redis://<host>:<port>/<model_name>`` this returns
    ``<model_name>``. Only used for db connector urls.
    """
    # The url path carries the model name; strip its leading slash.
    return urlparse(url).path.lstrip("/")
|
||||
|
||||
|
||||
def pull_files_from_db(
    connector: "BaseConnector",
    model_name: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None,
) -> None:
    """Copy every ``<model_name>/files/*`` entry from a KV connector to disk.

    Each entry is written under ``connector.get_local_dir()`` with the
    ``<model_name>/files/`` prefix stripped, preserving subdirectories.

    Args:
        connector: Connector providing ``list``/``getstr`` (a KV connector).
        model_name: Model whose files should be pulled.
        allow_pattern: Currently unused; reserved for future filtering.
        ignore_pattern: Currently unused; reserved for future filtering.
    """
    prefix = f"{model_name}/files/"
    local_dir = connector.get_local_dir()
    files = connector.list(prefix)

    for file in files:
        destination_file = os.path.join(local_dir, file.removeprefix(prefix))
        # Bug fix: use a separate variable for the parent directory.
        # Previously ``local_dir`` itself was overwritten inside the loop,
        # so any file after one in a subdirectory was joined against that
        # subdirectory and written to the wrong place.
        parent_dir = Path(destination_file).parent
        os.makedirs(parent_dir, exist_ok=True)
        with open(destination_file, "wb") as f:
            f.write(connector.getstr(file).encode("utf-8"))
|
||||
Reference in New Issue
Block a user