[feat] support customized and separated hccl_buffer_size for process group initialization (#3073)
### What this PR does / why we need it? Currently, users have to set `HCCL_BUFFSIZE` to 512~1024 to perform mc2 operators (dispatch and combine) while running moe models with large `ep_size` and `batch_size`. This environmental variable not only affects allocated VRAM for mc2 group, but also increases VRAM allocation for dp, tp & ep groups, leading to significant kvcache and free_memory drops. This PR supports to automatically calculate and set `hccl_buffer_size` for each process group **(except mc2 group)** separately when users set `HCCL_BUFFSIZE` for mc2 group. This can significantly reduce wasted buffer_size set for dp, tp & ep groups. Note that current mc2 operators can only perform communication space partitioning based on `HCCL_BUFFSIZE` configuration. Once they support `hccl_buffer_size` configuration with `pg_options` while initializing process group, we'll caculate the required buffer size and users would avoid set `HCCL_BUFFSIZE` themselves. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? We performed E2E serving with deepseek_r1 initializing DP/TP/EP/MC2 process group and observed significant kv_cache and free_memory increase! - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -15,17 +15,82 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import vllm
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from torch.distributed import Backend
|
||||
from vllm.distributed.parallel_state import (GroupCoordinator,
|
||||
_get_unique_name, _register_group)
|
||||
|
||||
from vllm_ascend.distributed.communicator import NPUCommunicator
|
||||
from vllm_ascend.utils import create_hccl_pg_options
|
||||
|
||||
|
||||
class GroupCoordinatorPatch(GroupCoordinator):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(
|
||||
self,
|
||||
group_ranks: list[list[int]],
|
||||
local_rank: int,
|
||||
torch_distributed_backend: Union[str, Backend],
|
||||
use_device_communicator: bool, # whether to use device communicator
|
||||
use_message_queue_broadcaster: bool = False,
|
||||
group_name: Optional[str] = None,
|
||||
):
|
||||
group_name = group_name or "anonymous"
|
||||
self.unique_name = _get_unique_name(group_name)
|
||||
_register_group(self)
|
||||
|
||||
self.rank = torch.distributed.get_rank()
|
||||
self.local_rank = local_rank
|
||||
|
||||
self_device_group = None
|
||||
self_cpu_group = None
|
||||
hccl_pg_options = create_hccl_pg_options(group_name)
|
||||
|
||||
for ranks in group_ranks:
|
||||
device_group = torch.distributed.new_group(
|
||||
ranks,
|
||||
backend=torch_distributed_backend,
|
||||
pg_options=hccl_pg_options)
|
||||
|
||||
# a group with `gloo` backend, to allow direct coordination between
|
||||
# processes through the CPU.
|
||||
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
|
||||
if self.rank in ranks:
|
||||
self.ranks = ranks
|
||||
self.world_size = len(ranks)
|
||||
self.rank_in_group = ranks.index(self.rank)
|
||||
self_device_group = device_group
|
||||
self_cpu_group = cpu_group
|
||||
|
||||
assert self_cpu_group is not None
|
||||
assert self_device_group is not None
|
||||
|
||||
self.cpu_group = self_cpu_group
|
||||
self.device_group = self_device_group
|
||||
self.device = torch.npu.current_device()
|
||||
|
||||
self.use_device_communicator = use_device_communicator
|
||||
self.device_communicator = None
|
||||
if use_device_communicator and self.world_size > 1:
|
||||
self.device_communicator = NPUCommunicator(
|
||||
cpu_group=self.cpu_group,
|
||||
device=self.device,
|
||||
device_group=self.device_group,
|
||||
unique_name=self.unique_name,
|
||||
)
|
||||
|
||||
from vllm.distributed.device_communicators.shm_broadcast import \
|
||||
MessageQueue
|
||||
self.mq_broadcaster: Optional[MessageQueue] = None
|
||||
if use_message_queue_broadcaster and self.world_size > 1:
|
||||
self.mq_broadcaster = MessageQueue.create_from_process_group(
|
||||
self.cpu_group, 1 << 22, 6)
|
||||
|
||||
self.use_custom_op_call = False
|
||||
self.use_cpu_custom_send_recv = False
|
||||
|
||||
def all_to_all(self,
|
||||
input_: torch.Tensor,
|
||||
@@ -41,9 +106,10 @@ class GroupCoordinatorPatch(GroupCoordinator):
|
||||
assert -input_.dim() <= gather_dim < input_.dim(), (
|
||||
f"Invalid gather dim ({gather_dim}) for input tensor with shape {input_.size()}"
|
||||
)
|
||||
assert self.device_communicator is not None, "device_communicator should be initialized when world_size > 1"
|
||||
return self.device_communicator.all_to_all(input_, scatter_dim,
|
||||
gather_dim, scatter_sizes,
|
||||
gather_sizes)
|
||||
|
||||
|
||||
vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch # Note: check the GroupCoordinator with online serving
|
||||
vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch
|
||||
|
||||
Reference in New Issue
Block a user