Sync from v0.13
This commit is contained in:
31
vllm/v1/worker/gpu/dp_utils.py
Normal file
31
vllm/v1/worker/gpu/dp_utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
|
||||
|
||||
def get_batch_metadata_across_dp(
    num_tokens: int,
    cudagraph_size: int,
    dp_size: int,
    dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Exchange (num_tokens, cudagraph_size) with every other DP rank.

    Each rank writes its own pair into its column of a zero-initialized
    2 x dp_size CPU tensor; a sum all-reduce then fills in every other
    rank's column, since all remaining slots are zero.

    Returns:
        A pair of int32 CPU tensors of shape (dp_size,):
        (num_tokens per rank, cudagraph_size per rank).
    """
    assert dp_size > 1
    # Run the collective on the CPU process group so it stays off the GPU
    # stream and no CPU-GPU synchronization is needed.
    cpu_group = get_dp_group().cpu_group
    metadata = torch.zeros(2, dp_size, dtype=torch.int32, device="cpu")
    metadata[0, dp_rank] = num_tokens
    metadata[1, dp_rank] = cudagraph_size
    dist.all_reduce(metadata, group=cpu_group)
    return metadata[0], metadata[1]
|
||||
|
||||
|
||||
def make_num_tokens_across_dp(
|
||||
dp_size: int,
|
||||
num_tokens: int,
|
||||
) -> torch.Tensor | None:
|
||||
if dp_size == 1:
|
||||
return None
|
||||
return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")
|
||||
Reference in New Issue
Block a user