[BugFix] NetLoader: No backend type associated with device type npu (#5700)
**What this PR does / why we need it?**
This PR fixes a bug in NetLoader
[PR#2888](https://github.com/vllm-project/vllm-ascend/pull/2888). The
bug was caused by
[PR#3612](https://github.com/vllm-project/vllm-ascend/pull/3612)
([1/N][Refactor] Refactor code to adapt with vllm main), which removed
the `stateless_init_device_torch_dist_pg` function from platform.py,
leading to a failure in the call. This PR adds a way to create a
stateless process group that does not depend on external code.
**Does this PR introduce any user-facing change?**
No
**How was this patch tested?**
Same as
[PR#2888](https://github.com/vllm-project/vllm-ascend/pull/2888)
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: destinysky <kangrui10@126.com>
This commit is contained in:
@@ -16,11 +16,11 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.distributed.utils import (
|
|
||||||
stateless_destroy_torch_distributed_process_group,
|
|
||||||
stateless_init_torch_distributed_process_group)
|
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
|
||||||
|
from .netloader_pg import (destroy_stateless_process_group,
|
||||||
|
stateless_init_process_group)
|
||||||
|
|
||||||
|
|
||||||
class P2PLoad:
|
class P2PLoad:
|
||||||
"""
|
"""
|
||||||
@@ -62,12 +62,12 @@ class P2PLoad:
|
|||||||
receiver_pg = None
|
receiver_pg = None
|
||||||
loaded_model = None
|
loaded_model = None
|
||||||
try:
|
try:
|
||||||
receiver_pg = stateless_init_torch_distributed_process_group(
|
receiver_pg = stateless_init_process_group(
|
||||||
host=self.world_name.split(":")[0],
|
host=self.world_name.split(":")[0],
|
||||||
port=self.source_port,
|
port=self.source_port,
|
||||||
rank=0,
|
rank=0,
|
||||||
world_size=2,
|
world_size=2,
|
||||||
backend='hccl',
|
group_name='netloader',
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Finish init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
|
f"Finish init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
|
||||||
@@ -97,7 +97,7 @@ class P2PLoad:
|
|||||||
logger.error("Failed to recv model: {}".format(e))
|
logger.error("Failed to recv model: {}".format(e))
|
||||||
finally:
|
finally:
|
||||||
if receiver_pg:
|
if receiver_pg:
|
||||||
stateless_destroy_torch_distributed_process_group(receiver_pg)
|
destroy_stateless_process_group(receiver_pg)
|
||||||
return loaded_model
|
return loaded_model
|
||||||
|
|
||||||
|
|
||||||
@@ -134,12 +134,12 @@ class P2PSend:
|
|||||||
)
|
)
|
||||||
sender_pg = None
|
sender_pg = None
|
||||||
try:
|
try:
|
||||||
sender_pg = stateless_init_torch_distributed_process_group(
|
sender_pg = stateless_init_process_group(
|
||||||
host=self.comm_name.split(":")[0],
|
host=self.comm_name.split(":")[0],
|
||||||
port=self.listen_port,
|
port=self.listen_port,
|
||||||
rank=1,
|
rank=1,
|
||||||
world_size=2,
|
world_size=2,
|
||||||
backend='hccl',
|
group_name='netloader',
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
|
f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
|
||||||
@@ -167,4 +167,4 @@ class P2PSend:
|
|||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
if sender_pg:
|
if sender_pg:
|
||||||
stateless_destroy_torch_distributed_process_group(sender_pg)
|
destroy_stateless_process_group(sender_pg)
|
||||||
188
vllm_ascend/model_loader/netloader/executor/netloader_pg.py
Normal file
188
vllm_ascend/model_loader/netloader/executor/netloader_pg.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import ipaddress
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_npu
|
||||||
|
from torch._C._distributed_c10d import (_DEFAULT_PG_TIMEOUT,
|
||||||
|
_register_process_group,
|
||||||
|
_unregister_process_group)
|
||||||
|
from torch.distributed import ProcessGroup, is_hccl_available
|
||||||
|
from torch.distributed.distributed_c10d import (Backend, BackendConfig,
|
||||||
|
PrefixStore, _world)
|
||||||
|
from torch.distributed.rendezvous import rendezvous
|
||||||
|
from torch_npu._C._distributed_c10d import ProcessGroupHCCL
|
||||||
|
from vllm.logger import logger
|
||||||
|
|
||||||
|
|
||||||
|
def stateless_init_process_group(
    host: str,
    port: int,
    world_size: int,
    rank: int,
    timeout: timedelta = _DEFAULT_PG_TIMEOUT,
    group_name: str = "",
    pg_options: Optional[Any] = None,
) -> ProcessGroup:
    """
    Initialize a stateless HCCL process group over TCP rendezvous.

    Unlike ``torch.distributed.init_process_group``, this builds the
    ``ProcessGroup`` and its HCCL backend by hand (store, prefix store,
    backend registration), so it does not depend on a pre-existing default
    group. It still registers the group in torch's private ``_world``
    bookkeeping so torch-side utilities can find it.

    NOTE(review): this leans on private torch APIs (``_world``,
    ``_register_process_group``, ``ProcessGroup._register_backend``) and
    torch_npu's ``ProcessGroupHCCL`` — behavior may shift across torch
    versions; verify against the pinned torch/torch_npu releases.

    Args:
        host: Hostname or IP address used for the TCP rendezvous.
        port: Port number for the TCP rendezvous.
        world_size: Size of the process group (must be > 0).
        rank: Rank of the current process, in [0, world_size - 1].
        timeout: Timeout duration, defaults to _DEFAULT_PG_TIMEOUT.
        group_name: Name of the process group, defaults to an empty string.
        pg_options: ``ProcessGroupHCCL.Options`` instance; any other value
            (including None) is replaced with a fresh default Options object.

    Returns:
        ProcessGroup: The initialized process group with an HCCL backend
        registered for the "npu" device.

    Raises:
        RuntimeError: If world_size is not positive, or if rank is not within [0, world_size - 1], or if HCCL is unavailable.
        TypeError: If timeout is not a timedelta type.
        ValueError: If group_name already exists.
    """

    # Check if world_size is positive
    if not world_size > 0:
        raise RuntimeError("world_size must be positive")
    # Check if rank is within [0, world_size - 1]
    if not (rank >= 0 and rank <= world_size - 1):
        raise RuntimeError(
            "rank should be a number between 0 and ``world_size``-1")
    # Fail fast if the torch build has no HCCL support
    if not is_hccl_available():
        raise RuntimeError("HCCL is not available")
    # Check if timeout is a timedelta type
    if not isinstance(timeout, timedelta):
        raise TypeError(
            f"Expected timeout argument to be of type datetime.timedelta, got {timeout}"
        )
    # Reject duplicate group names: _world.pg_names maps pg -> name globally
    if group_name in _world.pg_names.values():
        raise ValueError(
            f"The specified group name {group_name} has already been "
            "created, please use a different group name")

    # Return True iff `address` parses as a literal IPv6 address
    def is_valid_ipv6_address(address: str) -> bool:
        try:
            ipaddress.IPv6Address(address)
            return True
        except ValueError:
            return False

    # Build a tcp:// init URI; IPv6 literals must be bracketed
    def get_tcp_uri(ip: str, port: int) -> str:
        if is_valid_ipv6_address(ip):
            return f"tcp://[{ip}]:{port}"
        else:
            return f"tcp://{ip}:{port}"

    # Get initialization method
    init_method = get_tcp_uri(host, port)
    # Create Backend object
    backend = Backend('hccl')
    # Rendezvous over TCP; note this rebinds rank/world_size from the store
    store, rank, world_size = next(
        rendezvous(init_method, rank, world_size, timeout=timeout))

    # Set timeout for store
    store.set_timeout(timeout)
    # Prefix keys with the init URI + group name so concurrent groups on the
    # same store cannot collide
    prefix_store = PrefixStore(f"{init_method}/{group_name}/", store)
    # Set group_rank and group_size
    group_rank = rank
    group_size = world_size
    # Create the bare ProcessGroup shell; backends are attached below
    pg: ProcessGroup = ProcessGroup(
        prefix_store,
        group_rank,
        group_size,
    )
    # Create BackendConfig object (stringified into _world bookkeeping below)
    backend_config = BackendConfig(backend)
    # Set default backend for ProcessGroup
    pg._set_default_backend(Backend.backend_type_map[backend])

    # Replace pg_options with fresh defaults unless a genuine
    # ProcessGroupHCCL.Options instance was passed in
    if pg_options is None or not isinstance(
            pg_options,
            torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options):
        pg_options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
    # Set attributes for pg_options
    pg_options.is_high_priority_stream = False
    pg_options._timeout = timeout
    pg_options.global_ranks_in_group = []
    # group_id mirrors the PrefixStore prefix, keeping the two in sync
    pg_options.group_id = f"{init_method}/{group_name}/"
    # Create the concrete HCCL backend bound to the same prefix store
    backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
                                     pg_options)
    # Set sequence number for backend_class
    backend_class._set_sequence_number_for_group()
    # HCCL is not a built-in c10d backend type, so register it as CUSTOM
    backend_type = ProcessGroup.BackendType.CUSTOM
    # Attach the HCCL backend to the group for the "npu" device
    pg._register_backend(torch.device("npu"), backend_type, backend_class)

    # Set group_desc and pg_tag
    group_desc = "undefined"
    assert group_name is not None
    assert group_desc is not None
    pg._set_group_name(group_name)
    pg._set_group_desc(group_desc)

    # Register the group in torch's global _world state so that
    # destroy_stateless_process_group (and torch utilities) can find it.
    # Identity mapping: group rank i corresponds to global rank i.
    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
    _world.pg_map[pg] = (backend, prefix_store)
    _world.pg_names[pg] = group_name
    _register_process_group(group_name, pg)
    _world.pg_backend_config[pg] = str(backend_config)
    return pg
|
||||||
|
|
||||||
|
|
||||||
|
def destroy_stateless_process_group(pg: ProcessGroup, manual_gc: bool = False):
    """
    Destroy a process group created by ``stateless_init_process_group``.

    Shuts the group down and removes every entry that
    ``stateless_init_process_group`` added to torch's global ``_world``
    bookkeeping, so the group name can be reused afterwards.

    Args:
        pg: Process group to be destroyed.
        manual_gc: Whether to force a ``gc.collect()`` after teardown,
            defaults to False. Useful to promptly release HCCL resources
            held by reference cycles.
    """
    # Shut down the backend communicators first, while the group is still
    # fully registered.
    pg.shutdown()
    # Remove the entries added at init time; pop(..., None) keeps this
    # idempotent if some entry was never registered.
    _world.pg_map.pop(pg, None)
    _world.pg_names.pop(pg, None)
    _world.pg_group_ranks.pop(pg, None)
    _world.pg_backend_config.pop(pg, None)
    # Pending coalesced collectives are dropped, not launched — warn so the
    # caller knows work was discarded. (Membership test directly on the dict;
    # `.keys()` is redundant.)
    if pg in _world.pg_coalesce_state:
        logger.warning("Some coalesced collectives haven't been launched when "
                       "ProcessGroup is destroyed. They will be cleaned.")
        del _world.pg_coalesce_state[pg]
    # Drop the name -> group registration so the name can be reused
    _unregister_process_group(pg.group_name)

    # Optionally force collection of any lingering reference cycles
    if manual_gc:
        gc.collect()
|
||||||
Reference in New Issue
Block a user