# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TypeAlias

import numpy as np
import torch

from vllm.config import ParallelConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata


@dataclass
class UBatchSlice:
    request_slice: slice
    token_slice: slice

    def is_empty(self) -> bool:
        return (
            self.request_slice.start == self.request_slice.stop
            or self.token_slice.start == self.token_slice.stop
        )

    @property
    def num_tokens(self) -> int:
        return self.token_slice.stop - self.token_slice.start


UBatchSlices: TypeAlias = list[UBatchSlice]


def is_last_ubatch_empty(
    orig_num_tokens: int, padded_num_tokens: int, num_ubatches: int
) -> bool:
    return (padded_num_tokens // num_ubatches) * (num_ubatches - 1) >= orig_num_tokens


def check_ubatch_thresholds(
    config: ParallelConfig, num_tokens: int, uniform_decode: bool
) -> bool:
    if not config.use_ubatching:
        return False
    if uniform_decode:
        return num_tokens >= config.dbo_decode_token_threshold
    else:
        return num_tokens >= config.dbo_prefill_token_threshold


# This pads the last ubatch slice out to the total number of tokens
# (num_tokens + padding), since the ubatch slices are created before
# DP padding is applied.
def _pad_out_ubatch_slices(
    ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
) -> UBatchSlices:
    last_slice = ubatch_slices[-1]
    padded_last_request_slice = slice(last_slice.request_slice.start, num_reqs_padded)
    padded_last_token_slice = slice(last_slice.token_slice.start, num_total_tokens)
    return ubatch_slices[:-1] + [
        UBatchSlice(padded_last_request_slice, padded_last_token_slice)
    ]
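
# Worked example for is_last_ubatch_empty (illustrative only; the numbers are
# invented): the padded batch is split evenly across ubatches, so the last
# ubatch is pure padding exactly when the first (num_ubatches - 1) ubatches
# already cover every real token.
#
#   is_last_ubatch_empty(orig_num_tokens=5, padded_num_tokens=8, num_ubatches=2)
#   -> False  (the first ubatch holds 4 tokens; the 5th spills into the second)
#   is_last_ubatch_empty(orig_num_tokens=3, padded_num_tokens=8, num_ubatches=2)
#   -> True   (all 3 real tokens fit in the first ubatch)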


def maybe_create_ubatch_slices(
    should_ubatch: bool,
    num_scheduled_tokens: np.ndarray,
    num_tokens_padded: int,
    num_reqs_padded: int,
    num_ubatches: int,
    split_point: list[int] | int | None = None,
) -> tuple[UBatchSlices | None, UBatchSlices | None]:
    if not should_ubatch:
        return None, None

    if split_point is None:
        split_point = int(num_tokens_padded) // num_ubatches
    if isinstance(split_point, int):
        token_split_points = [split_point * i for i in range(1, num_ubatches)]
    else:
        # An explicit list of token split points was provided; use it as-is.
        token_split_points = split_point

    # TODO(lucas): Refactor the gpu_model_runner.py so we can pass
    # in cu_num_tokens directly (i.e. query_start_loc)
    cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
    np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])

    ubatch_slices = []
    start_token = 0
    # Add the end point to the split points to make iteration easier
    all_points = token_split_points + [cu_num_tokens[-1]]
    for end_token in all_points:
        token_slice = slice(start_token, end_token)

        # Determine request slices using exclusive stop semantics.
        # A ubatch includes every request whose tokens overlap
        # [start_token, end_token).
        # Start at the request that contains start_token, or at the request
        # starting exactly at start_token (if on a boundary).
        req_start = int(np.searchsorted(cu_num_tokens, start_token, side="right") - 1)
        # Stop at the first request that starts at or after end_token.
        req_stop = int(np.searchsorted(cu_num_tokens, end_token, side="left"))

        req_slice = slice(req_start, req_stop)
        ubatch_slices.append(UBatchSlice(req_slice, token_slice))
        start_token = end_token

    ubatch_slices_padded = _pad_out_ubatch_slices(
        ubatch_slices, num_tokens_padded, num_reqs_padded
    )
    assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded

    return ubatch_slices, ubatch_slices_padded


def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    request_slice: slice,
) -> torch.Tensor:
    """
    Creates a new query_start_loc that corresponds to the requests in
    request_slice.

    Note: This function creates a new tensor to hold the new query_start_locs,
    which breaks cudagraph compatibility.
    """
    return (
        query_start_loc[request_slice.start : request_slice.stop + 1]
        - query_start_loc[request_slice.start]
    )
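
# Worked example for maybe_create_ubatch_slices (illustrative sketch; the
# numbers are invented): two requests with 3 and 5 scheduled tokens, padded to
# 8 tokens and split into 2 ubatches of 4 tokens each. Request 1 straddles the
# split point, so it appears in both ubatch slices:
#
#   slices, slices_padded = maybe_create_ubatch_slices(
#       should_ubatch=True,
#       num_scheduled_tokens=np.array([3, 5]),
#       num_tokens_padded=8,
#       num_reqs_padded=2,
#       num_ubatches=2,
#   )
#   # slices[0]: requests [0, 2), tokens [0, 4)  (request 1 begins here)
#   # slices[1]: requests [1, 2), tokens [4, 8)  (request 1 continues here)
#
# slices_padded matches slices here because 8 tokens and 2 requests are
# already the padded sizes.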


def _make_metadata_with_slice(
    ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
) -> CommonAttentionMetadata:
    """
    This function creates a new CommonAttentionMetadata that corresponds to
    the requests included in ubatch_slice.
    """
    assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"

    request_slice = ubatch_slice.request_slice
    token_slice = ubatch_slice.token_slice

    start_locs = attn_metadata.query_start_loc_cpu
    first_req = request_slice.start
    first_tok = token_slice.start
    last_req = request_slice.stop - 1
    last_tok = token_slice.stop - 1

    assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], (
        "Token slice start outside of first request"
    )
    # NOTE: the last token can be outside of the last request if we have
    # CUDA graph (CG) padding.

    # If a request is split across ubatches, we have to adjust the metadata.
    # splits_first_request: The first request in this slice is the continuation
    # of a request that started in a previous slice.
    # splits_last_request: The last request in this slice continues into the
    # next slice.
    splits_first_request = first_tok > start_locs[first_req]
    splits_last_request = last_tok < start_locs[last_req + 1] - 1

    query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
    query_start_loc = slice_query_start_locs(
        attn_metadata.query_start_loc, request_slice
    )
    assert len(query_start_loc) >= 2, (
        f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}"
    )

    if splits_first_request:
        tokens_skipped = first_tok - start_locs[first_req]
        query_start_loc[1:] -= tokens_skipped
        query_start_loc_cpu[1:] -= tokens_skipped

    seq_lens = attn_metadata.seq_lens[request_slice]
    seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
    if splits_last_request:
        # NOTE: We use start_locs (the original query_start_loc_cpu) to
        # calculate the tokens skipped because query_start_loc_cpu might have
        # been modified if splits_first_request is True.
        tokens_skipped = start_locs[last_req + 1] - token_slice.stop
        query_start_loc[-1] -= tokens_skipped
        query_start_loc_cpu[-1] -= tokens_skipped

        # Make sure we don't modify the original seq_lens tensors
        # (not cudagraph compatible)
        seq_lens = seq_lens.clone()
        seq_lens_cpu = seq_lens_cpu.clone()
        seq_lens[-1] -= tokens_skipped
        seq_lens_cpu[-1] -= tokens_skipped

    max_seq_len = int(seq_lens_cpu.max())

    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]

    num_requests = request_slice.stop - request_slice.start
    num_actual_tokens = token_slice.stop - token_slice.start
    max_query_len = int(
        torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item()
    )
    # This is to account for the case where we are in a dummy run and
    # query_start_loc_cpu is full of 0s
    if max_query_len == 0:
        max_query_len = attn_metadata.max_query_len

    block_table_tensor = attn_metadata.block_table_tensor[request_slice]
    slot_mapping = attn_metadata.slot_mapping[token_slice]

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        num_reqs=num_requests,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
        _seq_lens_cpu=seq_lens_cpu,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
    )


def split_attn_metadata(
    ubatch_slices: list[UBatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
    """
    Creates a new CommonAttentionMetadata instance for the requests covered by
    each UBatchSlice in ubatch_slices.

    Note: This function does not modify common_attn_metadata.
    """
    results = []
    for ubatch_slice in ubatch_slices:
        results.append(_make_metadata_with_slice(ubatch_slice, common_attn_metadata))
    return results
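
# Usage sketch (hedged): `common_attn_metadata` stands in for the metadata the
# model runner builds for the full batch; it is not constructed in this module.
# Each returned CommonAttentionMetadata then drives attention for one
# micro-batch:
#
#   _, slices_padded = maybe_create_ubatch_slices(...)
#   if slices_padded is not None:
#       per_ubatch_metadata = split_attn_metadata(
#           slices_padded, common_attn_metadata
#       )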