Support loading weights from remote instance (#8215)
Signed-off-by: Anqi Shen <amy.saq@antgroup.com> Co-authored-by: Chayenne <74843776+zhaochenyang20@users.noreply.github.com>
This commit is contained in:
@@ -19,10 +19,12 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from sglang.srt.connector import ConnectorType
|
||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
|
||||
from sglang.srt.lora.lora_registry import LoRARef
|
||||
@@ -42,7 +44,9 @@ from sglang.srt.utils import (
|
||||
is_sm100_supported,
|
||||
is_triton_kernels_available,
|
||||
is_valid_ipv6_address,
|
||||
json_list_type,
|
||||
nullable_str,
|
||||
parse_connector_type,
|
||||
)
|
||||
from sglang.utils import is_in_ci
|
||||
|
||||
@@ -61,6 +65,7 @@ LOAD_FORMAT_CHOICES = [
|
||||
"bitsandbytes",
|
||||
"layered",
|
||||
"remote",
|
||||
"remote_instance",
|
||||
]
|
||||
|
||||
QUANTIZATION_CHOICES = [
|
||||
@@ -387,6 +392,11 @@ class ServerArgs:
|
||||
custom_weight_loader: Optional[List[str]] = None
|
||||
weight_loader_disable_mmap: bool = False
|
||||
|
||||
# Remote instance weight loading
|
||||
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
|
||||
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
|
||||
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
|
||||
|
||||
# For PD-Multiplexing
|
||||
enable_pdmux: bool = False
|
||||
sm_group_num: int = 3
|
||||
@@ -445,6 +455,7 @@ class ServerArgs:
|
||||
# Set missing default values
|
||||
if self.tokenizer_path is None:
|
||||
self.tokenizer_path = self.model_path
|
||||
|
||||
if self.served_model_name is None:
|
||||
self.served_model_name = self.model_path
|
||||
if self.device is None:
|
||||
@@ -538,7 +549,8 @@ class ServerArgs:
|
||||
self.sampling_backend = "pytorch"
|
||||
|
||||
# Model-specific adjustments
|
||||
self.model_specific_adjustments()
|
||||
if parse_connector_type(self.model_path) != ConnectorType.INSTANCE:
|
||||
self.model_specific_adjustments()
|
||||
|
||||
# Set kernel backends
|
||||
if self.device == "cpu":
|
||||
@@ -818,12 +830,19 @@ class ServerArgs:
|
||||
) and check_gguf_file(self.model_path):
|
||||
self.quantization = self.load_format = "gguf"
|
||||
|
||||
# Model loading
|
||||
if is_remote_url(self.model_path):
|
||||
self.load_format = "remote"
|
||||
if self.custom_weight_loader is None:
|
||||
self.custom_weight_loader = []
|
||||
|
||||
if self.load_format == "remote_instance":
|
||||
if (
|
||||
self.remote_instance_weight_loader_seed_instance_ip is None
|
||||
or self.remote_instance_weight_loader_seed_instance_service_port is None
|
||||
or self.remote_instance_weight_loader_send_weights_group_ports is None
|
||||
):
|
||||
self.load_format = "auto"
|
||||
|
||||
# PD disaggregation
|
||||
if self.disaggregation_mode == "decode":
|
||||
assert (
|
||||
@@ -881,6 +900,24 @@ class ServerArgs:
|
||||
help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote-instance-weight-loader-seed-instance-ip",
|
||||
type=str,
|
||||
default=ServerArgs.remote_instance_weight_loader_seed_instance_ip,
|
||||
help="The ip of the seed instance for loading weights from remote instance.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote-instance-weight-loader-seed-instance-service-port",
|
||||
type=int,
|
||||
default=ServerArgs.remote_instance_weight_loader_seed_instance_service_port,
|
||||
help="The service port of the seed instance for loading weights from remote instance.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote-instance-weight-loader-send-weights-group-ports",
|
||||
type=json_list_type,
|
||||
default=ServerArgs.remote_instance_weight_loader_send_weights_group_ports,
|
||||
help="The communication group ports for loading weights from remote instance.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer-path",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user