Files
xc-llm-ascend/vllm_ascend/model_loader/netloader/load.py
SILONG ZENG 4e53c1d900 [Lint]Style: Convert vllm-ascend/ to ruff format(Batch #6) (#6001)
### What this PR does / why we need it?
| File Path |
| :--- |
| ` vllm_ascend/eplb/adaptor/abstract_adaptor.py` |
| ` vllm_ascend/eplb/adaptor/vllm_adaptor.py` |
| ` vllm_ascend/eplb/core/eplb_device_transfer_loader.py` |
| ` vllm_ascend/eplb/core/eplb_utils.py` |
| ` vllm_ascend/eplb/core/eplb_worker.py` |
| ` vllm_ascend/eplb/core/policy/policy_abstract.py` |
| ` vllm_ascend/eplb/core/policy/policy_default_eplb.py` |
| ` vllm_ascend/eplb/core/policy/policy_factory.py` |
| ` vllm_ascend/eplb/core/policy/policy_flashlb.py` |
| ` vllm_ascend/eplb/core/policy/policy_random.py` |
| ` vllm_ascend/eplb/core/policy/policy_swift_balancer.py` |
| ` vllm_ascend/eplb/eplb_updator.py` |
| ` vllm_ascend/eplb/utils.py` |
| ` vllm_ascend/model_loader/netloader/executor/elastic_load.py` |
| ` vllm_ascend/model_loader/netloader/executor/netloader_pg.py` |
| ` vllm_ascend/model_loader/netloader/interaction/elastic.py` |
| ` vllm_ascend/model_loader/netloader/load.py` |
| ` vllm_ascend/model_loader/netloader/netloader.py` |
| ` vllm_ascend/model_loader/netloader/utils.py` |
| ` vllm_ascend/patch/platform/__init__.py` |
| ` vllm_ascend/patch/platform/patch_balance_schedule.py` |
| ` vllm_ascend/patch/platform/patch_ec_connector.py` |
| ` vllm_ascend/patch/platform/patch_mamba_config.py` |
| ` vllm_ascend/patch/platform/patch_multiproc_executor.py` |
| ` vllm_ascend/patch/platform/patch_sched_yield.py` |


- vLLM version: v0.13.0
- vLLM main:
2c24bc6996

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-24 22:08:33 +08:00

76 lines
2.8 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
from vllm.logger import logger
from .executor.elastic_load import P2PLoad
from .interaction.elastic import ElasticClient
def elastic_load(
    model,
    device_id: int,
    model_path: str,
    sources: list,
    tp: int,
    pp: int,
):
    """
    Load a model through the elastic (peer-to-peer) loading path.

    The function picks the source entries configured for ``device_id``,
    opens an ``ElasticClient`` against them and, once the client exposes an
    ``ack``, hands the transfer over to ``P2PLoad``.

    Parameters:
    - model: The model instance to be loaded.
    - device_id: The ID of the current device (i.e. global rank).
    - model_path: The path to the model file.
    - sources: A list of source configurations. Only ``dict`` entries whose
      ``"device_id"`` equals ``device_id`` and whose ``"sources"`` value is a
      list are used; every other entry is skipped.
    - tp: Tensor parallel size, forwarded to ``ElasticClient``.
    - pp: Pipeline parallel size, forwarded to ``ElasticClient``.

    Returns:
    - The object returned by ``P2PLoad.load`` on success, otherwise ``None``
      (no usable source for this device, client initialisation failure, or
      any exception raised while loading).
    """
    # Filter sources for the current device. Malformed entries (non-dict,
    # missing keys, non-list "sources") are skipped instead of raising, so a
    # partially broken config cannot crash the caller.
    sources_this_device = []
    for entry in sources:
        if not isinstance(entry, dict) or "device_id" not in entry:
            continue
        if entry["device_id"] != device_id:
            continue
        candidates = entry.get("sources")
        if isinstance(candidates, list):
            sources_this_device += candidates

    if not sources_this_device:
        return None

    try:
        # Initialize the interaction layer with the ElasticClient; the context
        # manager is responsible for releasing the client on exit.
        with ElasticClient(sources_this_device, device_id, model_path, tp, pp) as client_interaction_layer:
            if client_interaction_layer.s is None or client_interaction_layer.server_addr is None:
                raise RuntimeError("Failed to initialize ElasticClient: socket or server_addr is None")

            ack = client_interaction_layer.ack
            if ack is None:
                raise RuntimeError("ElasticClient.register did not return ack")

            t0 = time.perf_counter()
            # NOTE(review): the meaning of ack[0] / ack[1] is defined by
            # ElasticClient and P2PLoad, which are not visible here.
            elastic_loader = P2PLoad(ack[0], client_interaction_layer.server_addr, ack[1])
            model_loaded = elastic_loader.load(model=model)
            if model_loaded is None:
                logger.error("Failed to load model")
                return None

            logger.info(f"Finish elastic load (duration: {time.perf_counter() - t0}s)")
            return model_loaded
    except Exception as e:
        # Best-effort path: report the failure and signal it with None rather
        # than propagating, matching the documented return contract.
        logger.error(f"elastic_load error: {e}")
        return None