### What this PR does / why we need it?
| File Path |
| :--- |
| ` vllm_ascend/eplb/adaptor/abstract_adaptor.py` |
| ` vllm_ascend/eplb/adaptor/vllm_adaptor.py` |
| ` vllm_ascend/eplb/core/eplb_device_transfer_loader.py` |
| ` vllm_ascend/eplb/core/eplb_utils.py` |
| ` vllm_ascend/eplb/core/eplb_worker.py` |
| ` vllm_ascend/eplb/core/policy/policy_abstract.py` |
| ` vllm_ascend/eplb/core/policy/policy_default_eplb.py` |
| ` vllm_ascend/eplb/core/policy/policy_factory.py` |
| ` vllm_ascend/eplb/core/policy/policy_flashlb.py` |
| ` vllm_ascend/eplb/core/policy/policy_random.py` |
| ` vllm_ascend/eplb/core/policy/policy_swift_balancer.py` |
| ` vllm_ascend/eplb/eplb_updator.py` |
| ` vllm_ascend/eplb/utils.py` |
| ` vllm_ascend/model_loader/netloader/executor/elastic_load.py` |
| ` vllm_ascend/model_loader/netloader/executor/netloader_pg.py` |
| ` vllm_ascend/model_loader/netloader/interaction/elastic.py` |
| ` vllm_ascend/model_loader/netloader/load.py` |
| ` vllm_ascend/model_loader/netloader/netloader.py` |
| ` vllm_ascend/model_loader/netloader/utils.py` |
| ` vllm_ascend/patch/platform/__init__.py` |
| ` vllm_ascend/patch/platform/patch_balance_schedule.py` |
| ` vllm_ascend/patch/platform/patch_ec_connector.py` |
| ` vllm_ascend/patch/platform/patch_mamba_config.py` |
| ` vllm_ascend/patch/platform/patch_multiproc_executor.py` |
| ` vllm_ascend/patch/platform/patch_sched_yield.py` |
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
@@ -18,8 +18,7 @@ import torch
|
||||
import torch_npu
|
||||
from vllm.logger import logger
|
||||
|
||||
from .netloader_pg import (destroy_stateless_process_group,
|
||||
stateless_init_process_group)
|
||||
from .netloader_pg import destroy_stateless_process_group, stateless_init_process_group
|
||||
|
||||
|
||||
class P2PLoad:
|
||||
@@ -56,9 +55,7 @@ class P2PLoad:
|
||||
- The model if loading is successful, otherwise None.
|
||||
"""
|
||||
model_device = next(model.parameters()).device
|
||||
logger.info(
|
||||
f"Start init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
|
||||
)
|
||||
logger.info(f"Start init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}")
|
||||
receiver_pg = None
|
||||
loaded_model = None
|
||||
try:
|
||||
@@ -67,15 +64,13 @@ class P2PLoad:
|
||||
port=self.source_port,
|
||||
rank=0,
|
||||
world_size=2,
|
||||
group_name='netloader',
|
||||
group_name="netloader",
|
||||
)
|
||||
logger.info(
|
||||
f"Finish init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Start recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
|
||||
)
|
||||
logger.info(f"Start recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}")
|
||||
logger.info(f"Model device: {model_device}")
|
||||
|
||||
trans_stream = torch_npu.npu.Stream()
|
||||
@@ -84,14 +79,11 @@ class P2PLoad:
|
||||
if len(param.shape) == 0:
|
||||
continue
|
||||
receiver_pg.recv([param], 1, 0).wait()
|
||||
torch.distributed.barrier(group=receiver_pg,
|
||||
device_ids=[model_device.index])
|
||||
torch.distributed.barrier(group=receiver_pg, device_ids=[model_device.index])
|
||||
|
||||
torch_npu.npu.synchronize(trans_stream)
|
||||
|
||||
logger.info(
|
||||
f"Finish recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
|
||||
)
|
||||
logger.info(f"Finish recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}")
|
||||
loaded_model = model
|
||||
except Exception as e:
|
||||
logger.error("Failed to recv model: {}".format(e))
|
||||
@@ -129,9 +121,7 @@ class P2PSend:
|
||||
"""
|
||||
model_device = next(model.parameters()).device
|
||||
torch.npu.set_device(model_device)
|
||||
logger.info(
|
||||
f"Start init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
|
||||
)
|
||||
logger.info(f"Start init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}")
|
||||
sender_pg = None
|
||||
try:
|
||||
sender_pg = stateless_init_process_group(
|
||||
@@ -139,14 +129,10 @@ class P2PSend:
|
||||
port=self.listen_port,
|
||||
rank=1,
|
||||
world_size=2,
|
||||
group_name='netloader',
|
||||
)
|
||||
logger.info(
|
||||
f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
|
||||
)
|
||||
logger.info(
|
||||
f"Start send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
|
||||
group_name="netloader",
|
||||
)
|
||||
logger.info(f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}")
|
||||
logger.info(f"Start send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}")
|
||||
logger.info(f"Model device: {model_device}")
|
||||
|
||||
trans_stream = torch_npu.npu.Stream()
|
||||
@@ -155,16 +141,12 @@ class P2PSend:
|
||||
if "aclnn_input_scale" in name:
|
||||
continue
|
||||
if name in int8_params:
|
||||
sender_pg.send([int8_params[name].to(model_device)], 0,
|
||||
0).wait()
|
||||
sender_pg.send([int8_params[name].to(model_device)], 0, 0).wait()
|
||||
else:
|
||||
sender_pg.send([param.contiguous()], 0, 0).wait()
|
||||
torch.distributed.barrier(group=sender_pg,
|
||||
device_ids=[model_device.index])
|
||||
torch.distributed.barrier(group=sender_pg, device_ids=[model_device.index])
|
||||
torch_npu.npu.synchronize(trans_stream)
|
||||
logger.info(
|
||||
f"Finish send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
|
||||
)
|
||||
logger.info(f"Finish send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}")
|
||||
finally:
|
||||
if sender_pg:
|
||||
destroy_stateless_process_group(sender_pg)
|
||||
destroy_stateless_process_group(sender_pg)
|
||||
|
||||
@@ -17,16 +17,13 @@
|
||||
import gc
|
||||
import ipaddress
|
||||
from datetime import timedelta
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch._C._distributed_c10d import (_DEFAULT_PG_TIMEOUT,
|
||||
_register_process_group,
|
||||
_unregister_process_group)
|
||||
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT, _register_process_group, _unregister_process_group
|
||||
from torch.distributed import ProcessGroup, is_hccl_available
|
||||
from torch.distributed.distributed_c10d import (Backend, BackendConfig,
|
||||
PrefixStore, _world)
|
||||
from torch.distributed.distributed_c10d import Backend, BackendConfig, PrefixStore, _world
|
||||
from torch.distributed.rendezvous import rendezvous
|
||||
from torch_npu._C._distributed_c10d import ProcessGroupHCCL
|
||||
from vllm.logger import logger
|
||||
@@ -39,7 +36,7 @@ def stateless_init_process_group(
|
||||
rank: int,
|
||||
timeout: timedelta = _DEFAULT_PG_TIMEOUT,
|
||||
group_name: str = "",
|
||||
pg_options: Optional[Any] = None,
|
||||
pg_options: Any | None = None,
|
||||
) -> ProcessGroup:
|
||||
"""
|
||||
Initializes a stateless process group.
|
||||
@@ -57,7 +54,8 @@ def stateless_init_process_group(
|
||||
ProcessGroup: The initialized process group.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If world_size is not positive, or if rank is not within [0, world_size - 1], or if HCCL is unavailable.
|
||||
RuntimeError: If world_size is not positive, or if rank is not within
|
||||
[0, world_size - 1], or if HCCL is unavailable.
|
||||
TypeError: If timeout is not a timedelta type.
|
||||
ValueError: If group_name already exists.
|
||||
"""
|
||||
@@ -67,21 +65,18 @@ def stateless_init_process_group(
|
||||
raise RuntimeError("world_size must be positive")
|
||||
# Check if rank is within [0, world_size - 1]
|
||||
if not (rank >= 0 and rank <= world_size - 1):
|
||||
raise RuntimeError(
|
||||
"rank should be a number between 0 and ``world_size``-1")
|
||||
raise RuntimeError("rank should be a number between 0 and ``world_size``-1")
|
||||
# Check if HCCL is available
|
||||
if not is_hccl_available():
|
||||
raise RuntimeError("HCCL is not available")
|
||||
# Check if timeout is a timedelta type
|
||||
if not isinstance(timeout, timedelta):
|
||||
raise TypeError(
|
||||
f"Expected timeout argument to be of type datetime.timedelta, got {timeout}"
|
||||
)
|
||||
raise TypeError(f"Expected timeout argument to be of type datetime.timedelta, got {timeout}")
|
||||
# Check if group_name already exists
|
||||
if group_name in _world.pg_names.values():
|
||||
raise ValueError(
|
||||
f"The specified group name {group_name} has already been "
|
||||
"created, please use a different group name")
|
||||
f"The specified group name {group_name} has already been created, please use a different group name"
|
||||
)
|
||||
|
||||
# Function to check if an IPv6 address is valid
|
||||
def is_valid_ipv6_address(address: str) -> bool:
|
||||
@@ -101,10 +96,9 @@ def stateless_init_process_group(
|
||||
# Get initialization method
|
||||
init_method = get_tcp_uri(host, port)
|
||||
# Create Backend object
|
||||
backend = Backend('hccl')
|
||||
backend = Backend("hccl")
|
||||
# Use rendezvous function to get store, rank, and world_size
|
||||
store, rank, world_size = next(
|
||||
rendezvous(init_method, rank, world_size, timeout=timeout))
|
||||
store, rank, world_size = next(rendezvous(init_method, rank, world_size, timeout=timeout))
|
||||
|
||||
# Set timeout for store
|
||||
store.set_timeout(timeout)
|
||||
@@ -125,9 +119,7 @@ def stateless_init_process_group(
|
||||
pg._set_default_backend(Backend.backend_type_map[backend])
|
||||
|
||||
# Check if pg_options is None or not of type ProcessGroupHCCL.Options
|
||||
if pg_options is None or not isinstance(
|
||||
pg_options,
|
||||
torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options):
|
||||
if pg_options is None or not isinstance(pg_options, torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options):
|
||||
pg_options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
|
||||
# Set attributes for pg_options
|
||||
pg_options.is_high_priority_stream = False
|
||||
@@ -135,8 +127,7 @@ def stateless_init_process_group(
|
||||
pg_options.global_ranks_in_group = []
|
||||
pg_options.group_id = f"{init_method}/{group_name}/"
|
||||
# Create ProcessGroupHCCL object
|
||||
backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
|
||||
pg_options)
|
||||
backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size, pg_options)
|
||||
# Set sequence number for backend_class
|
||||
backend_class._set_sequence_number_for_group()
|
||||
# Set backend_type
|
||||
@@ -176,9 +167,10 @@ def destroy_stateless_process_group(pg: ProcessGroup, manual_gc: bool = False):
|
||||
_world.pg_group_ranks.pop(pg, None)
|
||||
_world.pg_backend_config.pop(pg, None)
|
||||
# Check if pg is in keys of _world.pg_coalesce_state
|
||||
if pg in _world.pg_coalesce_state.keys():
|
||||
logger.warning("Some coalesced collectives haven't been launched when "
|
||||
"ProcessGroup is destroyed. They will be cleaned.")
|
||||
if pg in _world.pg_coalesce_state:
|
||||
logger.warning(
|
||||
"Some coalesced collectives haven't been launched when ProcessGroup is destroyed. They will be cleaned."
|
||||
)
|
||||
del _world.pg_coalesce_state[pg]
|
||||
# Unregister the process group
|
||||
_unregister_process_group(pg.group_name)
|
||||
|
||||
Reference in New Issue
Block a user