feat(remote_model): support variable remote backend for model loader (#3964)

Signed-off-by: wangyu <wangyu.steph@bytedance.com>
This commit is contained in:
wangyu
2025-03-14 15:40:44 +08:00
committed by GitHub
parent 977d7cd26a
commit 1ce4878d31
22 changed files with 1055 additions and 9 deletions

View File

@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
import enum
import logging
from sglang.srt.connector.base_connector import (
BaseConnector,
BaseFileConnector,
BaseKVConnector,
)
from sglang.srt.connector.redis import RedisConnector
from sglang.srt.connector.s3 import S3Connector
from sglang.srt.utils import parse_connector_type
logger = logging.getLogger(__name__)
class ConnectorType(str, enum.Enum):
    # Categories of remote backends: plain file storage (e.g. S3) vs.
    # key-value stores (e.g. Redis).
    FS = "filesystem"
    KV = "KV"
def create_remote_connector(url, device="cpu") -> BaseConnector:
    """Instantiate the connector matching the URL scheme.

    Args:
        url: remote location, e.g. "redis://host:port/model" or
            "s3://bucket/path".
        device: device tensors should be loaded onto (KV connectors only).

    Returns:
        A concrete BaseConnector for the given backend.

    Raises:
        ValueError: if the URL scheme names no known connector.
    """
    connector_type = parse_connector_type(url)
    if connector_type == "redis":
        # Fix: forward `device` — it was previously accepted but ignored,
        # so RedisConnector always fell back to its own "cpu" default.
        return RedisConnector(url, device)
    elif connector_type == "s3":
        # S3Connector pulls raw files and takes no device argument.
        return S3Connector(url)
    else:
        raise ValueError(f"Invalid connector type: {url}")
def get_connector_type(client: BaseConnector) -> ConnectorType:
    """Classify a connector instance as KV-backed or filesystem-backed."""
    # KV is tested first, mirroring the original precedence.
    checks = (
        (BaseKVConnector, ConnectorType.KV),
        (BaseFileConnector, ConnectorType.FS),
    )
    for base_cls, kind in checks:
        if isinstance(client, base_cls):
            return kind
    raise ValueError(f"Invalid connector type: {client}")
# Public API of the connector package.
__all__ = [
    "BaseConnector",
    "BaseFileConnector",
    "BaseKVConnector",
    "RedisConnector",
    "S3Connector",
    "ConnectorType",
    "create_remote_connector",
    "get_connector_type",
]

View File

@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
import os
import shutil
import signal
import tempfile
from abc import ABC, abstractmethod
from typing import Generator, List, Optional, Tuple
import torch
class BaseConnector(ABC):
    """Abstract base for remote model-weight backends.

    URL layout:
        fs connector such as s3:
            <connector_type>://<path>/<filename>
        kv connector such as redis:
            <connector_type>://<host>:<port>/<model_name>/keys/<key>
            <connector_type>://<host>:<port>/<model_name>/files/<filename>
    """

    def __init__(self, url: str, device: torch.device = "cpu"):
        self.url = url
        self.device = device
        self.closed = False
        # Scratch directory that pulled files land in; removed by close().
        self.local_dir = tempfile.mkdtemp()
        # Chain our cleanup in front of any existing SIGINT/SIGTERM handler
        # so the temp directory is removed even on interruption.
        for sig in (signal.SIGINT, signal.SIGTERM):
            existing_handler = signal.getsignal(sig)
            signal.signal(sig, self._close_by_signal(existing_handler))

    def get_local_dir(self):
        """Return the temporary directory remote files are pulled into."""
        return self.local_dir

    @abstractmethod
    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield (weight_name, tensor) pairs for the given rank."""
        raise NotImplementedError()

    @abstractmethod
    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        """Download remote files matching the patterns into local_dir."""
        raise NotImplementedError()

    def close(self):
        """Idempotently delete the scratch directory."""
        if self.closed:
            return

        self.closed = True
        if os.path.exists(self.local_dir):
            shutil.rmtree(self.local_dir)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def _close_by_signal(self, existing_handler=None):
        def new_handler(signum, frame):
            self.close()
            # Fix: the previous handler may be signal.SIG_DFL (0) or
            # signal.SIG_IGN (1), which are truthy ints but not callable —
            # the old `if existing_handler:` check raised TypeError on
            # signal delivery. Only chain to real Python handlers.
            if callable(existing_handler):
                existing_handler(signum, frame)

        return new_handler
class BaseKVConnector(BaseConnector):
    """Key-value backed connector interface (e.g. Redis).

    Tensors and strings are stored under flat string keys; key layout is
    defined by the URL scheme documented on BaseConnector.
    """

    @abstractmethod
    def get(self, key: str) -> Optional[torch.Tensor]:
        """Return the tensor stored at *key*, or None if absent."""
        raise NotImplementedError()

    @abstractmethod
    def getstr(self, key: str) -> Optional[str]:
        """Return the string stored at *key*, or None if absent."""
        raise NotImplementedError()

    @abstractmethod
    def set(self, key: str, obj: torch.Tensor) -> None:
        """Store tensor *obj* at *key*."""
        raise NotImplementedError()

    @abstractmethod
    def setstr(self, key: str, obj: str) -> None:
        """Store string *obj* at *key*."""
        raise NotImplementedError()

    @abstractmethod
    def list(self, prefix: str) -> List[str]:
        """Return all keys beginning with *prefix*."""
        raise NotImplementedError()
class BaseFileConnector(BaseConnector):
    """Filesystem-like connector interface (e.g. S3), addressed by paths."""

    @abstractmethod
    # NOTE(review): the annotation says `str` but the docstring and the
    # S3Connector implementation take Optional[list[str]] — annotation
    # looks stale; confirm against implementers before tightening.
    def glob(self, allow_pattern: str) -> List[str]:
        """
        List full file names from the remote fs path filtered by pattern.

        Args:
            allow_pattern: A list of patterns of which files to pull.

        Returns:
            list[str]: List of full paths allowed by the pattern
        """
        raise NotImplementedError()

View File

@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Generator, List, Optional, Tuple
from urllib.parse import urlparse
import torch
from sglang.srt.connector import BaseKVConnector
from sglang.srt.connector.serde import create_serde
from sglang.srt.connector.utils import pull_files_from_db
logger = logging.getLogger(__name__)
class RedisConnector(BaseKVConnector):
    """KV connector storing tensors and files in a Redis instance.

    URL form: redis://<host>:<port>/<model_name>
    """

    def __init__(self, url: str, device: torch.device = "cpu"):
        # Imported lazily so redis-py is only required when Redis is used.
        import redis

        super().__init__(url, device)
        parsed_url = urlparse(url)
        self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
        self.model_name = parsed_url.path.lstrip("/")
        # TODO: more serde options
        self.s, self.d = create_serde("safe")

    def get(self, key: str) -> Optional[torch.Tensor]:
        """Return the deserialized tensor at *key*; logs and returns None if absent."""
        val = self.connection.get(key)
        if val is None:
            logger.error("Key %s not found", key)
            return None
        return self.d.from_bytes(val)

    def getstr(self, key: str) -> Optional[str]:
        """Return the UTF-8 decoded string at *key*; logs and returns None if absent."""
        val = self.connection.get(key)
        if val is None:
            logger.error("Key %s not found", key)
            return None
        return val.decode("utf-8")

    def set(self, key: str, tensor: torch.Tensor) -> None:
        """Serialize *tensor* and store it at *key*."""
        assert tensor is not None
        self.connection.set(key, self.s.to_bytes(tensor))

    def setstr(self, key: str, obj: str) -> None:
        """Store string *obj* at *key*."""
        self.connection.set(key, obj)

    def list(self, prefix: str) -> List[str]:
        """Return every key matching ``{prefix}*`` via cursor-based SCAN."""
        cursor = 0
        all_keys: List[bytes] = []
        while True:
            ret: Tuple[int, List[bytes]] = self.connection.scan(
                cursor=cursor, match=f"{prefix}*"
            )  # type: ignore
            cursor, keys = ret
            all_keys.extend(keys)
            # SCAN returns cursor 0 when the iteration is complete.
            if cursor == 0:
                break
        return [key.decode("utf-8") for key in all_keys]

    # NOTE: annotation corrected from Tuple[str, bytes] — get() returns the
    # deserialized torch.Tensor (or None when the key vanished mid-scan).
    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield (weight_name, tensor) pairs stored under this rank's key prefix."""
        keys = self.list(f"{self.model_name}/keys/rank_{rank}/")
        for key in keys:
            val = self.get(key)
            # Strip the rank prefix so callers see bare weight names.
            key = key.removeprefix(f"{self.model_name}/keys/rank_{rank}/")
            yield key, val

    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        """Download files stored under <model_name>/files/ into local_dir."""
        pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)

    def close(self):
        """Close the Redis connection, then release the scratch directory."""
        self.connection.close()
        super().close()

