diff --git a/examples/runtime/engine/save_remote_state.py b/examples/runtime/engine/save_remote_state.py
new file mode 100644
index 000000000..47812695f
--- /dev/null
+++ b/examples/runtime/engine/save_remote_state.py
@@ -0,0 +1,51 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Saves each worker's model state dict directly to remote storage, which enables a
+fast load path for large tensor-parallel models where each worker only needs to
+read its own shard rather than the entire checkpoint.
+
+Example usage:
+
+python save_remote_state.py \
+    --model-path /path/to/load \
+    --tp-size 8 \
+    --remote-model-save-url [protocol]://[host]:[port]/[model_name]
+
+Then, the model can be loaded with
+
+llm = Engine(
+    model_path="[protocol]://[host]:[port]/[model_name]",
+    tp_size=8,
+)
+"""
+import dataclasses
+from argparse import ArgumentParser
+from pathlib import Path
+
+from sglang import Engine, ServerArgs
+
+parser = ArgumentParser()
+ServerArgs.add_cli_args(parser)
+
+parser.add_argument(
+    "--remote-model-save-url",
+    required=True,
+    type=str,
+    help="remote address to store model weights",
+)
+
+
+def main(args):
+    engine_args = ServerArgs.from_cli_args(args)
+    model_path = engine_args.model_path
+    if not Path(model_path).is_dir():
+        raise ValueError("model path must be a local directory")
+    # Create LLM instance from arguments
+    llm = Engine(**dataclasses.asdict(engine_args))
+    llm.save_remote_model(url=args.remote_model_save_url)
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    main(args)
diff --git a/examples/runtime/engine/save_sharded_state.py b/examples/runtime/engine/save_sharded_state.py
new file mode 100644
index 000000000..80ad5321f
--- /dev/null
+++ b/examples/runtime/engine/save_sharded_state.py
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Saves each worker's model state dict directly to a checkpoint, which enables a
+fast load path for large tensor-parallel models where each worker only needs to
+read its own shard rather than the entire checkpoint.
+
+Example usage:
+
+python save_sharded_state.py \
+    --model-path /path/to/load \
+    --quantization deepspeedfp \
+    --tp-size 8 \
+    --output /path/to/save
+
+Then, the model can be loaded with
+
+llm = Engine(
+    model_path="/path/to/save",
+    load_format="sharded_state",
+    quantization="deepspeedfp",
+    tp_size=8,
+)
+"""
+import dataclasses
+import os
+import shutil
+from argparse import ArgumentParser
+from pathlib import Path
+
+from sglang import Engine, ServerArgs
+
+parser = ArgumentParser()
+ServerArgs.add_cli_args(parser)
+
+parser.add_argument(
+    "--output", "-o", required=True, type=str, help="path to output checkpoint"
+)
+parser.add_argument(
+    "--file-pattern", type=str, help="string pattern of saved filenames"
+)
+parser.add_argument(
+    "--max-file-size",
+    type=int,
+    default=5 * 1024**3,
+    help="max size (in bytes) of each safetensors file",
+)
+
+
+def main(args):
+    engine_args = ServerArgs.from_cli_args(args)
+    model_path = engine_args.model_path
+    if not Path(model_path).is_dir():
+        raise ValueError("model path must be a local directory")
+    # Create LLM instance from arguments
+    llm = Engine(**dataclasses.asdict(engine_args))
+    Path(args.output).mkdir(exist_ok=True)
+    llm.save_sharded_model(
+        path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
+    )
+
+    # Copy metadata files to output directory
+    for file in os.listdir(model_path):
+        if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
+            if os.path.isdir(os.path.join(model_path, file)):
+                shutil.copytree(
+                    os.path.join(model_path, file), os.path.join(args.output, file)
+                )
+            else:
+                shutil.copy(os.path.join(model_path, file), args.output)
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    main(args)
diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py
index 70d58043d..db0cf2604 100644
--- a/python/sglang/__init__.py
+++ b/python/sglang/__init__.py
@@ -32,6 +32,7 @@ from sglang.lang.choices import (
 )
 from sglang.utils import LazyImport
 
+ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
 Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
 LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
 OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
@@ -67,6 +68,7 @@ __all__ = [
     "greedy_token_selection",
     "token_length_normalized",
     "unconditional_likelihood_normalized",
+    "ServerArgs",
     "Anthropic",
     "LiteLLM",
     "OpenAI",
diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py
index c8521910e..be9a40b4b 100644
--- a/python/sglang/srt/configs/load_config.py
+++ b/python/sglang/srt/configs/load_config.py
@@ -22,6 +22,7 @@ class LoadFormat(str, enum.Enum):
     MISTRAL = "mistral"
     LAYERED = "layered"
     JAX = "jax"
+    REMOTE = "remote"
 
 
 @dataclass
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 13516c5c6..cb31edd1b 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -51,13 +51,14 @@ class ModelConfig:
         self.quantization = quantization
 
         # Parse args
+        self.maybe_pull_model_tokenizer_from_remote()
         self.model_override_args = json.loads(model_override_args)
         kwargs = {}
         if override_config_file and override_config_file.strip():
             kwargs["_configuration_file"] = override_config_file.strip()
 
         self.hf_config = get_config(
-            model_path,
+            self.model_path,
             trust_remote_code=trust_remote_code,
             revision=revision,
             model_override_args=self.model_override_args,
@@ -318,6 +319,29 @@ class ModelConfig:
         eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
         return eos_ids
 
+    def maybe_pull_model_tokenizer_from_remote(self) -> None:
+        """
+        Pull the model config files to a temporary
+        directory in case of a remote model path.
+        """
+        from sglang.srt.connector import create_remote_connector
+        from sglang.srt.utils import is_remote_url
+
+        if is_remote_url(self.model_path):
+            logger.info("Pulling model configs from remote...")
+            # BaseConnector implements __del__() to clean up the local dir.
+            # Since the config files need to exist for the lifetime of the
+            # process, we DO NOT use a with statement, which would close the
+            # client and delete the directory.
+            client = create_remote_connector(self.model_path)
+            client.pull_files(allow_pattern=["*config.json"])
+            self.model_weights = self.model_path
+            self.model_path = client.get_local_dir()
+
 
 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
diff --git a/python/sglang/srt/connector/__init__.py b/python/sglang/srt/connector/__init__.py
new file mode 100644
index 000000000..829644c91
--- /dev/null
+++ b/python/sglang/srt/connector/__init__.py
@@ -0,0 +1,51 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import enum
+import logging
+
+from sglang.srt.connector.base_connector import (
+    BaseConnector,
+    BaseFileConnector,
+    BaseKVConnector,
+)
+from sglang.srt.connector.redis import RedisConnector
+from sglang.srt.connector.s3 import S3Connector
+from sglang.srt.utils import parse_connector_type
+
+logger = logging.getLogger(__name__)
+
+
+class ConnectorType(str, enum.Enum):
+    FS = "filesystem"
+    KV = "KV"
+
+
+def create_remote_connector(url, device="cpu") -> BaseConnector:
+    connector_type = parse_connector_type(url)
+    if connector_type == "redis":
+        return RedisConnector(url)
+    elif connector_type == "s3":
+        return S3Connector(url)
+    else:
+        raise ValueError(f"Invalid connector type: {url}")
+
+
+def get_connector_type(client: BaseConnector) -> ConnectorType:
+    if isinstance(client, BaseKVConnector):
+        return ConnectorType.KV
+    if isinstance(client, BaseFileConnector):
+        return ConnectorType.FS
+
+    raise ValueError(f"Invalid connector type: {client}")
+
+
+__all__ = [
+    "BaseConnector",
+    "BaseFileConnector",
+    "BaseKVConnector",
+    "RedisConnector",
+    "S3Connector",
+    "ConnectorType",
+    "create_remote_connector",
+    "get_connector_type",
+]
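
Note: a minimal sketch of the connector API defined above, assuming a reachable Redis instance and a model previously saved under the (hypothetical) name my-model:

from sglang.srt.connector import (
    ConnectorType,
    create_remote_connector,
    get_connector_type,
)

# The URL scheme selects the implementation:
# redis:// -> RedisConnector (KV), s3:// -> S3Connector (FS).
client = create_remote_connector("redis://localhost:6379/my-model")
assert get_connector_type(client) == ConnectorType.KV

# Pull config/tokenizer files into the connector's temporary directory ...
client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
print(client.get_local_dir())

# ... and stream this rank's weights straight out of the KV store.
for name, tensor in client.weight_iterator(rank=0):
    print(name, tuple(tensor.shape))
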
diff --git a/python/sglang/srt/connector/base_connector.py b/python/sglang/srt/connector/base_connector.py
new file mode 100644
index 000000000..a9c00d0c9
--- /dev/null
+++ b/python/sglang/srt/connector/base_connector.py
@@ -0,0 +1,112 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import shutil
+import signal
+import tempfile
+from abc import ABC, abstractmethod
+from typing import Generator, List, Optional, Tuple
+
+import torch
+
+
+class BaseConnector(ABC):
+    """
+    For fs connector such as s3:
+    <connector_type>://<path>/<filename>
+
+    For kv connector such as redis:
+    <connector_type>://<host>:<port>/<model_name>/keys/<key>
+    <connector_type>://<host>:<port>/<model_name>/files/<filename>
+    """
+
+    def __init__(self, url: str, device: torch.device = "cpu"):
+        self.url = url
+        self.device = device
+        self.closed = False
+        self.local_dir = tempfile.mkdtemp()
+        for sig in (signal.SIGINT, signal.SIGTERM):
+            existing_handler = signal.getsignal(sig)
+            signal.signal(sig, self._close_by_signal(existing_handler))
+
+    def get_local_dir(self):
+        return self.local_dir
+
+    @abstractmethod
+    def weight_iterator(
+        self, rank: int = 0
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def pull_files(
+        self,
+        allow_pattern: Optional[List[str]] = None,
+        ignore_pattern: Optional[List[str]] = None,
+    ) -> None:
+        raise NotImplementedError()
+
+    def close(self):
+        if self.closed:
+            return
+
+        self.closed = True
+        if os.path.exists(self.local_dir):
+            shutil.rmtree(self.local_dir)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
+    def __del__(self):
+        self.close()
+
+    def _close_by_signal(self, existing_handler=None):
+
+        def new_handler(signum, frame):
+            self.close()
+            if existing_handler:
+                existing_handler(signum, frame)
+
+        return new_handler
+
+
+class BaseKVConnector(BaseConnector):
+
+    @abstractmethod
+    def get(self, key: str) -> Optional[torch.Tensor]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def getstr(self, key: str) -> Optional[str]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def set(self, key: str, obj: torch.Tensor) -> None:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def setstr(self, key: str, obj: str) -> None:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def list(self, prefix: str) -> List[str]:
+        raise NotImplementedError()
+
+
+class BaseFileConnector(BaseConnector):
+    """
+    List full file names from remote fs path and filter by allow pattern.
+
+    Args:
+        allow_pattern: A list of patterns of which files to pull.
+
+    Returns:
+        list[str]: List of full paths allowed by the pattern
+    """
+
+    @abstractmethod
+    def glob(self, allow_pattern: str) -> List[str]:
+        raise NotImplementedError()
diff --git a/python/sglang/srt/connector/redis.py b/python/sglang/srt/connector/redis.py
new file mode 100644
index 000000000..761594f78
--- /dev/null
+++ b/python/sglang/srt/connector/redis.py
@@ -0,0 +1,85 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from typing import Generator, List, Optional, Tuple
+from urllib.parse import urlparse
+
+import torch
+
+from sglang.srt.connector import BaseKVConnector
+from sglang.srt.connector.serde import create_serde
+from sglang.srt.connector.utils import pull_files_from_db
+
+logger = logging.getLogger(__name__)
+
+
+class RedisConnector(BaseKVConnector):
+
+    def __init__(self, url: str, device: torch.device = "cpu"):
+        import redis
+
+        super().__init__(url, device)
+        parsed_url = urlparse(url)
+        self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
+        self.model_name = parsed_url.path.lstrip("/")
+        # TODO: more serde options
+        self.s, self.d = create_serde("safe")
+
+    def get(self, key: str) -> Optional[torch.Tensor]:
+        val = self.connection.get(key)
+
+        if val is None:
+            logger.error("Key %s not found", key)
+            return None
+
+        return self.d.from_bytes(val)
+
+    def getstr(self, key: str) -> Optional[str]:
+        val = self.connection.get(key)
+        if val is None:
+            logger.error("Key %s not found", key)
+            return None
+
+        return val.decode("utf-8")
+
+    def set(self, key: str, tensor: torch.Tensor) -> None:
+        assert tensor is not None
+        self.connection.set(key, self.s.to_bytes(tensor))
+
+    def setstr(self, key: str, obj: str) -> None:
+        self.connection.set(key, obj)
+
+    def list(self, prefix: str) -> List[str]:
+        cursor = 0
+        all_keys: List[bytes] = []
+
+        while True:
+            ret: Tuple[int, List[bytes]] = self.connection.scan(
+                cursor=cursor, match=f"{prefix}*"
+            )  # type: ignore
+            cursor, keys = ret
+            all_keys.extend(keys)
+            if cursor == 0:
+                break
+
+        return [key.decode("utf-8") for key in all_keys]
+
+    def weight_iterator(
+        self, rank: int = 0
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        keys = self.list(f"{self.model_name}/keys/rank_{rank}/")
+        for key in keys:
+            val = self.get(key)
+            key = key.removeprefix(f"{self.model_name}/keys/rank_{rank}/")
+            yield key, val
+
+    def pull_files(
+        self,
+        allow_pattern: Optional[List[str]] = None,
+        ignore_pattern: Optional[List[str]] = None,
+    ) -> None:
+        pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)
+
+    def close(self):
+        self.connection.close()
+        super().close()
diff --git a/python/sglang/srt/connector/s3.py b/python/sglang/srt/connector/s3.py
new file mode 100644
index 000000000..7bef8f5d5
--- /dev/null
+++ b/python/sglang/srt/connector/s3.py
@@ -0,0 +1,122 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import fnmatch
+import os
+from pathlib import Path
+from typing import Generator, Optional, Tuple
+
+import torch
+
+from sglang.srt.connector import BaseFileConnector
+
+
+def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
+    return [
+        path
+        for path in paths
+        if any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
+    ]
+
+
+def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
+    return [
+        path
+        for path in paths
+        if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
+    ]
+
+
+def list_files(
+    s3,
+    path: str,
+    allow_pattern: Optional[list[str]] = None,
+    ignore_pattern: Optional[list[str]] = None,
+) -> tuple[str, str, list[str]]:
+    """
+    List files from an S3 path and filter by pattern.
+
+    Args:
+        s3: S3 client to use.
+        path: The S3 path to list from.
+        allow_pattern: A list of patterns of which files to pull.
+        ignore_pattern: A list of patterns of which files not to pull.
+
+    Returns:
+        tuple[str, str, list[str]]: A tuple where:
+            - The first element is the bucket name.
+            - The second element is the bucket name and prefix joined as a
+              directory-like string.
+            - The third element is the list of files allowed by
+              allow_pattern and not excluded by ignore_pattern.
+    """
+    parts = path.removeprefix("s3://").split("/")
+    prefix = "/".join(parts[1:])
+    bucket_name = parts[0]
+
+    objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
+    paths = [obj["Key"] for obj in objects.get("Contents", [])]
+
+    paths = _filter_ignore(paths, ["*/"])
+    if allow_pattern is not None:
+        paths = _filter_allow(paths, allow_pattern)
+
+    if ignore_pattern is not None:
+        paths = _filter_ignore(paths, ignore_pattern)
+
+    return bucket_name, prefix, paths
+
+
+class S3Connector(BaseFileConnector):
+
+    def __init__(self, url: str) -> None:
+        import boto3
+
+        super().__init__(url)
+        self.client = boto3.client("s3")
+
+    def glob(self, allow_pattern: Optional[list[str]] = None) -> list[str]:
+        bucket_name, _, paths = list_files(
+            self.client, path=self.url, allow_pattern=allow_pattern
+        )
+        return [f"s3://{bucket_name}/{path}" for path in paths]
+
+    def pull_files(
+        self,
+        allow_pattern: Optional[list[str]] = None,
+        ignore_pattern: Optional[list[str]] = None,
+    ) -> None:
+        """
+        Pull files from S3 storage into the temporary directory.
+
+        Args:
+            allow_pattern: A list of patterns of which files to pull.
+            ignore_pattern: A list of patterns of which files not to pull.
+ + """ + bucket_name, base_dir, files = list_files( + self.client, self.url, allow_pattern, ignore_pattern + ) + if len(files) == 0: + return + + for file in files: + destination_file = os.path.join(self.local_dir, file.removeprefix(base_dir)) + local_dir = Path(destination_file).parent + os.makedirs(local_dir, exist_ok=True) + self.client.download_file(bucket_name, file, destination_file) + + def weight_iterator( + self, rank: int = 0 + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + from sglang.srt.model_loader.weight_utils import ( + runai_safetensors_weights_iterator, + ) + + # only support safetensor files now + hf_weights_files = self.glob(allow_pattern=["*.safetensors"]) + return runai_safetensors_weights_iterator(hf_weights_files) + + def close(self): + self.client.close() + super().close() diff --git a/python/sglang/srt/connector/serde/__init__.py b/python/sglang/srt/connector/serde/__init__.py new file mode 100644 index 000000000..394dba0a6 --- /dev/null +++ b/python/sglang/srt/connector/serde/__init__.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 + +# inspired by LMCache +from typing import Optional, Tuple + +import torch + +from sglang.srt.connector.serde.safe_serde import SafeDeserializer, SafeSerializer +from sglang.srt.connector.serde.serde import Deserializer, Serializer + + +def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]: + s: Optional[Serializer] = None + d: Optional[Deserializer] = None + + if serde_type == "safe": + s = SafeSerializer() + d = SafeDeserializer(torch.uint8) + else: + raise ValueError(f"Unknown serde type: {serde_type}") + + return s, d + + +__all__ = [ + "Serializer", + "Deserializer", + "SafeSerializer", + "SafeDeserializer", + "create_serde", +] diff --git a/python/sglang/srt/connector/serde/safe_serde.py b/python/sglang/srt/connector/serde/safe_serde.py new file mode 100644 index 000000000..0163af9f5 --- /dev/null +++ b/python/sglang/srt/connector/serde/safe_serde.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +import torch +from safetensors.torch import load, save + +from sglang.srt.connector.serde.serde import Deserializer, Serializer + + +class SafeSerializer(Serializer): + + def __init__(self): + super().__init__() + + def to_bytes(self, t: torch.Tensor) -> bytes: + return save({"tensor_bytes": t.cpu().contiguous()}) + + +class SafeDeserializer(Deserializer): + + def __init__(self, dtype): + super().__init__(dtype) + + def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor: + return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype) + + def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor: + return self.from_bytes_normal(b) diff --git a/python/sglang/srt/connector/serde/serde.py b/python/sglang/srt/connector/serde/serde.py new file mode 100644 index 000000000..3d6f804d7 --- /dev/null +++ b/python/sglang/srt/connector/serde/serde.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 + +import abc +from abc import ABC, abstractmethod + +import torch + + +class Serializer(ABC): + + @abstractmethod + def to_bytes(self, t: torch.Tensor) -> bytes: + """ + Serialize a pytorch tensor to bytes. The serialized bytes should contain + both the data and the metadata (shape, dtype, etc.) of the tensor. 
diff --git a/python/sglang/srt/connector/serde/serde.py b/python/sglang/srt/connector/serde/serde.py
new file mode 100644
index 000000000..3d6f804d7
--- /dev/null
+++ b/python/sglang/srt/connector/serde/serde.py
@@ -0,0 +1,43 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import abc
+from abc import ABC, abstractmethod
+
+import torch
+
+
+class Serializer(ABC):
+
+    @abstractmethod
+    def to_bytes(self, t: torch.Tensor) -> bytes:
+        """
+        Serialize a pytorch tensor to bytes. The serialized bytes should contain
+        both the data and the metadata (shape, dtype, etc.) of the tensor.
+
+        Input:
+            t: the input pytorch tensor, can be on any device, in any shape,
+               with any dtype
+
+        Returns:
+            bytes: the serialized bytes
+        """
+        raise NotImplementedError
+
+
+class Deserializer(metaclass=abc.ABCMeta):
+
+    def __init__(self, dtype):
+        self.dtype = dtype
+
+    @abstractmethod
+    def from_bytes(self, bs: bytes) -> torch.Tensor:
+        """
+        Deserialize a pytorch tensor from bytes.
+
+        Input:
+            bs: a stream of bytes
+
+        Output:
+            torch.Tensor: the deserialized pytorch tensor
+        """
+        raise NotImplementedError
diff --git a/python/sglang/srt/connector/utils.py b/python/sglang/srt/connector/utils.py
new file mode 100644
index 000000000..6cd12da33
--- /dev/null
+++ b/python/sglang/srt/connector/utils.py
@@ -0,0 +1,35 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from pathlib import Path
+from typing import Optional
+from urllib.parse import urlparse
+
+from sglang.srt.connector import BaseConnector
+
+
+def parse_model_name(url: str) -> str:
+    """
+    Parse the model name from the url.
+    Only used for db connector.
+    """
+    parsed_url = urlparse(url)
+    return parsed_url.path.lstrip("/")
+
+
+def pull_files_from_db(
+    connector: BaseConnector,
+    model_name: str,
+    allow_pattern: Optional[list[str]] = None,
+    ignore_pattern: Optional[list[str]] = None,
+) -> None:
+    prefix = f"{model_name}/files/"
+    local_dir = connector.get_local_dir()
+    files = connector.list(prefix)
+
+    for file in files:
+        destination_file = os.path.join(local_dir, file.removeprefix(prefix))
+        # Use a separate variable for the parent directory so that local_dir
+        # still points at the connector's root on the next iteration.
+        parent_dir = Path(destination_file).parent
+        os.makedirs(parent_dir, exist_ok=True)
+        with open(destination_file, "wb") as f:
+            f.write(connector.getstr(file).encode("utf-8"))
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 0fa841993..ec4ea515b 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -27,6 +27,9 @@ import signal
 import threading
 from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
 
+import zmq
+import zmq.asyncio
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
@@ -44,6 +47,8 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
+    RpcReqInput,
+    RpcReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -57,6 +62,7 @@ from sglang.srt.utils import (
     MultiprocessingSerializer,
     assert_pkg_version,
     configure_logger,
+    get_zmq_socket,
     kill_process_tree,
     launch_dummy_health_check_server,
     maybe_set_triton_cache_manager,
@@ -102,15 +108,25 @@ class Engine:
         # Shutdown the subprocesses automatically when the program exits
         atexit.register(self.shutdown)
 
+        # Allocate ports for inter-process communications
+        port_args = PortArgs.init_new(server_args)
+        logger.info(f"{server_args=}")
+
         # Launch subprocesses
         tokenizer_manager, scheduler_info = _launch_subprocesses(
-            server_args=server_args
+            server_args=server_args,
+            port_args=port_args,
         )
 
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.scheduler_info = scheduler_info
 
+        context = zmq.Context(2)
+        self.send_to_rpc = get_zmq_socket(
+            context, zmq.DEALER, port_args.rpc_ipc_name, True
+        )
+
     def generate(
         self,
         # The input prompt. It can be a single prompt or a batch of prompts.
@@ -350,6 +366,23 @@ class Engine:
             self.tokenizer_manager.resume_memory_occupation(obj, None)
         )
 
+    def collective_rpc(self, method: str, **kwargs):
+        """Execute an RPC call on all scheduler processes."""
+        obj = RpcReqInput(method=method, parameters=kwargs)
+        self.send_to_rpc.send_pyobj(obj)
+        # recv_pyobj() blocks until the scheduler replies.
+        recv_req = self.send_to_rpc.recv_pyobj()
+        assert isinstance(recv_req, RpcReqOutput)
+        assert recv_req.success, recv_req.message
+
+    def save_remote_model(self, **kwargs):
+        self.collective_rpc("save_remote_model", **kwargs)
+
+    def save_sharded_model(self, **kwargs):
+        self.collective_rpc("save_sharded_model", **kwargs)
+
 
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
@@ -408,7 +441,9 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)
 
 
-def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
+def _launch_subprocesses(
+    server_args: ServerArgs, port_args: Optional[PortArgs] = None
+) -> Tuple[TokenizerManager, Dict]:
     """
     Launch the TokenizerManager in the main process, the Scheduler in a
     subprocess, and the DetokenizerManager in another subprocess.
     """
@@ -418,8 +453,9 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
     _set_envs_and_config(server_args)
 
     # Allocate ports for inter-process communications
-    port_args = PortArgs.init_new(server_args)
-    logger.info(f"{server_args=}")
+    if port_args is None:
+        port_args = PortArgs.init_new(server_args)
+        logger.info(f"{server_args=}")
 
     # If using model from www.modelscope.cn, first download the model.
     server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
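
Note: collective_rpc is a plain DEALER/DEALER round trip over the new rpc_ipc_name endpoint (added to PortArgs below): the Engine binds, the scheduler connects (the final boolean argument of get_zmq_socket selects bind vs. connect), a pickled RpcReqInput goes out, and the scheduler replies with one RpcReqOutput after the ranks have run the named method and passed a barrier(). A standalone sketch of the same socket pattern, with a hypothetical endpoint and no sglang imports:

import threading

import zmq

ENDPOINT = "ipc:///tmp/rpc_demo"  # stand-in for port_args.rpc_ipc_name


def scheduler_side(ctx):
    sock = ctx.socket(zmq.DEALER)
    sock.connect(ENDPOINT)  # the Scheduler connects
    method, params = sock.recv_pyobj()  # plays the role of RpcReqInput
    sock.send_pyobj((True, f"ran {method}"))  # plays the role of RpcReqOutput
    sock.close()


ctx = zmq.Context()
engine_sock = ctx.socket(zmq.DEALER)
engine_sock.bind(ENDPOINT)  # the Engine binds

t = threading.Thread(target=scheduler_side, args=(ctx,))
t.start()
engine_sock.send_pyobj(("save_remote_model", {"url": "redis://localhost:6379/m"}))
success, message = engine_sock.recv_pyobj()  # blocks, like collective_rpc
assert success, message
t.join()
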
diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py
index e5aa5a62e..5cd52bf4a 100644
--- a/python/sglang/srt/hf_transformers_utils.py
+++ b/python/sglang/srt/hf_transformers_utils.py
@@ -37,6 +37,8 @@ from sglang.srt.configs import (
     MultiModalityConfig,
     Qwen2_5_VLConfig,
 )
+from sglang.srt.connector import create_remote_connector
+from sglang.srt.utils import is_remote_url
 
 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
@@ -155,6 +157,14 @@ def get_tokenizer(
         kwargs["gguf_file"] = tokenizer_name
         tokenizer_name = Path(tokenizer_name).parent
 
+    if is_remote_url(tokenizer_name):
+        # BaseConnector implements __del__() to clean up the local dir.
+        # Since the tokenizer files need to exist for the lifetime of the
+        # process, we DO NOT use a with statement, which would close the
+        # client and delete the directory.
+        client = create_remote_connector(tokenizer_name)
+        client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
+        tokenizer_name = client.get_local_dir()
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 232fb3859..e1cb4f73e 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -723,3 +723,15 @@ class SeparateReasoningReqInput:
 class VertexGenerateReqInput:
     instances: List[dict]
     parameters: Optional[dict] = None
+
+
+@dataclass
+class RpcReqInput:
+    method: str
+    parameters: Optional[Dict] = None
+
+
+@dataclass
+class RpcReqOutput:
+    success: bool
+    message: str
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index d5ce3bc71..558141d74 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -32,6 +32,7 @@ import psutil
 import setproctitle
 import torch
 import zmq
+from torch.distributed import barrier
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
@@ -59,6 +60,8 @@ from sglang.srt.managers.io_struct import (
     ReleaseMemoryOccupationReqOutput,
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
+    RpcReqInput,
+    RpcReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
     TokenizedEmbeddingReqInput,
@@ -193,8 +196,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.detokenizer_ipc_name, False
             )
+
+            self.recv_from_rpc = get_zmq_socket(
+                context, zmq.DEALER, port_args.rpc_ipc_name, False
+            )
         else:
             self.recv_from_tokenizer = None
+            self.recv_from_rpc = None
             self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
@@ -376,6 +384,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
+                (RpcReqInput, self.handle_rpc_request),
             ]
         )
 
@@ -549,6 +558,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 except zmq.ZMQError:
                     break
                 recv_reqs.append(recv_req)
+
+            while True:
+                try:
+                    recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+                recv_reqs.append(recv_rpc)
         else:
             recv_reqs = None
 
@@ -600,7 +616,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
             output = self._request_dispatcher(recv_req)
             if output is not None:
-                self.send_to_tokenizer.send_pyobj(output)
+                if isinstance(output, RpcReqOutput):
+                    if self.recv_from_rpc is not None:
+                        self.recv_from_rpc.send_pyobj(output)
+                else:
+                    self.send_to_tokenizer.send_pyobj(output)
 
     def handle_generate_request(
         self,
@@ -1492,6 +1512,47 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 server_args=global_server_args_dict,
             )
 
+    def handle_rpc_request(self, recv_req: RpcReqInput):
+        # Handle RPC requests
+        logger.info(
+            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
+        )
+
+        success = True
+        err = None
+        try:
+            func = getattr(self, recv_req.method)
+            func(recv_req.parameters)
+        except Exception as e:
+            success = False
+            err = e
+            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")
+
+        barrier()
+        return RpcReqOutput(success, "" if err is None else str(err))
+
+    def save_remote_model(self, params):
+        url = params["url"]
+
+        if isinstance(self.tp_worker, TpModelWorkerClient):
+            worker = self.tp_worker.worker
+        else:
+            worker = self.tp_worker
+
+        worker.model_runner.save_remote_model(url)
+
+    def save_sharded_model(self, params):
+        if isinstance(self.tp_worker, TpModelWorkerClient):
+            worker = self.tp_worker.worker
+        else:
+            worker = self.tp_worker
+
+        worker.model_runner.save_sharded_model(
+            path=params["path"],
+            pattern=params["pattern"],
+            max_size=params["max_size"],
+        )
+
     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
         to_del = []
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 1b447b2b8..c4150bda1 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -1009,6 +1009,22 @@ class ModelRunner:
             return False
         return rope_scaling.get("type", None) == "mrope"
 
+    def save_remote_model(self, url: str):
+        from sglang.srt.model_loader.loader import RemoteModelLoader
+
+        logger.info(f"Saving model to {url}")
+        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)
+
+    def save_sharded_model(
+        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
+    ):
+        from sglang.srt.model_loader.loader import ShardedStateLoader
+
+        logger.info(
+            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
+        )
+        ShardedStateLoader.save_model(self.model, path, pattern, max_size)
+
 
 def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
     params_dict = dict(model.named_parameters())
diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py
index c241fd9d6..656a9718d 100644
--- a/python/sglang/srt/model_loader/loader.py
+++ b/python/sglang/srt/model_loader/loader.py
@@ -9,6 +9,7 @@ import json
 import logging
 import math
 import os
+import time
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
@@ -25,6 +26,12 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.connector import (
+    ConnectorType,
+    create_remote_connector,
+    get_connector_type,
+)
+from sglang.srt.connector.utils import parse_model_name
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -46,6 +53,7 @@ from sglang.srt.model_loader.weight_utils import (
     np_cache_weights_iterator,
     pt_weights_iterator,
     safetensors_weights_iterator,
+    set_runai_streamer_env,
 )
 from sglang.srt.utils import (
     get_bool_env_var,
@@ -490,7 +498,7 @@ class ShardedStateLoader(BaseModelLoader):
     Model loader that directly loads each worker's model state dict, which
     enables a fast load path for large tensor-parallel models where each worker
     only needs to read its own shard rather than the entire checkpoint. See
-    `examples/save_sharded_state.py` for creating a sharded checkpoint.
+    `examples/runtime/engine/save_sharded_state.py` for creating a sharded checkpoint.
""" DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" @@ -1204,6 +1212,153 @@ class GGUFModelLoader(BaseModelLoader): return model +class RemoteModelLoader(BaseModelLoader): + """Model loader that can load Tensors from remote database.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + # TODO @DellCurry: move to s3 connector only + set_runai_streamer_env(load_config) + + def _get_weights_iterator_kv( + self, + client, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights from remote storage.""" + assert get_connector_type(client) == ConnectorType.KV + rank = get_tensor_model_parallel_rank() + return client.weight_iterator(rank) + + def _get_weights_iterator_fs( + self, + client, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights from remote storage.""" + assert get_connector_type(client) == ConnectorType.FS + return client.weight_iterator() + + def download_model(self, model_config: ModelConfig) -> None: + pass + + @staticmethod + def save_model( + model: torch.nn.Module, + model_path: str, + url: str, + ) -> None: + with create_remote_connector(url) as client: + assert get_connector_type(client) == ConnectorType.KV + model_name = parse_model_name(url) + rank = get_tensor_model_parallel_rank() + state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) + for key, tensor in state_dict.items(): + r_key = f"{model_name}/keys/rank_{rank}/{key}" + client.set(r_key, tensor) + + for root, _, files in os.walk(model_path): + for file_name in files: + # ignore hidden files + if file_name.startswith("."): + continue + if os.path.splitext(file_name)[1] not in ( + ".bin", + ".pt", + ".safetensors", + ): + file_path = os.path.join(root, file_name) + with open(file_path, encoding="utf-8") as file: + file_content = file.read() + f_key = f"{model_name}/files/{file_name}" + client.setstr(f_key, file_content) + + def _load_model_from_remote_kv(self, model: nn.Module, client): + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + weights_iterator = self._get_weights_iterator_kv(client) + state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) + for key, tensor in weights_iterator: + # If loading with LoRA enabled, additional padding may + # be added to certain parameters. We only load into a + # narrowed view of the parameter data. 
+            param_data = state_dict[key].data
+            param_shape = state_dict[key].shape
+            for dim, size in enumerate(tensor.shape):
+                if size < param_shape[dim]:
+                    param_data = param_data.narrow(dim, 0, size)
+            if tensor.shape != param_shape:
+                logger.warning(
+                    "loading tensor of shape %s into parameter '%s' of shape %s",
+                    tensor.shape,
+                    key,
+                    param_shape,
+                )
+            param_data.copy_(tensor)
+            state_dict.pop(key)
+        if state_dict:
+            raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+    def _load_model_from_remote_fs(
+        self, model, client, model_config: ModelConfig, device_config: DeviceConfig
+    ) -> nn.Module:
+
+        target_device = torch.device(device_config.device)
+        with set_default_torch_dtype(model_config.dtype):
+            model.load_weights(self._get_weights_iterator_fs(client))
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    # When quant methods need to process weights after loading
+                    # (for repacking, quantizing, etc), they expect parameters
+                    # to be on the global target device. This scope is for the
+                    # case where cpu offloading is used, where we will move the
+                    # parameters onto device for processing and back off after.
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+        return model
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        logger.info("Loading weights from remote storage ...")
+        start = time.perf_counter()
+        load_config = self.load_config
+
+        assert load_config.load_format == LoadFormat.REMOTE, (
+            f"RemoteModelLoader only supports load format {LoadFormat.REMOTE}, "
+            f"got {load_config.load_format}"
+        )
+
+        model_weights = model_config.model_path
+        if hasattr(model_config, "model_weights"):
+            model_weights = model_config.model_weights
+
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config)
+                for _, module in model.named_modules():
+                    quant_method = getattr(module, "quant_method", None)
+                    if quant_method is not None:
+                        quant_method.process_weights_after_loading(module)
+
+            with create_remote_connector(model_weights, device_config.device) as client:
+                connector_type = get_connector_type(client)
+                if connector_type == ConnectorType.KV:
+                    self._load_model_from_remote_kv(model, client)
+                elif connector_type == ConnectorType.FS:
+                    self._load_model_from_remote_fs(
+                        model, client, model_config, device_config
+                    )
+
+        end = time.perf_counter()
+        logger.info("Loaded weights from remote storage in %.2f seconds.", end - start)
+        return model.eval()
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
 
@@ -1225,4 +1380,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.LAYERED:
         return LayeredModelLoader(load_config)
 
+    if load_config.load_format == LoadFormat.REMOTE:
+        return RemoteModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py
index be54f8a5d..92751b3c1 100644
--- a/python/sglang/srt/model_loader/weight_utils.py
+++ b/python/sglang/srt/model_loader/weight_utils.py
@@ -585,6 +585,51 @@ def composed_weight_loader(
     return composed_loader
 
 
+def runai_safetensors_weights_iterator(
+    hf_weights_files: List[str],
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model safetensor files."""
+    from runai_model_streamer import SafetensorsStreamer
+
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    with SafetensorsStreamer() as streamer:
+        for st_file in tqdm(
+            hf_weights_files,
+            desc="Loading safetensors using Runai Model Streamer",
+            disable=not enable_tqdm,
+            bar_format=_BAR_FORMAT,
+        ):
+            streamer.stream_file(st_file)
+            yield from streamer.get_tensors()
+
+
+def set_runai_streamer_env(load_config: LoadConfig):
+    if load_config.model_loader_extra_config:
+        extra_config = load_config.model_loader_extra_config
+
+        if "concurrency" in extra_config and isinstance(
+            extra_config.get("concurrency"), int
+        ):
+            os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
+                extra_config.get("concurrency")
+            )
+
+        if "memory_limit" in extra_config and isinstance(
+            extra_config.get("memory_limit"), int
+        ):
+            os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
+                extra_config.get("memory_limit")
+            )
+
+    runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
+    aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
+    if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
+        os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
+
+
 def initialize_dummy_weights(
     model: torch.nn.Module,
     low: float = -1e-3,
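
Note: set_runai_streamer_env only copies optional model_loader_extra_config entries (and a fallback AWS endpoint) into the Run:ai streamer's environment variables. A small check of that mapping, assuming the LoadConfig dataclass accepts these fields as keywords (values are arbitrary):

import os

from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.model_loader.weight_utils import set_runai_streamer_env

cfg = LoadConfig(
    load_format="remote",
    model_loader_extra_config={"concurrency": 8, "memory_limit": 2 * 1024**3},
)
set_runai_streamer_env(cfg)
assert os.environ["RUNAI_STREAMER_CONCURRENCY"] == "8"
assert os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] == str(2 * 1024**3)
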
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index fc3cb5cb2..d6cb878b9 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -30,6 +30,7 @@ from sglang.srt.utils import (
     is_flashinfer_available,
     is_hip,
     is_port_available,
+    is_remote_url,
     is_valid_ipv6_address,
     nullable_str,
 )
@@ -296,6 +297,9 @@ class ServerArgs:
         ) and check_gguf_file(self.model_path):
             self.quantization = self.load_format = "gguf"
 
+        if is_remote_url(self.model_path):
+            self.load_format = "remote"
+
         # AMD-specific Triton attention KV splits default number
         if is_hip():
             self.triton_attention_num_kv_splits = 16
@@ -345,9 +349,11 @@ class ServerArgs:
             "safetensors",
             "npcache",
             "dummy",
+            "sharded_state",
             "gguf",
             "bitsandbytes",
             "layered",
+            "remote",
         ],
         help="The format of the model weights to load. "
         '"auto" will try to load the weights in the safetensors format '
@@ -1088,6 +1094,9 @@ class PortArgs:
     # The port for nccl initialization (torch.dist)
     nccl_port: int
 
+    # The ipc filename for RPC calls between Engine and Scheduler
+    rpc_ipc_name: str
+
     @staticmethod
     def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
         port = server_args.port + random.randint(100, 1000)
@@ -1106,6 +1115,7 @@ class PortArgs:
                 scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                 detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                 nccl_port=port,
+                rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
             )
         else:
             # DP attention. Use TCP + port to handle both single-node and multi-node.
@@ -1131,6 +1141,7 @@ class PortArgs:
                 scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
                 detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
                 nccl_port=port,
+                rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
             )
 
 
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 19c06b607..b661d648f 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -42,6 +42,7 @@ from importlib.util import find_spec
 from io import BytesIO
 from multiprocessing import Pool
 from multiprocessing.reduction import ForkingPickler
+from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
 
 import numpy as np
@@ -774,12 +775,22 @@ def get_zmq_socket(
         buf_size = -1
 
     socket = context.socket(socket_type)
-    if socket_type == zmq.PUSH:
+
+    def set_send_opt():
         socket.setsockopt(zmq.SNDHWM, 0)
         socket.setsockopt(zmq.SNDBUF, buf_size)
-    elif socket_type == zmq.PULL:
+
+    def set_recv_opt():
         socket.setsockopt(zmq.RCVHWM, 0)
         socket.setsockopt(zmq.RCVBUF, buf_size)
+
+    if socket_type == zmq.PUSH:
+        set_send_opt()
+    elif socket_type == zmq.PULL:
+        set_recv_opt()
+    elif socket_type == zmq.DEALER:
+        set_send_opt()
+        set_recv_opt()
     else:
         raise ValueError(f"Unsupported socket type: {socket_type}")
 
@@ -1572,3 +1583,29 @@ def add_prefix(name: str, prefix: str) -> str:
         The string `prefix.name` if prefix is non-empty, otherwise just `name`.
     """
     return name if not prefix else f"{prefix}.{name}"
+
+
+def is_remote_url(url: Union[str, Path]) -> bool:
+    """
+    Check if the URL is a remote URL of the format:
+    <connector_type>://<host>:<port>/<model_name>
+    """
+    if isinstance(url, Path):
+        return False
+
+    pattern = r"(.+)://(.*)"
+    m = re.match(pattern, url)
+    return m is not None
+
+
+def parse_connector_type(url: str) -> str:
+    """
+    Parse the connector type from a URL of the format:
+    <connector_type>://<path>
+    """
+    pattern = r"(.+)://(.*)"
+    m = re.match(pattern, url)
+    if m is None:
+        return ""
+
+    return m.group(1)
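
Note: putting the pieces together, a sketch of the end-to-end flow this patch enables. The Redis URL, local path, and tp_size are placeholders:

from sglang import Engine

# Save: each TP rank pushes its own shard to my-model/keys/rank_{rank}/...,
# and config/tokenizer files go to my-model/files/... .
llm = Engine(model_path="/path/to/local/model", tp_size=2)
llm.save_remote_model(url="redis://localhost:6379/my-model")
llm.shutdown()

# Load: a remote model_path switches ServerArgs to load_format="remote";
# ModelConfig pulls *config.json locally, and RemoteModelLoader streams
# only this rank's shard from the KV store.
llm = Engine(model_path="redis://localhost:6379/my-model", tp_size=2)
print(llm.generate("Hello", {"max_new_tokens": 8}))
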