Files
sglang/python/sglang/srt/remote_instance_weight_loader_utils.py
2025-09-12 17:40:22 +08:00

70 lines
2.8 KiB
Python

# SPDX-License-Identifier: Apache-2.0
import logging
from typing import List
import requests
logger = logging.getLogger(__name__)
def trigger_init_weights_send_group_for_remote_instance_request(
    remote_instance_weight_loader_seed_instance_ip: str,
    remote_instance_weight_loader_seed_instance_service_port: int,
    remote_instance_weight_loader_send_weights_group_ports: List[int],
    remote_instance_weight_loader_client_id: str,
) -> None:
    """Ask the seed instance to initialize its weight-sending groups.

    Sends an HTTP POST to the seed instance's
    ``/init_weights_send_group_for_remote_instance`` endpoint. The JSON
    payload carries the seed IP as ``master_address``, the ports as a
    comma-separated string, ``group_rank`` 0, ``world_size`` 2, a group name
    derived from the client id, and the ``nccl`` backend.

    Args:
        remote_instance_weight_loader_seed_instance_ip: Host/IP of the seed
            instance; used both for the service URL and as ``master_address``.
        remote_instance_weight_loader_seed_instance_service_port: Port of the
            seed instance's HTTP service.
        remote_instance_weight_loader_send_weights_group_ports: Ports sent to
            the seed instance, joined with commas.
        remote_instance_weight_loader_client_id: Identifier embedded in the
            group name as ``send_weights_<client_id>``.

    Raises:
        Exception: Any error raised while building or sending the request,
            including ``requests.HTTPError`` for a 4xx/5xx reply, is logged
            and re-raised unchanged.
    """
    seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
    # Only support loading weights from instance with same parallelism strategy.
    # Per TP rank pair between seed and dst instances will build a communication group for sending weights.
    # i.e. seed TP 0 <-> dst TP 0, seed TP 1 <-> dst TP 1, etc.
    # Each communication group will have a world size 2.
    try:
        response = requests.post(
            f"{seed_instance_service_url}/init_weights_send_group_for_remote_instance",
            json={
                "master_address": remote_instance_weight_loader_seed_instance_ip,
                "ports": ",".join(
                    str(p)
                    for p in remote_instance_weight_loader_send_weights_group_ports
                ),
                "group_rank": 0,
                "world_size": 2,
                "group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
                "backend": "nccl",
            },
        )
        # A 4xx/5xx reply must not be mistaken for success: turn it into an
        # exception so it is logged and propagated like a connection failure.
        response.raise_for_status()
    except Exception as e:
        logger.error(
            f"Failed to trigger init_weights_send_group_for_remote_instance_request to seed instance {seed_instance_service_url}: {e}."
        )
        raise
def trigger_transferring_weights_request(
    remote_instance_weight_loader_seed_instance_ip: str,
    remote_instance_weight_loader_seed_instance_service_port: int,
    remote_instance_weight_loader_send_weights_group_ports: List[int],
    remote_instance_weight_loader_client_id: str,
) -> None:
    """Ask the seed instance to send its weights to the remote instance.

    Sends an HTTP POST to the seed instance's
    ``/send_weights_to_remote_instance`` endpoint. The JSON payload carries
    the seed IP as ``master_address``, the ports as a comma-separated string,
    and the group name ``send_weights_<client_id>``.

    Args:
        remote_instance_weight_loader_seed_instance_ip: Host/IP of the seed
            instance; used both for the service URL and as ``master_address``.
        remote_instance_weight_loader_seed_instance_service_port: Port of the
            seed instance's HTTP service.
        remote_instance_weight_loader_send_weights_group_ports: Ports sent to
            the seed instance, joined with commas.
        remote_instance_weight_loader_client_id: Identifier embedded in the
            group name.

    Raises:
        Exception: Any error raised while building or sending the request,
            including ``requests.HTTPError`` for a 4xx/5xx reply, is logged
            and re-raised unchanged.
    """
    seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
    try:
        response = requests.post(
            f"{seed_instance_service_url}/send_weights_to_remote_instance",
            json={
                "master_address": remote_instance_weight_loader_seed_instance_ip,
                "ports": ",".join(
                    str(p)
                    for p in remote_instance_weight_loader_send_weights_group_ports
                ),
                "group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
            },
        )
        # A 4xx/5xx reply must not be mistaken for success: turn it into an
        # exception so it is logged and propagated like a connection failure.
        response.raise_for_status()
    except Exception as e:
        logger.error(f"Failed to trigger send weights to remote instance request: {e}")
        raise