Signed-off-by: Anqi Shen <amy.saq@antgroup.com> Co-authored-by: Chayenne <74843776+zhaochenyang20@users.noreply.github.com>
70 lines
2.8 KiB
Python
70 lines
2.8 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import logging
|
|
from typing import List
|
|
|
|
import requests
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def trigger_init_weights_send_group_for_remote_instance_request(
    remote_instance_weight_loader_seed_instance_ip: str,
    remote_instance_weight_loader_seed_instance_service_port: int,
    remote_instance_weight_loader_send_weights_group_ports: List[int],
    remote_instance_weight_loader_client_id: str,
) -> None:
    """Ask the seed instance to initialize the weight-sending groups.

    POSTs to ``/init_weights_send_group_for_remote_instance`` on the seed
    instance's HTTP service. The JSON payload carries the seed IP as
    ``master_address``, the ports joined into one comma-separated string,
    ``group_rank`` 0, ``world_size`` 2, the ``nccl`` backend, and a group
    name derived from the client id (``send_weights_<client_id>``).

    Args:
        remote_instance_weight_loader_seed_instance_ip: Host of the seed
            instance; used both in the URL and as ``master_address``.
        remote_instance_weight_loader_seed_instance_service_port: Port of the
            seed instance's HTTP service.
        remote_instance_weight_loader_send_weights_group_ports: Ports sent to
            the seed instance as a comma-separated string.
        remote_instance_weight_loader_client_id: Identifier used to build the
            group name.

    Raises:
        Exception: Any error raised while sending the request, or a
            ``requests.HTTPError`` when the seed instance answers with a
            4xx/5xx status. The error is logged and then re-raised.
    """
    seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
    # Loading weights is only supported from an instance that uses the same
    # parallelism strategy.
    # Each TP rank pair between the seed and dst instances builds its own
    # communication group for sending weights,
    # i.e. seed TP 0 <-> dst TP 0, seed TP 1 <-> dst TP 1, etc.
    # Each communication group therefore has a world size of 2.
    try:
        # NOTE(review): no timeout is passed, so this call blocks until the
        # seed instance answers; confirm that is intended before adding one.
        response = requests.post(
            f"{seed_instance_service_url}/init_weights_send_group_for_remote_instance",
            json={
                "master_address": remote_instance_weight_loader_seed_instance_ip,
                "ports": ",".join(
                    str(p)
                    for p in remote_instance_weight_loader_send_weights_group_ports
                ),
                "group_rank": 0,
                "world_size": 2,
                "group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
                "backend": "nccl",
            },
        )
        # requests does not raise on 4xx/5xx by itself; without this check a
        # rejected request would be silently treated as success.
        response.raise_for_status()
    except Exception as e:
        logger.error(
            f"Failed to trigger init_weights_send_group_for_remote_instance_request to seed instance {seed_instance_service_url}: {e}."
        )
        raise
|
|
|
|
|
|
def trigger_transferring_weights_request(
    remote_instance_weight_loader_seed_instance_ip: str,
    remote_instance_weight_loader_seed_instance_service_port: int,
    remote_instance_weight_loader_send_weights_group_ports: List[int],
    remote_instance_weight_loader_client_id: str,
) -> None:
    """Ask the seed instance to send its weights to the remote instance.

    POSTs to ``/send_weights_to_remote_instance`` on the seed instance's HTTP
    service. The JSON payload carries the seed IP as ``master_address``, the
    ports joined into one comma-separated string, and a group name derived
    from the client id (``send_weights_<client_id>``).

    Args:
        remote_instance_weight_loader_seed_instance_ip: Host of the seed
            instance; used both in the URL and as ``master_address``.
        remote_instance_weight_loader_seed_instance_service_port: Port of the
            seed instance's HTTP service.
        remote_instance_weight_loader_send_weights_group_ports: Ports sent to
            the seed instance as a comma-separated string.
        remote_instance_weight_loader_client_id: Identifier used to build the
            group name.

    Raises:
        Exception: Any error raised while sending the request, or a
            ``requests.HTTPError`` when the seed instance answers with a
            4xx/5xx status. The error is logged and then re-raised.
    """
    seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
    try:
        # NOTE(review): no timeout is passed, so this call blocks until the
        # seed instance answers; confirm that is intended before adding one.
        response = requests.post(
            f"{seed_instance_service_url}/send_weights_to_remote_instance",
            json={
                "master_address": remote_instance_weight_loader_seed_instance_ip,
                "ports": ",".join(
                    str(p)
                    for p in remote_instance_weight_loader_send_weights_group_ports
                ),
                "group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
            },
        )
        # requests does not raise on 4xx/5xx by itself; without this check a
        # rejected request would be silently treated as success.
        response.raise_for_status()
    except Exception as e:
        logger.error(f"Failed to trigger send weights to remote instance request: {e}")
        raise
|