Support loading weights from remote instance (#8215)
Signed-off-by: Anqi Shen <amy.saq@antgroup.com>
Co-authored-by: Chayenne <74843776+zhaochenyang20@users.noreply.github.com>
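For orientation, the sketch below shows how a destination instance consumes this feature. It is a minimal example, not part of the diff: the model path and seed address are hypothetical, and it assumes a seed server is already serving the same model with the same TP size.

    # Destination side: boot an engine that pulls weights over NCCL from a
    # running seed instance instead of reading them from disk.
    import sglang as sgl

    engine = sgl.Engine(
        model_path="Qwen/Qwen2.5-0.5B-Instruct",  # hypothetical; must match the seed
        load_format="remote_instance",
        remote_instance_weight_loader_seed_instance_ip="10.0.0.1",       # hypothetical
        remote_instance_weight_loader_seed_instance_service_port=30000,  # hypothetical
        # One group port per TP rank pair (seed rank i <-> destination rank i).
        remote_instance_weight_loader_send_weights_group_ports=[60000],
    )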
@@ -8,10 +8,12 @@ logger = logging.getLogger(__name__)

class DeviceConfig:
    device: Optional[torch.device]
    gpu_id: Optional[int]

-    def __init__(self, device: str = "cuda") -> None:
+    def __init__(self, device: str = "cuda", gpu_id: int = -1) -> None:
        if device in ["cuda", "xpu", "hpu", "cpu", "npu"]:
            self.device_type = device
        else:
            raise RuntimeError(f"Not supported device type: {device}")
        self.device = torch.device(self.device_type)
        self.gpu_id = gpu_id

@@ -23,6 +23,7 @@ class LoadFormat(str, enum.Enum):
    LAYERED = "layered"
    JAX = "jax"
    REMOTE = "remote"
    REMOTE_INSTANCE = "remote_instance"


@dataclass

@@ -64,12 +64,28 @@ class ModelConfig:
        is_draft_model: bool = False,
        hybrid_kvcache_ratio: Optional[float] = None,
        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
        tp_rank: Optional[int] = None,
        remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
        remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
        remote_instance_weight_loader_send_weights_group_ports: Optional[
            List[int]
        ] = None,
    ) -> None:
        # Parse args
        self.model_path = model_path
        self.revision = revision
        self.quantization = quantization
        self.model_impl = model_impl
        self.tp_rank = tp_rank
        self.remote_instance_weight_loader_seed_instance_ip = (
            remote_instance_weight_loader_seed_instance_ip
        )
        self.remote_instance_weight_loader_seed_instance_service_port = (
            remote_instance_weight_loader_seed_instance_service_port
        )
        self.remote_instance_weight_loader_send_weights_group_ports = (
            remote_instance_weight_loader_send_weights_group_ports
        )

        self.maybe_pull_model_tokenizer_from_remote()
        self.model_override_args = json.loads(model_override_args)

@@ -329,6 +345,9 @@ class ModelConfig:
            quantization=server_args.quantization,
            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
            model_impl=server_args.model_impl,
            remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
            remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
            remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
            **kwargs,
        )

@@ -9,6 +9,7 @@ from sglang.srt.connector.base_connector import (
    BaseKVConnector,
)
from sglang.srt.connector.redis import RedisConnector
from sglang.srt.connector.remote_instance import RemoteInstanceConnector
from sglang.srt.connector.s3 import S3Connector
from sglang.srt.utils import parse_connector_type

@@ -18,14 +19,17 @@ logger = logging.getLogger(__name__)
class ConnectorType(str, enum.Enum):
    FS = "filesystem"
    KV = "KV"
    INSTANCE = "instance"


-def create_remote_connector(url, **kwargs) -> BaseConnector:
+def create_remote_connector(url, device, **kwargs) -> BaseConnector:
    connector_type = parse_connector_type(url)
    if connector_type == "redis":
        return RedisConnector(url)
    elif connector_type == "s3":
        return S3Connector(url)
    elif connector_type == "instance":
        return RemoteInstanceConnector(url, device)
    else:
        raise ValueError(f"Invalid connector type: {url}")

@@ -35,6 +39,8 @@ def get_connector_type(client: BaseConnector) -> ConnectorType:
        return ConnectorType.KV
    if isinstance(client, BaseFileConnector):
        return ConnectorType.FS
    if isinstance(client, RemoteInstanceConnector):
        return ConnectorType.INSTANCE

    raise ValueError(f"Invalid connector type: {client}")

@@ -44,6 +50,7 @@ __all__ = [
    "BaseFileConnector",
    "BaseKVConnector",
    "RedisConnector",
    "RemoteInstanceConnector",
    "S3Connector",
    "ConnectorType",
    "create_remote_connector",

python/sglang/srt/connector/remote_instance.py (new file, 82 lines)
@@ -0,0 +1,82 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Generator, List, Optional, Tuple
from urllib.parse import urlparse

import torch
import torch.distributed as dist

from sglang.srt.connector import BaseConnector
from sglang.srt.utils import init_custom_process_group

logger = logging.getLogger(__name__)


class RemoteInstanceConnector(BaseConnector):

    def __init__(self, url: str, device: torch.device = torch.device("cpu")):
        assert (
            device.type == "cuda"
        ), "RemoteInstanceConnector only supports cuda device."
        super().__init__(url)
        self.url = url
        self.device = device

    def build_group(
        self,
        gpu_id: int = -1,
        tp_rank: int = -1,
        instance_ip: str = None,
        group_rank: int = 1,
        world_size: int = 2,
    ):
        assert (
            self.device.type == "cuda"
        ), "RemoteInstanceConnector only supports cuda device."
        assert (
            gpu_id != -1 and tp_rank != -1
        ), "gpu_id and tp_rank must be specified for RemoteInstanceConnector."

        self.device_id = torch.device(self.device.type, gpu_id)

        parsed_url = urlparse(self.url)
        master_address = parsed_url.hostname
        master_port = parsed_url.port
        group_name = f"send_weights_{instance_ip}_{master_port}_{tp_rank}"
        backend = "nccl"

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=group_rank,
                group_name=group_name,
                device_id=self.device_id,
            )
            dist.barrier(group=self._model_update_group)
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    # Implemented as a no-op to keep the BaseConnector interface consistent.
    def pull_files(
        self,
        allow_pattern: Optional[list[str]] = None,
        ignore_pattern: Optional[list[str]] = None,
    ) -> None:
        return

    # Implemented as a no-op to keep the BaseConnector interface consistent.
    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        return

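A note on the URL scheme: the loader addresses the seed instance with an `instance://host:port` URL, and `build_group` recovers the NCCL rendezvous endpoint from it with `urlparse`. A small sketch with hypothetical values:

    from urllib.parse import urlparse

    url = "instance://10.0.0.1:60000"  # hypothetical seed address and group port
    parsed = urlparse(url)
    assert parsed.scheme == "instance"    # what parse_connector_type dispatches on
    assert parsed.hostname == "10.0.0.1"  # becomes master_address
    assert parsed.port == 60000           # becomes master_port in init_method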
@@ -73,6 +73,7 @@ from sglang.srt.managers.io_struct import (
    EmbeddingReqInput,
    GenerateReqInput,
    GetWeightsByNameReqInput,
    InitWeightsSendGroupForRemoteInstanceReqInput,
    InitWeightsUpdateGroupReqInput,
    LoadLoRAAdapterReqInput,
    OpenSessionReqInput,

@@ -80,6 +81,7 @@ from sglang.srt.managers.io_struct import (
    ProfileReqInput,
    ReleaseMemoryOccupationReqInput,
    ResumeMemoryOccupationReqInput,
    SendWeightsToRemoteInstanceReqInput,
    SeparateReasoningReqInput,
    SetInternalStateReq,
    SlowDownReqInput,

@@ -670,6 +672,38 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
    )


@app.post("/init_weights_send_group_for_remote_instance")
async def init_weights_send_group_for_remote_instance(
    obj: InitWeightsSendGroupForRemoteInstanceReqInput, request: Request
):
    success, message = (
        await _global_state.tokenizer_manager.init_weights_send_group_for_remote_instance(
            obj, request
        )
    )
    content = {"success": success, "message": message}
    if success:
        return ORJSONResponse(content, status_code=200)
    else:
        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)


@app.post("/send_weights_to_remote_instance")
async def send_weights_to_remote_instance(
    obj: SendWeightsToRemoteInstanceReqInput, request: Request
):
    success, message = (
        await _global_state.tokenizer_manager.send_weights_to_remote_instance(
            obj, request
        )
    )
    content = {"success": success, "message": message}
    if success:
        return ORJSONResponse(content, status_code=200)
    else:
        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)


@app.post("/init_weights_update_group")
async def init_weights_update_group(
    obj: InitWeightsUpdateGroupReqInput, request: Request

@@ -1020,6 +1020,44 @@ class UpdateWeightsFromTensorReqOutput:
    message: str


@dataclass
class InitWeightsSendGroupForRemoteInstanceReqInput:
    # The master address
    master_address: str
    # The ports for each rank's communication group
    ports: str
    # The rank in the communication group
    group_rank: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_send_group"
    # The backend
    backend: str = "nccl"


@dataclass
class InitWeightsSendGroupForRemoteInstanceReqOutput:
    success: bool
    message: str


@dataclass
class SendWeightsToRemoteInstanceReqInput:
    # The master address
    master_address: str
    # The ports for each rank's communication group
    ports: str
    # The group name
    group_name: str = "weight_send_group"


@dataclass
class SendWeightsToRemoteInstanceReqOutput:
    success: bool
    message: str


@dataclass
class InitWeightsUpdateGroupReqInput:
    # The master address

@@ -81,6 +81,8 @@ from sglang.srt.managers.io_struct import (
    GetInternalStateReqOutput,
    GetWeightsByNameReqInput,
    HealthCheckOutput,
    InitWeightsSendGroupForRemoteInstanceReqInput,
    InitWeightsSendGroupForRemoteInstanceReqOutput,
    InitWeightsUpdateGroupReqInput,
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,

@@ -93,6 +95,8 @@ from sglang.srt.managers.io_struct import (
    ResumeMemoryOccupationReqInput,
    RpcReqInput,
    RpcReqOutput,
    SendWeightsToRemoteInstanceReqInput,
    SendWeightsToRemoteInstanceReqOutput,
    SetInternalStateReq,
    SetInternalStateReqOutput,
    SlowDownReqInput,

@@ -538,6 +542,14 @@ class Scheduler(
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    InitWeightsSendGroupForRemoteInstanceReqInput,
                    self.init_weights_send_group_for_remote_instance,
                ),
                (
                    SendWeightsToRemoteInstanceReqInput,
                    self.send_weights_to_remote_instance,
                ),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,

@@ -2429,6 +2441,22 @@ class Scheduler(
        self.send_to_detokenizer.send_pyobj(recv_req)
        return recv_req

    def init_weights_send_group_for_remote_instance(
        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
    ):
        """Init the seed and client instance communication group."""
        success, message = self.tp_worker.init_weights_send_group_for_remote_instance(
            recv_req
        )
        return InitWeightsSendGroupForRemoteInstanceReqOutput(success, message)

    def send_weights_to_remote_instance(
        self, recv_req: SendWeightsToRemoteInstanceReqInput
    ):
        """Send the seed instance weights to the destination instance."""
        success, message = self.tp_worker.send_weights_to_remote_instance(recv_req)
        return SendWeightsToRemoteInstanceReqOutput(success, message)

    def slow_down(self, recv_req: SlowDownReqInput):
        t = recv_req.forward_sleep_time
        if t is not None and t <= 0:

@@ -30,6 +30,8 @@ from sglang.srt.managers.io_struct import (
    GetInternalStateReqOutput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    InitWeightsSendGroupForRemoteInstanceReqInput,
    InitWeightsSendGroupForRemoteInstanceReqOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    LoadLoRAAdapterReqInput,

@@ -43,6 +45,8 @@ from sglang.srt.managers.io_struct import (
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
    SendWeightsToRemoteInstanceReqInput,
    SendWeightsToRemoteInstanceReqOutput,
    SetInternalStateReq,
    SetInternalStateReqOutput,
    SlowDownReqInput,

@@ -119,6 +123,12 @@ class TokenizerCommunicatorMixin:
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.init_weights_send_group_for_remote_instance_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.send_weights_to_remote_instance_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )

@@ -169,6 +179,14 @@ class TokenizerCommunicatorMixin:
                    UpdateWeightsFromDistributedReqOutput,
                    self.update_weights_from_distributed_communicator.handle_recv,
                ),
                (
                    InitWeightsSendGroupForRemoteInstanceReqOutput,
                    self.init_weights_send_group_for_remote_instance_communicator.handle_recv,
                ),
                (
                    SendWeightsToRemoteInstanceReqOutput,
                    self.send_weights_to_remote_instance_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromTensorReqOutput,
                    self.update_weights_from_tensor_communicator.handle_recv,

@@ -310,6 +328,34 @@ class TokenizerCommunicatorMixin:
        result = (await self.update_weights_from_distributed_communicator(obj))[0]
        return result.success, result.message

    async def init_weights_send_group_for_remote_instance(
        self,
        obj: InitWeightsSendGroupForRemoteInstanceReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        # TODO: support DP
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init_weights_send_group_for_remote_instance"
        result = (
            await self.init_weights_send_group_for_remote_instance_communicator(obj)
        )[0]
        return result.success, result.message

    async def send_weights_to_remote_instance(
        self,
        obj: SendWeightsToRemoteInstanceReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        # TODO: support DP
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for send_weights_to_remote_instance"
        result = (await self.send_weights_to_remote_instance_communicator(obj))[0]
        return result.success, result.message

    async def update_weights_from_tensor(
        self: TokenizerManager,
        obj: UpdateWeightsFromTensorReqInput,

@@ -30,8 +30,10 @@ from sglang.srt.hf_transformers_utils import (
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
    InitWeightsSendGroupForRemoteInstanceReqInput,
    InitWeightsUpdateGroupReqInput,
    LoadLoRAAdapterReqInput,
    SendWeightsToRemoteInstanceReqInput,
    UnloadLoRAAdapterReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,

@@ -88,6 +90,7 @@ class TpModelWorker:
                else server_args.speculative_draft_model_revision
            ),
            is_draft_model=is_draft_worker,
            tp_rank=tp_rank,
        )

        self.model_runner = ModelRunner(

@@ -292,6 +295,31 @@ class TpModelWorker:
        )
        return success, message

    def init_weights_send_group_for_remote_instance(
        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
    ):
        success, message = (
            self.model_runner.init_weights_send_group_for_remote_instance(
                recv_req.master_address,
                recv_req.ports,
                recv_req.group_rank,
                recv_req.world_size,
                recv_req.group_name,
                recv_req.backend,
            )
        )
        return success, message

    def send_weights_to_remote_instance(
        self, recv_req: SendWeightsToRemoteInstanceReqInput
    ):
        success, message = self.model_runner.send_weights_to_remote_instance(
            recv_req.master_address,
            recv_req.ports,
            recv_req.group_name,
        )
        return success, message

    def update_weights_from_distributed(
        self, recv_req: UpdateWeightsFromDistributedReqInput
    ):

@@ -26,8 +26,10 @@ import torch

from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
    InitWeightsSendGroupForRemoteInstanceReqInput,
    InitWeightsUpdateGroupReqInput,
    LoadLoRAAdapterReqInput,
    SendWeightsToRemoteInstanceReqInput,
    UnloadLoRAAdapterReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,

@@ -267,6 +269,20 @@ class TpModelWorkerClient:
        success, message = self.worker.init_weights_update_group(recv_req)
        return success, message

    def init_weights_send_group_for_remote_instance(
        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
    ):
        success, message = self.worker.init_weights_send_group_for_remote_instance(
            recv_req
        )
        return success, message

    def send_weights_to_remote_instance(
        self, recv_req: SendWeightsToRemoteInstanceReqInput
    ):
        success, message = self.worker.send_weights_to_remote_instance(recv_req)
        return success, message

    def update_weights_from_distributed(
        self, recv_req: UpdateWeightsFromDistributedReqInput
    ):

@@ -19,18 +19,23 @@ import inspect
import json
import logging
import os
import socket
import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from urllib.parse import urlparse

import requests
import torch
import torch.distributed as dist

from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
from sglang.srt.connector import ConnectorType
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.distributed import (
    get_pp_group,

@@ -106,6 +111,9 @@ from sglang.srt.offloader import (
    set_offloader,
)
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.remote_instance_weight_loader_utils import (
    trigger_init_weights_send_group_for_remote_instance_request,
)
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

@@ -128,6 +136,7 @@ from sglang.srt.utils import (
    is_sm100_supported,
    monkey_patch_p2p_access_check,
    monkey_patch_vllm_gguf_config,
    parse_connector_type,
    set_cuda_arch,
)
from sglang.srt.weight_sync.tensor_bucket import (

@@ -256,6 +265,7 @@ class ModelRunner:

        # For weight updates
        self._model_update_group = {}
        self._weights_send_group = {}

    def initialize(self, min_per_gpu_memory: float):
        server_args = self.server_args

@@ -726,6 +736,20 @@ class ModelRunner:
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
            if self.tp_rank == 0:
                instance_ip = socket.gethostbyname(socket.gethostname())
                t = threading.Thread(
                    target=trigger_init_weights_send_group_for_remote_instance_request,
                    args=(
                        self.server_args.remote_instance_weight_loader_seed_instance_ip,
                        self.server_args.remote_instance_weight_loader_seed_instance_service_port,
                        self.server_args.remote_instance_weight_loader_send_weights_group_ports,
                        instance_ip,
                    ),
                )
                t.start()

        # Load the model
        # Remove monkey_patch when linear.py quant remove dependencies with vllm
        monkey_patch_vllm_parallel_state()

@@ -735,7 +759,7 @@ class ModelRunner:
        self.model = get_model(
            model_config=self.model_config,
            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
+            device_config=DeviceConfig(self.device, self.gpu_id),
        )
        monkey_patch_vllm_parallel_state(reverse=True)
        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

@@ -867,6 +891,103 @@ class ModelRunner:
        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

    def init_weights_send_group_for_remote_instance(
        self,
        master_address,
        ports,
        group_rank,
        world_size,
        group_name,
        backend="nccl",
    ):
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        ports_list = ports.split(",")
        assert (
            len(ports_list) == self.tp_size
        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
        group_port = ports_list[self.tp_rank]
        group_name = f"{group_name}_{group_port}_{self.tp_rank}"

        logger.info(
            f"init custom process group: tp_rank={self.tp_rank}, gpu_id={self.gpu_id}, master_address={master_address}, master_port={group_port}, "
            f"group_rank={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        torch.cuda.empty_cache()
        success = False
        message = ""
        try:
            self._weights_send_group[group_name] = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{group_port}",
                world_size=world_size,
                rank=group_rank,
                group_name=group_name,
                device_id=torch.device("cuda", self.gpu_id),
            )
            dist.barrier(group=self._weights_send_group[group_name])
            success = True
            message = (
                f"Succeeded to init group through {master_address}:{group_port} group."
            )
        except Exception as e:
            message = f"Failed to init group: {e}."
            logger.error(message)

        torch.cuda.empty_cache()
        return success, message

    def send_weights_to_remote_instance(
        self,
        master_address,
        ports,
        group_name,
    ):
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        ports_list = ports.split(",")
        assert (
            len(ports_list) == self.tp_size
        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
        group_port = ports_list[self.tp_rank]
        group_name = f"{group_name}_{group_port}_{self.tp_rank}"

        # Use .get() so an unknown group name reports an error instead of raising KeyError.
        if self._weights_send_group.get(group_name) is not None:
            send_group = self._weights_send_group[group_name]
        else:
            message = f"Group {group_name} not in _weights_send_group list. Please call `init_weights_send_group_for_remote_instance` first."
            logger.error(message)
            return False, message

        torch.cuda.empty_cache()
        success = False
        message = ""
        try:
            for _, weights in self.model.named_parameters():
                torch.distributed.broadcast(
                    weights,
                    src=0,
                    group=send_group,
                )
            success = True
            message = f"Succeeded to send weights through {master_address}:{group_port} {group_name}."
        except Exception as e:
            message = f"Failed to send weights: {e}."
            logger.error(message)

        # Destroy the process group after sending weights.
        del self._weights_send_group[group_name]
        torch.distributed.distributed_c10d.destroy_process_group(send_group)
        torch.cuda.empty_cache()
        return success, message

    def init_weights_update_group(
        self,
        master_address,

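To make the rendezvous concrete: both sides derive the same per-rank group name from the base name, that rank's port, and the TP rank, so seed rank i and destination rank i meet on a dedicated world-size-2 group. A sketch with hypothetical values for tp_size = 2:

    group_name = "send_weights_10.0.0.2"  # base name carries the destination's IP
    ports = "60000,60001"                 # one port per TP rank
    tp_rank = 1
    group_port = ports.split(",")[tp_rank]  # "60001"
    assert f"{group_name}_{group_port}_{tp_rank}" == "send_weights_10.0.0.2_60001_1"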
@@ -12,6 +12,9 @@ import json
import logging
import math
import os
import re
import socket
import threading
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor

@@ -27,9 +30,11 @@ from typing import (
    Tuple,
    cast,
)
from urllib.parse import urlparse

import huggingface_hub
import numpy as np
import requests
import safetensors.torch
import torch
from huggingface_hub import HfApi, hf_hub_download

@@ -56,6 +61,7 @@ from sglang.srt.model_loader.utils import (
)
from sglang.srt.model_loader.weight_utils import (
    _BAR_FORMAT,
    default_weight_loader,
    download_safetensors_index_file_from_hf,
    download_weights_from_hf,
    filter_duplicate_safetensors_files,

@@ -71,6 +77,9 @@ from sglang.srt.model_loader.weight_utils import (
    safetensors_weights_iterator,
    set_runai_streamer_env,
)
from sglang.srt.remote_instance_weight_loader_utils import (
    trigger_transferring_weights_request,
)
from sglang.srt.utils import (
    get_bool_env_var,
    get_device_capability,

@@ -1380,6 +1389,104 @@ class GGUFModelLoader(BaseModelLoader):
        return model


class RemoteInstanceModelLoader(BaseModelLoader):
    """Model loader that can load tensors from a remote sglang instance."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(
                f"Model loader extra config is not supported for "
                f"load format {load_config.load_format}"
            )

    def download_model(self, model_config: ModelConfig) -> None:
        raise NotImplementedError

    def load_model(
        self,
        *,
        model_config: ModelConfig,
        device_config: DeviceConfig,
    ) -> nn.Module:
        logger.info("Loading weights from remote instance ...")
        load_config = self.load_config

        assert load_config.load_format == LoadFormat.REMOTE_INSTANCE, (
            f"Model loader {self.load_config.load_format} is not supported for "
            f"load format {load_config.load_format}"
        )

        model_weights = f"instance://{model_config.remote_instance_weight_loader_seed_instance_ip}:{model_config.remote_instance_weight_loader_send_weights_group_ports[model_config.tp_rank]}"

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config)

            with create_remote_connector(model_weights, device_config.device) as client:
                connector_type = get_connector_type(client)
                if connector_type == ConnectorType.INSTANCE:
                    self.load_model_from_remote_instance(
                        model, client, model_config, device_config
                    )
                else:
                    raise ValueError(
                        f"Unsupported connector type {connector_type} for "
                        f"remote tensor model loading."
                    )
        return model.eval()

    def load_model_from_remote_instance(
        self, model, client, model_config: ModelConfig, device_config: DeviceConfig
    ) -> nn.Module:
        instance_ip = socket.gethostbyname(socket.gethostname())
        start_build_group_tic = time.time()
        client.build_group(
            gpu_id=device_config.gpu_id,
            tp_rank=model_config.tp_rank,
            instance_ip=instance_ip,
        )
        torch.cuda.synchronize()
        end_build_group_tic = time.time()
        logger.debug(
            f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
        )

        if model_config.tp_rank == 0:
            t = threading.Thread(
                target=trigger_transferring_weights_request,
                args=(
                    model_config.remote_instance_weight_loader_seed_instance_ip,
                    model_config.remote_instance_weight_loader_seed_instance_service_port,
                    model_config.remote_instance_weight_loader_send_weights_group_ports,
                    instance_ip,
                ),
            )
            t.start()

        start_get_weights_tic = time.time()
        with set_default_torch_dtype(model_config.dtype):
            for _, tensor in model.named_parameters():
                torch.distributed.broadcast(
                    tensor.data,
                    src=0,
                    group=client._model_update_group,
                )
            torch.cuda.synchronize()

            if hasattr(model, "post_load_weights"):
                model.post_load_weights()
        end_get_weights_tic = time.time()
        logger.debug(
            f"finish getting all weights from remote instance, time used: {(end_get_weights_tic - start_get_weights_tic):.4f}s"
        )
        # Destroy the process group after loading weights.
        torch.distributed.distributed_c10d.destroy_process_group(
            client._model_update_group
        )
        torch.cuda.empty_cache()


class RemoteModelLoader(BaseModelLoader):
    """Model loader that can load tensors from a remote database."""

@@ -1581,4 +1688,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    if load_config.load_format == LoadFormat.REMOTE:
        return RemoteModelLoader(load_config)

    if load_config.load_format == LoadFormat.REMOTE_INSTANCE:
        return RemoteInstanceModelLoader(load_config)

    return DefaultModelLoader(load_config)

python/sglang/srt/remote_instance_weight_loader_utils.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import List

import requests

logger = logging.getLogger(__name__)


def trigger_init_weights_send_group_for_remote_instance_request(
    remote_instance_weight_loader_seed_instance_ip: str,
    remote_instance_weight_loader_seed_instance_service_port: int,
    remote_instance_weight_loader_send_weights_group_ports: List[int],
    remote_instance_weight_loader_client_id: str,
):
    seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
    # Only loading weights from an instance with the same parallelism strategy is supported.
    # Each TP rank pair between the seed and dst instances builds a communication group
    # for sending weights, i.e. seed TP 0 <-> dst TP 0, seed TP 1 <-> dst TP 1, etc.
    # Each communication group has a world size of 2.
    try:
        requests.post(
            f"{seed_instance_service_url}/init_weights_send_group_for_remote_instance",
            json={
                "master_address": remote_instance_weight_loader_seed_instance_ip,
                "ports": (
                    ",".join(
                        str(p)
                        for p in remote_instance_weight_loader_send_weights_group_ports
                    )
                ),
                "group_rank": 0,
                "world_size": 2,
                "group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
                "backend": "nccl",
            },
        )
    except Exception as e:
        logger.error(
            f"Failed to trigger init_weights_send_group_for_remote_instance_request to seed instance {seed_instance_service_url}: {e}."
        )
        raise


def trigger_transferring_weights_request(
    remote_instance_weight_loader_seed_instance_ip: str,
    remote_instance_weight_loader_seed_instance_service_port: int,
    remote_instance_weight_loader_send_weights_group_ports: List[int],
    remote_instance_weight_loader_client_id: str,
):
    seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
    try:
        requests.post(
            f"{seed_instance_service_url}/send_weights_to_remote_instance",
            json={
                "master_address": remote_instance_weight_loader_seed_instance_ip,
                "ports": (
                    ",".join(
                        str(p)
                        for p in remote_instance_weight_loader_send_weights_group_ports
                    )
                ),
                "group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
            },
        )
    except Exception as e:
        logger.error(f"Failed to trigger send weights to remote instance request: {e}")
        raise

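The comment above fixes the topology: rank i on the seed pairs with rank i on the destination over its own port, each pair forming a world-size-2 group. A sketch of the resulting layout, assuming hypothetical ports for tp_size = 2:

    send_weights_group_ports = [60000, 60001]
    for tp_rank, port in enumerate(send_weights_group_ports):
        # Seed rank tp_rank joins as rank 0; destination rank tp_rank joins as rank 1.
        print(f"TP pair {tp_rank}: rendezvous tcp://<seed_ip>:{port}, world_size=2")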
@@ -19,10 +19,12 @@ import json
import logging
import os
import random
import socket
import sys
import tempfile
from typing import List, Literal, Optional, Union

from sglang.srt.connector import ConnectorType
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
from sglang.srt.lora.lora_registry import LoRARef

@@ -42,7 +44,9 @@ from sglang.srt.utils import (
    is_sm100_supported,
    is_triton_kernels_available,
    is_valid_ipv6_address,
    json_list_type,
    nullable_str,
    parse_connector_type,
)
from sglang.utils import is_in_ci

@@ -61,6 +65,7 @@ LOAD_FORMAT_CHOICES = [
    "bitsandbytes",
    "layered",
    "remote",
    "remote_instance",
]

QUANTIZATION_CHOICES = [

@@ -387,6 +392,11 @@ class ServerArgs:
    custom_weight_loader: Optional[List[str]] = None
    weight_loader_disable_mmap: bool = False

    # Remote instance weight loading
    remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
    remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
    remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None

    # For PD-Multiplexing
    enable_pdmux: bool = False
    sm_group_num: int = 3

@@ -445,6 +455,7 @@ class ServerArgs:
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path
        if self.device is None:

@@ -538,7 +549,8 @@ class ServerArgs:
            self.sampling_backend = "pytorch"

        # Model-specific adjustments
-        self.model_specific_adjustments()
+        if parse_connector_type(self.model_path) != ConnectorType.INSTANCE:
+            self.model_specific_adjustments()

        # Set kernel backends
        if self.device == "cpu":

@@ -818,12 +830,19 @@ class ServerArgs:
        ) and check_gguf_file(self.model_path):
            self.quantization = self.load_format = "gguf"

        # Model loading
        if is_remote_url(self.model_path):
            self.load_format = "remote"
        if self.custom_weight_loader is None:
            self.custom_weight_loader = []

        if self.load_format == "remote_instance":
            if (
                self.remote_instance_weight_loader_seed_instance_ip is None
                or self.remote_instance_weight_loader_seed_instance_service_port is None
                or self.remote_instance_weight_loader_send_weights_group_ports is None
            ):
                self.load_format = "auto"

        # PD disaggregation
        if self.disaggregation_mode == "decode":
            assert (

@@ -881,6 +900,24 @@ class ServerArgs:
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--remote-instance-weight-loader-seed-instance-ip",
            type=str,
            default=ServerArgs.remote_instance_weight_loader_seed_instance_ip,
            help="The IP of the seed instance for loading weights from a remote instance.",
        )
        parser.add_argument(
            "--remote-instance-weight-loader-seed-instance-service-port",
            type=int,
            default=ServerArgs.remote_instance_weight_loader_seed_instance_service_port,
            help="The service port of the seed instance for loading weights from a remote instance.",
        )
        parser.add_argument(
            "--remote-instance-weight-loader-send-weights-group-ports",
            type=json_list_type,
            default=ServerArgs.remote_instance_weight_loader_send_weights_group_ports,
            help="The communication group ports for loading weights from a remote instance.",
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,

@@ -15,6 +15,7 @@

from __future__ import annotations

import argparse
import asyncio
import builtins
import ctypes

@@ -1431,6 +1432,7 @@ def init_custom_process_group(
    store=None,
    group_name=None,
    pg_options=None,
    device_id=None,
):
    from torch.distributed.distributed_c10d import (
        Backend,

@@ -1484,6 +1486,7 @@ def init_custom_process_group(
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
        device_id=device_id,
    )

    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

@@ -3046,3 +3049,12 @@ def numa_bind_to_node(node: int):

    libnuma.numa_run_on_node(ctypes.c_int(node))
    libnuma.numa_set_localalloc()


def json_list_type(value):
    try:
        return json.loads(value)
    except json.JSONDecodeError:
        raise argparse.ArgumentTypeError(
            f"Invalid JSON list: {value}. Please provide a valid JSON list."
        )

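`json_list_type` is what lets the new `--remote-instance-weight-loader-send-weights-group-ports` flag accept a JSON list on the command line. A quick behavior sketch:

    import json

    assert json.loads("[60000,60001]") == [60000, 60001]  # valid flag value
    # An invalid value such as "60000,60001" raises argparse.ArgumentTypeError.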
@@ -123,6 +123,7 @@ suites = {
        TestFile("rl/test_update_weights_from_distributed.py", 103),
        TestFile("test_data_parallelism.py", 73),
        TestFile("test_dp_attention.py", 277),
        TestFile("test_load_weights_from_remote_instance.py", 72),
        TestFile("test_patch_torch.py", 19),
        TestFile("test_release_memory_occupation.py", 127),
        TestFile("hicache/test_hicache_storage_file_backend.py", 400),

@@ -251,6 +252,7 @@ suite_amd = {
        TestFile("lora/test_lora_tp.py", 116),
        TestFile("rl/test_update_weights_from_distributed.py", 103),
        TestFile("test_data_parallelism.py", 73),
        TestFile("test_load_weights_from_remote_instance.py", 72),
        TestFile("test_patch_torch.py", 19),
    ],
    "per-commit-4-gpu-amd": [

test/srt/test_load_weights_from_remote_instance.py (new file, 384 lines)
@@ -0,0 +1,384 @@
"""Test loading weights from remote instance.
|
||||
|
||||
This test suite simulates loading weights from a remote instance.
|
||||
Rank 0 represents the seed instance, while ranks 1 represents the
|
||||
new instance that needs to loading weights from the seed instance.
|
||||
|
||||
Seed instance must be started in `Server` mode, while the dst instance
|
||||
can be either `Engine` mode or `Server` mode.
|
||||
|
||||
Seed instance does not support concurrently serving multiple dst instances.
|
||||
User has to guarantee that there is only one dst instance trying to load
|
||||
weights from the seed instance at any time.
|
||||
|
||||
"""
|
||||
|
||||
import gc
|
||||
import os
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
)
|
||||
from sglang.utils import terminate_process
|
||||
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
|
||||
def verify_params_close(params1, params2, error_msg):
|
||||
"""Verify if two parameter arrays are close enough."""
|
||||
try:
|
||||
assert np.allclose(np.array(params1), np.array(params2)), error_msg
|
||||
except Exception as e:
|
||||
print(f"Parameters not close for {error_msg}")
|
||||
print("Params1:", np.array(params1))
|
||||
print("Params2:", np.array(params2))
|
||||
raise e
|
||||
|
||||
|
||||
def init_process(
|
||||
rank,
|
||||
param_queue,
|
||||
truncate_size,
|
||||
tp_size,
|
||||
model_name,
|
||||
backends,
|
||||
checking_parameters,
|
||||
seed_instance_ip,
|
||||
seed_instance_service_port,
|
||||
seed_instance_group_base_port,
|
||||
event_seed_ready,
|
||||
event_dst_ready_list,
|
||||
):
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
if rank == 0:
|
||||
init_process_seed(
|
||||
rank,
|
||||
param_queue,
|
||||
truncate_size,
|
||||
model_name,
|
||||
checking_parameters,
|
||||
tp_size,
|
||||
event_seed_ready,
|
||||
event_dst_ready_list,
|
||||
)
|
||||
elif rank in [1, 2]:
|
||||
init_process_dst(
|
||||
rank,
|
||||
param_queue,
|
||||
truncate_size,
|
||||
model_name,
|
||||
seed_instance_ip,
|
||||
seed_instance_service_port,
|
||||
seed_instance_group_base_port,
|
||||
checking_parameters,
|
||||
backends[rank - 1],
|
||||
tp_size,
|
||||
event_seed_ready,
|
||||
event_dst_ready_list,
|
||||
)
|
||||
|
||||
|
||||
def init_process_seed(
|
||||
rank,
|
||||
param_queue,
|
||||
truncate_size,
|
||||
model_name,
|
||||
checking_parameters,
|
||||
tp_size,
|
||||
event_seed_ready,
|
||||
event_dst_ready_list,
|
||||
):
|
||||
# These two environment variables are very important
|
||||
# to avoid unexpected behaviors of CUDA and NCCL.
|
||||
os.environ["NCCL_CUMEM_ENABLE"] = "0"
|
||||
os.environ["NCCL_NVLS_ENABLE"] = "0"
|
||||
|
||||
# Load model and get parameters
|
||||
torch.cuda.set_device(rank)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
url = DEFAULT_URL_FOR_TEST
|
||||
process = popen_launch_server(
|
||||
model_name,
|
||||
url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=(
|
||||
"--base-gpu-id",
|
||||
str(rank),
|
||||
"--tp-size",
|
||||
str(tp_size),
|
||||
),
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
seed_params = []
|
||||
# Get the weights of seed instance for correctness check.
|
||||
for parameter_name in checking_parameters:
|
||||
seed_params.append(
|
||||
requests.get(
|
||||
f"{url}/get_weights_by_name",
|
||||
json={
|
||||
"name": parameter_name,
|
||||
"truncate_size": truncate_size,
|
||||
},
|
||||
).json()
|
||||
)
|
||||
param_queue.put((f"seed_params", seed_params))
|
||||
|
||||
event_seed_ready.set()
|
||||
for i in range(len(event_dst_ready_list)):
|
||||
event_dst_ready_list[i].wait()
|
||||
terminate_process(process)
|
||||
|
||||
|
||||
def init_process_dst(
|
||||
rank,
|
||||
param_queue,
|
||||
truncate_size,
|
||||
model_name,
|
||||
seed_instance_ip,
|
||||
seed_instance_service_port,
|
||||
seed_instance_group_base_port,
|
||||
checking_parameters,
|
||||
backend,
|
||||
tp_size,
|
||||
event_seed_ready,
|
||||
event_dst_ready_list,
|
||||
):
|
||||
torch.cuda.set_device(rank * tp_size)
|
||||
torch.cuda.synchronize()
|
||||
base_gpu_id = rank * tp_size
|
||||
|
||||
event_seed_ready.wait()
|
||||
print(f"rank {rank}, seed ready")
|
||||
for i in range(rank - 1):
|
||||
print(f"rank {rank}, wait dst {i}")
|
||||
event_dst_ready_list[i].wait()
|
||||
|
||||
ports = []
|
||||
for i in range(tp_size):
|
||||
ports.append(seed_instance_group_base_port + (rank - 1) * tp_size + i)
|
||||
|
||||
if backend == "Engine":
|
||||
print(f"[sgl] rank {rank} init engine")
|
||||
engine = sgl.Engine(
|
||||
model_path=model_name,
|
||||
base_gpu_id=base_gpu_id,
|
||||
tp_size=tp_size,
|
||||
cuda_graph_max_bs=2,
|
||||
tokenizer_path=model_name,
|
||||
remote_instance_weight_loader_seed_instance_ip=seed_instance_ip,
|
||||
remote_instance_weight_loader_seed_instance_service_port=seed_instance_service_port,
|
||||
remote_instance_weight_loader_send_weights_group_ports=ports,
|
||||
load_format="remote_instance",
|
||||
)
|
||||
else:
|
||||
host, _, port = DEFAULT_URL_FOR_TEST.rpartition(":")
|
||||
url = ":".join([host, str(int(port) + 10000 + rank)])
|
||||
|
||||
print(f"[sgl] rank {rank} init server on url: {url}")
|
||||
process = popen_launch_server(
|
||||
model_name,
|
||||
url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=(
|
||||
"--base-gpu-id",
|
||||
str(base_gpu_id),
|
||||
"--tp-size",
|
||||
str(tp_size),
|
||||
"--cuda-graph-max-bs",
|
||||
2,
|
||||
"--tokenizer-path",
|
||||
model_name,
|
||||
"--remote-instance-weight-loader-seed-instance-ip",
|
||||
seed_instance_ip,
|
||||
"--remote-instance-weight-loader-seed-instance-service-port",
|
||||
seed_instance_service_port,
|
||||
"--remote-instance-weight-loader-send-weights-group-ports",
|
||||
f"[{','.join(str(port) for port in ports)}]",
|
||||
"--load-format",
|
||||
"remote_instance",
|
||||
),
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
event_dst_ready_list[rank - 1].set()
|
||||
|
||||
# Get weights of destination instance loaded from remote instance.
|
||||
dst_params = []
|
||||
for parameter_name in checking_parameters:
|
||||
dst_params.append(
|
||||
engine.get_weights_by_name(parameter_name, truncate_size)
|
||||
if backend == "Engine"
|
||||
else requests.get(
|
||||
f"{url}/get_weights_by_name",
|
||||
json={"name": parameter_name, "truncate_size": truncate_size},
|
||||
).json()
|
||||
)
|
||||
|
||||
param_queue.put((f"sgl_dp_{rank}_dst_params", dst_params))
|
||||
|
||||
# Shutdown the engine or terminate the server process.
|
||||
if backend == "Engine":
|
||||
engine.shutdown()
|
||||
else:
|
||||
terminate_process(process)
|
||||
|
||||
|
||||
def test_load_weights_from_remote_instance(
|
||||
tp_size,
|
||||
dp_size,
|
||||
model_name,
|
||||
backends,
|
||||
truncate_size,
|
||||
checking_parameters,
|
||||
seed_instance_ip,
|
||||
seed_instance_service_port,
|
||||
seed_instance_group_base_port,
|
||||
):
|
||||
print(
|
||||
f"Testing model: {model_name} tp_size: {tp_size}, dp_size: {dp_size} backend: {backends}"
|
||||
)
|
||||
param_queue = mp.Queue()
|
||||
results = {}
|
||||
event_seed_ready = mp.Event()
|
||||
event_dst_ready_list = []
|
||||
for i in range(dp_size):
|
||||
event_dst_ready = mp.Event()
|
||||
event_dst_ready_list.append(event_dst_ready)
|
||||
|
||||
context = mp.spawn(
|
||||
init_process,
|
||||
args=(
|
||||
param_queue,
|
||||
truncate_size,
|
||||
tp_size,
|
||||
model_name,
|
||||
backends,
|
||||
checking_parameters,
|
||||
seed_instance_ip,
|
||||
seed_instance_service_port,
|
||||
seed_instance_group_base_port,
|
||||
event_seed_ready,
|
||||
event_dst_ready_list,
|
||||
),
|
||||
nprocs=1 + dp_size,
|
||||
join=False,
|
||||
)
|
||||
|
||||
while len(results) < (1 + dp_size):
|
||||
try:
|
||||
key, value = param_queue.get(timeout=5)
|
||||
results[key] = value
|
||||
except Exception as e:
|
||||
if all(not p.is_alive() for p in context.processes):
|
||||
break
|
||||
|
||||
context.join()
|
||||
|
||||
if len(results) != (1 + dp_size):
|
||||
raise RuntimeError(
|
||||
f"Expected {(1 + dp_size)} parameters but got {len(results)}"
|
||||
)
|
||||
|
||||
params = {
|
||||
"seed": results.get("seed_params"),
|
||||
"sgl_dp_1_dest": results.get("sgl_dp_1_dst_params"),
|
||||
}
|
||||
|
||||
if dp_size == 2:
|
||||
dp2_params = {
|
||||
"sgl_dp_2_dest": results.get("sgl_dp_2_dst_params"),
|
||||
}
|
||||
assert all(v is not None for v in dp2_params.values())
|
||||
params.update(dp2_params)
|
||||
|
||||
# Check the correctness of weights loaded from remote instance
|
||||
# by verifying the weights of seed instance and destination instance.
|
||||
for i in range(len(params["seed"])):
|
||||
verify_params_close(
|
||||
params["seed"][i],
|
||||
params["sgl_dp_1_dest"][i],
|
||||
f"sgl_dp_1_dst_params rank {i}",
|
||||
)
|
||||
|
||||
if dp_size == 2:
|
||||
verify_params_close(
|
||||
params["seed"][i],
|
||||
params["sgl_dp_2_dest"][i],
|
||||
f"sgl_dp_2_dst_params rank {i}",
|
||||
)
|
||||
|
||||
# Delete the context and close the parameter queue.
|
||||
del context
|
||||
param_queue.close()
|
||||
param_queue.join_thread()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
class TestLoadWeightsFromRemoteInstance(CustomTestCase):
|
||||
|
||||
def test_load_weights_from_remote_instance(self):
|
||||
|
||||
assert torch.cuda.device_count() >= 2, "At least 2 GPUs are required"
|
||||
# test_suits : tp, dp, model_name, backend, dst_instance_id
|
||||
if is_in_ci():
|
||||
mode = random.choice(["Engine", "Server"])
|
||||
test_suits = [
|
||||
(1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, [mode]),
|
||||
]
|
||||
else:
|
||||
test_suits = [
|
||||
(1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ["Engine"]),
|
||||
(1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ["Sever"]),
|
||||
(2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ["Engine", "Server"]),
|
||||
]
|
||||
|
||||
truncate_size = 10
|
||||
checking_parameters = [
|
||||
"model.embed_tokens.weight",
|
||||
"model.layers.0.input_layernorm.weight",
|
||||
"model.layers.1.self_attn.q_proj.weight",
|
||||
"model.layers.2.self_attn.k_proj.weight",
|
||||
"model.layers.3.self_attn.v_proj.weight",
|
||||
"model.layers.4.self_attn.o_proj.weight",
|
||||
"model.layers.5.mlp.gate_proj.weight",
|
||||
"model.layers.6.mlp.up_proj.weight",
|
||||
"model.layers.7.mlp.down_proj.weight",
|
||||
"model.layers.8.post_attention_layernorm.weight",
|
||||
"model.norm.weight",
|
||||
]
|
||||
|
||||
for tp_size, dp_size, model_name, backends in test_suits:
|
||||
test_load_weights_from_remote_instance(
|
||||
tp_size,
|
||||
dp_size,
|
||||
model_name,
|
||||
backends,
|
||||
truncate_size,
|
||||
checking_parameters,
|
||||
"127.0.0.1",
|
||||
DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000,
|
||||
60000,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||