refactor loading weights from remote instance coding format (#10941)
Signed-off-by: Anqi Shen <amy.saq@antgroup.com>
This commit is contained in:
@@ -58,6 +58,10 @@ class LoadConfig:
|
|||||||
ignore_patterns: Optional[Union[List[str], str]] = None
|
ignore_patterns: Optional[Union[List[str], str]] = None
|
||||||
decryption_key_file: Optional[str] = None
|
decryption_key_file: Optional[str] = None
|
||||||
decrypt_max_concurrency: int = -1
|
decrypt_max_concurrency: int = -1
|
||||||
|
tp_rank: Optional[int] = None
|
||||||
|
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
|
||||||
|
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
|
||||||
|
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
model_loader_extra_config = self.model_loader_extra_config or {}
|
model_loader_extra_config = self.model_loader_extra_config or {}
|
||||||
|
|||||||
@@ -64,12 +64,6 @@ class ModelConfig:
|
|||||||
is_draft_model: bool = False,
|
is_draft_model: bool = False,
|
||||||
hybrid_kvcache_ratio: Optional[float] = None,
|
hybrid_kvcache_ratio: Optional[float] = None,
|
||||||
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
||||||
tp_rank: Optional[int] = None,
|
|
||||||
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
|
|
||||||
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
|
|
||||||
remote_instance_weight_loader_send_weights_group_ports: Optional[
|
|
||||||
List[int]
|
|
||||||
] = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
# Parse args
|
# Parse args
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
@@ -78,18 +72,6 @@ class ModelConfig:
|
|||||||
self.is_draft_model = is_draft_model
|
self.is_draft_model = is_draft_model
|
||||||
self.model_impl = model_impl
|
self.model_impl = model_impl
|
||||||
|
|
||||||
# TODO: remove these fields
|
|
||||||
self.tp_rank = tp_rank
|
|
||||||
self.remote_instance_weight_loader_seed_instance_ip = (
|
|
||||||
remote_instance_weight_loader_seed_instance_ip
|
|
||||||
)
|
|
||||||
self.remote_instance_weight_loader_seed_instance_service_port = (
|
|
||||||
remote_instance_weight_loader_seed_instance_service_port
|
|
||||||
)
|
|
||||||
self.remote_instance_weight_loader_send_weights_group_ports = (
|
|
||||||
remote_instance_weight_loader_send_weights_group_ports
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get hf config
|
# Get hf config
|
||||||
self._maybe_pull_model_tokenizer_from_remote()
|
self._maybe_pull_model_tokenizer_from_remote()
|
||||||
self.model_override_args = json.loads(model_override_args)
|
self.model_override_args = json.loads(model_override_args)
|
||||||
@@ -204,9 +186,6 @@ class ModelConfig:
|
|||||||
quantization=server_args.quantization,
|
quantization=server_args.quantization,
|
||||||
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
||||||
model_impl=server_args.model_impl,
|
model_impl=server_args.model_impl,
|
||||||
remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
|
|
||||||
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
|
|
||||||
remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -91,7 +91,6 @@ class TpModelWorker:
|
|||||||
else server_args.speculative_draft_model_revision
|
else server_args.speculative_draft_model_revision
|
||||||
),
|
),
|
||||||
is_draft_model=is_draft_worker,
|
is_draft_model=is_draft_worker,
|
||||||
tp_rank=tp_rank,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_runner = ModelRunner(
|
self.model_runner = ModelRunner(
|
||||||
|
|||||||
@@ -104,6 +104,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
|
|||||||
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
|
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
|
||||||
from sglang.srt.model_loader import get_model
|
from sglang.srt.model_loader import get_model
|
||||||
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
|
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
|
||||||
|
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
|
||||||
|
trigger_init_weights_send_group_for_remote_instance_request,
|
||||||
|
)
|
||||||
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.offloader import (
|
from sglang.srt.offloader import (
|
||||||
@@ -112,9 +115,6 @@ from sglang.srt.offloader import (
|
|||||||
set_offloader,
|
set_offloader,
|
||||||
)
|
)
|
||||||
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
||||||
from sglang.srt.remote_instance_weight_loader_utils import (
|
|
||||||
trigger_init_weights_send_group_for_remote_instance_request,
|
|
||||||
)
|
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
@@ -743,6 +743,10 @@ class ModelRunner:
|
|||||||
load_format=self.server_args.load_format,
|
load_format=self.server_args.load_format,
|
||||||
download_dir=self.server_args.download_dir,
|
download_dir=self.server_args.download_dir,
|
||||||
model_loader_extra_config=self.server_args.model_loader_extra_config,
|
model_loader_extra_config=self.server_args.model_loader_extra_config,
|
||||||
|
tp_rank=self.tp_rank,
|
||||||
|
remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
|
||||||
|
remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
|
||||||
|
remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
|
||||||
)
|
)
|
||||||
if self.device == "cpu":
|
if self.device == "cpu":
|
||||||
self.model_config = adjust_config_with_unaligned_cpu_tp(
|
self.model_config = adjust_config_with_unaligned_cpu_tp(
|
||||||
|
|||||||
@@ -54,6 +54,9 @@ from sglang.srt.distributed import (
|
|||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
|
||||||
|
trigger_transferring_weights_request,
|
||||||
|
)
|
||||||
from sglang.srt.model_loader.utils import (
|
from sglang.srt.model_loader.utils import (
|
||||||
get_model_architecture,
|
get_model_architecture,
|
||||||
post_load_weights,
|
post_load_weights,
|
||||||
@@ -77,9 +80,6 @@ from sglang.srt.model_loader.weight_utils import (
|
|||||||
safetensors_weights_iterator,
|
safetensors_weights_iterator,
|
||||||
set_runai_streamer_env,
|
set_runai_streamer_env,
|
||||||
)
|
)
|
||||||
from sglang.srt.remote_instance_weight_loader_utils import (
|
|
||||||
trigger_transferring_weights_request,
|
|
||||||
)
|
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_device_capability,
|
get_device_capability,
|
||||||
@@ -1420,7 +1420,7 @@ class RemoteInstanceModelLoader(BaseModelLoader):
|
|||||||
f"load format {load_config.load_format}"
|
f"load format {load_config.load_format}"
|
||||||
)
|
)
|
||||||
|
|
||||||
model_weights = f"instance://{model_config.remote_instance_weight_loader_seed_instance_ip}:{model_config.remote_instance_weight_loader_send_weights_group_ports[model_config.tp_rank]}"
|
model_weights = f"instance://{load_config.remote_instance_weight_loader_seed_instance_ip}:{load_config.remote_instance_weight_loader_send_weights_group_ports[load_config.tp_rank]}"
|
||||||
|
|
||||||
with set_default_torch_dtype(model_config.dtype):
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
with torch.device(device_config.device):
|
with torch.device(device_config.device):
|
||||||
@@ -1442,11 +1442,12 @@ class RemoteInstanceModelLoader(BaseModelLoader):
|
|||||||
def load_model_from_remote_instance(
|
def load_model_from_remote_instance(
|
||||||
self, model, client, model_config: ModelConfig, device_config: DeviceConfig
|
self, model, client, model_config: ModelConfig, device_config: DeviceConfig
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
|
load_config = self.load_config
|
||||||
instance_ip = socket.gethostbyname(socket.gethostname())
|
instance_ip = socket.gethostbyname(socket.gethostname())
|
||||||
start_build_group_tic = time.time()
|
start_build_group_tic = time.time()
|
||||||
client.build_group(
|
client.build_group(
|
||||||
gpu_id=device_config.gpu_id,
|
gpu_id=device_config.gpu_id,
|
||||||
tp_rank=model_config.tp_rank,
|
tp_rank=load_config.tp_rank,
|
||||||
instance_ip=instance_ip,
|
instance_ip=instance_ip,
|
||||||
)
|
)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@@ -1455,13 +1456,13 @@ class RemoteInstanceModelLoader(BaseModelLoader):
|
|||||||
f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
|
f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_config.tp_rank == 0:
|
if load_config.tp_rank == 0:
|
||||||
t = threading.Thread(
|
t = threading.Thread(
|
||||||
target=trigger_transferring_weights_request,
|
target=trigger_transferring_weights_request,
|
||||||
args=(
|
args=(
|
||||||
model_config.remote_instance_weight_loader_seed_instance_ip,
|
load_config.remote_instance_weight_loader_seed_instance_ip,
|
||||||
model_config.remote_instance_weight_loader_seed_instance_service_port,
|
load_config.remote_instance_weight_loader_seed_instance_service_port,
|
||||||
model_config.remote_instance_weight_loader_send_weights_group_ports,
|
load_config.remote_instance_weight_loader_send_weights_group_ports,
|
||||||
instance_ip,
|
instance_ip,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user