refactor loading weights from remote instance coding format (#10941)

Signed-off-by: Anqi Shen <amy.saq@antgroup.com>
This commit is contained in:
amysaq2023
2025-09-27 06:25:39 +08:00
committed by GitHub
parent 777eb53897
commit 2bdaf482f9
6 changed files with 21 additions and 34 deletions

View File

@@ -58,6 +58,10 @@ class LoadConfig:
ignore_patterns: Optional[Union[List[str], str]] = None
decryption_key_file: Optional[str] = None
decrypt_max_concurrency: int = -1
tp_rank: Optional[int] = None
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}

View File

@@ -64,12 +64,6 @@ class ModelConfig:
is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
tp_rank: Optional[int] = None,
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
remote_instance_weight_loader_send_weights_group_ports: Optional[
List[int]
] = None,
) -> None:
# Parse args
self.model_path = model_path
@@ -78,18 +72,6 @@ class ModelConfig:
self.is_draft_model = is_draft_model
self.model_impl = model_impl
# TODO: remove these fields
self.tp_rank = tp_rank
self.remote_instance_weight_loader_seed_instance_ip = (
remote_instance_weight_loader_seed_instance_ip
)
self.remote_instance_weight_loader_seed_instance_service_port = (
remote_instance_weight_loader_seed_instance_service_port
)
self.remote_instance_weight_loader_send_weights_group_ports = (
remote_instance_weight_loader_send_weights_group_ports
)
# Get hf config
self._maybe_pull_model_tokenizer_from_remote()
self.model_override_args = json.loads(model_override_args)
@@ -204,9 +186,6 @@ class ModelConfig:
quantization=server_args.quantization,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
model_impl=server_args.model_impl,
remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
**kwargs,
)