refactor loading weights from remote instance coding format (#10941)
Signed-off-by: Anqi Shen <amy.saq@antgroup.com>
This commit is contained in:
@@ -58,6 +58,10 @@ class LoadConfig:
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None
|
||||
decryption_key_file: Optional[str] = None
|
||||
decrypt_max_concurrency: int = -1
|
||||
tp_rank: Optional[int] = None
|
||||
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
|
||||
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
|
||||
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
model_loader_extra_config = self.model_loader_extra_config or {}
|
||||
|
||||
@@ -64,12 +64,6 @@ class ModelConfig:
|
||||
is_draft_model: bool = False,
|
||||
hybrid_kvcache_ratio: Optional[float] = None,
|
||||
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
||||
tp_rank: Optional[int] = None,
|
||||
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
|
||||
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
|
||||
remote_instance_weight_loader_send_weights_group_ports: Optional[
|
||||
List[int]
|
||||
] = None,
|
||||
) -> None:
|
||||
# Parse args
|
||||
self.model_path = model_path
|
||||
@@ -78,18 +72,6 @@ class ModelConfig:
|
||||
self.is_draft_model = is_draft_model
|
||||
self.model_impl = model_impl
|
||||
|
||||
# TODO: remove these fields
|
||||
self.tp_rank = tp_rank
|
||||
self.remote_instance_weight_loader_seed_instance_ip = (
|
||||
remote_instance_weight_loader_seed_instance_ip
|
||||
)
|
||||
self.remote_instance_weight_loader_seed_instance_service_port = (
|
||||
remote_instance_weight_loader_seed_instance_service_port
|
||||
)
|
||||
self.remote_instance_weight_loader_send_weights_group_ports = (
|
||||
remote_instance_weight_loader_send_weights_group_ports
|
||||
)
|
||||
|
||||
# Get hf config
|
||||
self._maybe_pull_model_tokenizer_from_remote()
|
||||
self.model_override_args = json.loads(model_override_args)
|
||||
@@ -204,9 +186,6 @@ class ModelConfig:
|
||||
quantization=server_args.quantization,
|
||||
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
||||
model_impl=server_args.model_impl,
|
||||
remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
|
||||
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
|
||||
remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user