Support loading weights from remote instance (#8215)

Signed-off-by: Anqi Shen <amy.saq@antgroup.com>
Co-authored-by: Chayenne <74843776+zhaochenyang20@users.noreply.github.com>
This commit is contained in:
amysaq2023
2025-09-12 17:40:22 +08:00
committed by GitHub
parent 1b1701f1f7
commit 30d20ce84f
18 changed files with 1042 additions and 6 deletions

View File

@@ -15,6 +15,7 @@
from __future__ import annotations
import argparse
import asyncio
import builtins
import ctypes
@@ -1431,6 +1432,7 @@ def init_custom_process_group(
store=None,
group_name=None,
pg_options=None,
device_id=None,
):
from torch.distributed.distributed_c10d import (
Backend,
@@ -1484,6 +1486,7 @@ def init_custom_process_group(
group_name=group_name,
**{pg_options_param_name: pg_options},
timeout=timeout,
device_id=device_id,
)
_world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
@@ -3046,3 +3049,12 @@ def numa_bind_to_node(node: int):
libnuma.numa_run_on_node(ctypes.c_int(node))
libnuma.numa_set_localalloc()
def json_list_type(value):
    """Parse a command-line string as a JSON list.

    Raises ``argparse.ArgumentTypeError`` on failure, so it is suitable as an
    ``argparse`` ``type=`` callable.

    Args:
        value: The raw string to decode, e.g. ``"[1, 2, 3]"``.

    Returns:
        The decoded Python ``list``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not valid JSON, or if it
            decodes to something other than a list (object, number, string,
            ``null``, ...).
    """
    message = f"Invalid JSON list: {value}. Please provide a valid JSON list."
    try:
        parsed = json.loads(value)
    except json.JSONDecodeError as err:
        # Keep the decode error as the cause for easier debugging.
        raise argparse.ArgumentTypeError(message) from err
    # Valid JSON is not enough: the contract (and the error text) is "a list".
    if not isinstance(parsed, list):
        raise argparse.ArgumentTypeError(message)
    return parsed