feat(remote_model): support variable remote backend for model loader (#3964)
Signed-off-by: wangyu <wangyu.steph@bytedance.com>
This commit is contained in:
51
examples/runtime/engine/save_remote_state.py
Normal file
51
examples/runtime/engine/save_remote_state.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""
|
||||||
|
Saves each worker's model state dict directly to a checkpoint, which enables a
|
||||||
|
fast load path for large tensor-parallel models where each worker only needs to
|
||||||
|
read its own shard rather than the entire checkpoint.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
python save_remote_state.py \
|
||||||
|
--model-path /path/to/load \
|
||||||
|
--tensor-parallel-size 8 \
|
||||||
|
--remote-model-save-url [protocol]://[host]:[port]/[model_name] \
|
||||||
|
|
||||||
|
Then, the model can be loaded with
|
||||||
|
|
||||||
|
llm = Engine(
|
||||||
|
model_path="/path/to/save",
|
||||||
|
--remote-model-url [protocol]://[host]:[port]/[model_name],
|
||||||
|
tensor_parallel_size=8,
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
import dataclasses
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from sglang import Engine, ServerArgs
|
||||||
|
|
||||||
|
parser = ArgumentParser()
|
||||||
|
ServerArgs.add_cli_args(parser)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--remote-model-save-url",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="remote address to store model weights",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
engine_args = ServerArgs.from_cli_args(args)
|
||||||
|
model_path = engine_args.model_path
|
||||||
|
if not Path(model_path).is_dir():
|
||||||
|
raise ValueError("model path must be a local directory")
|
||||||
|
# Create LLM instance from arguments
|
||||||
|
llm = Engine(**dataclasses.asdict(engine_args))
|
||||||
|
llm.save_remote_model(url=args.remote_model_save_url)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
||||||
74
examples/runtime/engine/save_sharded_state.py
Normal file
74
examples/runtime/engine/save_sharded_state.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""
|
||||||
|
Saves each worker's model state dict directly to a checkpoint, which enables a
|
||||||
|
fast load path for large tensor-parallel models where each worker only needs to
|
||||||
|
read its own shard rather than the entire checkpoint.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
python save_sharded_state.py \
|
||||||
|
--model-path /path/to/load \
|
||||||
|
--quantization deepspeedfp \
|
||||||
|
--tensor-parallel-size 8 \
|
||||||
|
--output /path/to/save
|
||||||
|
|
||||||
|
Then, the model can be loaded with
|
||||||
|
|
||||||
|
llm = Engine(
|
||||||
|
model_path="/path/to/save",
|
||||||
|
load_format="sharded_state",
|
||||||
|
quantization="deepspeedfp",
|
||||||
|
tensor_parallel_size=8,
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
import dataclasses
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from sglang import Engine, ServerArgs
|
||||||
|
|
||||||
|
parser = ArgumentParser()
|
||||||
|
ServerArgs.add_cli_args(parser)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--output", "-o", required=True, type=str, help="path to output checkpoint"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--file-pattern", type=str, help="string pattern of saved filenames"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-file-size",
|
||||||
|
type=str,
|
||||||
|
default=5 * 1024**3,
|
||||||
|
help="max size (in bytes) of each safetensors file",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
engine_args = ServerArgs.from_cli_args(args)
|
||||||
|
model_path = engine_args.model_path
|
||||||
|
if not Path(model_path).is_dir():
|
||||||
|
raise ValueError("model path must be a local directory")
|
||||||
|
# Create LLM instance from arguments
|
||||||
|
llm = Engine(**dataclasses.asdict(engine_args))
|
||||||
|
Path(args.output).mkdir(exist_ok=True)
|
||||||
|
llm.save_sharded_model(
|
||||||
|
path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copy metadata files to output directory
|
||||||
|
for file in os.listdir(model_path):
|
||||||
|
if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
|
||||||
|
if os.path.isdir(os.path.join(model_path, file)):
|
||||||
|
shutil.copytree(
|
||||||
|
os.path.join(model_path, file), os.path.join(args.output, file)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
shutil.copy(os.path.join(model_path, file), args.output)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
||||||
@@ -32,6 +32,7 @@ from sglang.lang.choices import (
|
|||||||
)
|
)
|
||||||
from sglang.utils import LazyImport
|
from sglang.utils import LazyImport
|
||||||
|
|
||||||
|
ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
|
||||||
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
|
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
|
||||||
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
|
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
|
||||||
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
|
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
|
||||||
@@ -67,6 +68,7 @@ __all__ = [
|
|||||||
"greedy_token_selection",
|
"greedy_token_selection",
|
||||||
"token_length_normalized",
|
"token_length_normalized",
|
||||||
"unconditional_likelihood_normalized",
|
"unconditional_likelihood_normalized",
|
||||||
|
"ServerArgs",
|
||||||
"Anthropic",
|
"Anthropic",
|
||||||
"LiteLLM",
|
"LiteLLM",
|
||||||
"OpenAI",
|
"OpenAI",
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ class LoadFormat(str, enum.Enum):
|
|||||||
MISTRAL = "mistral"
|
MISTRAL = "mistral"
|
||||||
LAYERED = "layered"
|
LAYERED = "layered"
|
||||||
JAX = "jax"
|
JAX = "jax"
|
||||||
|
REMOTE = "remote"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -51,13 +51,14 @@ class ModelConfig:
|
|||||||
self.quantization = quantization
|
self.quantization = quantization
|
||||||
|
|
||||||
# Parse args
|
# Parse args
|
||||||
|
self.maybe_pull_model_tokenizer_from_remote()
|
||||||
self.model_override_args = json.loads(model_override_args)
|
self.model_override_args = json.loads(model_override_args)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if override_config_file and override_config_file.strip():
|
if override_config_file and override_config_file.strip():
|
||||||
kwargs["_configuration_file"] = override_config_file.strip()
|
kwargs["_configuration_file"] = override_config_file.strip()
|
||||||
|
|
||||||
self.hf_config = get_config(
|
self.hf_config = get_config(
|
||||||
model_path,
|
self.model_path,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
model_override_args=self.model_override_args,
|
model_override_args=self.model_override_args,
|
||||||
@@ -318,6 +319,29 @@ class ModelConfig:
|
|||||||
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
|
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
|
||||||
return eos_ids
|
return eos_ids
|
||||||
|
|
||||||
|
def maybe_pull_model_tokenizer_from_remote(self) -> None:
|
||||||
|
"""
|
||||||
|
Pull the model config files to a temporary
|
||||||
|
directory in case of remote.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model name or path.
|
||||||
|
|
||||||
|
"""
|
||||||
|
from sglang.srt.connector import create_remote_connector
|
||||||
|
from sglang.srt.utils import is_remote_url
|
||||||
|
|
||||||
|
if is_remote_url(self.model_path):
|
||||||
|
logger.info("Pulling model configs from remote...")
|
||||||
|
# BaseConnector implements __del__() to clean up the local dir.
|
||||||
|
# Since config files need to exist all the time, so we DO NOT use
|
||||||
|
# with statement to avoid closing the client.
|
||||||
|
client = create_remote_connector(self.model_path)
|
||||||
|
if is_remote_url(self.model_path):
|
||||||
|
client.pull_files(allow_pattern=["*config.json"])
|
||||||
|
self.model_weights = self.model_path
|
||||||
|
self.model_path = client.get_local_dir()
|
||||||
|
|
||||||
|
|
||||||
def get_hf_text_config(config: PretrainedConfig):
|
def get_hf_text_config(config: PretrainedConfig):
|
||||||
"""Get the "sub" config relevant to llm for multi modal models.
|
"""Get the "sub" config relevant to llm for multi modal models.
|
||||||
|
|||||||
51
python/sglang/srt/connector/__init__.py
Normal file
51
python/sglang/srt/connector/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from sglang.srt.connector.base_connector import (
|
||||||
|
BaseConnector,
|
||||||
|
BaseFileConnector,
|
||||||
|
BaseKVConnector,
|
||||||
|
)
|
||||||
|
from sglang.srt.connector.redis import RedisConnector
|
||||||
|
from sglang.srt.connector.s3 import S3Connector
|
||||||
|
from sglang.srt.utils import parse_connector_type
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectorType(str, enum.Enum):
|
||||||
|
FS = "filesystem"
|
||||||
|
KV = "KV"
|
||||||
|
|
||||||
|
|
||||||
|
def create_remote_connector(url, device="cpu") -> BaseConnector:
|
||||||
|
connector_type = parse_connector_type(url)
|
||||||
|
if connector_type == "redis":
|
||||||
|
return RedisConnector(url)
|
||||||
|
elif connector_type == "s3":
|
||||||
|
return S3Connector(url)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid connector type: {url}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_connector_type(client: BaseConnector) -> ConnectorType:
|
||||||
|
if isinstance(client, BaseKVConnector):
|
||||||
|
return ConnectorType.KV
|
||||||
|
if isinstance(client, BaseFileConnector):
|
||||||
|
return ConnectorType.FS
|
||||||
|
|
||||||
|
raise ValueError(f"Invalid connector type: {client}")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseConnector",
|
||||||
|
"BaseFileConnector",
|
||||||
|
"BaseKVConnector",
|
||||||
|
"RedisConnector",
|
||||||
|
"S3Connector",
|
||||||
|
"ConnectorType",
|
||||||
|
"create_remote_connector",
|
||||||
|
"get_connector_type",
|
||||||
|
]
|
||||||
112
python/sglang/srt/connector/base_connector.py
Normal file
112
python/sglang/srt/connector/base_connector.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import signal
|
||||||
|
import tempfile
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Generator, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class BaseConnector(ABC):
|
||||||
|
"""
|
||||||
|
For fs connector such as s3:
|
||||||
|
<connector_type>://<path>/<filename>
|
||||||
|
|
||||||
|
For kv connector such as redis:
|
||||||
|
<connector_type>://<host>:<port>/<model_name>/keys/<key>
|
||||||
|
<connector_type://<host>:<port>/<model_name>/files/<filename>
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, url: str, device: torch.device = "cpu"):
|
||||||
|
self.url = url
|
||||||
|
self.device = device
|
||||||
|
self.closed = False
|
||||||
|
self.local_dir = tempfile.mkdtemp()
|
||||||
|
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||||
|
existing_handler = signal.getsignal(sig)
|
||||||
|
signal.signal(sig, self._close_by_signal(existing_handler))
|
||||||
|
|
||||||
|
def get_local_dir(self):
|
||||||
|
return self.local_dir
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def weight_iterator(
|
||||||
|
self, rank: int = 0
|
||||||
|
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def pull_files(
|
||||||
|
self,
|
||||||
|
allow_pattern: Optional[List[str]] = None,
|
||||||
|
ignore_pattern: Optional[List[str]] = None,
|
||||||
|
) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self.closed:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.closed = True
|
||||||
|
if os.path.exists(self.local_dir):
|
||||||
|
shutil.rmtree(self.local_dir)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
def _close_by_signal(self, existing_handler=None):
|
||||||
|
|
||||||
|
def new_handler(signum, frame):
|
||||||
|
self.close()
|
||||||
|
if existing_handler:
|
||||||
|
existing_handler(signum, frame)
|
||||||
|
|
||||||
|
return new_handler
|
||||||
|
|
||||||
|
|
||||||
|
class BaseKVConnector(BaseConnector):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, key: str) -> Optional[torch.Tensor]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def getstr(self, key: str) -> Optional[str]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set(self, key: str, obj: torch.Tensor) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def setstr(self, key: str, obj: str) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list(self, prefix: str) -> List[str]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFileConnector(BaseConnector):
|
||||||
|
"""
|
||||||
|
List full file names from remote fs path and filter by allow pattern.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allow_pattern: A list of patterns of which files to pull.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[str]: List of full paths allowed by the pattern
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def glob(self, allow_pattern: str) -> List[str]:
|
||||||
|
raise NotImplementedError()
|
||||||
85
python/sglang/srt/connector/redis.py
Normal file
85
python/sglang/srt/connector/redis.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Generator, List, Optional, Tuple
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.connector import BaseKVConnector
|
||||||
|
from sglang.srt.connector.serde import create_serde
|
||||||
|
from sglang.srt.connector.utils import pull_files_from_db
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisConnector(BaseKVConnector):
|
||||||
|
|
||||||
|
def __init__(self, url: str, device: torch.device = "cpu"):
|
||||||
|
import redis
|
||||||
|
|
||||||
|
super().__init__(url, device)
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
|
||||||
|
self.model_name = parsed_url.path.lstrip("/")
|
||||||
|
# TODO: more serde options
|
||||||
|
self.s, self.d = create_serde("safe")
|
||||||
|
|
||||||
|
def get(self, key: str) -> Optional[torch.Tensor]:
|
||||||
|
val = self.connection.get(key)
|
||||||
|
|
||||||
|
if val is None:
|
||||||
|
logger.error("Key %s not found", key)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.d.from_bytes(val)
|
||||||
|
|
||||||
|
def getstr(self, key: str) -> Optional[str]:
|
||||||
|
val = self.connection.get(key)
|
||||||
|
if val is None:
|
||||||
|
logger.error("Key %s not found", key)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return val.decode("utf-8")
|
||||||
|
|
||||||
|
def set(self, key: str, tensor: torch.Tensor) -> None:
|
||||||
|
assert tensor is not None
|
||||||
|
self.connection.set(key, self.s.to_bytes(tensor))
|
||||||
|
|
||||||
|
def setstr(self, key: str, obj: str) -> None:
|
||||||
|
self.connection.set(key, obj)
|
||||||
|
|
||||||
|
def list(self, prefix: str) -> List[str]:
|
||||||
|
cursor = 0
|
||||||
|
all_keys: List[bytes] = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
ret: Tuple[int, List[bytes]] = self.connection.scan(
|
||||||
|
cursor=cursor, match=f"{prefix}*"
|
||||||
|
) # type: ignore
|
||||||
|
cursor, keys = ret
|
||||||
|
all_keys.extend(keys)
|
||||||
|
if cursor == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
return [key.decode("utf-8") for key in all_keys]
|
||||||
|
|
||||||
|
def weight_iterator(
|
||||||
|
self, rank: int = 0
|
||||||
|
) -> Generator[Tuple[str, bytes], None, None]:
|
||||||
|
keys = self.list(f"{self.model_name}/keys/rank_{rank}/")
|
||||||
|
for key in keys:
|
||||||
|
val = self.get(key)
|
||||||
|
key = key.removeprefix(f"{self.model_name}/keys/rank_{rank}/")
|
||||||
|
yield key, val
|
||||||
|
|
||||||
|
def pull_files(
|
||||||
|
self,
|
||||||
|
allow_pattern: Optional[List[str]] = None,
|
||||||
|
ignore_pattern: Optional[List[str]] = None,
|
||||||
|
) -> None:
|
||||||
|
pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.connection.close()
|
||||||
|
super().close()
|
||||||
122
python/sglang/srt/connector/s3.py
Normal file
122
python/sglang/srt/connector/s3.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import fnmatch
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Generator, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.connector import BaseFileConnector
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
|
||||||
|
return [
|
||||||
|
path
|
||||||
|
for path in paths
|
||||||
|
if any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
|
||||||
|
return [
|
||||||
|
path
|
||||||
|
for path in paths
|
||||||
|
if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def list_files(
|
||||||
|
s3,
|
||||||
|
path: str,
|
||||||
|
allow_pattern: Optional[list[str]] = None,
|
||||||
|
ignore_pattern: Optional[list[str]] = None,
|
||||||
|
) -> tuple[str, str, list[str]]:
|
||||||
|
"""
|
||||||
|
List files from S3 path and filter by pattern.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s3: S3 client to use.
|
||||||
|
path: The S3 path to list from.
|
||||||
|
allow_pattern: A list of patterns of which files to pull.
|
||||||
|
ignore_pattern: A list of patterns of which files not to pull.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[str, str, list[str]]: A tuple where:
|
||||||
|
- The first element is the bucket name
|
||||||
|
- The second element is string represent the bucket
|
||||||
|
and the prefix as a dir like string
|
||||||
|
- The third element is a list of files allowed or
|
||||||
|
disallowed by pattern
|
||||||
|
"""
|
||||||
|
parts = path.removeprefix("s3://").split("/")
|
||||||
|
prefix = "/".join(parts[1:])
|
||||||
|
bucket_name = parts[0]
|
||||||
|
|
||||||
|
objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
|
||||||
|
paths = [obj["Key"] for obj in objects.get("Contents", [])]
|
||||||
|
|
||||||
|
paths = _filter_ignore(paths, ["*/"])
|
||||||
|
if allow_pattern is not None:
|
||||||
|
paths = _filter_allow(paths, allow_pattern)
|
||||||
|
|
||||||
|
if ignore_pattern is not None:
|
||||||
|
paths = _filter_ignore(paths, ignore_pattern)
|
||||||
|
|
||||||
|
return bucket_name, prefix, paths
|
||||||
|
|
||||||
|
|
||||||
|
class S3Connector(BaseFileConnector):
|
||||||
|
|
||||||
|
def __init__(self, url: str) -> None:
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
super().__init__(url)
|
||||||
|
self.client = boto3.client("s3")
|
||||||
|
|
||||||
|
def glob(self, allow_pattern: Optional[list[str]] = None) -> list[str]:
|
||||||
|
bucket_name, _, paths = list_files(
|
||||||
|
self.client, path=self.url, allow_pattern=allow_pattern
|
||||||
|
)
|
||||||
|
return [f"s3://{bucket_name}/{path}" for path in paths]
|
||||||
|
|
||||||
|
def pull_files(
|
||||||
|
self,
|
||||||
|
allow_pattern: Optional[list[str]] = None,
|
||||||
|
ignore_pattern: Optional[list[str]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Pull files from S3 storage into the temporary directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s3_model_path: The S3 path of the model.
|
||||||
|
allow_pattern: A list of patterns of which files to pull.
|
||||||
|
ignore_pattern: A list of patterns of which files not to pull.
|
||||||
|
|
||||||
|
"""
|
||||||
|
bucket_name, base_dir, files = list_files(
|
||||||
|
self.client, self.url, allow_pattern, ignore_pattern
|
||||||
|
)
|
||||||
|
if len(files) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
for file in files:
|
||||||
|
destination_file = os.path.join(self.local_dir, file.removeprefix(base_dir))
|
||||||
|
local_dir = Path(destination_file).parent
|
||||||
|
os.makedirs(local_dir, exist_ok=True)
|
||||||
|
self.client.download_file(bucket_name, file, destination_file)
|
||||||
|
|
||||||
|
def weight_iterator(
|
||||||
|
self, rank: int = 0
|
||||||
|
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||||
|
from sglang.srt.model_loader.weight_utils import (
|
||||||
|
runai_safetensors_weights_iterator,
|
||||||
|
)
|
||||||
|
|
||||||
|
# only support safetensor files now
|
||||||
|
hf_weights_files = self.glob(allow_pattern=["*.safetensors"])
|
||||||
|
return runai_safetensors_weights_iterator(hf_weights_files)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.client.close()
|
||||||
|
super().close()
|
||||||
31
python/sglang/srt/connector/serde/__init__.py
Normal file
31
python/sglang/srt/connector/serde/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
# inspired by LMCache
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.connector.serde.safe_serde import SafeDeserializer, SafeSerializer
|
||||||
|
from sglang.srt.connector.serde.serde import Deserializer, Serializer
|
||||||
|
|
||||||
|
|
||||||
|
def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
|
||||||
|
s: Optional[Serializer] = None
|
||||||
|
d: Optional[Deserializer] = None
|
||||||
|
|
||||||
|
if serde_type == "safe":
|
||||||
|
s = SafeSerializer()
|
||||||
|
d = SafeDeserializer(torch.uint8)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown serde type: {serde_type}")
|
||||||
|
|
||||||
|
return s, d
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Serializer",
|
||||||
|
"Deserializer",
|
||||||
|
"SafeSerializer",
|
||||||
|
"SafeDeserializer",
|
||||||
|
"create_serde",
|
||||||
|
]
|
||||||
29
python/sglang/srt/connector/serde/safe_serde.py
Normal file
29
python/sglang/srt/connector/serde/safe_serde.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import load, save
|
||||||
|
|
||||||
|
from sglang.srt.connector.serde.serde import Deserializer, Serializer
|
||||||
|
|
||||||
|
|
||||||
|
class SafeSerializer(Serializer):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def to_bytes(self, t: torch.Tensor) -> bytes:
|
||||||
|
return save({"tensor_bytes": t.cpu().contiguous()})
|
||||||
|
|
||||||
|
|
||||||
|
class SafeDeserializer(Deserializer):
|
||||||
|
|
||||||
|
def __init__(self, dtype):
|
||||||
|
super().__init__(dtype)
|
||||||
|
|
||||||
|
def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
|
||||||
|
return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype)
|
||||||
|
|
||||||
|
def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
|
||||||
|
return self.from_bytes_normal(b)
|
||||||
43
python/sglang/srt/connector/serde/serde.py
Normal file
43
python/sglang/srt/connector/serde/serde.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import abc
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class Serializer(ABC):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def to_bytes(self, t: torch.Tensor) -> bytes:
|
||||||
|
"""
|
||||||
|
Serialize a pytorch tensor to bytes. The serialized bytes should contain
|
||||||
|
both the data and the metadata (shape, dtype, etc.) of the tensor.
|
||||||
|
|
||||||
|
Input:
|
||||||
|
t: the input pytorch tensor, can be on any device, in any shape,
|
||||||
|
with any dtype
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytes: the serialized bytes
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class Deserializer(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
|
def __init__(self, dtype):
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def from_bytes(self, bs: bytes) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Deserialize a pytorch tensor from bytes.
|
||||||
|
|
||||||
|
Input:
|
||||||
|
bytes: a stream of bytes
|
||||||
|
|
||||||
|
Output:
|
||||||
|
torch.Tensor: the deserialized pytorch tensor
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
35
python/sglang/srt/connector/utils.py
Normal file
35
python/sglang/srt/connector/utils.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from sglang.srt.connector import BaseConnector
|
||||||
|
|
||||||
|
|
||||||
|
def parse_model_name(url: str) -> str:
|
||||||
|
"""
|
||||||
|
Parse the model name from the url.
|
||||||
|
Only used for db connector
|
||||||
|
"""
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
return parsed_url.path.lstrip("/")
|
||||||
|
|
||||||
|
|
||||||
|
def pull_files_from_db(
|
||||||
|
connector: BaseConnector,
|
||||||
|
model_name: str,
|
||||||
|
allow_pattern: Optional[list[str]] = None,
|
||||||
|
ignore_pattern: Optional[list[str]] = None,
|
||||||
|
) -> None:
|
||||||
|
prefix = f"{model_name}/files/"
|
||||||
|
local_dir = connector.get_local_dir()
|
||||||
|
files = connector.list(prefix)
|
||||||
|
|
||||||
|
for file in files:
|
||||||
|
destination_file = os.path.join(local_dir, file.removeprefix(prefix))
|
||||||
|
local_dir = Path(destination_file).parent
|
||||||
|
os.makedirs(local_dir, exist_ok=True)
|
||||||
|
with open(destination_file, "wb") as f:
|
||||||
|
f.write(connector.getstr(file).encode("utf-8"))
|
||||||
@@ -27,6 +27,9 @@ import signal
|
|||||||
import threading
|
import threading
|
||||||
from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
|
from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
import zmq.asyncio
|
||||||
|
|
||||||
# Fix a bug of Python threading
|
# Fix a bug of Python threading
|
||||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||||
|
|
||||||
@@ -44,6 +47,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
ReleaseMemoryOccupationReqInput,
|
ReleaseMemoryOccupationReqInput,
|
||||||
ResumeMemoryOccupationReqInput,
|
ResumeMemoryOccupationReqInput,
|
||||||
|
RpcReqInput,
|
||||||
|
RpcReqOutput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
@@ -57,6 +62,7 @@ from sglang.srt.utils import (
|
|||||||
MultiprocessingSerializer,
|
MultiprocessingSerializer,
|
||||||
assert_pkg_version,
|
assert_pkg_version,
|
||||||
configure_logger,
|
configure_logger,
|
||||||
|
get_zmq_socket,
|
||||||
kill_process_tree,
|
kill_process_tree,
|
||||||
launch_dummy_health_check_server,
|
launch_dummy_health_check_server,
|
||||||
maybe_set_triton_cache_manager,
|
maybe_set_triton_cache_manager,
|
||||||
@@ -102,15 +108,25 @@ class Engine:
|
|||||||
# Shutdown the subprocesses automatically when the program exits
|
# Shutdown the subprocesses automatically when the program exits
|
||||||
atexit.register(self.shutdown)
|
atexit.register(self.shutdown)
|
||||||
|
|
||||||
|
# Allocate ports for inter-process communications
|
||||||
|
port_args = PortArgs.init_new(server_args)
|
||||||
|
logger.info(f"{server_args=}")
|
||||||
|
|
||||||
# Launch subprocesses
|
# Launch subprocesses
|
||||||
tokenizer_manager, scheduler_info = _launch_subprocesses(
|
tokenizer_manager, scheduler_info = _launch_subprocesses(
|
||||||
server_args=server_args
|
server_args=server_args,
|
||||||
|
port_args=port_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
self.tokenizer_manager = tokenizer_manager
|
self.tokenizer_manager = tokenizer_manager
|
||||||
self.scheduler_info = scheduler_info
|
self.scheduler_info = scheduler_info
|
||||||
|
|
||||||
|
context = zmq.Context(2)
|
||||||
|
self.send_to_rpc = get_zmq_socket(
|
||||||
|
context, zmq.DEALER, port_args.rpc_ipc_name, True
|
||||||
|
)
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||||
@@ -350,6 +366,23 @@ class Engine:
|
|||||||
self.tokenizer_manager.resume_memory_occupation(obj, None)
|
self.tokenizer_manager.resume_memory_occupation(obj, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Execute an RPC call on all scheduler processes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def collective_rpc(self, method: str, **kwargs):
|
||||||
|
obj = RpcReqInput(method=method, parameters=kwargs)
|
||||||
|
self.send_to_rpc.send_pyobj(obj)
|
||||||
|
recv_req = self.send_to_rpc.recv_pyobj(zmq.BLOCKY)
|
||||||
|
assert isinstance(recv_req, RpcReqOutput)
|
||||||
|
assert recv_req.success, recv_req.message
|
||||||
|
|
||||||
|
def save_remote_model(self, **kwargs):
|
||||||
|
self.collective_rpc("save_remote_model", **kwargs)
|
||||||
|
|
||||||
|
def save_sharded_model(self, **kwargs):
|
||||||
|
self.collective_rpc("save_sharded_model", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _set_envs_and_config(server_args: ServerArgs):
|
def _set_envs_and_config(server_args: ServerArgs):
|
||||||
# Set global environments
|
# Set global environments
|
||||||
@@ -408,7 +441,9 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
mp.set_start_method("spawn", force=True)
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
|
||||||
def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
|
def _launch_subprocesses(
|
||||||
|
server_args: ServerArgs, port_args: Optional[PortArgs] = None
|
||||||
|
) -> Tuple[TokenizerManager, Dict]:
|
||||||
"""
|
"""
|
||||||
Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
|
Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
|
||||||
"""
|
"""
|
||||||
@@ -418,8 +453,9 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
|
|||||||
_set_envs_and_config(server_args)
|
_set_envs_and_config(server_args)
|
||||||
|
|
||||||
# Allocate ports for inter-process communications
|
# Allocate ports for inter-process communications
|
||||||
port_args = PortArgs.init_new(server_args)
|
if port_args is None:
|
||||||
logger.info(f"{server_args=}")
|
port_args = PortArgs.init_new(server_args)
|
||||||
|
logger.info(f"{server_args=}")
|
||||||
|
|
||||||
# If using model from www.modelscope.cn, first download the model.
|
# If using model from www.modelscope.cn, first download the model.
|
||||||
server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
|
server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
|
||||||
|
|||||||
@@ -37,6 +37,8 @@ from sglang.srt.configs import (
|
|||||||
MultiModalityConfig,
|
MultiModalityConfig,
|
||||||
Qwen2_5_VLConfig,
|
Qwen2_5_VLConfig,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.connector import create_remote_connector
|
||||||
|
from sglang.srt.utils import is_remote_url
|
||||||
|
|
||||||
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
||||||
ChatGLMConfig.model_type: ChatGLMConfig,
|
ChatGLMConfig.model_type: ChatGLMConfig,
|
||||||
@@ -155,6 +157,14 @@ def get_tokenizer(
|
|||||||
kwargs["gguf_file"] = tokenizer_name
|
kwargs["gguf_file"] = tokenizer_name
|
||||||
tokenizer_name = Path(tokenizer_name).parent
|
tokenizer_name = Path(tokenizer_name).parent
|
||||||
|
|
||||||
|
if is_remote_url(tokenizer_name):
|
||||||
|
# BaseConnector implements __del__() to clean up the local dir.
|
||||||
|
# Since config files need to exist all the time, so we DO NOT use
|
||||||
|
# with statement to avoid closing the client.
|
||||||
|
client = create_remote_connector(tokenizer_name)
|
||||||
|
client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
|
||||||
|
tokenizer_name = client.get_local_dir()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
|
|||||||
@@ -723,3 +723,15 @@ class SeparateReasoningReqInput:
|
|||||||
class VertexGenerateReqInput:
|
class VertexGenerateReqInput:
|
||||||
instances: List[dict]
|
instances: List[dict]
|
||||||
parameters: Optional[dict] = None
|
parameters: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RpcReqInput:
|
||||||
|
method: str
|
||||||
|
parameters: Optional[Dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RpcReqOutput:
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ import psutil
|
|||||||
import setproctitle
|
import setproctitle
|
||||||
import torch
|
import torch
|
||||||
import zmq
|
import zmq
|
||||||
|
from torch.distributed import barrier
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
@@ -59,6 +60,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
ReleaseMemoryOccupationReqOutput,
|
ReleaseMemoryOccupationReqOutput,
|
||||||
ResumeMemoryOccupationReqInput,
|
ResumeMemoryOccupationReqInput,
|
||||||
ResumeMemoryOccupationReqOutput,
|
ResumeMemoryOccupationReqOutput,
|
||||||
|
RpcReqInput,
|
||||||
|
RpcReqOutput,
|
||||||
SetInternalStateReq,
|
SetInternalStateReq,
|
||||||
SetInternalStateReqOutput,
|
SetInternalStateReqOutput,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
@@ -193,8 +196,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
|||||||
self.send_to_detokenizer = get_zmq_socket(
|
self.send_to_detokenizer = get_zmq_socket(
|
||||||
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
|
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.recv_from_rpc = get_zmq_socket(
|
||||||
|
context, zmq.DEALER, port_args.rpc_ipc_name, False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.recv_from_tokenizer = None
|
self.recv_from_tokenizer = None
|
||||||
|
self.recv_from_rpc = None
|
||||||
self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
|
self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
|
||||||
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
|
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
|
||||||
|
|
||||||
@@ -376,6 +384,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
|||||||
(ProfileReq, self.profile),
|
(ProfileReq, self.profile),
|
||||||
(GetInternalStateReq, self.get_internal_state),
|
(GetInternalStateReq, self.get_internal_state),
|
||||||
(SetInternalStateReq, self.set_internal_state),
|
(SetInternalStateReq, self.set_internal_state),
|
||||||
|
(RpcReqInput, self.handle_rpc_request),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -549,6 +558,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
|||||||
except zmq.ZMQError:
|
except zmq.ZMQError:
|
||||||
break
|
break
|
||||||
recv_reqs.append(recv_req)
|
recv_reqs.append(recv_req)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
|
||||||
|
except zmq.ZMQError:
|
||||||
|
break
|
||||||
|
recv_reqs.append(recv_rpc)
|
||||||
else:
|
else:
|
||||||
recv_reqs = None
|
recv_reqs = None
|
||||||
|
|
||||||
@@ -600,7 +616,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
|||||||
|
|
||||||
output = self._request_dispatcher(recv_req)
|
output = self._request_dispatcher(recv_req)
|
||||||
if output is not None:
|
if output is not None:
|
||||||
self.send_to_tokenizer.send_pyobj(output)
|
if isinstance(output, RpcReqOutput):
|
||||||
|
if self.recv_from_rpc is not None:
|
||||||
|
self.recv_from_rpc.send_pyobj(output)
|
||||||
|
else:
|
||||||
|
self.send_to_tokenizer.send_pyobj(output)
|
||||||
|
|
||||||
def handle_generate_request(
|
def handle_generate_request(
|
||||||
self,
|
self,
|
||||||
@@ -1492,6 +1512,47 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
|||||||
server_args=global_server_args_dict,
|
server_args=global_server_args_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def handle_rpc_request(self, recv_req: RpcReqInput):
|
||||||
|
# Handle RPC requests
|
||||||
|
logger.info(
|
||||||
|
f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
|
||||||
|
)
|
||||||
|
|
||||||
|
success = True
|
||||||
|
exec = None
|
||||||
|
try:
|
||||||
|
func = getattr(self, recv_req.method)
|
||||||
|
func(recv_req.parameters)
|
||||||
|
except Exception as e:
|
||||||
|
success = False
|
||||||
|
exec = e
|
||||||
|
logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")
|
||||||
|
|
||||||
|
barrier()
|
||||||
|
return RpcReqOutput(success, "" if not exec else str(exec))
|
||||||
|
|
||||||
|
def save_remote_model(self, params):
|
||||||
|
url = params["url"]
|
||||||
|
|
||||||
|
if isinstance(self.tp_worker, TpModelWorkerClient):
|
||||||
|
worker = self.tp_worker.worker
|
||||||
|
else:
|
||||||
|
worker = self.tp_worker
|
||||||
|
|
||||||
|
worker.model_runner.save_remote_model(url)
|
||||||
|
|
||||||
|
def save_sharded_model(self, params):
|
||||||
|
if isinstance(self.tp_worker, TpModelWorkerClient):
|
||||||
|
worker = self.tp_worker.worker
|
||||||
|
else:
|
||||||
|
worker = self.tp_worker
|
||||||
|
|
||||||
|
worker.model_runner.save_sharded_model(
|
||||||
|
path=params["path"],
|
||||||
|
pattern=params["pattern"],
|
||||||
|
max_size=params["max_size"],
|
||||||
|
)
|
||||||
|
|
||||||
def abort_request(self, recv_req: AbortReq):
|
def abort_request(self, recv_req: AbortReq):
|
||||||
# Delete requests in the waiting queue
|
# Delete requests in the waiting queue
|
||||||
to_del = []
|
to_del = []
|
||||||
|
|||||||
@@ -1009,6 +1009,22 @@ class ModelRunner:
|
|||||||
return False
|
return False
|
||||||
return rope_scaling.get("type", None) == "mrope"
|
return rope_scaling.get("type", None) == "mrope"
|
||||||
|
|
||||||
|
def save_remote_model(self, url: str):
|
||||||
|
from sglang.srt.model_loader.loader import RemoteModelLoader
|
||||||
|
|
||||||
|
logger.info(f"Saving model to {url}")
|
||||||
|
RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)
|
||||||
|
|
||||||
|
def save_sharded_model(
|
||||||
|
self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
|
||||||
|
):
|
||||||
|
from sglang.srt.model_loader.loader import ShardedStateLoader
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
|
||||||
|
)
|
||||||
|
ShardedStateLoader.save_model(self.model, path, pattern, max_size)
|
||||||
|
|
||||||
|
|
||||||
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
|
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
|
||||||
params_dict = dict(model.named_parameters())
|
params_dict = dict(model.named_parameters())
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
|
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
|
||||||
@@ -25,6 +26,12 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
|||||||
from sglang.srt.configs.device_config import DeviceConfig
|
from sglang.srt.configs.device_config import DeviceConfig
|
||||||
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
|
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
|
from sglang.srt.connector import (
|
||||||
|
ConnectorType,
|
||||||
|
create_remote_connector,
|
||||||
|
get_connector_type,
|
||||||
|
)
|
||||||
|
from sglang.srt.connector.utils import parse_model_name
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
@@ -46,6 +53,7 @@ from sglang.srt.model_loader.weight_utils import (
|
|||||||
np_cache_weights_iterator,
|
np_cache_weights_iterator,
|
||||||
pt_weights_iterator,
|
pt_weights_iterator,
|
||||||
safetensors_weights_iterator,
|
safetensors_weights_iterator,
|
||||||
|
set_runai_streamer_env,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
@@ -490,7 +498,7 @@ class ShardedStateLoader(BaseModelLoader):
|
|||||||
Model loader that directly loads each worker's model state dict, which
|
Model loader that directly loads each worker's model state dict, which
|
||||||
enables a fast load path for large tensor-parallel models where each worker
|
enables a fast load path for large tensor-parallel models where each worker
|
||||||
only needs to read its own shard rather than the entire checkpoint. See
|
only needs to read its own shard rather than the entire checkpoint. See
|
||||||
`examples/save_sharded_state.py` for creating a sharded checkpoint.
|
`examples/runtime/engine/save_sharded_state.py` for creating a sharded checkpoint.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
|
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
|
||||||
@@ -1204,6 +1212,153 @@ class GGUFModelLoader(BaseModelLoader):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteModelLoader(BaseModelLoader):
|
||||||
|
"""Model loader that can load Tensors from remote database."""
|
||||||
|
|
||||||
|
def __init__(self, load_config: LoadConfig):
|
||||||
|
super().__init__(load_config)
|
||||||
|
# TODO @DellCurry: move to s3 connector only
|
||||||
|
set_runai_streamer_env(load_config)
|
||||||
|
|
||||||
|
def _get_weights_iterator_kv(
|
||||||
|
self,
|
||||||
|
client,
|
||||||
|
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||||
|
"""Get an iterator for the model weights from remote storage."""
|
||||||
|
assert get_connector_type(client) == ConnectorType.KV
|
||||||
|
rank = get_tensor_model_parallel_rank()
|
||||||
|
return client.weight_iterator(rank)
|
||||||
|
|
||||||
|
def _get_weights_iterator_fs(
|
||||||
|
self,
|
||||||
|
client,
|
||||||
|
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||||
|
"""Get an iterator for the model weights from remote storage."""
|
||||||
|
assert get_connector_type(client) == ConnectorType.FS
|
||||||
|
return client.weight_iterator()
|
||||||
|
|
||||||
|
def download_model(self, model_config: ModelConfig) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_model(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
model_path: str,
|
||||||
|
url: str,
|
||||||
|
) -> None:
|
||||||
|
with create_remote_connector(url) as client:
|
||||||
|
assert get_connector_type(client) == ConnectorType.KV
|
||||||
|
model_name = parse_model_name(url)
|
||||||
|
rank = get_tensor_model_parallel_rank()
|
||||||
|
state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
|
||||||
|
for key, tensor in state_dict.items():
|
||||||
|
r_key = f"{model_name}/keys/rank_{rank}/{key}"
|
||||||
|
client.set(r_key, tensor)
|
||||||
|
|
||||||
|
for root, _, files in os.walk(model_path):
|
||||||
|
for file_name in files:
|
||||||
|
# ignore hidden files
|
||||||
|
if file_name.startswith("."):
|
||||||
|
continue
|
||||||
|
if os.path.splitext(file_name)[1] not in (
|
||||||
|
".bin",
|
||||||
|
".pt",
|
||||||
|
".safetensors",
|
||||||
|
):
|
||||||
|
file_path = os.path.join(root, file_name)
|
||||||
|
with open(file_path, encoding="utf-8") as file:
|
||||||
|
file_content = file.read()
|
||||||
|
f_key = f"{model_name}/files/{file_name}"
|
||||||
|
client.setstr(f_key, file_content)
|
||||||
|
|
||||||
|
def _load_model_from_remote_kv(self, model: nn.Module, client):
|
||||||
|
for _, module in model.named_modules():
|
||||||
|
quant_method = getattr(module, "quant_method", None)
|
||||||
|
if quant_method is not None:
|
||||||
|
quant_method.process_weights_after_loading(module)
|
||||||
|
weights_iterator = self._get_weights_iterator_kv(client)
|
||||||
|
state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
|
||||||
|
for key, tensor in weights_iterator:
|
||||||
|
# If loading with LoRA enabled, additional padding may
|
||||||
|
# be added to certain parameters. We only load into a
|
||||||
|
# narrowed view of the parameter data.
|
||||||
|
param_data = state_dict[key].data
|
||||||
|
param_shape = state_dict[key].shape
|
||||||
|
for dim, size in enumerate(tensor.shape):
|
||||||
|
if size < param_shape[dim]:
|
||||||
|
param_data = param_data.narrow(dim, 0, size)
|
||||||
|
if tensor.shape != param_shape:
|
||||||
|
logger.warning(
|
||||||
|
"loading tensor of shape %s into " "parameter '%s' of shape %s",
|
||||||
|
tensor.shape,
|
||||||
|
key,
|
||||||
|
param_shape,
|
||||||
|
)
|
||||||
|
param_data.copy_(tensor)
|
||||||
|
state_dict.pop(key)
|
||||||
|
if state_dict:
|
||||||
|
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
|
||||||
|
|
||||||
|
def _load_model_from_remote_fs(
|
||||||
|
self, model, client, model_config: ModelConfig, device_config: DeviceConfig
|
||||||
|
) -> nn.Module:
|
||||||
|
|
||||||
|
target_device = torch.device(device_config.device)
|
||||||
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
|
model.load_weights(self._get_weights_iterator_fs(client))
|
||||||
|
|
||||||
|
for _, module in model.named_modules():
|
||||||
|
quant_method = getattr(module, "quant_method", None)
|
||||||
|
if quant_method is not None:
|
||||||
|
# When quant methods need to process weights after loading
|
||||||
|
# (for repacking, quantizing, etc), they expect parameters
|
||||||
|
# to be on the global target device. This scope is for the
|
||||||
|
# case where cpu offloading is used, where we will move the
|
||||||
|
# parameters onto device for processing and back off after.
|
||||||
|
with device_loading_context(module, target_device):
|
||||||
|
quant_method.process_weights_after_loading(module)
|
||||||
|
|
||||||
|
def load_model(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
device_config: DeviceConfig,
|
||||||
|
) -> nn.Module:
|
||||||
|
logger.info("Loading weights from remote storage ...")
|
||||||
|
start = time.perf_counter()
|
||||||
|
load_config = self.load_config
|
||||||
|
|
||||||
|
assert load_config.load_format == LoadFormat.REMOTE, (
|
||||||
|
f"Model loader {self.load_config.load_format} is not supported for "
|
||||||
|
f"load format {load_config.load_format}"
|
||||||
|
)
|
||||||
|
|
||||||
|
model_weights = model_config.model_path
|
||||||
|
if hasattr(model_config, "model_weights"):
|
||||||
|
model_weights = model_config.model_weights
|
||||||
|
|
||||||
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
|
with torch.device(device_config.device):
|
||||||
|
model = _initialize_model(model_config, self.load_config)
|
||||||
|
for _, module in model.named_modules():
|
||||||
|
quant_method = getattr(module, "quant_method", None)
|
||||||
|
if quant_method is not None:
|
||||||
|
quant_method.process_weights_after_loading(module)
|
||||||
|
|
||||||
|
with create_remote_connector(model_weights, device_config.device) as client:
|
||||||
|
connector_type = get_connector_type(client)
|
||||||
|
if connector_type == ConnectorType.KV:
|
||||||
|
self._load_model_from_remote_kv(model, client)
|
||||||
|
elif connector_type == ConnectorType.FS:
|
||||||
|
self._load_model_from_remote_fs(
|
||||||
|
model, client, model_config, device_config
|
||||||
|
)
|
||||||
|
|
||||||
|
end = time.perf_counter()
|
||||||
|
logger.info("Loaded weights from remote storage in %.2f seconds.", end - start)
|
||||||
|
return model.eval()
|
||||||
|
|
||||||
|
|
||||||
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||||
"""Get a model loader based on the load format."""
|
"""Get a model loader based on the load format."""
|
||||||
|
|
||||||
@@ -1225,4 +1380,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
|||||||
if load_config.load_format == LoadFormat.LAYERED:
|
if load_config.load_format == LoadFormat.LAYERED:
|
||||||
return LayeredModelLoader(load_config)
|
return LayeredModelLoader(load_config)
|
||||||
|
|
||||||
|
if load_config.load_format == LoadFormat.REMOTE:
|
||||||
|
return RemoteModelLoader(load_config)
|
||||||
|
|
||||||
return DefaultModelLoader(load_config)
|
return DefaultModelLoader(load_config)
|
||||||
|
|||||||
@@ -585,6 +585,51 @@ def composed_weight_loader(
|
|||||||
return composed_loader
|
return composed_loader
|
||||||
|
|
||||||
|
|
||||||
|
def runai_safetensors_weights_iterator(
|
||||||
|
hf_weights_files: List[str],
|
||||||
|
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||||
|
"""Iterate over the weights in the model safetensor files."""
|
||||||
|
from runai_model_streamer import SafetensorsStreamer
|
||||||
|
|
||||||
|
enable_tqdm = (
|
||||||
|
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
|
||||||
|
)
|
||||||
|
|
||||||
|
with SafetensorsStreamer() as streamer:
|
||||||
|
for st_file in tqdm(
|
||||||
|
hf_weights_files,
|
||||||
|
desc="Loading safetensors using Runai Model Streamer",
|
||||||
|
disable=not enable_tqdm,
|
||||||
|
bar_format=_BAR_FORMAT,
|
||||||
|
):
|
||||||
|
streamer.stream_file(st_file)
|
||||||
|
yield from streamer.get_tensors()
|
||||||
|
|
||||||
|
|
||||||
|
def set_runai_streamer_env(load_config: LoadConfig):
|
||||||
|
if load_config.model_loader_extra_config:
|
||||||
|
extra_config = load_config.model_loader_extra_config
|
||||||
|
|
||||||
|
if "concurrency" in extra_config and isinstance(
|
||||||
|
extra_config.get("concurrency"), int
|
||||||
|
):
|
||||||
|
os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
|
||||||
|
extra_config.get("concurrency")
|
||||||
|
)
|
||||||
|
|
||||||
|
if "memory_limit" in extra_config and isinstance(
|
||||||
|
extra_config.get("memory_limit"), int
|
||||||
|
):
|
||||||
|
os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
|
||||||
|
extra_config.get("memory_limit")
|
||||||
|
)
|
||||||
|
|
||||||
|
runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
|
||||||
|
aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
|
||||||
|
if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
|
||||||
|
os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
|
||||||
|
|
||||||
|
|
||||||
def initialize_dummy_weights(
|
def initialize_dummy_weights(
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
low: float = -1e-3,
|
low: float = -1e-3,
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from sglang.srt.utils import (
|
|||||||
is_flashinfer_available,
|
is_flashinfer_available,
|
||||||
is_hip,
|
is_hip,
|
||||||
is_port_available,
|
is_port_available,
|
||||||
|
is_remote_url,
|
||||||
is_valid_ipv6_address,
|
is_valid_ipv6_address,
|
||||||
nullable_str,
|
nullable_str,
|
||||||
)
|
)
|
||||||
@@ -296,6 +297,9 @@ class ServerArgs:
|
|||||||
) and check_gguf_file(self.model_path):
|
) and check_gguf_file(self.model_path):
|
||||||
self.quantization = self.load_format = "gguf"
|
self.quantization = self.load_format = "gguf"
|
||||||
|
|
||||||
|
if is_remote_url(self.model_path):
|
||||||
|
self.load_format = "remote"
|
||||||
|
|
||||||
# AMD-specific Triton attention KV splits default number
|
# AMD-specific Triton attention KV splits default number
|
||||||
if is_hip():
|
if is_hip():
|
||||||
self.triton_attention_num_kv_splits = 16
|
self.triton_attention_num_kv_splits = 16
|
||||||
@@ -345,9 +349,11 @@ class ServerArgs:
|
|||||||
"safetensors",
|
"safetensors",
|
||||||
"npcache",
|
"npcache",
|
||||||
"dummy",
|
"dummy",
|
||||||
|
"sharded_state",
|
||||||
"gguf",
|
"gguf",
|
||||||
"bitsandbytes",
|
"bitsandbytes",
|
||||||
"layered",
|
"layered",
|
||||||
|
"remote",
|
||||||
],
|
],
|
||||||
help="The format of the model weights to load. "
|
help="The format of the model weights to load. "
|
||||||
'"auto" will try to load the weights in the safetensors format '
|
'"auto" will try to load the weights in the safetensors format '
|
||||||
@@ -1088,6 +1094,9 @@ class PortArgs:
|
|||||||
# The port for nccl initialization (torch.dist)
|
# The port for nccl initialization (torch.dist)
|
||||||
nccl_port: int
|
nccl_port: int
|
||||||
|
|
||||||
|
# The ipc filename for rpc call between Engine and Scheduler
|
||||||
|
rpc_ipc_name: str
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
|
def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
|
||||||
port = server_args.port + random.randint(100, 1000)
|
port = server_args.port + random.randint(100, 1000)
|
||||||
@@ -1106,6 +1115,7 @@ class PortArgs:
|
|||||||
scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
|
scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
|
||||||
detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
|
detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
|
||||||
nccl_port=port,
|
nccl_port=port,
|
||||||
|
rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# DP attention. Use TCP + port to handle both single-node and multi-node.
|
# DP attention. Use TCP + port to handle both single-node and multi-node.
|
||||||
@@ -1131,6 +1141,7 @@ class PortArgs:
|
|||||||
scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
|
scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
|
||||||
detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
|
detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
|
||||||
nccl_port=port,
|
nccl_port=port,
|
||||||
|
rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ from importlib.util import find_spec
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
from multiprocessing.reduction import ForkingPickler
|
from multiprocessing.reduction import ForkingPickler
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -774,12 +775,22 @@ def get_zmq_socket(
|
|||||||
buf_size = -1
|
buf_size = -1
|
||||||
|
|
||||||
socket = context.socket(socket_type)
|
socket = context.socket(socket_type)
|
||||||
if socket_type == zmq.PUSH:
|
|
||||||
|
def set_send_opt():
|
||||||
socket.setsockopt(zmq.SNDHWM, 0)
|
socket.setsockopt(zmq.SNDHWM, 0)
|
||||||
socket.setsockopt(zmq.SNDBUF, buf_size)
|
socket.setsockopt(zmq.SNDBUF, buf_size)
|
||||||
elif socket_type == zmq.PULL:
|
|
||||||
|
def set_recv_opt():
|
||||||
socket.setsockopt(zmq.RCVHWM, 0)
|
socket.setsockopt(zmq.RCVHWM, 0)
|
||||||
socket.setsockopt(zmq.RCVBUF, buf_size)
|
socket.setsockopt(zmq.RCVBUF, buf_size)
|
||||||
|
|
||||||
|
if socket_type == zmq.PUSH:
|
||||||
|
set_send_opt()
|
||||||
|
elif socket_type == zmq.PULL:
|
||||||
|
set_recv_opt()
|
||||||
|
elif socket_type == zmq.DEALER:
|
||||||
|
set_send_opt()
|
||||||
|
set_recv_opt()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported socket type: {socket_type}")
|
raise ValueError(f"Unsupported socket type: {socket_type}")
|
||||||
|
|
||||||
@@ -1572,3 +1583,29 @@ def add_prefix(name: str, prefix: str) -> str:
|
|||||||
The string `prefix.name` if prefix is non-empty, otherwise just `name`.
|
The string `prefix.name` if prefix is non-empty, otherwise just `name`.
|
||||||
"""
|
"""
|
||||||
return name if not prefix else f"{prefix}.{name}"
|
return name if not prefix else f"{prefix}.{name}"
|
||||||
|
|
||||||
|
|
||||||
|
def is_remote_url(url: Union[str, Path]) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the URL is a remote URL of the format:
|
||||||
|
<connector_type>://<host>:<port>/<model_name>
|
||||||
|
"""
|
||||||
|
if isinstance(url, Path):
|
||||||
|
return False
|
||||||
|
|
||||||
|
pattern = r"(.+)://(.*)"
|
||||||
|
m = re.match(pattern, url)
|
||||||
|
return m is not None
|
||||||
|
|
||||||
|
|
||||||
|
def parse_connector_type(url: str) -> str:
|
||||||
|
"""
|
||||||
|
Parse the connector type from the URL of the format:
|
||||||
|
<connector_type>://<path>
|
||||||
|
"""
|
||||||
|
pattern = r"(.+)://(.*)"
|
||||||
|
m = re.match(pattern, url)
|
||||||
|
if m is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return m.group(1)
|
||||||
|
|||||||
Reference in New Issue
Block a user