diff --git a/python/sglang/srt/configs/device_config.py b/python/sglang/srt/configs/device_config.py index 3b9d3a1ed..20b9af9be 100644 --- a/python/sglang/srt/configs/device_config.py +++ b/python/sglang/srt/configs/device_config.py @@ -8,10 +8,12 @@ logger = logging.getLogger(__name__) class DeviceConfig: device: Optional[torch.device] + gpu_id: Optional[int] - def __init__(self, device: str = "cuda") -> None: + def __init__(self, device: str = "cuda", gpu_id: int = -1) -> None: if device in ["cuda", "xpu", "hpu", "cpu", "npu"]: self.device_type = device else: raise RuntimeError(f"Not supported device type: {device}") self.device = torch.device(self.device_type) + self.gpu_id = gpu_id diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py index be9a40b4b..6ac003ea4 100644 --- a/python/sglang/srt/configs/load_config.py +++ b/python/sglang/srt/configs/load_config.py @@ -23,6 +23,7 @@ class LoadFormat(str, enum.Enum): LAYERED = "layered" JAX = "jax" REMOTE = "remote" + REMOTE_INSTANCE = "remote_instance" @dataclass diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 2ba4bbc7a..febf94735 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -64,12 +64,28 @@ class ModelConfig: is_draft_model: bool = False, hybrid_kvcache_ratio: Optional[float] = None, model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, + tp_rank: Optional[int] = None, + remote_instance_weight_loader_seed_instance_ip: Optional[str] = None, + remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None, + remote_instance_weight_loader_send_weights_group_ports: Optional[ + List[int] + ] = None, ) -> None: # Parse args self.model_path = model_path self.revision = revision self.quantization = quantization self.model_impl = model_impl + self.tp_rank = tp_rank + self.remote_instance_weight_loader_seed_instance_ip = ( + remote_instance_weight_loader_seed_instance_ip + ) + self.remote_instance_weight_loader_seed_instance_service_port = ( + remote_instance_weight_loader_seed_instance_service_port + ) + self.remote_instance_weight_loader_send_weights_group_ports = ( + remote_instance_weight_loader_send_weights_group_ports + ) self.maybe_pull_model_tokenizer_from_remote() self.model_override_args = json.loads(model_override_args) @@ -329,6 +345,9 @@ class ModelConfig: quantization=server_args.quantization, hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio, model_impl=server_args.model_impl, + remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip, + remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port, + remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports, **kwargs, ) diff --git a/python/sglang/srt/connector/__init__.py b/python/sglang/srt/connector/__init__.py index 38e1d5eab..c9663a836 100644 --- a/python/sglang/srt/connector/__init__.py +++ b/python/sglang/srt/connector/__init__.py @@ -9,6 +9,7 @@ from sglang.srt.connector.base_connector import ( BaseKVConnector, ) from sglang.srt.connector.redis import RedisConnector +from sglang.srt.connector.remote_instance import RemoteInstanceConnector from sglang.srt.connector.s3 import S3Connector from sglang.srt.utils import parse_connector_type @@ -18,14 +19,17 @@ logger = logging.getLogger(__name__) class ConnectorType(str, enum.Enum): FS = "filesystem" KV = "KV" + INSTANCE = "instance" -def create_remote_connector(url, **kwargs) -> BaseConnector: +def create_remote_connector(url, device, **kwargs) -> BaseConnector: connector_type = parse_connector_type(url) if connector_type == "redis": return RedisConnector(url) elif connector_type == "s3": return S3Connector(url) + elif connector_type == "instance": + return RemoteInstanceConnector(url, device) else: raise ValueError(f"Invalid connector type: {url}") @@ -35,6 +39,8 @@ def get_connector_type(client: BaseConnector) -> ConnectorType: return ConnectorType.KV if isinstance(client, BaseFileConnector): return ConnectorType.FS + if isinstance(client, RemoteInstanceConnector): + return ConnectorType.INSTANCE raise ValueError(f"Invalid connector type: {client}") @@ -44,6 +50,7 @@ __all__ = [ "BaseFileConnector", "BaseKVConnector", "RedisConnector", + "RemoteInstanceConnector", "S3Connector", "ConnectorType", "create_remote_connector", diff --git a/python/sglang/srt/connector/remote_instance.py b/python/sglang/srt/connector/remote_instance.py new file mode 100644 index 000000000..e1f00037f --- /dev/null +++ b/python/sglang/srt/connector/remote_instance.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Generator, List, Optional, Tuple +from urllib.parse import urlparse + +import torch +import torch.distributed as dist + +from sglang.srt.connector import BaseConnector +from sglang.srt.utils import init_custom_process_group + +logger = logging.getLogger(__name__) + + +class RemoteInstanceConnector(BaseConnector): + + def __init__(self, url: str, device: torch.device = "cpu"): + assert ( + device.type == "cuda" + ), "RemoteInstanceConnector only supports cuda device." + super().__init__(url) + self.url = url + self.device = device + + def build_group( + self, + gpu_id: int = -1, + tp_rank: int = -1, + instance_ip: str = None, + group_rank: int = 1, + world_size: int = 2, + ): + assert ( + self.device.type == "cuda" + ), "RemoteInstanceConnector only supports cuda device." + assert ( + gpu_id != -1 and tp_rank != -1 + ), "gpu_id and tp_rank must be specified for RemoteInstanceConnector. " + + self.device_id = torch.device(self.device.type, gpu_id) + + parsed_url = urlparse(self.url) + master_address = parsed_url.hostname + master_port = parsed_url.port + group_name = f"send_weights_{instance_ip}_{master_port}_{tp_rank}" + backend = "nccl" + + logger.info( + f"init custom process group: master_address={master_address}, master_port={master_port}, " + f"rank_offset={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}" + ) + + try: + self._model_update_group = init_custom_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=group_rank, + group_name=group_name, + device_id=self.device_id, + ) + dist.barrier(group=self._model_update_group) + return True, "Succeeded to initialize custom process group." + except Exception as e: + message = f"Failed to initialize custom process group: {e}." + logger.error(message) + return False, message + + # Implemented as a no-op to make BaseConnector interface consistent. + def pull_files( + self, + allow_pattern: Optional[list[str]] = None, + ignore_pattern: Optional[list[str]] = None, + ) -> None: + return + + # Implemented as a no-op to make BaseConnector interface consistent. + def weight_iterator( + self, rank: int = 0 + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + return diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 06f9db057..524db4693 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -73,6 +73,7 @@ from sglang.srt.managers.io_struct import ( EmbeddingReqInput, GenerateReqInput, GetWeightsByNameReqInput, + InitWeightsSendGroupForRemoteInstanceReqInput, InitWeightsUpdateGroupReqInput, LoadLoRAAdapterReqInput, OpenSessionReqInput, @@ -80,6 +81,7 @@ from sglang.srt.managers.io_struct import ( ProfileReqInput, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, + SendWeightsToRemoteInstanceReqInput, SeparateReasoningReqInput, SetInternalStateReq, SlowDownReqInput, @@ -670,6 +672,38 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R ) +@app.post("/init_weights_send_group_for_remote_instance") +async def init_weights_send_group_for_remote_instance( + obj: InitWeightsSendGroupForRemoteInstanceReqInput, request: Request +): + success, message = ( + await _global_state.tokenizer_manager.init_weights_send_group_for_remote_instance( + obj, request + ) + ) + content = {"success": success, "message": message} + if success: + return ORJSONResponse(content, status_code=200) + else: + return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) + + +@app.post("/send_weights_to_remote_instance") +async def send_weights_to_remote_instance( + obj: SendWeightsToRemoteInstanceReqInput, request: Request +): + success, message = ( + await _global_state.tokenizer_manager.send_weights_to_remote_instance( + obj, request + ) + ) + content = {"success": success, "message": message} + if success: + return ORJSONResponse(content, status_code=200) + else: + return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) + + @app.post("/init_weights_update_group") async def init_weights_update_group( obj: InitWeightsUpdateGroupReqInput, request: Request diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 6237cd383..093060174 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1020,6 +1020,44 @@ class UpdateWeightsFromTensorReqOutput: message: str +@dataclass +class InitWeightsSendGroupForRemoteInstanceReqInput: + # The master address + master_address: str + # The ports for each rank's communication group + ports: str + # The rank in the communication group + group_rank: int + # The world size + world_size: int + # The group name + group_name: str = "weight_send_group" + # The backend + backend: str = "nccl" + + +@dataclass +class InitWeightsSendGroupForRemoteInstanceReqOutput: + success: bool + message: str + + +@dataclass +class SendWeightsToRemoteInstanceReqInput: + # The master address + master_address: str + # The ports for each rank's communication group + ports: str + # The group name + group_name: str = "weight_send_group" + + +@dataclass +class SendWeightsToRemoteInstanceReqOutput: + success: bool + message: str + + @dataclass class InitWeightsUpdateGroupReqInput: # The master address diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 5b80afcc1..cd915f765 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -81,6 +81,8 @@ from sglang.srt.managers.io_struct import ( GetInternalStateReqOutput, GetWeightsByNameReqInput, HealthCheckOutput, + InitWeightsSendGroupForRemoteInstanceReqInput, + InitWeightsSendGroupForRemoteInstanceReqOutput, InitWeightsUpdateGroupReqInput, LoadLoRAAdapterReqInput, LoadLoRAAdapterReqOutput, @@ -93,6 +95,8 @@ from sglang.srt.managers.io_struct import ( ResumeMemoryOccupationReqInput, RpcReqInput, RpcReqOutput, + SendWeightsToRemoteInstanceReqInput, + SendWeightsToRemoteInstanceReqOutput, SetInternalStateReq, SetInternalStateReqOutput, SlowDownReqInput, @@ -538,6 +542,14 @@ class Scheduler( (CloseSessionReqInput, self.close_session), (UpdateWeightFromDiskReqInput, self.update_weights_from_disk), (InitWeightsUpdateGroupReqInput, self.init_weights_update_group), + ( + InitWeightsSendGroupForRemoteInstanceReqInput, + self.init_weights_send_group_for_remote_instance, + ), + ( + SendWeightsToRemoteInstanceReqInput, + self.send_weights_to_remote_instance, + ), ( UpdateWeightsFromDistributedReqInput, self.update_weights_from_distributed, @@ -2429,6 +2441,22 @@ class Scheduler( self.send_to_detokenizer.send_pyobj(recv_req) return recv_req + def init_weights_send_group_for_remote_instance( + self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput + ): + """Init the seed and client instance communication group.""" + success, message = self.tp_worker.init_weights_send_group_for_remote_instance( + recv_req + ) + return InitWeightsSendGroupForRemoteInstanceReqOutput(success, message) + + def send_weights_to_remote_instance( + self, recv_req: SendWeightsToRemoteInstanceReqInput + ): + """Send the seed instance weights to the destination instance.""" + success, message = self.tp_worker.send_weights_to_remote_instance(recv_req) + return SendWeightsToRemoteInstanceReqOutput(success, message) + def slow_down(self, recv_req: SlowDownReqInput): t = recv_req.forward_sleep_time if t is not None and t <= 0: diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py index e59d3f296..33c222a94 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -30,6 +30,8 @@ from sglang.srt.managers.io_struct import ( GetInternalStateReqOutput, GetWeightsByNameReqInput, GetWeightsByNameReqOutput, + InitWeightsSendGroupForRemoteInstanceReqInput, + InitWeightsSendGroupForRemoteInstanceReqOutput, InitWeightsUpdateGroupReqInput, InitWeightsUpdateGroupReqOutput, LoadLoRAAdapterReqInput, @@ -43,6 +45,8 @@ from sglang.srt.managers.io_struct import ( ReleaseMemoryOccupationReqOutput, ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqOutput, + SendWeightsToRemoteInstanceReqInput, + SendWeightsToRemoteInstanceReqOutput, SetInternalStateReq, SetInternalStateReqOutput, SlowDownReqInput, @@ -119,6 +123,12 @@ class TokenizerCommunicatorMixin: self.update_weights_from_distributed_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.init_weights_send_group_for_remote_instance_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.send_weights_to_remote_instance_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) self.update_weights_from_tensor_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) @@ -169,6 +179,14 @@ class TokenizerCommunicatorMixin: UpdateWeightsFromDistributedReqOutput, self.update_weights_from_distributed_communicator.handle_recv, ), + ( + InitWeightsSendGroupForRemoteInstanceReqOutput, + self.init_weights_send_group_for_remote_instance_communicator.handle_recv, + ), + ( + SendWeightsToRemoteInstanceReqOutput, + self.send_weights_to_remote_instance_communicator.handle_recv, + ), ( UpdateWeightsFromTensorReqOutput, self.update_weights_from_tensor_communicator.handle_recv, @@ -310,6 +328,34 @@ class TokenizerCommunicatorMixin: result = (await self.update_weights_from_distributed_communicator(obj))[0] return result.success, result.message + async def init_weights_send_group_for_remote_instance( + self, + obj: InitWeightsSendGroupForRemoteInstanceReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() + # TODO: support DP + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be 1 for init_weights_send_group_for_remote_instance" + result = ( + await self.init_weights_send_group_for_remote_instance_communicator(obj) + )[0] + return result.success, result.message + + async def send_weights_to_remote_instance( + self, + obj: SendWeightsToRemoteInstanceReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() + # TODO: support DP + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be 1 for send_weights_to_remote_instance" + result = (await self.send_weights_to_remote_instance_communicator(obj))[0] + return result.success, result.message + async def update_weights_from_tensor( self: TokenizerManager, obj: UpdateWeightsFromTensorReqInput, diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 1cdc48c25..4c13ff796 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -30,8 +30,10 @@ from sglang.srt.hf_transformers_utils import ( from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( GetWeightsByNameReqInput, + InitWeightsSendGroupForRemoteInstanceReqInput, InitWeightsUpdateGroupReqInput, LoadLoRAAdapterReqInput, + SendWeightsToRemoteInstanceReqInput, UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, @@ -88,6 +90,7 @@ class TpModelWorker: else server_args.speculative_draft_model_revision ), is_draft_model=is_draft_worker, + tp_rank=tp_rank, ) self.model_runner = ModelRunner( @@ -292,6 +295,31 @@ class TpModelWorker: ) return success, message + def init_weights_send_group_for_remote_instance( + self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput + ): + success, message = ( + self.model_runner.init_weights_send_group_for_remote_instance( + recv_req.master_address, + recv_req.ports, + recv_req.group_rank, + recv_req.world_size, + recv_req.group_name, + recv_req.backend, + ) + ) + return success, message + + def send_weights_to_remote_instance( + self, recv_req: SendWeightsToRemoteInstanceReqInput + ): + success, message = self.model_runner.send_weights_to_remote_instance( + recv_req.master_address, + recv_req.ports, + recv_req.group_name, + ) + return success, message + def update_weights_from_distributed( self, recv_req: UpdateWeightsFromDistributedReqInput ): diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index e72d4fb6e..399ac1675 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -26,8 +26,10 @@ import torch from sglang.srt.managers.io_struct import ( GetWeightsByNameReqInput, + InitWeightsSendGroupForRemoteInstanceReqInput, InitWeightsUpdateGroupReqInput, LoadLoRAAdapterReqInput, + SendWeightsToRemoteInstanceReqInput, UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, @@ -267,6 +269,20 @@ class TpModelWorkerClient: success, message = self.worker.init_weights_update_group(recv_req) return success, message + def init_weights_send_group_for_remote_instance( + self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput + ): + success, message = self.worker.init_weights_send_group_for_remote_instance( + recv_req + ) + return success, message + + def send_weights_to_remote_instance( + self, recv_req: SendWeightsToRemoteInstanceReqInput + ): + success, message = self.worker.send_weights_to_remote_instance(recv_req) + return success, message + def update_weights_from_distributed( self, recv_req: UpdateWeightsFromDistributedReqInput ): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index aa0e2e0e6..39d1ab5fd 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -19,18 +19,23 @@ import inspect import json import logging import os +import socket +import threading import time from collections import defaultdict from dataclasses import dataclass from typing import List, Optional, Tuple, Union +from urllib.parse import urlparse +import requests import torch import torch.distributed as dist from sglang.srt.configs.device_config import DeviceConfig -from sglang.srt.configs.load_config import LoadConfig +from sglang.srt.configs.load_config import LoadConfig, LoadFormat from sglang.srt.configs.model_config import AttentionArch, ModelConfig from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp +from sglang.srt.connector import ConnectorType from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS from sglang.srt.distributed import ( get_pp_group, @@ -106,6 +111,9 @@ from sglang.srt.offloader import ( set_offloader, ) from sglang.srt.patch_torch import monkey_patch_torch_reductions +from sglang.srt.remote_instance_weight_loader_utils import ( + trigger_init_weights_send_group_for_remote_instance_request, +) from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm @@ -128,6 +136,7 @@ from sglang.srt.utils import ( is_sm100_supported, monkey_patch_p2p_access_check, monkey_patch_vllm_gguf_config, + parse_connector_type, set_cuda_arch, ) from sglang.srt.weight_sync.tensor_bucket import ( @@ -256,6 +265,7 @@ class ModelRunner: # For weight updates self._model_update_group = {} + self._weights_send_group = {} def initialize(self, min_per_gpu_memory: float): server_args = self.server_args @@ -726,6 +736,20 @@ class ModelRunner: if self.server_args.load_format == "gguf": monkey_patch_vllm_gguf_config() + if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE: + if self.tp_rank == 0: + instance_ip = socket.gethostbyname(socket.gethostname()) + t = threading.Thread( + target=trigger_init_weights_send_group_for_remote_instance_request, + args=( + self.server_args.remote_instance_weight_loader_seed_instance_ip, + self.server_args.remote_instance_weight_loader_seed_instance_service_port, + self.server_args.remote_instance_weight_loader_send_weights_group_ports, + instance_ip, + ), + ) + t.start() + # Load the model # Remove monkey_patch when linear.py quant remove dependencies with vllm monkey_patch_vllm_parallel_state() @@ -735,7 +759,7 @@ class ModelRunner: self.model = get_model( model_config=self.model_config, load_config=self.load_config, - device_config=DeviceConfig(self.device), + device_config=DeviceConfig(self.device, self.gpu_id), ) monkey_patch_vllm_parallel_state(reverse=True) monkey_patch_isinstance_for_vllm_base_layer(reverse=True) @@ -867,6 +891,103 @@ class ModelRunner: logger.info("Update weights end.") return True, "Succeeded to update model weights." + def init_weights_send_group_for_remote_instance( + self, + master_address, + ports, + group_rank, + world_size, + group_name, + backend="nccl", + ): + assert ( + torch.distributed.is_initialized() + ), "Default torch process group must be initialized" + assert group_name != "", "Group name cannot be empty" + + ports_list = ports.split(",") + assert ( + len(ports_list) == self.tp_size + ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports." + group_port = ports_list[self.tp_rank] + group_name = f"{group_name}_{group_port}_{self.tp_rank}" + + logger.info( + f"init custom process group: tp_rank={self.tp_rank}, gpu_id={self.gpu_id}, master_address={master_address}, master_port={group_port}, " + f"group_rank={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}" + ) + + torch.cuda.empty_cache() + success = False + message = "" + try: + self._weights_send_group[group_name] = init_custom_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{group_port}", + world_size=world_size, + rank=group_rank, + group_name=group_name, + device_id=torch.device("cuda", self.gpu_id), + ) + dist.barrier(group=self._weights_send_group[group_name]) + success = True + message = ( + f"Succeeded to init group through {master_address}:{group_port} group." + ) + except Exception as e: + message = f"Failed to init group: {e}." + logger.error(message) + + torch.cuda.empty_cache() + return success, message + + def send_weights_to_remote_instance( + self, + master_address, + ports, + group_name, + ): + assert ( + torch.distributed.is_initialized() + ), "Default torch process group must be initialized" + assert group_name != "", "Group name cannot be empty" + + ports_list = ports.split(",") + assert ( + len(ports_list) == self.tp_size + ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports." + group_port = ports_list[self.tp_rank] + group_name = f"{group_name}_{group_port}_{self.tp_rank}" + + if self._weights_send_group[group_name] is not None: + send_group = self._weights_send_group[group_name] + else: + message = f"Group {group_name} not in _weights_send_group list. Please call `init_weights_send_group_for_remote_instance` first." + logger.error(message) + return False, message + + torch.cuda.empty_cache() + success = False + message = "" + try: + for _, weights in self.model.named_parameters(): + torch.distributed.broadcast( + weights, + src=0, + group=send_group, + ) + success = True + message = f"Succeeded to send weights through {master_address}:{group_port} {group_name}." + except Exception as e: + message = f"Failed to send weights: {e}." + logger.error(message) + + # destroy the process group after sending weights + del self._weights_send_group[group_name] + torch.distributed.distributed_c10d.destroy_process_group(send_group) + torch.cuda.empty_cache() + return success, message + def init_weights_update_group( self, master_address, diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index d2b4c6bfc..ab9c69fc2 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -12,6 +12,9 @@ import json import logging import math import os +import re +import socket +import threading import time from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor @@ -27,9 +30,11 @@ from typing import ( Tuple, cast, ) +from urllib.parse import urlparse import huggingface_hub import numpy as np +import requests import safetensors.torch import torch from huggingface_hub import HfApi, hf_hub_download @@ -56,6 +61,7 @@ from sglang.srt.model_loader.utils import ( ) from sglang.srt.model_loader.weight_utils import ( _BAR_FORMAT, + default_weight_loader, download_safetensors_index_file_from_hf, download_weights_from_hf, filter_duplicate_safetensors_files, @@ -71,6 +77,9 @@ from sglang.srt.model_loader.weight_utils import ( safetensors_weights_iterator, set_runai_streamer_env, ) +from sglang.srt.remote_instance_weight_loader_utils import ( + trigger_transferring_weights_request, +) from sglang.srt.utils import ( get_bool_env_var, get_device_capability, @@ -1380,6 +1389,104 @@ class GGUFModelLoader(BaseModelLoader): return model +class RemoteInstanceModelLoader(BaseModelLoader): + """Model loader that can load Tensors from remote sglang instance.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError( + f"Model loader extra config is not supported for " + f"load format {load_config.load_format}" + ) + + def download_model(self, model_config: ModelConfig) -> None: + raise NotImplementedError + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + ) -> nn.Module: + logger.info("Loading weights from remote instance ...") + load_config = self.load_config + + assert load_config.load_format == LoadFormat.REMOTE_INSTANCE, ( + f"Model loader {self.load_config.load_format} is not supported for " + f"load format {load_config.load_format}" + ) + + model_weights = f"instance://{model_config.remote_instance_weight_loader_seed_instance_ip}:{model_config.remote_instance_weight_loader_send_weights_group_ports[model_config.tp_rank]}" + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config) + + with create_remote_connector(model_weights, device_config.device) as client: + connector_type = get_connector_type(client) + if connector_type == ConnectorType.INSTANCE: + self.load_model_from_remote_instance( + model, client, model_config, device_config + ) + else: + raise ValueError( + f"Unsupported connector type {connector_type} for " + f"remote tensor model loading." + ) + return model.eval() + + def load_model_from_remote_instance( + self, model, client, model_config: ModelConfig, device_config: DeviceConfig + ) -> nn.Module: + instance_ip = socket.gethostbyname(socket.gethostname()) + start_build_group_tic = time.time() + client.build_group( + gpu_id=device_config.gpu_id, + tp_rank=model_config.tp_rank, + instance_ip=instance_ip, + ) + torch.cuda.synchronize() + end_build_group_tic = time.time() + logger.debug( + f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s" + ) + + if model_config.tp_rank == 0: + t = threading.Thread( + target=trigger_transferring_weights_request, + args=( + model_config.remote_instance_weight_loader_seed_instance_ip, + model_config.remote_instance_weight_loader_seed_instance_service_port, + model_config.remote_instance_weight_loader_send_weights_group_ports, + instance_ip, + ), + ) + t.start() + + start_get_weights_tic = time.time() + with set_default_torch_dtype(model_config.dtype): + for _, tensor in model.named_parameters(): + torch.distributed.broadcast( + tensor.data, + src=0, + group=client._model_update_group, + ) + torch.cuda.synchronize() + + if hasattr(model, "post_load_weights"): + model.post_load_weights() + end_get_weights_tic = time.time() + logger.debug( + f"finish getting all weights from remote instance, time used: {(end_get_weights_tic - start_get_weights_tic):.4f}s" + ) + # destroy the process group after loading weights + torch.distributed.distributed_c10d.destroy_process_group( + client._model_update_group + ) + torch.cuda.empty_cache() + + class RemoteModelLoader(BaseModelLoader): """Model loader that can load Tensors from remote database.""" @@ -1581,4 +1688,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: if load_config.load_format == LoadFormat.REMOTE: return RemoteModelLoader(load_config) + if load_config.load_format == LoadFormat.REMOTE_INSTANCE: + return RemoteInstanceModelLoader(load_config) + return DefaultModelLoader(load_config) diff --git a/python/sglang/srt/remote_instance_weight_loader_utils.py b/python/sglang/srt/remote_instance_weight_loader_utils.py new file mode 100644 index 000000000..5974bba20 --- /dev/null +++ b/python/sglang/srt/remote_instance_weight_loader_utils.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import List + +import requests + +logger = logging.getLogger(__name__) + + +def trigger_init_weights_send_group_for_remote_instance_request( + remote_instance_weight_loader_seed_instance_ip: str, + remote_instance_weight_loader_seed_instance_service_port: int, + remote_instance_weight_loader_send_weights_group_ports: List[int], + remote_instance_weight_loader_client_id: str, +): + seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}" + # Only support loading weights from instance with same parallelism strategy. + # Per TP rank pair between seed and dst instances will build a communication group for sending weights. + # i.e. seed TP 0 <-> dst TP 0, seed TP 1 <-> dst TP 1, etc. + # Each communication group will have a world size 2. + try: + requests.post( + f"{seed_instance_service_url}/init_weights_send_group_for_remote_instance", + json={ + "master_address": remote_instance_weight_loader_seed_instance_ip, + "ports": ( + ",".join( + str(p) + for p in remote_instance_weight_loader_send_weights_group_ports + ) + ), + "group_rank": 0, + "world_size": 2, + "group_name": f"send_weights_{remote_instance_weight_loader_client_id}", + "backend": "nccl", + }, + ) + except Exception as e: + logger.error( + f"Failed to trigger init_weights_send_group_for_remote_instance_request to seed instance {seed_instance_service_url}: {e}." + ) + raise + + +def trigger_transferring_weights_request( + remote_instance_weight_loader_seed_instance_ip: str, + remote_instance_weight_loader_seed_instance_service_port: int, + remote_instance_weight_loader_send_weights_group_ports: List[int], + remote_instance_weight_loader_client_id: str, +): + seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}" + try: + requests.post( + f"{seed_instance_service_url}/send_weights_to_remote_instance", + json={ + "master_address": remote_instance_weight_loader_seed_instance_ip, + "ports": ( + ",".join( + str(p) + for p in remote_instance_weight_loader_send_weights_group_ports + ) + ), + "group_name": f"send_weights_{remote_instance_weight_loader_client_id}", + }, + ) + except Exception as e: + logger.error(f"Failed to trigger send weights to remote instance request: {e}") + raise diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 77199e7d3..f041011e0 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -19,10 +19,12 @@ import json import logging import os import random +import socket import sys import tempfile from typing import List, Literal, Optional, Union +from sglang.srt.connector import ConnectorType from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.hf_transformers_utils import check_gguf_file, get_config from sglang.srt.lora.lora_registry import LoRARef @@ -42,7 +44,9 @@ from sglang.srt.utils import ( is_sm100_supported, is_triton_kernels_available, is_valid_ipv6_address, + json_list_type, nullable_str, + parse_connector_type, ) from sglang.utils import is_in_ci @@ -61,6 +65,7 @@ LOAD_FORMAT_CHOICES = [ "bitsandbytes", "layered", "remote", + "remote_instance", ] QUANTIZATION_CHOICES = [ @@ -387,6 +392,11 @@ class ServerArgs: custom_weight_loader: Optional[List[str]] = None weight_loader_disable_mmap: bool = False + # Remote instance weight loading + remote_instance_weight_loader_seed_instance_ip: Optional[str] = None + remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None + remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None + # For PD-Multiplexing enable_pdmux: bool = False sm_group_num: int = 3 @@ -445,6 +455,7 @@ class ServerArgs: # Set missing default values if self.tokenizer_path is None: self.tokenizer_path = self.model_path + if self.served_model_name is None: self.served_model_name = self.model_path if self.device is None: @@ -538,7 +549,8 @@ class ServerArgs: self.sampling_backend = "pytorch" # Model-specific adjustments - self.model_specific_adjustments() + if parse_connector_type(self.model_path) != ConnectorType.INSTANCE: + self.model_specific_adjustments() # Set kernel backends if self.device == "cpu": @@ -818,12 +830,19 @@ class ServerArgs: ) and check_gguf_file(self.model_path): self.quantization = self.load_format = "gguf" - # Model loading if is_remote_url(self.model_path): self.load_format = "remote" if self.custom_weight_loader is None: self.custom_weight_loader = [] + if self.load_format == "remote_instance": + if ( + self.remote_instance_weight_loader_seed_instance_ip is None + or self.remote_instance_weight_loader_seed_instance_service_port is None + or self.remote_instance_weight_loader_send_weights_group_ports is None + ): + self.load_format = "auto" + # PD disaggregation if self.disaggregation_mode == "decode": assert ( @@ -881,6 +900,24 @@ class ServerArgs: help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.", required=True, ) + parser.add_argument( + "--remote-instance-weight-loader-seed-instance-ip", + type=str, + default=ServerArgs.remote_instance_weight_loader_seed_instance_ip, + help="The ip of the seed instance for loading weights from remote instance.", + ) + parser.add_argument( + "--remote-instance-weight-loader-seed-instance-service-port", + type=int, + default=ServerArgs.remote_instance_weight_loader_seed_instance_service_port, + help="The service port of the seed instance for loading weights from remote instance.", + ) + parser.add_argument( + "--remote-instance-weight-loader-send-weights-group-ports", + type=json_list_type, + default=ServerArgs.remote_instance_weight_loader_send_weights_group_ports, + help="The communication group ports for loading weights from remote instance.", + ) parser.add_argument( "--tokenizer-path", type=str, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 846baeb01..914d371b7 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -15,6 +15,7 @@ from __future__ import annotations +import argparse import asyncio import builtins import ctypes @@ -1431,6 +1432,7 @@ def init_custom_process_group( store=None, group_name=None, pg_options=None, + device_id=None, ): from torch.distributed.distributed_c10d import ( Backend, @@ -1484,6 +1486,7 @@ def init_custom_process_group( group_name=group_name, **{pg_options_param_name: pg_options}, timeout=timeout, + device_id=device_id, ) _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} @@ -3046,3 +3049,12 @@ def numa_bind_to_node(node: int): libnuma.numa_run_on_node(ctypes.c_int(node)) libnuma.numa_set_localalloc() + + +def json_list_type(value): + try: + return json.loads(value) + except json.JSONDecodeError: + raise argparse.ArgumentTypeError( + f"Invalid JSON list: {value}. Please provide a valid JSON list." + ) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 593920d9d..bb881bf8c 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -123,6 +123,7 @@ suites = { TestFile("rl/test_update_weights_from_distributed.py", 103), TestFile("test_data_parallelism.py", 73), TestFile("test_dp_attention.py", 277), + TestFile("test_load_weights_from_remote_instance.py", 72), TestFile("test_patch_torch.py", 19), TestFile("test_release_memory_occupation.py", 127), TestFile("hicache/test_hicache_storage_file_backend.py", 400), @@ -251,6 +252,7 @@ suite_amd = { TestFile("lora/test_lora_tp.py", 116), TestFile("rl/test_update_weights_from_distributed.py", 103), TestFile("test_data_parallelism.py", 73), + TestFile("test_load_weights_from_remote_instance.py", 72), TestFile("test_patch_torch.py", 19), ], "per-commit-4-gpu-amd": [ diff --git a/test/srt/test_load_weights_from_remote_instance.py b/test/srt/test_load_weights_from_remote_instance.py new file mode 100644 index 000000000..71ab24d1d --- /dev/null +++ b/test/srt/test_load_weights_from_remote_instance.py @@ -0,0 +1,384 @@ +"""Test loading weights from remote instance. + +This test suite simulates loading weights from a remote instance. +Rank 0 represents the seed instance, while ranks 1 represents the +new instance that needs to loading weights from the seed instance. + +Seed instance must be started in `Server` mode, while the dst instance +can be either `Engine` mode or `Server` mode. + +Seed instance does not support concurrently serving multiple dst instances. +User has to guarantee that there is only one dst instance trying to load +weights from the seed instance at any time. + +""" + +import gc +import os +import random +import unittest + +import numpy as np +import requests +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import sglang as sgl +from sglang.test.test_utils import ( + DEFAULT_PORT_FOR_SRT_TEST_RUNNER, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, +) +from sglang.utils import terminate_process + +mp.set_start_method("spawn", force=True) + + +def verify_params_close(params1, params2, error_msg): + """Verify if two parameter arrays are close enough.""" + try: + assert np.allclose(np.array(params1), np.array(params2)), error_msg + except Exception as e: + print(f"Parameters not close for {error_msg}") + print("Params1:", np.array(params1)) + print("Params2:", np.array(params2)) + raise e + + +def init_process( + rank, + param_queue, + truncate_size, + tp_size, + model_name, + backends, + checking_parameters, + seed_instance_ip, + seed_instance_service_port, + seed_instance_group_base_port, + event_seed_ready, + event_dst_ready_list, +): + torch.cuda.set_device(rank) + + if rank == 0: + init_process_seed( + rank, + param_queue, + truncate_size, + model_name, + checking_parameters, + tp_size, + event_seed_ready, + event_dst_ready_list, + ) + elif rank in [1, 2]: + init_process_dst( + rank, + param_queue, + truncate_size, + model_name, + seed_instance_ip, + seed_instance_service_port, + seed_instance_group_base_port, + checking_parameters, + backends[rank - 1], + tp_size, + event_seed_ready, + event_dst_ready_list, + ) + + +def init_process_seed( + rank, + param_queue, + truncate_size, + model_name, + checking_parameters, + tp_size, + event_seed_ready, + event_dst_ready_list, +): + # These two environment variables are very important + # to avoid unexpected behaviors of CUDA and NCCL. + os.environ["NCCL_CUMEM_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = "0" + + # Load model and get parameters + torch.cuda.set_device(rank) + torch.cuda.synchronize() + + url = DEFAULT_URL_FOR_TEST + process = popen_launch_server( + model_name, + url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=( + "--base-gpu-id", + str(rank), + "--tp-size", + str(tp_size), + ), + ) + torch.cuda.synchronize() + + seed_params = [] + # Get the weights of seed instance for correctness check. + for parameter_name in checking_parameters: + seed_params.append( + requests.get( + f"{url}/get_weights_by_name", + json={ + "name": parameter_name, + "truncate_size": truncate_size, + }, + ).json() + ) + param_queue.put((f"seed_params", seed_params)) + + event_seed_ready.set() + for i in range(len(event_dst_ready_list)): + event_dst_ready_list[i].wait() + terminate_process(process) + + +def init_process_dst( + rank, + param_queue, + truncate_size, + model_name, + seed_instance_ip, + seed_instance_service_port, + seed_instance_group_base_port, + checking_parameters, + backend, + tp_size, + event_seed_ready, + event_dst_ready_list, +): + torch.cuda.set_device(rank * tp_size) + torch.cuda.synchronize() + base_gpu_id = rank * tp_size + + event_seed_ready.wait() + print(f"rank {rank}, seed ready") + for i in range(rank - 1): + print(f"rank {rank}, wait dst {i}") + event_dst_ready_list[i].wait() + + ports = [] + for i in range(tp_size): + ports.append(seed_instance_group_base_port + (rank - 1) * tp_size + i) + + if backend == "Engine": + print(f"[sgl] rank {rank} init engine") + engine = sgl.Engine( + model_path=model_name, + base_gpu_id=base_gpu_id, + tp_size=tp_size, + cuda_graph_max_bs=2, + tokenizer_path=model_name, + remote_instance_weight_loader_seed_instance_ip=seed_instance_ip, + remote_instance_weight_loader_seed_instance_service_port=seed_instance_service_port, + remote_instance_weight_loader_send_weights_group_ports=ports, + load_format="remote_instance", + ) + else: + host, _, port = DEFAULT_URL_FOR_TEST.rpartition(":") + url = ":".join([host, str(int(port) + 10000 + rank)]) + + print(f"[sgl] rank {rank} init server on url: {url}") + process = popen_launch_server( + model_name, + url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=( + "--base-gpu-id", + str(base_gpu_id), + "--tp-size", + str(tp_size), + "--cuda-graph-max-bs", + 2, + "--tokenizer-path", + model_name, + "--remote-instance-weight-loader-seed-instance-ip", + seed_instance_ip, + "--remote-instance-weight-loader-seed-instance-service-port", + seed_instance_service_port, + "--remote-instance-weight-loader-send-weights-group-ports", + f"[{','.join(str(port) for port in ports)}]", + "--load-format", + "remote_instance", + ), + ) + torch.cuda.synchronize() + + event_dst_ready_list[rank - 1].set() + + # Get weights of destination instance loaded from remote instance. + dst_params = [] + for parameter_name in checking_parameters: + dst_params.append( + engine.get_weights_by_name(parameter_name, truncate_size) + if backend == "Engine" + else requests.get( + f"{url}/get_weights_by_name", + json={"name": parameter_name, "truncate_size": truncate_size}, + ).json() + ) + + param_queue.put((f"sgl_dp_{rank}_dst_params", dst_params)) + + # Shutdown the engine or terminate the server process. + if backend == "Engine": + engine.shutdown() + else: + terminate_process(process) + + +def test_load_weights_from_remote_instance( + tp_size, + dp_size, + model_name, + backends, + truncate_size, + checking_parameters, + seed_instance_ip, + seed_instance_service_port, + seed_instance_group_base_port, +): + print( + f"Testing model: {model_name} tp_size: {tp_size}, dp_size: {dp_size} backend: {backends}" + ) + param_queue = mp.Queue() + results = {} + event_seed_ready = mp.Event() + event_dst_ready_list = [] + for i in range(dp_size): + event_dst_ready = mp.Event() + event_dst_ready_list.append(event_dst_ready) + + context = mp.spawn( + init_process, + args=( + param_queue, + truncate_size, + tp_size, + model_name, + backends, + checking_parameters, + seed_instance_ip, + seed_instance_service_port, + seed_instance_group_base_port, + event_seed_ready, + event_dst_ready_list, + ), + nprocs=1 + dp_size, + join=False, + ) + + while len(results) < (1 + dp_size): + try: + key, value = param_queue.get(timeout=5) + results[key] = value + except Exception as e: + if all(not p.is_alive() for p in context.processes): + break + + context.join() + + if len(results) != (1 + dp_size): + raise RuntimeError( + f"Expected {(1 + dp_size)} parameters but got {len(results)}" + ) + + params = { + "seed": results.get("seed_params"), + "sgl_dp_1_dest": results.get("sgl_dp_1_dst_params"), + } + + if dp_size == 2: + dp2_params = { + "sgl_dp_2_dest": results.get("sgl_dp_2_dst_params"), + } + assert all(v is not None for v in dp2_params.values()) + params.update(dp2_params) + + # Check the correctness of weights loaded from remote instance + # by verifying the weights of seed instance and destination instance. + for i in range(len(params["seed"])): + verify_params_close( + params["seed"][i], + params["sgl_dp_1_dest"][i], + f"sgl_dp_1_dst_params rank {i}", + ) + + if dp_size == 2: + verify_params_close( + params["seed"][i], + params["sgl_dp_2_dest"][i], + f"sgl_dp_2_dst_params rank {i}", + ) + + # Delete the context and close the parameter queue. + del context + param_queue.close() + param_queue.join_thread() + gc.collect() + torch.cuda.empty_cache() + + +class TestLoadWeightsFromRemoteInstance(CustomTestCase): + + def test_load_weights_from_remote_instance(self): + + assert torch.cuda.device_count() >= 2, "At least 2 GPUs are required" + # test_suits : tp, dp, model_name, backend, dst_instance_id + if is_in_ci(): + mode = random.choice(["Engine", "Server"]) + test_suits = [ + (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, [mode]), + ] + else: + test_suits = [ + (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ["Engine"]), + (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ["Sever"]), + (2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ["Engine", "Server"]), + ] + + truncate_size = 10 + checking_parameters = [ + "model.embed_tokens.weight", + "model.layers.0.input_layernorm.weight", + "model.layers.1.self_attn.q_proj.weight", + "model.layers.2.self_attn.k_proj.weight", + "model.layers.3.self_attn.v_proj.weight", + "model.layers.4.self_attn.o_proj.weight", + "model.layers.5.mlp.gate_proj.weight", + "model.layers.6.mlp.up_proj.weight", + "model.layers.7.mlp.down_proj.weight", + "model.layers.8.post_attention_layernorm.weight", + "model.norm.weight", + ] + + for tp_size, dp_size, model_name, backends in test_suits: + test_load_weights_from_remote_instance( + tp_size, + dp_size, + model_name, + backends, + truncate_size, + checking_parameters, + "127.0.0.1", + DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000, + 60000, + ) + + +if __name__ == "__main__": + unittest.main()