View File

@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
import fnmatch
import os
from pathlib import Path
from typing import Generator, Optional, Tuple
import torch
from sglang.srt.connector import BaseFileConnector
def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
return [
path
for path in paths
if any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
]
def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
return [
path
for path in paths
if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
]
def list_files(
    s3,
    path: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None,
) -> tuple[str, str, list[str]]:
    """
    List files from an S3 path and filter them by pattern.

    Args:
        s3: S3 client to use.
        path: The S3 path to list from.
        allow_pattern: A list of patterns of which files to pull.
        ignore_pattern: A list of patterns of which files not to pull.

    Returns:
        tuple[str, str, list[str]]: (bucket name, prefix string acting as
        the base directory, object keys surviving the pattern filters).
    """
    bucket_name, _, key_prefix = path.removeprefix("s3://").partition("/")

    response = s3.list_objects_v2(Bucket=bucket_name, Prefix=key_prefix)
    keys = [entry["Key"] for entry in response.get("Contents", [])]

    # Directory placeholder objects end in "/"; they are never real files.
    keys = [k for k in keys if not fnmatch.fnmatch(k, "*/")]
    if allow_pattern is not None:
        keys = [k for k in keys if any(fnmatch.fnmatch(k, p) for p in allow_pattern)]
    if ignore_pattern is not None:
        keys = [k for k in keys if not any(fnmatch.fnmatch(k, p) for p in ignore_pattern)]

    return bucket_name, key_prefix, keys
