From 2bdaf482f95e375f46bf3eeaa8870b36c34ba420 Mon Sep 17 00:00:00 2001 From: amysaq2023 <130625925+amysaq2023@users.noreply.github.com> Date: Sat, 27 Sep 2025 06:25:39 +0800 Subject: [PATCH] refactor loading weights from remote instance coding format (#10941) Signed-off-by: Anqi Shen --- python/sglang/srt/configs/load_config.py | 4 ++++ python/sglang/srt/configs/model_config.py | 21 ------------------- python/sglang/srt/managers/tp_worker.py | 1 - .../sglang/srt/model_executor/model_runner.py | 10 ++++++--- python/sglang/srt/model_loader/loader.py | 19 +++++++++-------- .../remote_instance_weight_loader_utils.py | 0 6 files changed, 21 insertions(+), 34 deletions(-) rename python/sglang/srt/{ => model_loader}/remote_instance_weight_loader_utils.py (100%) diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py index c734bd2e6..fb8be846b 100644 --- a/python/sglang/srt/configs/load_config.py +++ b/python/sglang/srt/configs/load_config.py @@ -58,6 +58,10 @@ class LoadConfig: ignore_patterns: Optional[Union[List[str], str]] = None decryption_key_file: Optional[str] = None decrypt_max_concurrency: int = -1 + tp_rank: Optional[int] = None + remote_instance_weight_loader_seed_instance_ip: Optional[str] = None + remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None + remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None def __post_init__(self): model_loader_extra_config = self.model_loader_extra_config or {} diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index ae8cc0282..92d0e130f 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -64,12 +64,6 @@ class ModelConfig: is_draft_model: bool = False, hybrid_kvcache_ratio: Optional[float] = None, model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, - tp_rank: Optional[int] = None, - remote_instance_weight_loader_seed_instance_ip: Optional[str] = None, - remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None, - remote_instance_weight_loader_send_weights_group_ports: Optional[ - List[int] - ] = None, ) -> None: # Parse args self.model_path = model_path @@ -78,18 +72,6 @@ class ModelConfig: self.is_draft_model = is_draft_model self.model_impl = model_impl - # TODO: remove these fields - self.tp_rank = tp_rank - self.remote_instance_weight_loader_seed_instance_ip = ( - remote_instance_weight_loader_seed_instance_ip - ) - self.remote_instance_weight_loader_seed_instance_service_port = ( - remote_instance_weight_loader_seed_instance_service_port - ) - self.remote_instance_weight_loader_send_weights_group_ports = ( - remote_instance_weight_loader_send_weights_group_ports - ) - # Get hf config self._maybe_pull_model_tokenizer_from_remote() self.model_override_args = json.loads(model_override_args) @@ -204,9 +186,6 @@ class ModelConfig: quantization=server_args.quantization, hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio, model_impl=server_args.model_impl, - remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip, - remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port, - remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports, **kwargs, ) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 6453b5675..0d3f76658 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -91,7 +91,6 @@ class TpModelWorker: else server_args.speculative_draft_model_revision ), is_draft_model=is_draft_worker, - tp_rank=tp_rank, ) self.model_runner = ModelRunner( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d053e2bb8..bd5461397 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -104,6 +104,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner from sglang.srt.model_loader import get_model from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader +from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( + trigger_init_weights_send_group_for_remote_instance_request, +) from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.offloader import ( @@ -112,9 +115,6 @@ from sglang.srt.offloader import ( set_offloader, ) from sglang.srt.patch_torch import monkey_patch_torch_reductions -from sglang.srt.remote_instance_weight_loader_utils import ( - trigger_init_weights_send_group_for_remote_instance_request, -) from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm @@ -743,6 +743,10 @@ class ModelRunner: load_format=self.server_args.load_format, download_dir=self.server_args.download_dir, model_loader_extra_config=self.server_args.model_loader_extra_config, + tp_rank=self.tp_rank, + remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip, + remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port, + remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports, ) if self.device == "cpu": self.model_config = adjust_config_with_unaligned_cpu_tp( diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index e5bf320be..12b4575f9 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -54,6 +54,9 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) +from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( + trigger_transferring_weights_request, +) from sglang.srt.model_loader.utils import ( get_model_architecture, post_load_weights, @@ -77,9 +80,6 @@ from sglang.srt.model_loader.weight_utils import ( safetensors_weights_iterator, set_runai_streamer_env, ) -from sglang.srt.remote_instance_weight_loader_utils import ( - trigger_transferring_weights_request, -) from sglang.srt.utils import ( get_bool_env_var, get_device_capability, @@ -1420,7 +1420,7 @@ class RemoteInstanceModelLoader(BaseModelLoader): f"load format {load_config.load_format}" ) - model_weights = f"instance://{model_config.remote_instance_weight_loader_seed_instance_ip}:{model_config.remote_instance_weight_loader_send_weights_group_ports[model_config.tp_rank]}" + model_weights = f"instance://{load_config.remote_instance_weight_loader_seed_instance_ip}:{load_config.remote_instance_weight_loader_send_weights_group_ports[load_config.tp_rank]}" with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): @@ -1442,11 +1442,12 @@ class RemoteInstanceModelLoader(BaseModelLoader): def load_model_from_remote_instance( self, model, client, model_config: ModelConfig, device_config: DeviceConfig ) -> nn.Module: + load_config = self.load_config instance_ip = socket.gethostbyname(socket.gethostname()) start_build_group_tic = time.time() client.build_group( gpu_id=device_config.gpu_id, - tp_rank=model_config.tp_rank, + tp_rank=load_config.tp_rank, instance_ip=instance_ip, ) torch.cuda.synchronize() @@ -1455,13 +1456,13 @@ class RemoteInstanceModelLoader(BaseModelLoader): f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s" ) - if model_config.tp_rank == 0: + if load_config.tp_rank == 0: t = threading.Thread( target=trigger_transferring_weights_request, args=( - model_config.remote_instance_weight_loader_seed_instance_ip, - model_config.remote_instance_weight_loader_seed_instance_service_port, - model_config.remote_instance_weight_loader_send_weights_group_ports, + load_config.remote_instance_weight_loader_seed_instance_ip, + load_config.remote_instance_weight_loader_seed_instance_service_port, + load_config.remote_instance_weight_loader_send_weights_group_ports, instance_ip, ), ) diff --git a/python/sglang/srt/remote_instance_weight_loader_utils.py b/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py similarity index 100% rename from python/sglang/srt/remote_instance_weight_loader_utils.py rename to python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py