# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom Sparse Attention Indexer layers."""

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
    fp8_mqa_logits,
    fp8_mqa_logits_torch,
    fp8_paged_mqa_logits,
    fp8_paged_mqa_logits_torch,
    is_deep_gemm_supported,
)
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mla.indexer import (
    DeepseekV32IndexerMetadata,
)
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.v1.worker.workspace import current_workspace_manager

if current_platform.is_cuda_alike():
    from vllm import _custom_ops as ops
elif current_platform.is_xpu():
    from vllm._xpu_ops import xpu_ops as ops
    import ixformer.inference.functions as ixfops

logger = init_logger(__name__)


@torch.inference_mode()
def cp_gather_indexer_k_quant_cache(
    kv_cache,  # [num_blocks, block_size, head_dim]
    dst_value,  # [cu_seq_lens[-1], head_dim]
    block_table,  # [batch_size, num_blocks]
    cu_seq_lens,  # [batch_size + 1, ]
    batch_size,
):
    num_blocks, block_size, _ = kv_cache.shape
    head_dim = dst_value.shape[-1]
    kv_cache = kv_cache.view(num_blocks, -1)

    expected_value = []
    # expected_scale = []
    for b in range(batch_size):
        s = cu_seq_lens[b + 1] - cu_seq_lens[b]
        if s == 0:
            continue
        tot = cdiv(s, block_size)
        blocks = block_table[b, :tot]

        # All blocks except the last one are fully occupied.
        full_block = torch.arange(tot - 1, device=kv_cache.device, dtype=torch.int32)
        non_remaining_value = kv_cache[
            blocks[full_block], : block_size * head_dim
        ].view(-1, head_dim)
        # non_remaining_scale = kv_cache[blocks[full_block],
        #                                block_size * head_dim:].view(-1, 4)

        # The last block may be only partially filled.
        remaining = s - (tot - 1) * block_size
        value = torch.cat(
            [
                non_remaining_value,
                kv_cache[blocks[-1], : remaining * head_dim].view(-1, head_dim),
            ],
            dim=0,
        )
        # scale = torch.cat([
        #     non_remaining_scale,
        #     kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim +
        #              remaining * 4].view(-1, 4)
        # ], dim=0)

        expected_value.append(value)
        # expected_scale.append(scale)

    gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
    # gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
    gather_value = gather_value.view(torch.bfloat16)
    # gather_scale = gather_scale.view(torch.float32)
    dst_value.copy_(gather_value)
    # dst_scale.copy_(gather_scale)

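# High-level flow of the two indexer entry points below: the per-token indexer
# keys `k` are first written into the paged `kv_cache` via the slot mapping,
# MQA-style logits are then computed between the indexer queries `q` and the
# cached keys (per prefill chunk and for the decode batch), and finally the
# per-token top-k token indices are written into `topk_indices_buffer` for the
# sparse attention backend to consume.
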
def sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    # Careful: attn_metadata will be None in a dummy run.
    attn_metadata = get_forward_context().attn_metadata
    # assert isinstance(attn_metadata, dict)
    if not isinstance(attn_metadata, dict):
        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q,
            k,
            weights,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )
    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    # Write the indexer keys into the paged cache.
    ops.indexer_k_cache(k, kv_cache, slot_mapping)
    # topk_indices_buffer[: hidden_states.shape[0]] = -1

    if has_prefill:
        prefill_metadata = attn_metadata.prefill
        for chunk in prefill_metadata.chunks:
            logits = ixfops.dsa_indexer_mqa_logits_with_blocks(
                q[chunk.token_start : chunk.token_end],
                chunk.cu_seqlens_q,
                chunk.cu_seq_lens,
                kv_cache,
                chunk.block_table,
                weights[chunk.token_start : chunk.token_end],
                max_q_len=chunk.max_q_len,
                max_kv_len=chunk.max_kv_len,
                max_context_len=chunk.max_context_len,
            )
            ixfops.dsa_update_topk_indices(
                logits,
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
                topk_tokens,
                topk_indices_buffer[chunk.token_start : chunk.token_end],
            )

    if has_decode:
        decode_metadata = attn_metadata.decode
        # TODO: support speculative decode
        if decode_metadata.requires_padding:
            raise NotImplementedError(
                "Sparse attention indexer does not support requires_padding"
            )
        # Use dsa_indexer_mqa_logits_with_blocks, similar to prefill.
        logits = ixfops.dsa_indexer_mqa_logits_with_blocks(
            q[:num_decode_tokens],
            decode_metadata.cu_seqlens_q,
            decode_metadata.cu_seqlens_kv,
            kv_cache,
            decode_metadata.block_table,
            weights[:num_decode_tokens],
            max_q_len=decode_metadata.max_q_len,
            max_kv_len=decode_metadata.max_kv_len,
            max_context_len=decode_metadata.max_context_len,
        )
        ixfops.dsa_update_topk_indices(
            logits,
            decode_metadata.cu_seqlen_ks,
            decode_metadata.cu_seqlen_ke,
            topk_tokens,
            topk_indices_buffer[:num_decode_tokens],
        )

    return topk_indices_buffer

def sparse_attn_indexer_original(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor,
) -> torch.Tensor:
    # Careful: attn_metadata will be None in a dummy run.
    attn_metadata = get_forward_context().attn_metadata
    # fp8_dtype = current_platform.fp8_dtype()
    # assert isinstance(attn_metadata, dict)
    if not isinstance(attn_metadata, dict):
        # Reserve workspace for the indexer during the profiling run.
        current_workspace_manager().get_simultaneous(
            ((total_seq_lens, head_dim), torch.float8_e4m3fn),
            ((total_seq_lens, 4), torch.uint8),
        )
        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q,
            k,
            weights,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )
    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    # During speculative decoding, k may be padded to the CUDA graph batch
    # size while slot_mapping only covers actual tokens. Truncate k to avoid
    # out-of-bounds reads in the kernel.
    num_tokens = slot_mapping.shape[0]
    k = k[:num_tokens]
    ops.indexer_k_cache(
        k,
        kv_cache,
        slot_mapping,
    )
    topk_indices_buffer[: hidden_states.shape[0]] = -1

    if has_prefill:
        prefill_metadata = attn_metadata.prefill
        # Get the full shared workspace buffer once (allocated on first use).
        workspace_manager = current_workspace_manager()
        k_full = workspace_manager.get_simultaneous(
            ((total_seq_lens, head_dim), torch.bfloat16),
        )[0]
        for chunk in prefill_metadata.chunks:
            k = k_full[: chunk.total_seq_lens]
            # k_scale = k_scale_full[: chunk.total_seq_lens]
            cp_gather_indexer_k_quant_cache(
                kv_cache,
                k,
                chunk.block_table,
                chunk.cu_seq_lens,
                chunk.num_reqs,
            )
            logits = ops.ref_mqa_logits(
                q[chunk.token_start : chunk.token_end],
                k,
                weights[chunk.token_start : chunk.token_end],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )
            topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), dim=-1)[1]
            # Convert to positions relative to the start of each row's KV span
            # and mask out indices that fall outside [ks, ke).
            topk_indices -= chunk.cu_seqlen_ks[:, None]
            mask_lo = topk_indices >= 0
            mask_hi = (
                topk_indices - (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks)[:, None] < 0
            )
            mask = mask_lo & mask_hi
            topk_indices = topk_indices.masked_fill(~mask, -1)
            topk_indices_buffer[
                chunk.token_start : chunk.token_end, : topk_indices.shape[-1]
            ] = topk_indices.to(dtype=torch.int32)
            # Compute lengths from row spans
            # lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32)
            # torch.ops._C.large_context_topk(
            #     logits,
            #     topk_indices,
            #     lengths,
            #     chunk.cu_seqlen_ks,
            #     row_starts
            # )

    if has_decode:
        decode_metadata = attn_metadata.decode
        # The kernel expects kv_cache shaped [num_blocks, block_size, n_head,
        # head_dim]; we only have [num_blocks, block_size, head_dim], so add a
        # singleton head dimension.
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # Pad in the edge case where a short chunked-prefill length is
            # smaller than decode_threshold, since prefill and decode are
            # split loosely by decode_threshold (currently set to
            # 1 + speculative tokens).
            padded_q_decode_tokens = pack_seq_triton(
                q[:num_decode_tokens], decode_lens
            )
        else:
            padded_q_decode_tokens = q[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q.shape[1:]
            )
        # TODO: move and optimize the logic below with Triton kernels.
        batch_size = padded_q_decode_tokens.shape[0]
        next_n = padded_q_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n
        logits = ops.ref_paged_mqa_logits(
            padded_q_decode_tokens,
            kv_cache,
            weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            max_model_len=max_model_len,
            clean_logits=False,
        )
        # Mask out positions beyond each padded query token's visible context.
        current_device = padded_q_decode_tokens.device
        positions = (
            torch.arange(max_model_len, device=current_device)
            .unsqueeze(0)
            .expand(num_padded_tokens, -1)
        )
        row_indices = torch.arange(num_padded_tokens, device=current_device) // next_n
        next_n_offset = torch.arange(num_padded_tokens, device=current_device) % next_n
        # index_end_pos: [B * N, 1]
        index_end_pos = (
            decode_metadata.seq_lens[row_indices] - next_n + next_n_offset
        ).unsqueeze(1)
        # mask: [B * N, L]
        mask = positions <= index_end_pos
        logits = logits.masked_fill(~mask, float("-inf"))
        # topk_indices: [B * N, K]
        topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32)
        # Don't emit indices that are out of range (masked already); this
        # happens when the context length is shorter than K.
        topk_indices[topk_indices > index_end_pos] = -1
        if decode_metadata.requires_padding:
            # If padded, unpack the top-k indices, removing padded tokens.
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens,
            )
        topk_indices_buffer[
            :num_decode_tokens, : topk_indices.shape[-1]
        ] = topk_indices.to(dtype=torch.int32)

    return topk_indices_buffer

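# Worked example for the decode masking above (illustrative numbers only): with
# next_n = 2 padded query positions per request and seq_lens = [5] for a single
# request, index_end_pos evaluates to [[3], [4]], i.e. the first padded query
# token may select cache positions 0..3 and the second positions 0..4; any
# top-k index beyond that bound is set to -1.
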
def sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    return topk_indices_buffer


direct_register_custom_op(
    op_name="sparse_attn_indexer",
    op_func=sparse_attn_indexer,
    mutates_args=["topk_indices_buffer"],
    fake_impl=sparse_attn_indexer_fake,
    dispatch_key=current_platform.dispatch_key,
)

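# The op mutates `topk_indices_buffer` in place, hence `mutates_args`, and
# `sparse_attn_indexer_fake` is the shape-preserving fake implementation used
# when no real attention metadata is available (dummy/profiling runs) and for
# fake-tensor tracing during compilation.
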
@CustomOp.register("sparse_attn_indexer")
class SparseAttnIndexer(CustomOp):
    """Sparse Attention Indexer custom op layer.

    This layer is extracted as a separate custom op since it involves heavy
    custom kernels such as `mqa_logits`, `paged_mqa_logits`, and
    `top_k_per_row`. Those kernels may require a specific memory layout or
    implementation on different hardware backends to achieve optimal
    performance. For now, the default native path uses the CUDA backend path.
    Other platforms may need to add the corresponding custom op name
    `sparse_attn_indexer` to `custom_ops` in `CompilationConfig` to enable
    the platform-specific path.
    """

    def __init__(
        self,
        k_cache,
        quant_block_size: int,
        scale_fmt: str,
        topk_tokens: int,
        head_dim: int,
        max_model_len: int,
        max_total_seq_len: int,
        topk_indices_buffer: torch.Tensor,
    ):
        super().__init__()
        self.k_cache = k_cache
        self.quant_block_size = quant_block_size
        self.scale_fmt = scale_fmt
        self.topk_tokens = topk_tokens
        self.head_dim = head_dim
        self.max_model_len = max_model_len
        self.max_total_seq_len = max_total_seq_len
        self.topk_indices_buffer = topk_indices_buffer

        if current_platform.is_cuda() and not is_deep_gemm_supported():
            logger.warning_once(
                "DeepGEMM is not supported or available. SparseAttnIndexer "
                "will use a less efficient PyTorch implementation. Please "
                "make sure you have the required hardware and software setup "
                "for DeepGEMM to achieve optimal performance."
            )

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        q_fp8: torch.Tensor,
        k: torch.Tensor,
        weights: torch.Tensor,
    ):
        if current_platform.is_cuda():
            return self.forward_cuda(hidden_states, q_fp8, k, weights)
        elif current_platform.is_rocm():
            return self.forward_hip(hidden_states, q_fp8, k, weights)
        else:
            raise NotImplementedError(
                "SparseAttnIndexer native forward is only implemented for "
                "the CUDA and ROCm platforms."
            )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        weights: torch.Tensor,
    ):
        return torch.ops.vllm.sparse_attn_indexer(
            hidden_states,
            self.k_cache.prefix,
            self.k_cache.kv_cache[0],
            q,
            k,
            weights,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )

    def forward_hip(
        self,
        hidden_states: torch.Tensor,
        q_fp8: torch.Tensor,
        k: torch.Tensor,
        weights: torch.Tensor,
    ):
        if rocm_aiter_ops.is_enabled():
            return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
                hidden_states,
                self.k_cache.prefix,
                self.k_cache.kv_cache[0],
                q_fp8,
                k,
                weights,
                self.quant_block_size,
                self.scale_fmt,
                self.topk_tokens,
                self.head_dim,
                self.max_model_len,
                self.max_total_seq_len,
                self.topk_indices_buffer,
            )
        else:
            raise RuntimeError(
                "Sparse attention indexer ROCm custom op requires ROCm "
                "AITER ops to be enabled."
            )

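# Minimal construction sketch (hypothetical wiring; `indexer_k_cache`,
# `scale_fmt`, `topk_tokens`, etc. are placeholders supplied by the model's
# indexer module, not values defined in this file):
#
#     indexer = SparseAttnIndexer(
#         k_cache=indexer_k_cache,
#         quant_block_size=quant_block_size,
#         scale_fmt=scale_fmt,
#         topk_tokens=topk_tokens,
#         head_dim=head_dim,
#         max_model_len=max_model_len,
#         max_total_seq_len=max_total_seq_len,
#         topk_indices_buffer=topk_indices_buffer,
#     )
#     # CustomOp dispatches __call__ to forward_cuda / forward_hip / a
#     # platform-specific override as appropriate.
#     topk_indices = indexer(hidden_states, q_fp8, k, weights)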