# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility functions for sparse MLA backends."""

import torch

from vllm.triton_utils import tl, triton


# Kernel with prefill workspace support and valid count tracking
@triton.jit
def _convert_req_index_to_global_index_kernel(
    req_id_ptr,  # int32 [num_tokens]
    block_table_ptr,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    out_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    valid_count_ptr,  # int32 [num_tokens] - output valid count per row
    prefill_request_id_ptr,  # int32 [num_tokens], -1 for decode, >=0 for prefill
    workspace_starts_ptr,  # int32 [num_prefill_reqs+1] or nullptr
    # shapes (compile-time where possible)
    max_num_blocks_per_req: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_N: tl.constexpr,  # tile width along columns
    HAS_PREFILL: tl.constexpr,
    COUNT_VALID: tl.constexpr,  # whether to count valid indices
    # strides (in elements)
    bt_stride0,
    bt_stride1,
    ti_stride0,
    ti_stride1,
    out_stride0,
    out_stride1,
):
    # program_id(0) -> token_id (row)
    # program_id(1) -> tile index along columns
    token_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    # Each program covers BLOCK_N consecutive columns
    indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # Load request id for this token (no mask: grid is exact)
    req = tl.load(req_id_ptr + token_id)

    # Load token indices for this tile
    ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
    tok = tl.load(ti_ptr)  # int32

    # Only token == -1 should propagate as -1
    is_invalid_tok = tok < 0

    is_prefill = False
    if HAS_PREFILL:
        prefill_req_id = tl.load(prefill_request_id_ptr + token_id)
        is_prefill = prefill_req_id >= 0

    # Compute block id and in-block offset
    block_id = tok // BLOCK_SIZE
    inblock_off = tok % BLOCK_SIZE

    # Guard block_table access
    valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
    bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
    is_invalid_tok |= ~valid_block
    base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0)
    out_val = base * BLOCK_SIZE + inblock_off

    # Override with prefill output if prefill is enabled
    if HAS_PREFILL:
        workspace_start = tl.load(
            workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0
        )
        prefill_out = workspace_start + tok
        out_val = tl.where(is_prefill, prefill_out, out_val)

    out_val = tl.where(is_invalid_tok, -1, out_val)

    # Store results
    out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
    tl.store(out_ptr_ij, out_val)

    # Count valid indices in this tile and atomically add to row total
    if COUNT_VALID:
        tile_valid_count = tl.sum((~is_invalid_tok).to(tl.int32))
        tl.atomic_add(valid_count_ptr + token_id, tile_valid_count)


def triton_convert_req_index_to_global_index(
    req_id: torch.Tensor,  # int32 [num_tokens]
    block_table: torch.Tensor,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    BLOCK_SIZE: int = 64,
    NUM_TOPK_TOKENS: int = 2048,
    BLOCK_N: int = 128,  # tile width along columns
    HAS_PREFILL_WORKSPACE: bool = False,
    prefill_workspace_request_ids: torch.Tensor | None = None,
    prefill_workspace_starts: torch.Tensor | None = None,
    return_valid_counts: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    out[token_id, indice_id] =
        block_table[req_id[token_id],
                    token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
        + token_indices[token_id, indice_id] % BLOCK_SIZE

    Only when token_indices[token_id, indice_id] == -1 do we output -1. For
    safety, we also output -1 if the derived block_id would be out-of-bounds.

    When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace
    offsets instead of global cache slots. prefill_workspace_request_ids and
    prefill_workspace_starts must be provided.
        prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else
            prefill request index (maps to prefill_workspace_starts)
        prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace
            starts for each prefill request

    When return_valid_counts is True, also returns the count of valid (non -1)
    indices per row, computed during the same kernel pass (no extra overhead).
    """
    assert req_id.dtype == torch.int32
    assert block_table.dtype == torch.int32
    assert token_indices.dtype == torch.int32
    assert token_indices.shape[1] == NUM_TOPK_TOKENS
    assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
        f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})"
    )
    if HAS_PREFILL_WORKSPACE:
        assert prefill_workspace_request_ids is not None
        assert prefill_workspace_starts is not None
        assert prefill_workspace_request_ids.dtype == torch.int32
        assert prefill_workspace_starts.dtype == torch.int32

    num_tokens = req_id.shape[0]
    max_num_blocks_per_req = block_table.shape[1]
    tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N

    # Ensure contiguous tensors on the same device
    req_id_c = req_id.contiguous()
    block_table_c = block_table.contiguous()
    token_indices_c = token_indices.contiguous()
    out = torch.empty_like(token_indices_c)

    # Allocate valid count buffer if needed (must be zero-initialized for atomics)
    valid_counts: torch.Tensor | None = None
    if return_valid_counts:
        valid_counts = torch.zeros(
            num_tokens, dtype=torch.int32, device=token_indices.device
        )

    # Strides in elements
    bt_stride0, bt_stride1 = block_table_c.stride()
    ti_stride0, ti_stride1 = token_indices_c.stride()
    out_stride0, out_stride1 = out.stride()

    # Prepare prefill pointers
    if HAS_PREFILL_WORKSPACE:
        assert prefill_workspace_request_ids is not None  # for mypy
        assert prefill_workspace_starts is not None  # for mypy
        assert prefill_workspace_request_ids.is_contiguous()
        assert prefill_workspace_starts.is_contiguous()

    # Exact 2D grid: tokens × column tiles
    grid = (num_tokens, tiles_per_row)
    _convert_req_index_to_global_index_kernel[grid](
        req_id_c,
        block_table_c,
        token_indices_c,
        out,
        valid_counts,
        prefill_workspace_request_ids,
        prefill_workspace_starts,
        # shapes / constexprs
        max_num_blocks_per_req,
        BLOCK_SIZE,
        BLOCK_N,
        HAS_PREFILL_WORKSPACE,
        return_valid_counts,
        # strides
        bt_stride0,
        bt_stride1,
        ti_stride0,
        ti_stride1,
        out_stride0,
        out_stride1,
    )

    if return_valid_counts:
        assert valid_counts is not None
        return out, valid_counts
    return out
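

# Illustrative usage sketch (not part of the vLLM API surface): the shapes and
# values below are hypothetical and assume a CUDA device, since the Triton
# kernel cannot run on CPU. With BLOCK_SIZE=4 and block_table[0] = [7, 3], a
# per-request token index of 5 maps to
# block_table[0][5 // 4] * 4 + 5 % 4 = 3 * 4 + 1 = 13, while padding entries
# of -1 stay -1 and are excluded from the valid count.
if __name__ == "__main__":  # pragma: no cover - example only
    if torch.cuda.is_available():
        device = torch.device("cuda")
        req_id = torch.zeros(1, dtype=torch.int32, device=device)
        block_table = torch.tensor([[7, 3]], dtype=torch.int32, device=device)
        # Rows must be NUM_TOPK_TOKENS (default 2048) wide; pad with -1.
        token_indices = torch.full(
            (1, 2048), -1, dtype=torch.int32, device=device
        )
        token_indices[0, 0] = 5
        out, valid_counts = triton_convert_req_index_to_global_index(
            req_id,
            block_table,
            token_indices,
            BLOCK_SIZE=4,
            return_valid_counts=True,
        )
        # Expected: out[0, 0] == 13, valid_counts[0] == 1, everything else -1.
        print(out[0, :4].tolist(), valid_counts.tolist())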