32 lines
920 B
Python
32 lines
920 B
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
from vllm.distributed.parallel_state import get_dp_group
|
|
|
|
|
|
def get_batch_metadata_across_dp(
    num_tokens: int,
    cudagraph_size: int,
    dp_size: int,
    dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Exchange per-rank batch metadata across all data-parallel ranks.

    Each rank writes its own ``num_tokens`` and ``cudagraph_size`` into its
    slot of a zero-initialized ``2 x dp_size`` CPU tensor; a sum all-reduce
    then fills in every other rank's slot (zeros elsewhere make sum act as
    a gather).

    Returns:
        A pair of int32 CPU tensors of length ``dp_size``: the token count
        per rank and the cudagraph size per rank.
    """
    assert dp_size > 1
    # Reduce over the CPU group to avoid a CPU-GPU synchronization.
    cpu_group = get_dp_group().cpu_group
    buf = torch.zeros((2, dp_size), dtype=torch.int32, device="cpu")
    buf[0, dp_rank] = num_tokens
    buf[1, dp_rank] = cudagraph_size
    dist.all_reduce(buf, group=cpu_group)
    num_tokens_across_dp, cudagraph_sizes = buf[0], buf[1]
    return num_tokens_across_dp, cudagraph_sizes
|
|
|
|
|
|
def make_num_tokens_across_dp(
|
|
dp_size: int,
|
|
num_tokens: int,
|
|
) -> torch.Tensor | None:
|
|
if dp_size == 1:
|
|
return None
|
|
return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")
|