Support loading weights from remote instance (#8215)

Signed-off-by: Anqi Shen <amy.saq@antgroup.com>
Co-authored-by: Chayenne <74843776+zhaochenyang20@users.noreply.github.com>
This commit is contained in:
amysaq2023
2025-09-12 17:40:22 +08:00
committed by GitHub
parent 1b1701f1f7
commit 30d20ce84f
18 changed files with 1042 additions and 6 deletions

View File

@@ -19,10 +19,12 @@ import json
import logging
import os
import random
import socket
import sys
import tempfile
from typing import List, Literal, Optional, Union
from sglang.srt.connector import ConnectorType
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
from sglang.srt.lora.lora_registry import LoRARef
@@ -42,7 +44,9 @@ from sglang.srt.utils import (
is_sm100_supported,
is_triton_kernels_available,
is_valid_ipv6_address,
json_list_type,
nullable_str,
parse_connector_type,
)
from sglang.utils import is_in_ci
@@ -61,6 +65,7 @@ LOAD_FORMAT_CHOICES = [
"bitsandbytes",
"layered",
"remote",
"remote_instance",
]
QUANTIZATION_CHOICES = [
@@ -387,6 +392,11 @@ class ServerArgs:
custom_weight_loader: Optional[List[str]] = None
weight_loader_disable_mmap: bool = False
# Remote instance weight loading
# All three fields default to None and are exposed on the CLI as the
# --remote-instance-weight-loader-* options.
# IP of the seed instance that weights are loaded from.
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
# Service port of the seed instance.
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
# Communication group ports used for loading weights from the remote instance.
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
# For PD-Multiplexing
enable_pdmux: bool = False
sm_group_num: int = 3
@@ -445,6 +455,7 @@ class ServerArgs:
# Set missing default values
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
if self.served_model_name is None:
self.served_model_name = self.model_path
if self.device is None:
@@ -538,7 +549,8 @@ class ServerArgs:
self.sampling_backend = "pytorch"
# Model-specific adjustments
# NOTE(review): this view is a rendered diff with the +/- markers and the
# indentation stripped. The bare call directly below looks like the removed
# line and the guarded call after it like the added replacement — confirm
# against the real file before relying on both being present.
self.model_specific_adjustments()
# Run the adjustments only when the model path does not parse to the
# INSTANCE connector type.
if parse_connector_type(self.model_path) != ConnectorType.INSTANCE:
self.model_specific_adjustments()
# Set kernel backends
if self.device == "cpu":
@@ -818,12 +830,19 @@ class ServerArgs:
) and check_gguf_file(self.model_path):
self.quantization = self.load_format = "gguf"
# Model loading
# A model path that is a remote URL forces the "remote" load format.
if is_remote_url(self.model_path):
self.load_format = "remote"
# Replace an unset custom_weight_loader with an empty list.
if self.custom_weight_loader is None:
self.custom_weight_loader = []
# The "remote_instance" load format needs the seed instance IP, the seed
# instance service port, and the send-weights group ports. If any of them is
# missing, fall back to "auto" as before, but emit a warning so that the
# change of load format is not silent.
if self.load_format == "remote_instance":
    required_options = {
        "--remote-instance-weight-loader-seed-instance-ip": (
            self.remote_instance_weight_loader_seed_instance_ip
        ),
        "--remote-instance-weight-loader-seed-instance-service-port": (
            self.remote_instance_weight_loader_seed_instance_service_port
        ),
        "--remote-instance-weight-loader-send-weights-group-ports": (
            self.remote_instance_weight_loader_send_weights_group_ports
        ),
    }
    missing_options = [
        flag for flag, value in required_options.items() if value is None
    ]
    if missing_options:
        logging.getLogger(__name__).warning(
            "load_format='remote_instance' requires %s; "
            "falling back to load_format='auto'.",
            ", ".join(missing_options),
        )
        self.load_format = "auto"
# PD disaggregation
if self.disaggregation_mode == "decode":
assert (
@@ -881,6 +900,24 @@ class ServerArgs:
help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
required=True,
)
# Remote instance weight loading options. Each default is read from the
# ServerArgs class attribute of the same name.
parser.add_argument(
"--remote-instance-weight-loader-seed-instance-ip",
type=str,
default=ServerArgs.remote_instance_weight_loader_seed_instance_ip,
help="The ip of the seed instance for loading weights from remote instance.",
)
parser.add_argument(
"--remote-instance-weight-loader-seed-instance-service-port",
type=int,
default=ServerArgs.remote_instance_weight_loader_seed_instance_service_port,
help="The service port of the seed instance for loading weights from remote instance.",
)
parser.add_argument(
"--remote-instance-weight-loader-send-weights-group-ports",
# NOTE(review): json_list_type comes from sglang.srt.utils; presumably it
# parses the argument as a JSON list (here, of port numbers) — confirm.
type=json_list_type,
default=ServerArgs.remote_instance_weight_loader_send_weights_group_ports,
help="The communication group ports for loading weights from remote instance.",
)
parser.add_argument(
"--tokenizer-path",
type=str,