class S3Connector(BaseFileConnector):
    """File connector that pulls model artifacts from an S3 bucket.

    self.url has the form s3://<bucket>/<prefix>.
    """

    def __init__(self, url: str) -> None:
        # Imported lazily so boto3 is only required when S3 is used.
        import boto3

        super().__init__(url)
        self.client = boto3.client("s3")

    def glob(self, allow_pattern: Optional[list[str]] = None) -> list[str]:
        """Return fully-qualified s3:// URLs under self.url matching the patterns."""
        bucket_name, _, paths = list_files(
            self.client, path=self.url, allow_pattern=allow_pattern
        )
        return [f"s3://{bucket_name}/{path}" for path in paths]

    def pull_files(
        self,
        allow_pattern: Optional[list[str]] = None,
        ignore_pattern: Optional[list[str]] = None,
    ) -> None:
        """
        Pull files from S3 storage into the temporary directory.

        Args:
            allow_pattern: A list of patterns of which files to pull.
            ignore_pattern: A list of patterns of which files not to pull.
        """
        bucket_name, base_dir, files = list_files(
            self.client, self.url, allow_pattern, ignore_pattern
        )
        if len(files) == 0:
            return
        for file in files:
            # Recreate the remote layout (relative to base_dir) under local_dir.
            destination_file = os.path.join(self.local_dir, file.removeprefix(base_dir))
            local_dir = Path(destination_file).parent
            os.makedirs(local_dir, exist_ok=True)
            self.client.download_file(bucket_name, file, destination_file)

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield (weight_name, tensor) pairs streamed from remote safetensors."""
        # Imported here to avoid a module-level cycle with the model loader.
        from sglang.srt.model_loader.weight_utils import (
            runai_safetensors_weights_iterator,
        )

        # only support safetensor files now
        hf_weights_files = self.glob(allow_pattern=["*.safetensors"])
        return runai_safetensors_weights_iterator(hf_weights_files)

    def close(self):
        """Close the boto3 client, then release the scratch directory."""
        self.client.close()
        super().close()

View File

@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# inspired by LMCache
from typing import Optional, Tuple
import torch
from sglang.srt.connector.serde.safe_serde import SafeDeserializer, SafeSerializer
from sglang.srt.connector.serde.serde import Deserializer, Serializer
def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
    """Build a matching (serializer, deserializer) pair.

    Args:
        serde_type: serialization scheme name; only "safe"
            (safetensors-based) is currently supported.

    Raises:
        ValueError: for any unrecognized scheme name.
    """
    if serde_type != "safe":
        raise ValueError(f"Unknown serde type: {serde_type}")

    serializer: Serializer = SafeSerializer()
    deserializer: Deserializer = SafeDeserializer(torch.uint8)
    return serializer, deserializer
# Public API of the serde subpackage.
__all__ = [
    "Serializer",
    "Deserializer",
    "SafeSerializer",
    "SafeDeserializer",
    "create_serde",
]

View File

@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Union
import torch
from safetensors.torch import load, save
from sglang.srt.connector.serde.serde import Deserializer, Serializer
class SafeSerializer(Serializer):
    """Serializer that packs tensors with the safetensors format."""

    def __init__(self):
        super().__init__()

    def to_bytes(self, t: torch.Tensor) -> bytes:
        """Serialize *t* (moved to CPU, made contiguous) to safetensors bytes."""
        payload = {"tensor_bytes": t.cpu().contiguous()}
        return save(payload)
class SafeDeserializer(Deserializer):
    """Deserializer that unpacks safetensors bytes back into a tensor."""

    def __init__(self, dtype):
        super().__init__(dtype)

    def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        """Decode safetensors bytes and cast the tensor to self.dtype."""
        tensor = load(bytes(b))["tensor_bytes"]
        return tensor.to(dtype=self.dtype)

    def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        return self.from_bytes_normal(b)

View File

@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
import abc
from abc import ABC, abstractmethod
import torch
class Serializer(ABC):
    """Interface for turning tensors into self-describing byte payloads."""

    @abstractmethod
    def to_bytes(self, t: torch.Tensor) -> bytes:
        """
        Serialize a pytorch tensor to bytes. The serialized bytes should contain
        both the data and the metadata (shape, dtype, etc.) of the tensor.

        Input:
            t: the input pytorch tensor, can be on any device, in any shape,
               with any dtype

        Returns:
            bytes: the serialized bytes
        """
        raise NotImplementedError
class Deserializer(metaclass=abc.ABCMeta):
    """Interface for reconstructing tensors from serialized byte payloads."""

    def __init__(self, dtype):
        # Target dtype every deserialized tensor is converted to.
        self.dtype = dtype

    @abstractmethod
    def from_bytes(self, bs: bytes) -> torch.Tensor:
        """
        Deserialize a pytorch tensor from bytes.

        Input:
            bytes: a stream of bytes

        Output:
            torch.Tensor: the deserialized pytorch tensor
        """
        raise NotImplementedError

View File

@@ -0,0 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
import os
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
from sglang.srt.connector import BaseConnector
def parse_model_name(url: str) -> str:
    """Extract the model name from a connector URL (path minus leading slash).

    Only used for db connector URLs such as redis://host:port/<model_name>.
    """
    return urlparse(url).path.lstrip("/")
def pull_files_from_db(
    connector: "BaseConnector",
    model_name: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None,
) -> None:
    """Download every file stored under <model_name>/files/ into the
    connector's local scratch directory, preserving relative layout.

    Args:
        connector: KV connector exposing list()/getstr()/get_local_dir().
        model_name: model whose "<model_name>/files/" namespace is pulled.
        allow_pattern: glob patterns a remote key must match to be pulled.
        ignore_pattern: glob patterns that exclude a remote key.
    """
    import fnmatch

    prefix = f"{model_name}/files/"
    local_dir = connector.get_local_dir()
    for file in connector.list(prefix):
        # Fix: apply the allow/ignore patterns — they were previously
        # accepted but silently ignored, violating pull_files' contract.
        if allow_pattern is not None and not any(
            fnmatch.fnmatch(file, p) for p in allow_pattern
        ):
            continue
        if ignore_pattern is not None and any(
            fnmatch.fnmatch(file, p) for p in ignore_pattern
        ):
            continue
        destination_file = os.path.join(local_dir, file.removeprefix(prefix))
        # Fix: the original rebound `local_dir` to the parent of the current
        # destination inside the loop, so every file after a nested one was
        # written relative to the wrong directory.
        os.makedirs(Path(destination_file).parent, exist_ok=True)
        # NOTE(review): getstr() round-trips the payload through UTF-8 text,
        # which would corrupt binary files — presumably only text/config
        # files are stored under files/; confirm with writers of this keyspace.
        with open(destination_file, "wb") as f:
            f.write(connector.getstr(file).encode("utf-8"))