Support loading weights from remote instance (#8215)

Signed-off-by: Anqi Shen <amy.saq@antgroup.com>
Co-authored-by: Chayenne <74843776+zhaochenyang20@users.noreply.github.com>
This commit is contained in:
amysaq2023
2025-09-12 17:40:22 +08:00
committed by GitHub
parent 1b1701f1f7
commit 30d20ce84f
18 changed files with 1042 additions and 6 deletions

View File

@@ -8,10 +8,12 @@ logger = logging.getLogger(__name__)
class DeviceConfig:
    """Device selection for a model instance.

    Resolves a backend name string into a ``torch.device`` and records which
    GPU index (if any) this instance is pinned to.

    Attributes:
        device: The resolved ``torch.device`` object.
        gpu_id: GPU index for this instance; ``-1`` means "not specified".
    """

    device: Optional[torch.device]
    gpu_id: Optional[int]

    def __init__(self, device: str = "cuda", gpu_id: int = -1) -> None:
        """Validate *device* and build the torch device handle.

        Args:
            device: Backend name; one of "cuda", "xpu", "hpu", "cpu", "npu".
            gpu_id: GPU index to bind to; defaults to -1 (unspecified) so
                existing callers that omit it keep their old behavior.

        Raises:
            RuntimeError: If *device* is not a supported backend name.
        """
        # Only these backend strings are recognized; fail fast on anything else.
        if device in ["cuda", "xpu", "hpu", "cpu", "npu"]:
            self.device_type = device
        else:
            raise RuntimeError(f"Not supported device type: {device}")
        self.device = torch.device(self.device_type)
        self.gpu_id = gpu_id

View File

@@ -23,6 +23,7 @@ class LoadFormat(str, enum.Enum):
LAYERED = "layered"
JAX = "jax"
REMOTE = "remote"
REMOTE_INSTANCE = "remote_instance"
@dataclass

View File

@@ -64,12 +64,28 @@ class ModelConfig:
is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
tp_rank: Optional[int] = None,
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
remote_instance_weight_loader_send_weights_group_ports: Optional[
List[int]
] = None,
) -> None:
# Parse args
self.model_path = model_path
self.revision = revision
self.quantization = quantization
self.model_impl = model_impl
self.tp_rank = tp_rank
self.remote_instance_weight_loader_seed_instance_ip = (
remote_instance_weight_loader_seed_instance_ip
)
self.remote_instance_weight_loader_seed_instance_service_port = (
remote_instance_weight_loader_seed_instance_service_port
)
self.remote_instance_weight_loader_send_weights_group_ports = (
remote_instance_weight_loader_send_weights_group_ports
)
self.maybe_pull_model_tokenizer_from_remote()
self.model_override_args = json.loads(model_override_args)
@@ -329,6 +345,9 @@ class ModelConfig:
quantization=server_args.quantization,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
model_impl=server_args.model_impl,
remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
**kwargs,
)