Support loading weights from remote instance (#8215)
Signed-off-by: Anqi Shen <amy.saq@antgroup.com> Co-authored-by: Chayenne <74843776+zhaochenyang20@users.noreply.github.com>
This commit is contained in:
@@ -8,10 +8,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class DeviceConfig:
    """Hold the device selection used when loading and running a model.

    Attributes:
        device_type: The validated device type string passed to the
            constructor (one of ``SUPPORTED_DEVICE_TYPES``).
        device: The ``torch.device`` built from ``device_type``.
        gpu_id: The GPU index given by the caller.
    """

    # Device type strings accepted by the constructor.
    SUPPORTED_DEVICE_TYPES = ("cuda", "xpu", "hpu", "cpu", "npu")

    device: Optional[torch.device]
    gpu_id: Optional[int]

    def __init__(self, device: str = "cuda", gpu_id: int = -1) -> None:
        """Validate ``device`` and record the device selection.

        Args:
            device: Device type name; must be one of
                ``SUPPORTED_DEVICE_TYPES``.
            gpu_id: GPU index to record. Defaults to ``-1``
                (NOTE(review): presumably means "not assigned" - confirm
                against callers).

        Raises:
            RuntimeError: If ``device`` is not a supported device type.
        """
        # Reject unknown device types before any attribute is assigned, so a
        # failed construction never leaves a half-initialized object behind.
        if device not in self.SUPPORTED_DEVICE_TYPES:
            raise RuntimeError(f"Not supported device type: {device}")

        self.device_type = device
        self.device = torch.device(self.device_type)
        self.gpu_id = gpu_id
|
||||
|
||||
@@ -23,6 +23,7 @@ class LoadFormat(str, enum.Enum):
|
||||
LAYERED = "layered"
|
||||
JAX = "jax"
|
||||
REMOTE = "remote"
|
||||
REMOTE_INSTANCE = "remote_instance"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -64,12 +64,28 @@ class ModelConfig:
|
||||
is_draft_model: bool = False,
|
||||
hybrid_kvcache_ratio: Optional[float] = None,
|
||||
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
||||
tp_rank: Optional[int] = None,
|
||||
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
|
||||
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
|
||||
remote_instance_weight_loader_send_weights_group_ports: Optional[
|
||||
List[int]
|
||||
] = None,
|
||||
) -> None:
|
||||
# Parse args
|
||||
self.model_path = model_path
|
||||
self.revision = revision
|
||||
self.quantization = quantization
|
||||
self.model_impl = model_impl
|
||||
self.tp_rank = tp_rank
|
||||
self.remote_instance_weight_loader_seed_instance_ip = (
|
||||
remote_instance_weight_loader_seed_instance_ip
|
||||
)
|
||||
self.remote_instance_weight_loader_seed_instance_service_port = (
|
||||
remote_instance_weight_loader_seed_instance_service_port
|
||||
)
|
||||
self.remote_instance_weight_loader_send_weights_group_ports = (
|
||||
remote_instance_weight_loader_send_weights_group_ports
|
||||
)
|
||||
|
||||
self.maybe_pull_model_tokenizer_from_remote()
|
||||
self.model_override_args = json.loads(model_override_args)
|
||||
@@ -329,6 +345,9 @@ class ModelConfig:
|
||||
quantization=server_args.quantization,
|
||||
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
||||
model_impl=server_args.model_impl,
|
||||
remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
|
||||
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
|
||||
remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user