[BugFix] NetLoader: No backend type associated with device type npu (#5700)
**What this PR does / why we need it?**
This PR fixes a bug in NetLoader
[PR#2888](https://github.com/vllm-project/vllm-ascend/pull/2888). The
bug was caused by
[PR#3612](https://github.com/vllm-project/vllm-ascend/pull/3612)
([1/N][Refactor] Refactor code to adapt with vllm main), which removed
the `stateless_init_device_torch_dist_pg` function from platform.py,
leading to a failure in the call. This PR adds a way to create a
stateless process group that does not depend on external code.
**Does this PR introduce any user-facing change?**
No
**How was this patch tested?**
Same with
[PR#2888](https://github.com/vllm-project/vllm-ascend/pull/2888)
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: destinysky <kangrui10@126.com>
This commit is contained in:
@@ -16,11 +16,11 @@
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.distributed.utils import (
|
||||
stateless_destroy_torch_distributed_process_group,
|
||||
stateless_init_torch_distributed_process_group)
|
||||
from vllm.logger import logger
|
||||
|
||||
from .netloader_pg import (destroy_stateless_process_group,
|
||||
stateless_init_process_group)
|
||||
|
||||
|
||||
class P2PLoad:
|
||||
"""
|
||||
@@ -62,12 +62,12 @@ class P2PLoad:
|
||||
receiver_pg = None
|
||||
loaded_model = None
|
||||
try:
|
||||
receiver_pg = stateless_init_torch_distributed_process_group(
|
||||
receiver_pg = stateless_init_process_group(
|
||||
host=self.world_name.split(":")[0],
|
||||
port=self.source_port,
|
||||
rank=0,
|
||||
world_size=2,
|
||||
backend='hccl',
|
||||
group_name='netloader',
|
||||
)
|
||||
logger.info(
|
||||
f"Finish init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
|
||||
@@ -97,7 +97,7 @@ class P2PLoad:
|
||||
logger.error("Failed to recv model: {}".format(e))
|
||||
finally:
|
||||
if receiver_pg:
|
||||
stateless_destroy_torch_distributed_process_group(receiver_pg)
|
||||
destroy_stateless_process_group(receiver_pg)
|
||||
return loaded_model
|
||||
|
||||
|
||||
@@ -134,12 +134,12 @@ class P2PSend:
|
||||
)
|
||||
sender_pg = None
|
||||
try:
|
||||
sender_pg = stateless_init_torch_distributed_process_group(
|
||||
sender_pg = stateless_init_process_group(
|
||||
host=self.comm_name.split(":")[0],
|
||||
port=self.listen_port,
|
||||
rank=1,
|
||||
world_size=2,
|
||||
backend='hccl',
|
||||
group_name='netloader',
|
||||
)
|
||||
logger.info(
|
||||
f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
|
||||
@@ -167,4 +167,4 @@ class P2PSend:
|
||||
)
|
||||
finally:
|
||||
if sender_pg:
|
||||
stateless_destroy_torch_distributed_process_group(sender_pg)
|
||||
destroy_stateless_process_group(sender_pg)
|
||||
188
vllm_ascend/model_loader/netloader/executor/netloader_pg.py
Normal file
188
vllm_ascend/model_loader/netloader/executor/netloader_pg.py
Normal file
@@ -0,0 +1,188 @@
|
||||
#
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import gc
|
||||
import ipaddress
|
||||
from datetime import timedelta
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch._C._distributed_c10d import (_DEFAULT_PG_TIMEOUT,
|
||||
_register_process_group,
|
||||
_unregister_process_group)
|
||||
from torch.distributed import ProcessGroup, is_hccl_available
|
||||
from torch.distributed.distributed_c10d import (Backend, BackendConfig,
|
||||
PrefixStore, _world)
|
||||
from torch.distributed.rendezvous import rendezvous
|
||||
from torch_npu._C._distributed_c10d import ProcessGroupHCCL
|
||||
from vllm.logger import logger
|
||||
|
||||
|
||||
def stateless_init_process_group(
|
||||
host: str,
|
||||
port: int,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
timeout: timedelta = _DEFAULT_PG_TIMEOUT,
|
||||
group_name: str = "",
|
||||
pg_options: Optional[Any] = None,
|
||||
) -> ProcessGroup:
|
||||
"""
|
||||
Initializes a stateless process group.
|
||||
|
||||
Args:
|
||||
host: Hostname.
|
||||
port: Port number.
|
||||
world_size: Size of the process group.
|
||||
rank: Rank of the current process.
|
||||
timeout: Timeout duration, defaults to _DEFAULT_PG_TIMEOUT.
|
||||
group_name: Name of the process group, defaults to an empty string.
|
||||
pg_options: Options for the process group, defaults to None.
|
||||
|
||||
Returns:
|
||||
ProcessGroup: The initialized process group.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If world_size is not positive, or if rank is not within [0, world_size - 1], or if HCCL is unavailable.
|
||||
TypeError: If timeout is not a timedelta type.
|
||||
ValueError: If group_name already exists.
|
||||
"""
|
||||
|
||||
# Check if world_size is positive
|
||||
if not world_size > 0:
|
||||
raise RuntimeError("world_size must be positive")
|
||||
# Check if rank is within [0, world_size - 1]
|
||||
if not (rank >= 0 and rank <= world_size - 1):
|
||||
raise RuntimeError(
|
||||
"rank should be a number between 0 and ``world_size``-1")
|
||||
# Check if HCCL is available
|
||||
if not is_hccl_available():
|
||||
raise RuntimeError("HCCL is not available")
|
||||
# Check if timeout is a timedelta type
|
||||
if not isinstance(timeout, timedelta):
|
||||
raise TypeError(
|
||||
f"Expected timeout argument to be of type datetime.timedelta, got {timeout}"
|
||||
)
|
||||
# Check if group_name already exists
|
||||
if group_name in _world.pg_names.values():
|
||||
raise ValueError(
|
||||
f"The specified group name {group_name} has already been "
|
||||
"created, please use a different group name")
|
||||
|
||||
# Function to check if an IPv6 address is valid
|
||||
def is_valid_ipv6_address(address: str) -> bool:
|
||||
try:
|
||||
ipaddress.IPv6Address(address)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
# Function to get TCP URI
|
||||
def get_tcp_uri(ip: str, port: int) -> str:
|
||||
if is_valid_ipv6_address(ip):
|
||||
return f"tcp://[{ip}]:{port}"
|
||||
else:
|
||||
return f"tcp://{ip}:{port}"
|
||||
|
||||
# Get initialization method
|
||||
init_method = get_tcp_uri(host, port)
|
||||
# Create Backend object
|
||||
backend = Backend('hccl')
|
||||
# Use rendezvous function to get store, rank, and world_size
|
||||
store, rank, world_size = next(
|
||||
rendezvous(init_method, rank, world_size, timeout=timeout))
|
||||
|
||||
# Set timeout for store
|
||||
store.set_timeout(timeout)
|
||||
# Create PrefixStore object
|
||||
prefix_store = PrefixStore(f"{init_method}/{group_name}/", store)
|
||||
# Set group_rank and group_size
|
||||
group_rank = rank
|
||||
group_size = world_size
|
||||
# Create ProcessGroup object
|
||||
pg: ProcessGroup = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
)
|
||||
# Create BackendConfig object
|
||||
backend_config = BackendConfig(backend)
|
||||
# Set default backend for ProcessGroup
|
||||
pg._set_default_backend(Backend.backend_type_map[backend])
|
||||
|
||||
# Check if pg_options is None or not of type ProcessGroupHCCL.Options
|
||||
if pg_options is None or not isinstance(
|
||||
pg_options,
|
||||
torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options):
|
||||
pg_options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
|
||||
# Set attributes for pg_options
|
||||
pg_options.is_high_priority_stream = False
|
||||
pg_options._timeout = timeout
|
||||
pg_options.global_ranks_in_group = []
|
||||
pg_options.group_id = f"{init_method}/{group_name}/"
|
||||
# Create ProcessGroupHCCL object
|
||||
backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
|
||||
pg_options)
|
||||
# Set sequence number for backend_class
|
||||
backend_class._set_sequence_number_for_group()
|
||||
# Set backend_type
|
||||
backend_type = ProcessGroup.BackendType.CUSTOM
|
||||
# Register backend
|
||||
pg._register_backend(torch.device("npu"), backend_type, backend_class)
|
||||
|
||||
# Set group_desc and pg_tag
|
||||
group_desc = "undefined"
|
||||
assert group_name is not None
|
||||
assert group_desc is not None
|
||||
pg._set_group_name(group_name)
|
||||
pg._set_group_desc(group_desc)
|
||||
|
||||
# Update attributes in _world
|
||||
_world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
|
||||
_world.pg_map[pg] = (backend, prefix_store)
|
||||
_world.pg_names[pg] = group_name
|
||||
_register_process_group(group_name, pg)
|
||||
_world.pg_backend_config[pg] = str(backend_config)
|
||||
return pg
|
||||
|
||||
|
||||
def destroy_stateless_process_group(pg: ProcessGroup, manual_gc: bool = False):
|
||||
"""
|
||||
Destroy a stateless process group.
|
||||
|
||||
Args:
|
||||
pg: Process group to be destroyed.
|
||||
manual_gc: Whether to manually perform garbage collection, defaults to False.
|
||||
"""
|
||||
# Shutdown the process group
|
||||
pg.shutdown()
|
||||
# Remove related attributes from _world
|
||||
_world.pg_map.pop(pg, None)
|
||||
_world.pg_names.pop(pg, None)
|
||||
_world.pg_group_ranks.pop(pg, None)
|
||||
_world.pg_backend_config.pop(pg, None)
|
||||
# Check if pg is in keys of _world.pg_coalesce_state
|
||||
if pg in _world.pg_coalesce_state.keys():
|
||||
logger.warning("Some coalesced collectives haven't been launched when "
|
||||
"ProcessGroup is destroyed. They will be cleaned.")
|
||||
del _world.pg_coalesce_state[pg]
|
||||
# Unregister the process group
|
||||
_unregister_process_group(pg.group_name)
|
||||
|
||||
# If manual_gc is True, perform garbage collection
|
||||
if manual_gc:
|
||||
gc.collect()
|
||||
Reference in New Issue
Block a user