Support loading weights from remote instance (#8215)
Signed-off-by: Anqi Shen <amy.saq@antgroup.com> Co-authored-by: Chayenne <74843776+zhaochenyang20@users.noreply.github.com>
This commit is contained in:
@@ -8,10 +8,12 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class DeviceConfig:
|
class DeviceConfig:
|
||||||
device: Optional[torch.device]
|
device: Optional[torch.device]
|
||||||
|
gpu_id: Optional[int]
|
||||||
|
|
||||||
def __init__(self, device: str = "cuda") -> None:
|
def __init__(self, device: str = "cuda", gpu_id: int = -1) -> None:
|
||||||
if device in ["cuda", "xpu", "hpu", "cpu", "npu"]:
|
if device in ["cuda", "xpu", "hpu", "cpu", "npu"]:
|
||||||
self.device_type = device
|
self.device_type = device
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Not supported device type: {device}")
|
raise RuntimeError(f"Not supported device type: {device}")
|
||||||
self.device = torch.device(self.device_type)
|
self.device = torch.device(self.device_type)
|
||||||
|
self.gpu_id = gpu_id
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ class LoadFormat(str, enum.Enum):
|
|||||||
LAYERED = "layered"
|
LAYERED = "layered"
|
||||||
JAX = "jax"
|
JAX = "jax"
|
||||||
REMOTE = "remote"
|
REMOTE = "remote"
|
||||||
|
REMOTE_INSTANCE = "remote_instance"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -64,12 +64,28 @@ class ModelConfig:
|
|||||||
is_draft_model: bool = False,
|
is_draft_model: bool = False,
|
||||||
hybrid_kvcache_ratio: Optional[float] = None,
|
hybrid_kvcache_ratio: Optional[float] = None,
|
||||||
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
||||||
|
tp_rank: Optional[int] = None,
|
||||||
|
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
|
||||||
|
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
|
||||||
|
remote_instance_weight_loader_send_weights_group_ports: Optional[
|
||||||
|
List[int]
|
||||||
|
] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Parse args
|
# Parse args
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
self.revision = revision
|
self.revision = revision
|
||||||
self.quantization = quantization
|
self.quantization = quantization
|
||||||
self.model_impl = model_impl
|
self.model_impl = model_impl
|
||||||
|
self.tp_rank = tp_rank
|
||||||
|
self.remote_instance_weight_loader_seed_instance_ip = (
|
||||||
|
remote_instance_weight_loader_seed_instance_ip
|
||||||
|
)
|
||||||
|
self.remote_instance_weight_loader_seed_instance_service_port = (
|
||||||
|
remote_instance_weight_loader_seed_instance_service_port
|
||||||
|
)
|
||||||
|
self.remote_instance_weight_loader_send_weights_group_ports = (
|
||||||
|
remote_instance_weight_loader_send_weights_group_ports
|
||||||
|
)
|
||||||
|
|
||||||
self.maybe_pull_model_tokenizer_from_remote()
|
self.maybe_pull_model_tokenizer_from_remote()
|
||||||
self.model_override_args = json.loads(model_override_args)
|
self.model_override_args = json.loads(model_override_args)
|
||||||
@@ -329,6 +345,9 @@ class ModelConfig:
|
|||||||
quantization=server_args.quantization,
|
quantization=server_args.quantization,
|
||||||
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
||||||
model_impl=server_args.model_impl,
|
model_impl=server_args.model_impl,
|
||||||
|
remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
|
||||||
|
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
|
||||||
|
remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from sglang.srt.connector.base_connector import (
|
|||||||
BaseKVConnector,
|
BaseKVConnector,
|
||||||
)
|
)
|
||||||
from sglang.srt.connector.redis import RedisConnector
|
from sglang.srt.connector.redis import RedisConnector
|
||||||
|
from sglang.srt.connector.remote_instance import RemoteInstanceConnector
|
||||||
from sglang.srt.connector.s3 import S3Connector
|
from sglang.srt.connector.s3 import S3Connector
|
||||||
from sglang.srt.utils import parse_connector_type
|
from sglang.srt.utils import parse_connector_type
|
||||||
|
|
||||||
@@ -18,14 +19,17 @@ logger = logging.getLogger(__name__)
|
|||||||
class ConnectorType(str, enum.Enum):
|
class ConnectorType(str, enum.Enum):
|
||||||
FS = "filesystem"
|
FS = "filesystem"
|
||||||
KV = "KV"
|
KV = "KV"
|
||||||
|
INSTANCE = "instance"
|
||||||
|
|
||||||
|
|
||||||
def create_remote_connector(url, **kwargs) -> BaseConnector:
|
def create_remote_connector(url, device, **kwargs) -> BaseConnector:
|
||||||
connector_type = parse_connector_type(url)
|
connector_type = parse_connector_type(url)
|
||||||
if connector_type == "redis":
|
if connector_type == "redis":
|
||||||
return RedisConnector(url)
|
return RedisConnector(url)
|
||||||
elif connector_type == "s3":
|
elif connector_type == "s3":
|
||||||
return S3Connector(url)
|
return S3Connector(url)
|
||||||
|
elif connector_type == "instance":
|
||||||
|
return RemoteInstanceConnector(url, device)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid connector type: {url}")
|
raise ValueError(f"Invalid connector type: {url}")
|
||||||
|
|
||||||
@@ -35,6 +39,8 @@ def get_connector_type(client: BaseConnector) -> ConnectorType:
|
|||||||
return ConnectorType.KV
|
return ConnectorType.KV
|
||||||
if isinstance(client, BaseFileConnector):
|
if isinstance(client, BaseFileConnector):
|
||||||
return ConnectorType.FS
|
return ConnectorType.FS
|
||||||
|
if isinstance(client, RemoteInstanceConnector):
|
||||||
|
return ConnectorType.INSTANCE
|
||||||
|
|
||||||
raise ValueError(f"Invalid connector type: {client}")
|
raise ValueError(f"Invalid connector type: {client}")
|
||||||
|
|
||||||
@@ -44,6 +50,7 @@ __all__ = [
|
|||||||
"BaseFileConnector",
|
"BaseFileConnector",
|
||||||
"BaseKVConnector",
|
"BaseKVConnector",
|
||||||
"RedisConnector",
|
"RedisConnector",
|
||||||
|
"RemoteInstanceConnector",
|
||||||
"S3Connector",
|
"S3Connector",
|
||||||
"ConnectorType",
|
"ConnectorType",
|
||||||
"create_remote_connector",
|
"create_remote_connector",
|
||||||
|
|||||||
82
python/sglang/srt/connector/remote_instance.py
Normal file
82
python/sglang/srt/connector/remote_instance.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Generator, List, Optional, Tuple
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from sglang.srt.connector import BaseConnector
|
||||||
|
from sglang.srt.utils import init_custom_process_group
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteInstanceConnector(BaseConnector):
|
||||||
|
|
||||||
|
def __init__(self, url: str, device: torch.device = "cpu"):
|
||||||
|
assert (
|
||||||
|
device.type == "cuda"
|
||||||
|
), "RemoteInstanceConnector only supports cuda device."
|
||||||
|
super().__init__(url)
|
||||||
|
self.url = url
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def build_group(
|
||||||
|
self,
|
||||||
|
gpu_id: int = -1,
|
||||||
|
tp_rank: int = -1,
|
||||||
|
instance_ip: str = None,
|
||||||
|
group_rank: int = 1,
|
||||||
|
world_size: int = 2,
|
||||||
|
):
|
||||||
|
assert (
|
||||||
|
self.device.type == "cuda"
|
||||||
|
), "RemoteInstanceConnector only supports cuda device."
|
||||||
|
assert (
|
||||||
|
gpu_id != -1 and tp_rank != -1
|
||||||
|
), "gpu_id and tp_rank must be specified for RemoteInstanceConnector. "
|
||||||
|
|
||||||
|
self.device_id = torch.device(self.device.type, gpu_id)
|
||||||
|
|
||||||
|
parsed_url = urlparse(self.url)
|
||||||
|
master_address = parsed_url.hostname
|
||||||
|
master_port = parsed_url.port
|
||||||
|
group_name = f"send_weights_{instance_ip}_{master_port}_{tp_rank}"
|
||||||
|
backend = "nccl"
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"init custom process group: master_address={master_address}, master_port={master_port}, "
|
||||||
|
f"rank_offset={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._model_update_group = init_custom_process_group(
|
||||||
|
backend=backend,
|
||||||
|
init_method=f"tcp://{master_address}:{master_port}",
|
||||||
|
world_size=world_size,
|
||||||
|
rank=group_rank,
|
||||||
|
group_name=group_name,
|
||||||
|
device_id=self.device_id,
|
||||||
|
)
|
||||||
|
dist.barrier(group=self._model_update_group)
|
||||||
|
return True, "Succeeded to initialize custom process group."
|
||||||
|
except Exception as e:
|
||||||
|
message = f"Failed to initialize custom process group: {e}."
|
||||||
|
logger.error(message)
|
||||||
|
return False, message
|
||||||
|
|
||||||
|
# Implemented as a no-op to make BaseConnector interface consistent.
|
||||||
|
def pull_files(
|
||||||
|
self,
|
||||||
|
allow_pattern: Optional[list[str]] = None,
|
||||||
|
ignore_pattern: Optional[list[str]] = None,
|
||||||
|
) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Implemented as a no-op to make BaseConnector interface consistent.
|
||||||
|
def weight_iterator(
|
||||||
|
self, rank: int = 0
|
||||||
|
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||||
|
return
|
||||||
@@ -73,6 +73,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
|
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
LoadLoRAAdapterReqInput,
|
LoadLoRAAdapterReqInput,
|
||||||
OpenSessionReqInput,
|
OpenSessionReqInput,
|
||||||
@@ -80,6 +81,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
ProfileReqInput,
|
ProfileReqInput,
|
||||||
ReleaseMemoryOccupationReqInput,
|
ReleaseMemoryOccupationReqInput,
|
||||||
ResumeMemoryOccupationReqInput,
|
ResumeMemoryOccupationReqInput,
|
||||||
|
SendWeightsToRemoteInstanceReqInput,
|
||||||
SeparateReasoningReqInput,
|
SeparateReasoningReqInput,
|
||||||
SetInternalStateReq,
|
SetInternalStateReq,
|
||||||
SlowDownReqInput,
|
SlowDownReqInput,
|
||||||
@@ -670,6 +672,38 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/init_weights_send_group_for_remote_instance")
|
||||||
|
async def init_weights_send_group_for_remote_instance(
|
||||||
|
obj: InitWeightsSendGroupForRemoteInstanceReqInput, request: Request
|
||||||
|
):
|
||||||
|
success, message = (
|
||||||
|
await _global_state.tokenizer_manager.init_weights_send_group_for_remote_instance(
|
||||||
|
obj, request
|
||||||
|
)
|
||||||
|
)
|
||||||
|
content = {"success": success, "message": message}
|
||||||
|
if success:
|
||||||
|
return ORJSONResponse(content, status_code=200)
|
||||||
|
else:
|
||||||
|
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/send_weights_to_remote_instance")
|
||||||
|
async def send_weights_to_remote_instance(
|
||||||
|
obj: SendWeightsToRemoteInstanceReqInput, request: Request
|
||||||
|
):
|
||||||
|
success, message = (
|
||||||
|
await _global_state.tokenizer_manager.send_weights_to_remote_instance(
|
||||||
|
obj, request
|
||||||
|
)
|
||||||
|
)
|
||||||
|
content = {"success": success, "message": message}
|
||||||
|
if success:
|
||||||
|
return ORJSONResponse(content, status_code=200)
|
||||||
|
else:
|
||||||
|
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/init_weights_update_group")
|
@app.post("/init_weights_update_group")
|
||||||
async def init_weights_update_group(
|
async def init_weights_update_group(
|
||||||
obj: InitWeightsUpdateGroupReqInput, request: Request
|
obj: InitWeightsUpdateGroupReqInput, request: Request
|
||||||
|
|||||||
@@ -1020,6 +1020,44 @@ class UpdateWeightsFromTensorReqOutput:
|
|||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InitWeightsSendGroupForRemoteInstanceReqInput:
|
||||||
|
# The master address
|
||||||
|
master_address: str
|
||||||
|
# The ports for each rank's communication group
|
||||||
|
ports: str
|
||||||
|
# The rank in the communication group
|
||||||
|
group_rank: int
|
||||||
|
# The world size
|
||||||
|
world_size: int
|
||||||
|
# The group name
|
||||||
|
group_name: str = "weight_send_group"
|
||||||
|
# The backend
|
||||||
|
backend: str = "nccl"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InitWeightsSendGroupForRemoteInstanceReqOutput:
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SendWeightsToRemoteInstanceReqInput:
|
||||||
|
# The master address
|
||||||
|
master_address: str
|
||||||
|
# The ports for each rank's communication group
|
||||||
|
ports: str
|
||||||
|
# The group name
|
||||||
|
group_name: str = "weight_send_group"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SendWeightsToRemoteInstanceReqOutput:
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InitWeightsUpdateGroupReqInput:
|
class InitWeightsUpdateGroupReqInput:
|
||||||
# The master address
|
# The master address
|
||||||
|
|||||||
@@ -81,6 +81,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
GetInternalStateReqOutput,
|
GetInternalStateReqOutput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
HealthCheckOutput,
|
HealthCheckOutput,
|
||||||
|
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||||
|
InitWeightsSendGroupForRemoteInstanceReqOutput,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
LoadLoRAAdapterReqInput,
|
LoadLoRAAdapterReqInput,
|
||||||
LoadLoRAAdapterReqOutput,
|
LoadLoRAAdapterReqOutput,
|
||||||
@@ -93,6 +95,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
ResumeMemoryOccupationReqInput,
|
ResumeMemoryOccupationReqInput,
|
||||||
RpcReqInput,
|
RpcReqInput,
|
||||||
RpcReqOutput,
|
RpcReqOutput,
|
||||||
|
SendWeightsToRemoteInstanceReqInput,
|
||||||
|
SendWeightsToRemoteInstanceReqOutput,
|
||||||
SetInternalStateReq,
|
SetInternalStateReq,
|
||||||
SetInternalStateReqOutput,
|
SetInternalStateReqOutput,
|
||||||
SlowDownReqInput,
|
SlowDownReqInput,
|
||||||
@@ -538,6 +542,14 @@ class Scheduler(
|
|||||||
(CloseSessionReqInput, self.close_session),
|
(CloseSessionReqInput, self.close_session),
|
||||||
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
|
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
|
||||||
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
|
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
|
||||||
|
(
|
||||||
|
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||||
|
self.init_weights_send_group_for_remote_instance,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
SendWeightsToRemoteInstanceReqInput,
|
||||||
|
self.send_weights_to_remote_instance,
|
||||||
|
),
|
||||||
(
|
(
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
self.update_weights_from_distributed,
|
self.update_weights_from_distributed,
|
||||||
@@ -2429,6 +2441,22 @@ class Scheduler(
|
|||||||
self.send_to_detokenizer.send_pyobj(recv_req)
|
self.send_to_detokenizer.send_pyobj(recv_req)
|
||||||
return recv_req
|
return recv_req
|
||||||
|
|
||||||
|
def init_weights_send_group_for_remote_instance(
|
||||||
|
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
||||||
|
):
|
||||||
|
"""Init the seed and client instance communication group."""
|
||||||
|
success, message = self.tp_worker.init_weights_send_group_for_remote_instance(
|
||||||
|
recv_req
|
||||||
|
)
|
||||||
|
return InitWeightsSendGroupForRemoteInstanceReqOutput(success, message)
|
||||||
|
|
||||||
|
def send_weights_to_remote_instance(
|
||||||
|
self, recv_req: SendWeightsToRemoteInstanceReqInput
|
||||||
|
):
|
||||||
|
"""Send the seed instance weights to the destination instance."""
|
||||||
|
success, message = self.tp_worker.send_weights_to_remote_instance(recv_req)
|
||||||
|
return SendWeightsToRemoteInstanceReqOutput(success, message)
|
||||||
|
|
||||||
def slow_down(self, recv_req: SlowDownReqInput):
|
def slow_down(self, recv_req: SlowDownReqInput):
|
||||||
t = recv_req.forward_sleep_time
|
t = recv_req.forward_sleep_time
|
||||||
if t is not None and t <= 0:
|
if t is not None and t <= 0:
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
GetInternalStateReqOutput,
|
GetInternalStateReqOutput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
GetWeightsByNameReqOutput,
|
GetWeightsByNameReqOutput,
|
||||||
|
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||||
|
InitWeightsSendGroupForRemoteInstanceReqOutput,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
InitWeightsUpdateGroupReqOutput,
|
InitWeightsUpdateGroupReqOutput,
|
||||||
LoadLoRAAdapterReqInput,
|
LoadLoRAAdapterReqInput,
|
||||||
@@ -43,6 +45,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
ReleaseMemoryOccupationReqOutput,
|
ReleaseMemoryOccupationReqOutput,
|
||||||
ResumeMemoryOccupationReqInput,
|
ResumeMemoryOccupationReqInput,
|
||||||
ResumeMemoryOccupationReqOutput,
|
ResumeMemoryOccupationReqOutput,
|
||||||
|
SendWeightsToRemoteInstanceReqInput,
|
||||||
|
SendWeightsToRemoteInstanceReqOutput,
|
||||||
SetInternalStateReq,
|
SetInternalStateReq,
|
||||||
SetInternalStateReqOutput,
|
SetInternalStateReqOutput,
|
||||||
SlowDownReqInput,
|
SlowDownReqInput,
|
||||||
@@ -119,6 +123,12 @@ class TokenizerCommunicatorMixin:
|
|||||||
self.update_weights_from_distributed_communicator = _Communicator(
|
self.update_weights_from_distributed_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
|
self.init_weights_send_group_for_remote_instance_communicator = _Communicator(
|
||||||
|
self.send_to_scheduler, server_args.dp_size
|
||||||
|
)
|
||||||
|
self.send_weights_to_remote_instance_communicator = _Communicator(
|
||||||
|
self.send_to_scheduler, server_args.dp_size
|
||||||
|
)
|
||||||
self.update_weights_from_tensor_communicator = _Communicator(
|
self.update_weights_from_tensor_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
@@ -169,6 +179,14 @@ class TokenizerCommunicatorMixin:
|
|||||||
UpdateWeightsFromDistributedReqOutput,
|
UpdateWeightsFromDistributedReqOutput,
|
||||||
self.update_weights_from_distributed_communicator.handle_recv,
|
self.update_weights_from_distributed_communicator.handle_recv,
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
InitWeightsSendGroupForRemoteInstanceReqOutput,
|
||||||
|
self.init_weights_send_group_for_remote_instance_communicator.handle_recv,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
SendWeightsToRemoteInstanceReqOutput,
|
||||||
|
self.send_weights_to_remote_instance_communicator.handle_recv,
|
||||||
|
),
|
||||||
(
|
(
|
||||||
UpdateWeightsFromTensorReqOutput,
|
UpdateWeightsFromTensorReqOutput,
|
||||||
self.update_weights_from_tensor_communicator.handle_recv,
|
self.update_weights_from_tensor_communicator.handle_recv,
|
||||||
@@ -310,6 +328,34 @@ class TokenizerCommunicatorMixin:
|
|||||||
result = (await self.update_weights_from_distributed_communicator(obj))[0]
|
result = (await self.update_weights_from_distributed_communicator(obj))[0]
|
||||||
return result.success, result.message
|
return result.success, result.message
|
||||||
|
|
||||||
|
async def init_weights_send_group_for_remote_instance(
|
||||||
|
self,
|
||||||
|
obj: InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||||
|
request: Optional[fastapi.Request] = None,
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
|
self.auto_create_handle_loop()
|
||||||
|
# TODO: support DP
|
||||||
|
assert (
|
||||||
|
self.server_args.dp_size == 1
|
||||||
|
), "dp_size must be 1 for init_weights_send_group_for_remote_instance"
|
||||||
|
result = (
|
||||||
|
await self.init_weights_send_group_for_remote_instance_communicator(obj)
|
||||||
|
)[0]
|
||||||
|
return result.success, result.message
|
||||||
|
|
||||||
|
async def send_weights_to_remote_instance(
|
||||||
|
self,
|
||||||
|
obj: SendWeightsToRemoteInstanceReqInput,
|
||||||
|
request: Optional[fastapi.Request] = None,
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
|
self.auto_create_handle_loop()
|
||||||
|
# TODO: support DP
|
||||||
|
assert (
|
||||||
|
self.server_args.dp_size == 1
|
||||||
|
), "dp_size must be 1 for send_weights_to_remote_instance"
|
||||||
|
result = (await self.send_weights_to_remote_instance_communicator(obj))[0]
|
||||||
|
return result.success, result.message
|
||||||
|
|
||||||
async def update_weights_from_tensor(
|
async def update_weights_from_tensor(
|
||||||
self: TokenizerManager,
|
self: TokenizerManager,
|
||||||
obj: UpdateWeightsFromTensorReqInput,
|
obj: UpdateWeightsFromTensorReqInput,
|
||||||
|
|||||||
@@ -30,8 +30,10 @@ from sglang.srt.hf_transformers_utils import (
|
|||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
|
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
LoadLoRAAdapterReqInput,
|
LoadLoRAAdapterReqInput,
|
||||||
|
SendWeightsToRemoteInstanceReqInput,
|
||||||
UnloadLoRAAdapterReqInput,
|
UnloadLoRAAdapterReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
@@ -88,6 +90,7 @@ class TpModelWorker:
|
|||||||
else server_args.speculative_draft_model_revision
|
else server_args.speculative_draft_model_revision
|
||||||
),
|
),
|
||||||
is_draft_model=is_draft_worker,
|
is_draft_model=is_draft_worker,
|
||||||
|
tp_rank=tp_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_runner = ModelRunner(
|
self.model_runner = ModelRunner(
|
||||||
@@ -292,6 +295,31 @@ class TpModelWorker:
|
|||||||
)
|
)
|
||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
|
def init_weights_send_group_for_remote_instance(
|
||||||
|
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
||||||
|
):
|
||||||
|
success, message = (
|
||||||
|
self.model_runner.init_weights_send_group_for_remote_instance(
|
||||||
|
recv_req.master_address,
|
||||||
|
recv_req.ports,
|
||||||
|
recv_req.group_rank,
|
||||||
|
recv_req.world_size,
|
||||||
|
recv_req.group_name,
|
||||||
|
recv_req.backend,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return success, message
|
||||||
|
|
||||||
|
def send_weights_to_remote_instance(
|
||||||
|
self, recv_req: SendWeightsToRemoteInstanceReqInput
|
||||||
|
):
|
||||||
|
success, message = self.model_runner.send_weights_to_remote_instance(
|
||||||
|
recv_req.master_address,
|
||||||
|
recv_req.ports,
|
||||||
|
recv_req.group_name,
|
||||||
|
)
|
||||||
|
return success, message
|
||||||
|
|
||||||
def update_weights_from_distributed(
|
def update_weights_from_distributed(
|
||||||
self, recv_req: UpdateWeightsFromDistributedReqInput
|
self, recv_req: UpdateWeightsFromDistributedReqInput
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -26,8 +26,10 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
|
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
LoadLoRAAdapterReqInput,
|
LoadLoRAAdapterReqInput,
|
||||||
|
SendWeightsToRemoteInstanceReqInput,
|
||||||
UnloadLoRAAdapterReqInput,
|
UnloadLoRAAdapterReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
@@ -267,6 +269,20 @@ class TpModelWorkerClient:
|
|||||||
success, message = self.worker.init_weights_update_group(recv_req)
|
success, message = self.worker.init_weights_update_group(recv_req)
|
||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
|
def init_weights_send_group_for_remote_instance(
|
||||||
|
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
||||||
|
):
|
||||||
|
success, message = self.worker.init_weights_send_group_for_remote_instance(
|
||||||
|
recv_req
|
||||||
|
)
|
||||||
|
return success, message
|
||||||
|
|
||||||
|
def send_weights_to_remote_instance(
|
||||||
|
self, recv_req: SendWeightsToRemoteInstanceReqInput
|
||||||
|
):
|
||||||
|
success, message = self.worker.send_weights_to_remote_instance(recv_req)
|
||||||
|
return success, message
|
||||||
|
|
||||||
def update_weights_from_distributed(
|
def update_weights_from_distributed(
|
||||||
self, recv_req: UpdateWeightsFromDistributedReqInput
|
self, recv_req: UpdateWeightsFromDistributedReqInput
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -19,18 +19,23 @@ import inspect
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import socket
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import requests
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from sglang.srt.configs.device_config import DeviceConfig
|
from sglang.srt.configs.device_config import DeviceConfig
|
||||||
from sglang.srt.configs.load_config import LoadConfig
|
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
|
||||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||||
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
|
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
|
||||||
|
from sglang.srt.connector import ConnectorType
|
||||||
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
|
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
get_pp_group,
|
get_pp_group,
|
||||||
@@ -106,6 +111,9 @@ from sglang.srt.offloader import (
|
|||||||
set_offloader,
|
set_offloader,
|
||||||
)
|
)
|
||||||
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
||||||
|
from sglang.srt.remote_instance_weight_loader_utils import (
|
||||||
|
trigger_init_weights_send_group_for_remote_instance_request,
|
||||||
|
)
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
@@ -128,6 +136,7 @@ from sglang.srt.utils import (
|
|||||||
is_sm100_supported,
|
is_sm100_supported,
|
||||||
monkey_patch_p2p_access_check,
|
monkey_patch_p2p_access_check,
|
||||||
monkey_patch_vllm_gguf_config,
|
monkey_patch_vllm_gguf_config,
|
||||||
|
parse_connector_type,
|
||||||
set_cuda_arch,
|
set_cuda_arch,
|
||||||
)
|
)
|
||||||
from sglang.srt.weight_sync.tensor_bucket import (
|
from sglang.srt.weight_sync.tensor_bucket import (
|
||||||
@@ -256,6 +265,7 @@ class ModelRunner:
|
|||||||
|
|
||||||
# For weight updates
|
# For weight updates
|
||||||
self._model_update_group = {}
|
self._model_update_group = {}
|
||||||
|
self._weights_send_group = {}
|
||||||
|
|
||||||
def initialize(self, min_per_gpu_memory: float):
|
def initialize(self, min_per_gpu_memory: float):
|
||||||
server_args = self.server_args
|
server_args = self.server_args
|
||||||
@@ -726,6 +736,20 @@ class ModelRunner:
|
|||||||
if self.server_args.load_format == "gguf":
|
if self.server_args.load_format == "gguf":
|
||||||
monkey_patch_vllm_gguf_config()
|
monkey_patch_vllm_gguf_config()
|
||||||
|
|
||||||
|
if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
|
||||||
|
if self.tp_rank == 0:
|
||||||
|
instance_ip = socket.gethostbyname(socket.gethostname())
|
||||||
|
t = threading.Thread(
|
||||||
|
target=trigger_init_weights_send_group_for_remote_instance_request,
|
||||||
|
args=(
|
||||||
|
self.server_args.remote_instance_weight_loader_seed_instance_ip,
|
||||||
|
self.server_args.remote_instance_weight_loader_seed_instance_service_port,
|
||||||
|
self.server_args.remote_instance_weight_loader_send_weights_group_ports,
|
||||||
|
instance_ip,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
t.start()
|
||||||
|
|
||||||
# Load the model
|
# Load the model
|
||||||
# Remove monkey_patch when linear.py quant remove dependencies with vllm
|
# Remove monkey_patch when linear.py quant remove dependencies with vllm
|
||||||
monkey_patch_vllm_parallel_state()
|
monkey_patch_vllm_parallel_state()
|
||||||
@@ -735,7 +759,7 @@ class ModelRunner:
|
|||||||
self.model = get_model(
|
self.model = get_model(
|
||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
load_config=self.load_config,
|
load_config=self.load_config,
|
||||||
device_config=DeviceConfig(self.device),
|
device_config=DeviceConfig(self.device, self.gpu_id),
|
||||||
)
|
)
|
||||||
monkey_patch_vllm_parallel_state(reverse=True)
|
monkey_patch_vllm_parallel_state(reverse=True)
|
||||||
monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
|
monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
|
||||||
@@ -867,6 +891,103 @@ class ModelRunner:
|
|||||||
logger.info("Update weights end.")
|
logger.info("Update weights end.")
|
||||||
return True, "Succeeded to update model weights."
|
return True, "Succeeded to update model weights."
|
||||||
|
|
||||||
|
def init_weights_send_group_for_remote_instance(
|
||||||
|
self,
|
||||||
|
master_address,
|
||||||
|
ports,
|
||||||
|
group_rank,
|
||||||
|
world_size,
|
||||||
|
group_name,
|
||||||
|
backend="nccl",
|
||||||
|
):
|
||||||
|
assert (
|
||||||
|
torch.distributed.is_initialized()
|
||||||
|
), "Default torch process group must be initialized"
|
||||||
|
assert group_name != "", "Group name cannot be empty"
|
||||||
|
|
||||||
|
ports_list = ports.split(",")
|
||||||
|
assert (
|
||||||
|
len(ports_list) == self.tp_size
|
||||||
|
), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
|
||||||
|
group_port = ports_list[self.tp_rank]
|
||||||
|
group_name = f"{group_name}_{group_port}_{self.tp_rank}"
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"init custom process group: tp_rank={self.tp_rank}, gpu_id={self.gpu_id}, master_address={master_address}, master_port={group_port}, "
|
||||||
|
f"group_rank={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
success = False
|
||||||
|
message = ""
|
||||||
|
try:
|
||||||
|
self._weights_send_group[group_name] = init_custom_process_group(
|
||||||
|
backend=backend,
|
||||||
|
init_method=f"tcp://{master_address}:{group_port}",
|
||||||
|
world_size=world_size,
|
||||||
|
rank=group_rank,
|
||||||
|
group_name=group_name,
|
||||||
|
device_id=torch.device("cuda", self.gpu_id),
|
||||||
|
)
|
||||||
|
dist.barrier(group=self._weights_send_group[group_name])
|
||||||
|
success = True
|
||||||
|
message = (
|
||||||
|
f"Succeeded to init group through {master_address}:{group_port} group."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
message = f"Failed to init group: {e}."
|
||||||
|
logger.error(message)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return success, message
|
||||||
|
|
||||||
|
def send_weights_to_remote_instance(
|
||||||
|
self,
|
||||||
|
master_address,
|
||||||
|
ports,
|
||||||
|
group_name,
|
||||||
|
):
|
||||||
|
assert (
|
||||||
|
torch.distributed.is_initialized()
|
||||||
|
), "Default torch process group must be initialized"
|
||||||
|
assert group_name != "", "Group name cannot be empty"
|
||||||
|
|
||||||
|
ports_list = ports.split(",")
|
||||||
|
assert (
|
||||||
|
len(ports_list) == self.tp_size
|
||||||
|
), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
|
||||||
|
group_port = ports_list[self.tp_rank]
|
||||||
|
group_name = f"{group_name}_{group_port}_{self.tp_rank}"
|
||||||
|
|
||||||
|
if self._weights_send_group[group_name] is not None:
|
||||||
|
send_group = self._weights_send_group[group_name]
|
||||||
|
else:
|
||||||
|
message = f"Group {group_name} not in _weights_send_group list. Please call `init_weights_send_group_for_remote_instance` first."
|
||||||
|
logger.error(message)
|
||||||
|
return False, message
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
success = False
|
||||||
|
message = ""
|
||||||
|
try:
|
||||||
|
for _, weights in self.model.named_parameters():
|
||||||
|
torch.distributed.broadcast(
|
||||||
|
weights,
|
||||||
|
src=0,
|
||||||
|
group=send_group,
|
||||||
|
)
|
||||||
|
success = True
|
||||||
|
message = f"Succeeded to send weights through {master_address}:{group_port} {group_name}."
|
||||||
|
except Exception as e:
|
||||||
|
message = f"Failed to send weights: {e}."
|
||||||
|
logger.error(message)
|
||||||
|
|
||||||
|
# destroy the process group after sending weights
|
||||||
|
del self._weights_send_group[group_name]
|
||||||
|
torch.distributed.distributed_c10d.destroy_process_group(send_group)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return success, message
|
||||||
|
|
||||||
def init_weights_update_group(
|
def init_weights_update_group(
|
||||||
self,
|
self,
|
||||||
master_address,
|
master_address,
|
||||||
|
|||||||
@@ -12,6 +12,9 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
import socket
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
@@ -27,9 +30,11 @@ from typing import (
|
|||||||
Tuple,
|
Tuple,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import requests
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import HfApi, hf_hub_download
|
from huggingface_hub import HfApi, hf_hub_download
|
||||||
@@ -56,6 +61,7 @@ from sglang.srt.model_loader.utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.model_loader.weight_utils import (
|
from sglang.srt.model_loader.weight_utils import (
|
||||||
_BAR_FORMAT,
|
_BAR_FORMAT,
|
||||||
|
default_weight_loader,
|
||||||
download_safetensors_index_file_from_hf,
|
download_safetensors_index_file_from_hf,
|
||||||
download_weights_from_hf,
|
download_weights_from_hf,
|
||||||
filter_duplicate_safetensors_files,
|
filter_duplicate_safetensors_files,
|
||||||
@@ -71,6 +77,9 @@ from sglang.srt.model_loader.weight_utils import (
|
|||||||
safetensors_weights_iterator,
|
safetensors_weights_iterator,
|
||||||
set_runai_streamer_env,
|
set_runai_streamer_env,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.remote_instance_weight_loader_utils import (
|
||||||
|
trigger_transferring_weights_request,
|
||||||
|
)
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_device_capability,
|
get_device_capability,
|
||||||
@@ -1380,6 +1389,104 @@ class GGUFModelLoader(BaseModelLoader):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteInstanceModelLoader(BaseModelLoader):
|
||||||
|
"""Model loader that can load Tensors from remote sglang instance."""
|
||||||
|
|
||||||
|
def __init__(self, load_config: LoadConfig):
|
||||||
|
super().__init__(load_config)
|
||||||
|
if load_config.model_loader_extra_config:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model loader extra config is not supported for "
|
||||||
|
f"load format {load_config.load_format}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def download_model(self, model_config: ModelConfig) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def load_model(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
device_config: DeviceConfig,
|
||||||
|
) -> nn.Module:
|
||||||
|
logger.info("Loading weights from remote instance ...")
|
||||||
|
load_config = self.load_config
|
||||||
|
|
||||||
|
assert load_config.load_format == LoadFormat.REMOTE_INSTANCE, (
|
||||||
|
f"Model loader {self.load_config.load_format} is not supported for "
|
||||||
|
f"load format {load_config.load_format}"
|
||||||
|
)
|
||||||
|
|
||||||
|
model_weights = f"instance://{model_config.remote_instance_weight_loader_seed_instance_ip}:{model_config.remote_instance_weight_loader_send_weights_group_ports[model_config.tp_rank]}"
|
||||||
|
|
||||||
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
|
with torch.device(device_config.device):
|
||||||
|
model = _initialize_model(model_config, self.load_config)
|
||||||
|
|
||||||
|
with create_remote_connector(model_weights, device_config.device) as client:
|
||||||
|
connector_type = get_connector_type(client)
|
||||||
|
if connector_type == ConnectorType.INSTANCE:
|
||||||
|
self.load_model_from_remote_instance(
|
||||||
|
model, client, model_config, device_config
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported connector type {connector_type} for "
|
||||||
|
f"remote tensor model loading."
|
||||||
|
)
|
||||||
|
return model.eval()
|
||||||
|
|
||||||
|
def load_model_from_remote_instance(
|
||||||
|
self, model, client, model_config: ModelConfig, device_config: DeviceConfig
|
||||||
|
) -> nn.Module:
|
||||||
|
instance_ip = socket.gethostbyname(socket.gethostname())
|
||||||
|
start_build_group_tic = time.time()
|
||||||
|
client.build_group(
|
||||||
|
gpu_id=device_config.gpu_id,
|
||||||
|
tp_rank=model_config.tp_rank,
|
||||||
|
instance_ip=instance_ip,
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
end_build_group_tic = time.time()
|
||||||
|
logger.debug(
|
||||||
|
f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_config.tp_rank == 0:
|
||||||
|
t = threading.Thread(
|
||||||
|
target=trigger_transferring_weights_request,
|
||||||
|
args=(
|
||||||
|
model_config.remote_instance_weight_loader_seed_instance_ip,
|
||||||
|
model_config.remote_instance_weight_loader_seed_instance_service_port,
|
||||||
|
model_config.remote_instance_weight_loader_send_weights_group_ports,
|
||||||
|
instance_ip,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
start_get_weights_tic = time.time()
|
||||||
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
|
for _, tensor in model.named_parameters():
|
||||||
|
torch.distributed.broadcast(
|
||||||
|
tensor.data,
|
||||||
|
src=0,
|
||||||
|
group=client._model_update_group,
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
if hasattr(model, "post_load_weights"):
|
||||||
|
model.post_load_weights()
|
||||||
|
end_get_weights_tic = time.time()
|
||||||
|
logger.debug(
|
||||||
|
f"finish getting all weights from remote instance, time used: {(end_get_weights_tic - start_get_weights_tic):.4f}s"
|
||||||
|
)
|
||||||
|
# destroy the process group after loading weights
|
||||||
|
torch.distributed.distributed_c10d.destroy_process_group(
|
||||||
|
client._model_update_group
|
||||||
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
class RemoteModelLoader(BaseModelLoader):
|
class RemoteModelLoader(BaseModelLoader):
|
||||||
"""Model loader that can load Tensors from remote database."""
|
"""Model loader that can load Tensors from remote database."""
|
||||||
|
|
||||||
@@ -1581,4 +1688,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
|||||||
if load_config.load_format == LoadFormat.REMOTE:
|
if load_config.load_format == LoadFormat.REMOTE:
|
||||||
return RemoteModelLoader(load_config)
|
return RemoteModelLoader(load_config)
|
||||||
|
|
||||||
|
if load_config.load_format == LoadFormat.REMOTE_INSTANCE:
|
||||||
|
return RemoteInstanceModelLoader(load_config)
|
||||||
|
|
||||||
return DefaultModelLoader(load_config)
|
return DefaultModelLoader(load_config)
|
||||||
|
|||||||
69
python/sglang/srt/remote_instance_weight_loader_utils.py
Normal file
69
python/sglang/srt/remote_instance_weight_loader_utils.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def trigger_init_weights_send_group_for_remote_instance_request(
|
||||||
|
remote_instance_weight_loader_seed_instance_ip: str,
|
||||||
|
remote_instance_weight_loader_seed_instance_service_port: int,
|
||||||
|
remote_instance_weight_loader_send_weights_group_ports: List[int],
|
||||||
|
remote_instance_weight_loader_client_id: str,
|
||||||
|
):
|
||||||
|
seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
|
||||||
|
# Only support loading weights from instance with same parallelism strategy.
|
||||||
|
# Per TP rank pair between seed and dst instances will build a communication group for sending weights.
|
||||||
|
# i.e. seed TP 0 <-> dst TP 0, seed TP 1 <-> dst TP 1, etc.
|
||||||
|
# Each communication group will have a world size 2.
|
||||||
|
try:
|
||||||
|
requests.post(
|
||||||
|
f"{seed_instance_service_url}/init_weights_send_group_for_remote_instance",
|
||||||
|
json={
|
||||||
|
"master_address": remote_instance_weight_loader_seed_instance_ip,
|
||||||
|
"ports": (
|
||||||
|
",".join(
|
||||||
|
str(p)
|
||||||
|
for p in remote_instance_weight_loader_send_weights_group_ports
|
||||||
|
)
|
||||||
|
),
|
||||||
|
"group_rank": 0,
|
||||||
|
"world_size": 2,
|
||||||
|
"group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
|
||||||
|
"backend": "nccl",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to trigger init_weights_send_group_for_remote_instance_request to seed instance {seed_instance_service_url}: {e}."
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def trigger_transferring_weights_request(
|
||||||
|
remote_instance_weight_loader_seed_instance_ip: str,
|
||||||
|
remote_instance_weight_loader_seed_instance_service_port: int,
|
||||||
|
remote_instance_weight_loader_send_weights_group_ports: List[int],
|
||||||
|
remote_instance_weight_loader_client_id: str,
|
||||||
|
):
|
||||||
|
seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
|
||||||
|
try:
|
||||||
|
requests.post(
|
||||||
|
f"{seed_instance_service_url}/send_weights_to_remote_instance",
|
||||||
|
json={
|
||||||
|
"master_address": remote_instance_weight_loader_seed_instance_ip,
|
||||||
|
"ports": (
|
||||||
|
",".join(
|
||||||
|
str(p)
|
||||||
|
for p in remote_instance_weight_loader_send_weights_group_ports
|
||||||
|
)
|
||||||
|
),
|
||||||
|
"group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to trigger send weights to remote instance request: {e}")
|
||||||
|
raise
|
||||||
@@ -19,10 +19,12 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import socket
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from sglang.srt.connector import ConnectorType
|
||||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||||
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
|
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
|
||||||
from sglang.srt.lora.lora_registry import LoRARef
|
from sglang.srt.lora.lora_registry import LoRARef
|
||||||
@@ -42,7 +44,9 @@ from sglang.srt.utils import (
|
|||||||
is_sm100_supported,
|
is_sm100_supported,
|
||||||
is_triton_kernels_available,
|
is_triton_kernels_available,
|
||||||
is_valid_ipv6_address,
|
is_valid_ipv6_address,
|
||||||
|
json_list_type,
|
||||||
nullable_str,
|
nullable_str,
|
||||||
|
parse_connector_type,
|
||||||
)
|
)
|
||||||
from sglang.utils import is_in_ci
|
from sglang.utils import is_in_ci
|
||||||
|
|
||||||
@@ -61,6 +65,7 @@ LOAD_FORMAT_CHOICES = [
|
|||||||
"bitsandbytes",
|
"bitsandbytes",
|
||||||
"layered",
|
"layered",
|
||||||
"remote",
|
"remote",
|
||||||
|
"remote_instance",
|
||||||
]
|
]
|
||||||
|
|
||||||
QUANTIZATION_CHOICES = [
|
QUANTIZATION_CHOICES = [
|
||||||
@@ -387,6 +392,11 @@ class ServerArgs:
|
|||||||
custom_weight_loader: Optional[List[str]] = None
|
custom_weight_loader: Optional[List[str]] = None
|
||||||
weight_loader_disable_mmap: bool = False
|
weight_loader_disable_mmap: bool = False
|
||||||
|
|
||||||
|
# Remote instance weight loading
|
||||||
|
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
|
||||||
|
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
|
||||||
|
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
|
||||||
|
|
||||||
# For PD-Multiplexing
|
# For PD-Multiplexing
|
||||||
enable_pdmux: bool = False
|
enable_pdmux: bool = False
|
||||||
sm_group_num: int = 3
|
sm_group_num: int = 3
|
||||||
@@ -445,6 +455,7 @@ class ServerArgs:
|
|||||||
# Set missing default values
|
# Set missing default values
|
||||||
if self.tokenizer_path is None:
|
if self.tokenizer_path is None:
|
||||||
self.tokenizer_path = self.model_path
|
self.tokenizer_path = self.model_path
|
||||||
|
|
||||||
if self.served_model_name is None:
|
if self.served_model_name is None:
|
||||||
self.served_model_name = self.model_path
|
self.served_model_name = self.model_path
|
||||||
if self.device is None:
|
if self.device is None:
|
||||||
@@ -538,7 +549,8 @@ class ServerArgs:
|
|||||||
self.sampling_backend = "pytorch"
|
self.sampling_backend = "pytorch"
|
||||||
|
|
||||||
# Model-specific adjustments
|
# Model-specific adjustments
|
||||||
self.model_specific_adjustments()
|
if parse_connector_type(self.model_path) != ConnectorType.INSTANCE:
|
||||||
|
self.model_specific_adjustments()
|
||||||
|
|
||||||
# Set kernel backends
|
# Set kernel backends
|
||||||
if self.device == "cpu":
|
if self.device == "cpu":
|
||||||
@@ -818,12 +830,19 @@ class ServerArgs:
|
|||||||
) and check_gguf_file(self.model_path):
|
) and check_gguf_file(self.model_path):
|
||||||
self.quantization = self.load_format = "gguf"
|
self.quantization = self.load_format = "gguf"
|
||||||
|
|
||||||
# Model loading
|
|
||||||
if is_remote_url(self.model_path):
|
if is_remote_url(self.model_path):
|
||||||
self.load_format = "remote"
|
self.load_format = "remote"
|
||||||
if self.custom_weight_loader is None:
|
if self.custom_weight_loader is None:
|
||||||
self.custom_weight_loader = []
|
self.custom_weight_loader = []
|
||||||
|
|
||||||
|
if self.load_format == "remote_instance":
|
||||||
|
if (
|
||||||
|
self.remote_instance_weight_loader_seed_instance_ip is None
|
||||||
|
or self.remote_instance_weight_loader_seed_instance_service_port is None
|
||||||
|
or self.remote_instance_weight_loader_send_weights_group_ports is None
|
||||||
|
):
|
||||||
|
self.load_format = "auto"
|
||||||
|
|
||||||
# PD disaggregation
|
# PD disaggregation
|
||||||
if self.disaggregation_mode == "decode":
|
if self.disaggregation_mode == "decode":
|
||||||
assert (
|
assert (
|
||||||
@@ -881,6 +900,24 @@ class ServerArgs:
|
|||||||
help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
|
help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
|
||||||
required=True,
|
required=True,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--remote-instance-weight-loader-seed-instance-ip",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.remote_instance_weight_loader_seed_instance_ip,
|
||||||
|
help="The ip of the seed instance for loading weights from remote instance.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--remote-instance-weight-loader-seed-instance-service-port",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.remote_instance_weight_loader_seed_instance_service_port,
|
||||||
|
help="The service port of the seed instance for loading weights from remote instance.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--remote-instance-weight-loader-send-weights-group-ports",
|
||||||
|
type=json_list_type,
|
||||||
|
default=ServerArgs.remote_instance_weight_loader_send_weights_group_ports,
|
||||||
|
help="The communication group ports for loading weights from remote instance.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokenizer-path",
|
"--tokenizer-path",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import builtins
|
import builtins
|
||||||
import ctypes
|
import ctypes
|
||||||
@@ -1431,6 +1432,7 @@ def init_custom_process_group(
|
|||||||
store=None,
|
store=None,
|
||||||
group_name=None,
|
group_name=None,
|
||||||
pg_options=None,
|
pg_options=None,
|
||||||
|
device_id=None,
|
||||||
):
|
):
|
||||||
from torch.distributed.distributed_c10d import (
|
from torch.distributed.distributed_c10d import (
|
||||||
Backend,
|
Backend,
|
||||||
@@ -1484,6 +1486,7 @@ def init_custom_process_group(
|
|||||||
group_name=group_name,
|
group_name=group_name,
|
||||||
**{pg_options_param_name: pg_options},
|
**{pg_options_param_name: pg_options},
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
device_id=device_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
_world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
|
_world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
|
||||||
@@ -3046,3 +3049,12 @@ def numa_bind_to_node(node: int):
|
|||||||
|
|
||||||
libnuma.numa_run_on_node(ctypes.c_int(node))
|
libnuma.numa_run_on_node(ctypes.c_int(node))
|
||||||
libnuma.numa_set_localalloc()
|
libnuma.numa_set_localalloc()
|
||||||
|
|
||||||
|
|
||||||
|
def json_list_type(value):
|
||||||
|
try:
|
||||||
|
return json.loads(value)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise argparse.ArgumentTypeError(
|
||||||
|
f"Invalid JSON list: {value}. Please provide a valid JSON list."
|
||||||
|
)
|
||||||
|
|||||||
@@ -123,6 +123,7 @@ suites = {
|
|||||||
TestFile("rl/test_update_weights_from_distributed.py", 103),
|
TestFile("rl/test_update_weights_from_distributed.py", 103),
|
||||||
TestFile("test_data_parallelism.py", 73),
|
TestFile("test_data_parallelism.py", 73),
|
||||||
TestFile("test_dp_attention.py", 277),
|
TestFile("test_dp_attention.py", 277),
|
||||||
|
TestFile("test_load_weights_from_remote_instance.py", 72),
|
||||||
TestFile("test_patch_torch.py", 19),
|
TestFile("test_patch_torch.py", 19),
|
||||||
TestFile("test_release_memory_occupation.py", 127),
|
TestFile("test_release_memory_occupation.py", 127),
|
||||||
TestFile("hicache/test_hicache_storage_file_backend.py", 400),
|
TestFile("hicache/test_hicache_storage_file_backend.py", 400),
|
||||||
@@ -251,6 +252,7 @@ suite_amd = {
|
|||||||
TestFile("lora/test_lora_tp.py", 116),
|
TestFile("lora/test_lora_tp.py", 116),
|
||||||
TestFile("rl/test_update_weights_from_distributed.py", 103),
|
TestFile("rl/test_update_weights_from_distributed.py", 103),
|
||||||
TestFile("test_data_parallelism.py", 73),
|
TestFile("test_data_parallelism.py", 73),
|
||||||
|
TestFile("test_load_weights_from_remote_instance.py", 72),
|
||||||
TestFile("test_patch_torch.py", 19),
|
TestFile("test_patch_torch.py", 19),
|
||||||
],
|
],
|
||||||
"per-commit-4-gpu-amd": [
|
"per-commit-4-gpu-amd": [
|
||||||
|
|||||||
384
test/srt/test_load_weights_from_remote_instance.py
Normal file
384
test/srt/test_load_weights_from_remote_instance.py
Normal file
@@ -0,0 +1,384 @@
|
|||||||
|
"""Test loading weights from remote instance.
|
||||||
|
|
||||||
|
This test suite simulates loading weights from a remote instance.
|
||||||
|
Rank 0 represents the seed instance, while ranks 1 represents the
|
||||||
|
new instance that needs to loading weights from the seed instance.
|
||||||
|
|
||||||
|
Seed instance must be started in `Server` mode, while the dst instance
|
||||||
|
can be either `Engine` mode or `Server` mode.
|
||||||
|
|
||||||
|
Seed instance does not support concurrently serving multiple dst instances.
|
||||||
|
User has to guarantee that there is only one dst instance trying to load
|
||||||
|
weights from the seed instance at any time.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
is_in_ci,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
from sglang.utils import terminate_process
|
||||||
|
|
||||||
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_params_close(params1, params2, error_msg):
|
||||||
|
"""Verify if two parameter arrays are close enough."""
|
||||||
|
try:
|
||||||
|
assert np.allclose(np.array(params1), np.array(params2)), error_msg
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Parameters not close for {error_msg}")
|
||||||
|
print("Params1:", np.array(params1))
|
||||||
|
print("Params2:", np.array(params2))
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
def init_process(
|
||||||
|
rank,
|
||||||
|
param_queue,
|
||||||
|
truncate_size,
|
||||||
|
tp_size,
|
||||||
|
model_name,
|
||||||
|
backends,
|
||||||
|
checking_parameters,
|
||||||
|
seed_instance_ip,
|
||||||
|
seed_instance_service_port,
|
||||||
|
seed_instance_group_base_port,
|
||||||
|
event_seed_ready,
|
||||||
|
event_dst_ready_list,
|
||||||
|
):
|
||||||
|
torch.cuda.set_device(rank)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
init_process_seed(
|
||||||
|
rank,
|
||||||
|
param_queue,
|
||||||
|
truncate_size,
|
||||||
|
model_name,
|
||||||
|
checking_parameters,
|
||||||
|
tp_size,
|
||||||
|
event_seed_ready,
|
||||||
|
event_dst_ready_list,
|
||||||
|
)
|
||||||
|
elif rank in [1, 2]:
|
||||||
|
init_process_dst(
|
||||||
|
rank,
|
||||||
|
param_queue,
|
||||||
|
truncate_size,
|
||||||
|
model_name,
|
||||||
|
seed_instance_ip,
|
||||||
|
seed_instance_service_port,
|
||||||
|
seed_instance_group_base_port,
|
||||||
|
checking_parameters,
|
||||||
|
backends[rank - 1],
|
||||||
|
tp_size,
|
||||||
|
event_seed_ready,
|
||||||
|
event_dst_ready_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def init_process_seed(
|
||||||
|
rank,
|
||||||
|
param_queue,
|
||||||
|
truncate_size,
|
||||||
|
model_name,
|
||||||
|
checking_parameters,
|
||||||
|
tp_size,
|
||||||
|
event_seed_ready,
|
||||||
|
event_dst_ready_list,
|
||||||
|
):
|
||||||
|
# These two environment variables are very important
|
||||||
|
# to avoid unexpected behaviors of CUDA and NCCL.
|
||||||
|
os.environ["NCCL_CUMEM_ENABLE"] = "0"
|
||||||
|
os.environ["NCCL_NVLS_ENABLE"] = "0"
|
||||||
|
|
||||||
|
# Load model and get parameters
|
||||||
|
torch.cuda.set_device(rank)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
url = DEFAULT_URL_FOR_TEST
|
||||||
|
process = popen_launch_server(
|
||||||
|
model_name,
|
||||||
|
url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=(
|
||||||
|
"--base-gpu-id",
|
||||||
|
str(rank),
|
||||||
|
"--tp-size",
|
||||||
|
str(tp_size),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
seed_params = []
|
||||||
|
# Get the weights of seed instance for correctness check.
|
||||||
|
for parameter_name in checking_parameters:
|
||||||
|
seed_params.append(
|
||||||
|
requests.get(
|
||||||
|
f"{url}/get_weights_by_name",
|
||||||
|
json={
|
||||||
|
"name": parameter_name,
|
||||||
|
"truncate_size": truncate_size,
|
||||||
|
},
|
||||||
|
).json()
|
||||||
|
)
|
||||||
|
param_queue.put((f"seed_params", seed_params))
|
||||||
|
|
||||||
|
event_seed_ready.set()
|
||||||
|
for i in range(len(event_dst_ready_list)):
|
||||||
|
event_dst_ready_list[i].wait()
|
||||||
|
terminate_process(process)
|
||||||
|
|
||||||
|
|
||||||
|
def init_process_dst(
|
||||||
|
rank,
|
||||||
|
param_queue,
|
||||||
|
truncate_size,
|
||||||
|
model_name,
|
||||||
|
seed_instance_ip,
|
||||||
|
seed_instance_service_port,
|
||||||
|
seed_instance_group_base_port,
|
||||||
|
checking_parameters,
|
||||||
|
backend,
|
||||||
|
tp_size,
|
||||||
|
event_seed_ready,
|
||||||
|
event_dst_ready_list,
|
||||||
|
):
|
||||||
|
torch.cuda.set_device(rank * tp_size)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
base_gpu_id = rank * tp_size
|
||||||
|
|
||||||
|
event_seed_ready.wait()
|
||||||
|
print(f"rank {rank}, seed ready")
|
||||||
|
for i in range(rank - 1):
|
||||||
|
print(f"rank {rank}, wait dst {i}")
|
||||||
|
event_dst_ready_list[i].wait()
|
||||||
|
|
||||||
|
ports = []
|
||||||
|
for i in range(tp_size):
|
||||||
|
ports.append(seed_instance_group_base_port + (rank - 1) * tp_size + i)
|
||||||
|
|
||||||
|
if backend == "Engine":
|
||||||
|
print(f"[sgl] rank {rank} init engine")
|
||||||
|
engine = sgl.Engine(
|
||||||
|
model_path=model_name,
|
||||||
|
base_gpu_id=base_gpu_id,
|
||||||
|
tp_size=tp_size,
|
||||||
|
cuda_graph_max_bs=2,
|
||||||
|
tokenizer_path=model_name,
|
||||||
|
remote_instance_weight_loader_seed_instance_ip=seed_instance_ip,
|
||||||
|
remote_instance_weight_loader_seed_instance_service_port=seed_instance_service_port,
|
||||||
|
remote_instance_weight_loader_send_weights_group_ports=ports,
|
||||||
|
load_format="remote_instance",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
host, _, port = DEFAULT_URL_FOR_TEST.rpartition(":")
|
||||||
|
url = ":".join([host, str(int(port) + 10000 + rank)])
|
||||||
|
|
||||||
|
print(f"[sgl] rank {rank} init server on url: {url}")
|
||||||
|
process = popen_launch_server(
|
||||||
|
model_name,
|
||||||
|
url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=(
|
||||||
|
"--base-gpu-id",
|
||||||
|
str(base_gpu_id),
|
||||||
|
"--tp-size",
|
||||||
|
str(tp_size),
|
||||||
|
"--cuda-graph-max-bs",
|
||||||
|
2,
|
||||||
|
"--tokenizer-path",
|
||||||
|
model_name,
|
||||||
|
"--remote-instance-weight-loader-seed-instance-ip",
|
||||||
|
seed_instance_ip,
|
||||||
|
"--remote-instance-weight-loader-seed-instance-service-port",
|
||||||
|
seed_instance_service_port,
|
||||||
|
"--remote-instance-weight-loader-send-weights-group-ports",
|
||||||
|
f"[{','.join(str(port) for port in ports)}]",
|
||||||
|
"--load-format",
|
||||||
|
"remote_instance",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
event_dst_ready_list[rank - 1].set()
|
||||||
|
|
||||||
|
# Get weights of destination instance loaded from remote instance.
|
||||||
|
dst_params = []
|
||||||
|
for parameter_name in checking_parameters:
|
||||||
|
dst_params.append(
|
||||||
|
engine.get_weights_by_name(parameter_name, truncate_size)
|
||||||
|
if backend == "Engine"
|
||||||
|
else requests.get(
|
||||||
|
f"{url}/get_weights_by_name",
|
||||||
|
json={"name": parameter_name, "truncate_size": truncate_size},
|
||||||
|
).json()
|
||||||
|
)
|
||||||
|
|
||||||
|
param_queue.put((f"sgl_dp_{rank}_dst_params", dst_params))
|
||||||
|
|
||||||
|
# Shutdown the engine or terminate the server process.
|
||||||
|
if backend == "Engine":
|
||||||
|
engine.shutdown()
|
||||||
|
else:
|
||||||
|
terminate_process(process)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_weights_from_remote_instance(
    tp_size,
    dp_size,
    model_name,
    backends,
    truncate_size,
    checking_parameters,
    seed_instance_ip,
    seed_instance_service_port,
    seed_instance_group_base_port,
):
    """Spawn one seed instance plus ``dp_size`` destination instances and
    verify the destinations load identical weights from the remote seed.

    Args:
        tp_size: Tensor-parallel size of each instance.
        dp_size: Number of destination instances (data-parallel replicas).
        model_name: Model to load.
        backends: Backend kind per destination ("Engine" or "Server").
        truncate_size: Number of leading elements of each parameter compared.
        checking_parameters: Names of the parameters whose weights are compared.
        seed_instance_ip: IP address of the seed instance.
        seed_instance_service_port: HTTP service port of the seed instance.
        seed_instance_group_base_port: Base port for the weight-send groups.

    Raises:
        RuntimeError: If some worker process exited without reporting its
            parameters back through the queue.
    """
    print(
        f"Testing model: {model_name} tp_size: {tp_size}, dp_size: {dp_size} backend: {backends}"
    )
    param_queue = mp.Queue()
    results = {}
    event_seed_ready = mp.Event()
    # One readiness event per destination instance.
    event_dst_ready_list = [mp.Event() for _ in range(dp_size)]

    # Rank 0 is the seed instance; ranks 1..dp_size are destinations.
    context = mp.spawn(
        init_process,
        args=(
            param_queue,
            truncate_size,
            tp_size,
            model_name,
            backends,
            checking_parameters,
            seed_instance_ip,
            seed_instance_service_port,
            seed_instance_group_base_port,
            event_seed_ready,
            event_dst_ready_list,
        ),
        nprocs=1 + dp_size,
        join=False,
    )

    # Collect one result per process; stop waiting early if every worker died
    # (otherwise a crashed worker would hang this loop forever).
    while len(results) < (1 + dp_size):
        try:
            key, value = param_queue.get(timeout=5)
            results[key] = value
        except Exception:
            # Expected on queue-read timeout; only give up once no worker
            # process is still alive.
            if all(not p.is_alive() for p in context.processes):
                break

    context.join()

    if len(results) != (1 + dp_size):
        raise RuntimeError(
            f"Expected {(1 + dp_size)} parameters but got {len(results)}"
        )

    params = {
        "seed": results.get("seed_params"),
        "sgl_dp_1_dest": results.get("sgl_dp_1_dst_params"),
    }

    if dp_size == 2:
        dp2_params = {
            "sgl_dp_2_dest": results.get("sgl_dp_2_dst_params"),
        }
        assert all(v is not None for v in dp2_params.values())
        params.update(dp2_params)

    # Check the correctness of weights loaded from the remote instance by
    # comparing the seed instance's weights against each destination's.
    for i, seed_param in enumerate(params["seed"]):
        verify_params_close(
            seed_param,
            params["sgl_dp_1_dest"][i],
            f"sgl_dp_1_dst_params rank {i}",
        )

        if dp_size == 2:
            verify_params_close(
                seed_param,
                params["sgl_dp_2_dest"][i],
                f"sgl_dp_2_dst_params rank {i}",
            )

    # Release the spawn context and close the parameter queue.
    del context
    param_queue.close()
    param_queue.join_thread()
    gc.collect()
    torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadWeightsFromRemoteInstance(CustomTestCase):
    """End-to-end test: destination instances load weights from a remote
    seed instance and must end up with identical parameters."""

    def test_load_weights_from_remote_instance(self):
        assert torch.cuda.device_count() >= 2, "At least 2 GPUs are required"
        # test_suits: (tp, dp, model_name, backends)
        if is_in_ci():
            # Keep CI cheap: one 1x1 run with a randomly chosen backend.
            mode = random.choice(["Engine", "Server"])
            test_suits = [
                (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, [mode]),
            ]
        else:
            test_suits = [
                (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ["Engine"]),
                # Fixed typo "Sever" -> "Server" so the HTTP-server backend
                # label is correct.
                (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ["Server"]),
                (2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ["Engine", "Server"]),
            ]

        truncate_size = 10
        # Representative parameters across the model: embeddings, norms,
        # attention projections, and MLP projections.
        checking_parameters = [
            "model.embed_tokens.weight",
            "model.layers.0.input_layernorm.weight",
            "model.layers.1.self_attn.q_proj.weight",
            "model.layers.2.self_attn.k_proj.weight",
            "model.layers.3.self_attn.v_proj.weight",
            "model.layers.4.self_attn.o_proj.weight",
            "model.layers.5.mlp.gate_proj.weight",
            "model.layers.6.mlp.up_proj.weight",
            "model.layers.7.mlp.down_proj.weight",
            "model.layers.8.post_attention_layernorm.weight",
            "model.norm.weight",
        ]

        for tp_size, dp_size, model_name, backends in test_suits:
            test_load_weights_from_remote_instance(
                tp_size,
                dp_size,
                model_name,
                backends,
                truncate_size,
                checking_parameters,
                "127.0.0.1",
                DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000,
                60000,
            )
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this test module directly (outside a pytest/CI runner).
if __name__ == "__main__":
    unittest.main()
|
||||||
Reference in New Issue
Block a user