init
This commit is contained in:
192
vllm/v1/worker/ubatch_splitting.py
Normal file
192
vllm/v1/worker/ubatch_splitting.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.forward_context import DPMetadata
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import round_up
|
||||
from vllm.v1.worker.ubatch_utils import (UBatchSlice, UBatchSlices,
|
||||
is_second_ubatch_empty)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def should_ubatch_with_num_tokens(
|
||||
should_ubatch: bool,
|
||||
orig_num_tokens_per_ubatch: int,
|
||||
padded_num_tokens_per_ubatch: int,
|
||||
vllm_config: VllmConfig,
|
||||
) -> tuple[bool, Optional[torch.Tensor]]:
|
||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
return DPMetadata.should_ubatch_across_dp(should_ubatch,
|
||||
orig_num_tokens_per_ubatch,
|
||||
padded_num_tokens_per_ubatch,
|
||||
dp_size, dp_rank)
|
||||
|
||||
|
||||
def check_ubatch_thresholds(config: ParallelConfig, num_tokens: int,
|
||||
uniform_decode: bool) -> bool:
|
||||
if not config.enable_dbo:
|
||||
return False
|
||||
if uniform_decode:
|
||||
return num_tokens >= config.dbo_decode_token_threshold
|
||||
else:
|
||||
return num_tokens >= config.dbo_prefill_token_threshold
|
||||
|
||||
|
||||
def get_dp_padding_ubatch(
|
||||
num_tokens_unpadded: int, num_tokens_padded: int,
|
||||
should_attempt_ubatching: bool,
|
||||
vllm_config: VllmConfig) -> tuple[bool, Optional[torch.Tensor]]:
|
||||
"""
|
||||
1. Decides if each DP rank is going to microbatch. Either all ranks
|
||||
run with microbatching or none of them do. If this function decides
|
||||
not to run with microbatching. It will "abort" meaning that no padding
|
||||
information will be returned to the caller. It will return (False, None)
|
||||
|
||||
2. Determines the total number of tokens that each rank will run.
|
||||
All ranks will be padded out so that the run with the same number
|
||||
of tokens
|
||||
|
||||
Returns: tuple[
|
||||
should_ubatch: Are all DP ranks going to microbatch
|
||||
num_tokens_after_padding: A tensor containing the total number of
|
||||
tokens per-microbatch for each DP rank including padding. Will be
|
||||
None if should_ubatch if False
|
||||
]
|
||||
|
||||
"""
|
||||
assert num_tokens_padded >= num_tokens_unpadded
|
||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
if dp_size == 1:
|
||||
# Early exit.
|
||||
return False, None
|
||||
|
||||
# If this DP rank doesn't want to attempt microbatching
|
||||
if not should_attempt_ubatching:
|
||||
(should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens(
|
||||
False, 0, 0, vllm_config)
|
||||
assert should_ubatch is False
|
||||
assert num_tokens_across_dp is None
|
||||
return should_ubatch, num_tokens_across_dp
|
||||
|
||||
# Round up to the next multiple of two for even divisibility
|
||||
num_tokens_padded = round_up(num_tokens_padded, 2)
|
||||
num_tokens_per_ubatch = num_tokens_padded // 2
|
||||
should_ubatch = True
|
||||
|
||||
# Sanity Check that the existing padding isn't giving us an empty second
|
||||
# ubatch. Abort if so
|
||||
if is_second_ubatch_empty(num_tokens_unpadded, num_tokens_padded):
|
||||
logger.debug(
|
||||
"Empty second µbatch detected: unpadded tokens: %s, padded "
|
||||
"tokens: %s", num_tokens_unpadded, num_tokens_padded)
|
||||
should_ubatch = False
|
||||
|
||||
# Note that we compute the number of padded tokens per ubatch
|
||||
(should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens(
|
||||
should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch,
|
||||
vllm_config)
|
||||
if not should_ubatch:
|
||||
assert num_tokens_across_dp is None
|
||||
return should_ubatch, num_tokens_across_dp
|
||||
|
||||
assert num_tokens_across_dp is not None
|
||||
|
||||
max_tokens_across_dp_cpu = int(torch.max(num_tokens_across_dp).item())
|
||||
num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
|
||||
dp_size,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
return should_ubatch, num_tokens_after_padding
|
||||
|
||||
def create_ubatch_slices(num_scheduled_tokens: np.ndarray, split_point: int) \
|
||||
-> UBatchSlices:
|
||||
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
|
||||
# in cu_num_tokens directly (i.e. query_start_loc)
|
||||
cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
|
||||
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
|
||||
|
||||
first_ubatch_token_slice = slice(0, split_point)
|
||||
second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1])
|
||||
|
||||
# Determine request slices using exclusive stop semantics
|
||||
# First ubatch includes requests whose tokens overlap [0, split_point)
|
||||
first_ubatch_req_stop = int(
|
||||
np.searchsorted(cu_num_tokens, split_point, side="left"))
|
||||
first_ubatch_req_slice = slice(0, first_ubatch_req_stop)
|
||||
|
||||
# Second ubatch starts at the request that contains the split_point
|
||||
# or the request starting exactly at split_point (if on boundary)
|
||||
second_ubatch_req_start = int(
|
||||
np.searchsorted(cu_num_tokens, split_point, side="right") - 1)
|
||||
second_ubatch_req_slice = slice(second_ubatch_req_start,
|
||||
len(cu_num_tokens) - 1)
|
||||
|
||||
return [
|
||||
UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
|
||||
UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice)
|
||||
]
|
||||
|
||||
|
||||
def ubatch_split(
|
||||
num_scheduled_tokens_per_request: np.ndarray,
|
||||
num_tokens_unpadded: int,
|
||||
num_tokens_padded: int,
|
||||
uniform_decode: bool,
|
||||
vllm_config: VllmConfig,
|
||||
) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Coordinates amongst all DP ranks to determine if and how the full batch
|
||||
should be split into microbatches.
|
||||
|
||||
Returns: tuple[
|
||||
ubatch_slices: if this is set then all DP ranks have agreed to
|
||||
microbatch
|
||||
num_tokens_after_padding: A tensor containing the total number of
|
||||
tokens per-microbatch for each DP rank including padding. Will be
|
||||
None if ubatch_slices is None
|
||||
]
|
||||
|
||||
"""
|
||||
parallel_config = vllm_config.parallel_config
|
||||
# Don't bother with the should_ubatch handshaking unless microbatching
|
||||
# is enabled
|
||||
if not parallel_config.enable_dbo:
|
||||
return (None, None)
|
||||
|
||||
# Check preconditions for microbatching
|
||||
should_attempt_ubatching = check_ubatch_thresholds(
|
||||
parallel_config,
|
||||
num_tokens_unpadded,
|
||||
uniform_decode=uniform_decode,
|
||||
)
|
||||
|
||||
# Don't microbatch unless every other DP worker is also microbatching
|
||||
should_ubatch, num_tokens_after_padding = get_dp_padding_ubatch(
|
||||
num_tokens_unpadded,
|
||||
num_tokens_padded,
|
||||
should_attempt_ubatching,
|
||||
vllm_config,
|
||||
)
|
||||
|
||||
if not should_ubatch:
|
||||
return (None, None)
|
||||
|
||||
# This doesn't actually pad the ubatch slices. It just initializes the
|
||||
# split point to the padded value so that padding can be applied
|
||||
# to the second ubatch in pad_out_ubatch_slice after attention
|
||||
# metadata creation
|
||||
assert num_tokens_after_padding is not None
|
||||
token_split_point = int(num_tokens_after_padding[0].item())
|
||||
|
||||
ubatch_slices = create_ubatch_slices(num_scheduled_tokens_per_request,
|
||||
token_split_point)
|
||||
|
||||
return (ubatch_slices, num_tokens_after_padding)
|
||||
Reference in New Issue
Block a user