Sync from v0.13
This commit is contained in:
109
vllm/v1/worker/ubatch_utils.py
Normal file
109
vllm/v1/worker/ubatch_utils.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeAlias
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class UBatchSlice:
|
||||
request_slice: slice
|
||||
token_slice: slice
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return (
|
||||
self.request_slice.start == self.request_slice.stop
|
||||
or self.token_slice.start == self.token_slice.stop
|
||||
)
|
||||
|
||||
@property
|
||||
def num_tokens(self) -> int:
|
||||
return self.token_slice.stop - self.token_slice.start
|
||||
|
||||
|
||||
UBatchSlices: TypeAlias = list[UBatchSlice]
|
||||
|
||||
|
||||
def is_second_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int) -> bool:
|
||||
return (padded_num_tokens // 2) >= orig_num_tokens
|
||||
|
||||
|
||||
def check_ubatch_thresholds(
|
||||
config: ParallelConfig, num_tokens: int, uniform_decode: bool
|
||||
) -> bool:
|
||||
if not config.enable_dbo:
|
||||
return False
|
||||
if uniform_decode:
|
||||
return num_tokens >= config.dbo_decode_token_threshold
|
||||
else:
|
||||
return num_tokens >= config.dbo_prefill_token_threshold
|
||||
|
||||
|
||||
# This just pads the second ubatch slice out to the total number of tokens
|
||||
# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
|
||||
def _pad_out_ubatch_slices(
|
||||
ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
|
||||
) -> UBatchSlices:
|
||||
# TODO(lucas): handle empty second ubatch
|
||||
padded_second_request_slice = slice(
|
||||
ubatch_slices[1].request_slice.start, num_reqs_padded
|
||||
)
|
||||
padded_second_token_slice = slice(
|
||||
ubatch_slices[1].token_slice.start, num_total_tokens
|
||||
)
|
||||
return [
|
||||
ubatch_slices[0],
|
||||
UBatchSlice(padded_second_request_slice, padded_second_token_slice),
|
||||
]
|
||||
|
||||
|
||||
def maybe_create_ubatch_slices(
|
||||
should_ubatch: bool,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
num_tokens_padded: int,
|
||||
num_reqs_padded: int,
|
||||
split_point: int | None = None,
|
||||
) -> tuple[UBatchSlices | None, UBatchSlices | None]:
|
||||
if not should_ubatch:
|
||||
return None, None
|
||||
|
||||
if split_point is None:
|
||||
split_point = int(num_tokens_padded) // 2
|
||||
|
||||
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
|
||||
# in cu_num_tokens directly (i.e. query_start_loc)
|
||||
cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
|
||||
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
|
||||
|
||||
first_ubatch_token_slice = slice(0, split_point)
|
||||
second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1])
|
||||
|
||||
# Determine request slices using exclusive stop semantics
|
||||
# First ubatch includes requests whose tokens overlap [0, split_point)
|
||||
first_ubatch_req_stop = int(
|
||||
np.searchsorted(cu_num_tokens, split_point, side="left")
|
||||
)
|
||||
first_ubatch_req_slice = slice(0, first_ubatch_req_stop)
|
||||
|
||||
# Second ubatch starts at the request that contains the split_point
|
||||
# or the request starting exactly at split_point (if on boundary)
|
||||
second_ubatch_req_start = int(
|
||||
np.searchsorted(cu_num_tokens, split_point, side="right") - 1
|
||||
)
|
||||
second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)
|
||||
|
||||
ubatch_slices = [
|
||||
UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
|
||||
UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
|
||||
]
|
||||
|
||||
ubatch_slices_padded = _pad_out_ubatch_slices(
|
||||
ubatch_slices, num_tokens_padded, num_reqs_padded
|
||||
)
|
||||
|
||||
assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded
|
||||
|
||||
return ubatch_slices, ubatch_slices_padded
|
||||
Reference in New Issue
Block a user