85 lines
3.0 KiB
Python
85 lines
3.0 KiB
Python
|
|
#
|
||
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||
|
|
#
|
||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
|
# you may not use this file except in compliance with the License.
|
||
|
|
# You may obtain a copy of the License at
|
||
|
|
#
|
||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
#
|
||
|
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
|
# See the License for the specific language governing permissions and
|
||
|
|
# limitations under the License.
|
||
|
|
#
|
||
|
|
|
||
|
|
import time
|
||
|
|
|
||
|
|
from vllm.logger import logger
|
||
|
|
|
||
|
|
from .executor.elastic_load import P2PLoad
|
||
|
|
from .interaction.elastic import ElasticClient
|
||
|
|
|
||
|
|
|
||
|
|
def elastic_load(
|
||
|
|
model,
|
||
|
|
device_id: int,
|
||
|
|
model_path: str,
|
||
|
|
sources: list,
|
||
|
|
tp: int,
|
||
|
|
pp: int,
|
||
|
|
):
|
||
|
|
"""
|
||
|
|
Loads a model using elastic loading across multiple devices.
|
||
|
|
|
||
|
|
Parameters:
|
||
|
|
- model: The model instance to be loaded.
|
||
|
|
- device_id: The ID of the current device (i.e. global rank).
|
||
|
|
- model_path: The path to the model file.
|
||
|
|
- sources: A list of source configurations, each containing device_id and sources.
|
||
|
|
- tp: Tensor parallel size, indicating the number of devices for tensor parallelism.
|
||
|
|
- pp: Pipeline parallel size, indicating the number of devices for pipeline parallelism.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
- The loaded model if successful, otherwise None.
|
||
|
|
"""
|
||
|
|
|
||
|
|
# Filter sources for the current device
|
||
|
|
sources_this_device = []
|
||
|
|
for s in sources:
|
||
|
|
if isinstance(
|
||
|
|
s, dict
|
||
|
|
) and "device_id" in s and s["device_id"] == device_id and isinstance(
|
||
|
|
s["sources"], list):
|
||
|
|
sources_this_device += s["sources"]
|
||
|
|
if len(sources_this_device) == 0:
|
||
|
|
return None
|
||
|
|
|
||
|
|
try:
|
||
|
|
# Initialize the interaction layer with the ElasticClient
|
||
|
|
with ElasticClient(sources_this_device, device_id, model_path, tp,
|
||
|
|
pp) as client_interaction_layer:
|
||
|
|
if client_interaction_layer.s is None or client_interaction_layer.server_addr is None:
|
||
|
|
raise RuntimeError(
|
||
|
|
"Failed to initialize ElasticClient: socket or server_addr is None"
|
||
|
|
)
|
||
|
|
ack = client_interaction_layer.ack
|
||
|
|
if ack is None:
|
||
|
|
raise RuntimeError("ElasticClient.register did not return ack")
|
||
|
|
|
||
|
|
t0 = time.perf_counter()
|
||
|
|
elastic_loader = P2PLoad(ack[0],
|
||
|
|
client_interaction_layer.server_addr,
|
||
|
|
ack[1])
|
||
|
|
model_loaded = elastic_loader.load(model=model)
|
||
|
|
if model_loaded is None:
|
||
|
|
logger.error("Failed to load model")
|
||
|
|
return None
|
||
|
|
logger.info("Finish elastic load (duration: {}s)".format(
|
||
|
|
time.perf_counter() - t0))
|
||
|
|
return model_loaded
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"elastic_load error: {e}")
|
||
|
|
return None
|