Support loading weights from remote instance (#8215)
Signed-off-by: Anqi Shen <amy.saq@antgroup.com> Co-authored-by: Chayenne <74843776+zhaochenyang20@users.noreply.github.com>
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import builtins
|
||||
import ctypes
|
||||
@@ -1431,6 +1432,7 @@ def init_custom_process_group(
|
||||
store=None,
|
||||
group_name=None,
|
||||
pg_options=None,
|
||||
device_id=None,
|
||||
):
|
||||
from torch.distributed.distributed_c10d import (
|
||||
Backend,
|
||||
@@ -1484,6 +1486,7 @@ def init_custom_process_group(
|
||||
group_name=group_name,
|
||||
**{pg_options_param_name: pg_options},
|
||||
timeout=timeout,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
_world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
|
||||
@@ -3046,3 +3049,12 @@ def numa_bind_to_node(node: int):
|
||||
|
||||
libnuma.numa_run_on_node(ctypes.c_int(node))
|
||||
libnuma.numa_set_localalloc()
|
||||
|
||||
|
||||
def json_list_type(value):
    """Parse a command-line string as a JSON list.

    Raises ``argparse.ArgumentTypeError`` on bad input, so it can be used as
    an argparse ``type=`` callable.

    Args:
        value: The raw string to decode as JSON.

    Returns:
        The decoded Python list.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not valid JSON, or if it
            decodes to something other than a list.
    """
    message = f"Invalid JSON list: {value}. Please provide a valid JSON list."
    try:
        parsed = json.loads(value)
    except json.JSONDecodeError as err:
        # Chain the decode error so the root cause stays visible in tracebacks.
        raise argparse.ArgumentTypeError(message) from err
    # Valid JSON is not enough: objects, numbers, strings, and null must be
    # rejected, because the contract of this converter is a list.
    if not isinstance(parsed, list):
        raise argparse.ArgumentTypeError(message)
    return parsed
|
||||
|
||||
Reference in New Issue
Block a user