update
This commit is contained in:
77
vllm/v1/worker/gpu/dp_utils.py
Normal file
77
vllm/v1/worker/gpu/dp_utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
|
||||
|
||||
def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | None:
|
||||
if dp_size == 1:
|
||||
return None
|
||||
return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")
|
||||
|
||||
|
||||
def get_batch_metadata_across_dp(
    num_tokens: int,
    cudagraph_size: int,
    cudagraph_runtime_mode: int,
    dp_size: int,
    dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Gather per-rank batch metadata across all data-parallel ranks.

    Each rank contributes its ``num_tokens``, ``cudagraph_size``, and
    ``cudagraph_runtime_mode``; the returned tensors (each of length
    ``dp_size``) hold every rank's value for the corresponding field.
    """
    assert dp_size > 1
    # Use CPU group to avoid CPU-GPU synchronization.
    cpu_group = get_dp_group().cpu_group
    # One row per metadata field. Only this rank's column is non-zero, so a
    # sum all-reduce acts as an all-gather of each rank's values.
    buf = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu")
    buf[:, dp_rank] = torch.tensor(
        [num_tokens, cudagraph_size, cudagraph_runtime_mode], dtype=torch.int32
    )
    dist.all_reduce(buf, group=cpu_group)
    return buf[0], buf[1], buf[2]
|
||||
|
||||
|
||||
def get_cudagraph_and_dp_padding(
|
||||
num_tokens: int,
|
||||
cudagraph_size: int | None,
|
||||
cudagraph_runtime_mode: int,
|
||||
dp_size: int,
|
||||
dp_rank: int,
|
||||
) -> tuple[int, torch.Tensor | None, int]:
|
||||
if dp_size == 1:
|
||||
if cudagraph_size is not None:
|
||||
return cudagraph_size, None, cudagraph_runtime_mode
|
||||
else:
|
||||
return num_tokens, None, cudagraph_runtime_mode
|
||||
|
||||
# Convert None to -1 for sync (indicates no cudagraph available)
|
||||
if num_tokens == 0:
|
||||
cudagraph_size = 0
|
||||
elif cudagraph_size is None:
|
||||
cudagraph_size = -1
|
||||
|
||||
num_tokens_across_dp, cudagraph_size_across_dp, cudagraph_mode_across_dp = (
|
||||
get_batch_metadata_across_dp(
|
||||
num_tokens, cudagraph_size, cudagraph_runtime_mode, dp_size, dp_rank
|
||||
)
|
||||
)
|
||||
if torch.all(num_tokens_across_dp == 0).item():
|
||||
# All ranks have zero tokens to run.
|
||||
return 0, None, 0
|
||||
|
||||
# Synchronize cudagraph_runtime_mode across ranks by taking the minimum.
|
||||
synced_cudagraph_mode = int(cudagraph_mode_across_dp.min().item())
|
||||
# Check if all ranks have valid cudagraph_size.
|
||||
all_have_cudagraph = torch.all(cudagraph_size_across_dp != -1).item()
|
||||
|
||||
if synced_cudagraph_mode != 0 and all_have_cudagraph:
|
||||
# All ranks use cudagraph. Pad to max cudagraph_size.
|
||||
max_cudagraph_size = int(cudagraph_size_across_dp.max().item())
|
||||
num_tokens_across_dp[:] = max_cudagraph_size
|
||||
return max_cudagraph_size, num_tokens_across_dp, synced_cudagraph_mode
|
||||
else:
|
||||
# Fall back to eager mode (no cudagraph).
|
||||
# Either some rank doesn't have cudagraph size or mode is NONE.
|
||||
synced_cudagraph_mode = 0
|
||||
num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
|
||||
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank].item())
|
||||
return num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode
|
||||
Reference in New Issue
Block a user