# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from functools import lru_cache

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84
def fp8_mqa_logits_torch(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to `torch.float8_e4m3fn`
            by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N]
            (or [N, 1]) with dtype `torch.float32`.
        weights: Weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query
            position, shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query
            position, shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    kv, scale = kv
    seq_len_kv = kv.shape[0]

    k = kv.to(torch.bfloat16)
    q = q.to(torch.bfloat16)

    mask_lo = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
    )
    mask_hi = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
    )
    mask = mask_lo & mask_hi

    score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    return logits
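

# --- Illustrative usage sketch (not part of the upstream module) -------------
# A hypothetical helper showing how `fp8_mqa_logits_torch` might be exercised
# with synthetic inputs. It assumes a CUDA/ROCm device (the reference
# implementation allocates its masks on "cuda") and that the installed torch
# build supports casting float32 tensors to the platform FP8 dtype. All names,
# shapes, and values here are illustrative only.
def _example_fp8_mqa_logits_torch(
    m: int = 4, n: int = 16, h: int = 8, d: int = 64
) -> torch.Tensor:
    fp8_dtype = current_platform.fp8_dtype()
    q = torch.randn(m, h, d, device="cuda").to(fp8_dtype)
    k_fp8 = torch.randn(n, d, device="cuda").to(fp8_dtype)
    k_scales = torch.ones(n, device="cuda", dtype=torch.float32)
    weights = torch.rand(m, h, device="cuda", dtype=torch.float32)
    # Let every query position attend to the full key range [0, n).
    cu_seqlen_ks = torch.zeros(m, device="cuda", dtype=torch.int32)
    cu_seqlen_ke = torch.full((m,), n, device="cuda", dtype=torch.int32)
    # Returns [m, n] float32 logits; keys outside [ks, ke) come back as -inf.
    return fp8_mqa_logits_torch(
        q, (k_fp8, k_scales), weights, cu_seqlen_ks, cu_seqlen_ke
    )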
""" # TODO(ganyi): Temporarily workaround, will remove the module check and reference # path after aiter merge this kernel into main @lru_cache def has_mqa_logits_module(): return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None if rocm_aiter_ops.is_enabled() and has_mqa_logits_module(): from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits kv, scale = kv return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke) else: return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) # Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156 def fp8_paged_mqa_logits_torch( q: torch.Tensor, kv_cache: torch.Tensor, weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor, max_model_len: int, ): from vllm.utils.math_utils import cdiv fp8_dtype = current_platform.fp8_dtype() batch_size, next_n, _, dim = q.size() kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:] scale = scale.contiguous().view(torch.float) q = q.float() kv_cache = kv_cache.view(fp8_dtype).float() * scale num_block, block_size, _, dim = kv_cache.size() logits = torch.full( [batch_size * next_n, max_model_len], float("-inf"), device=q.device, dtype=torch.float32, ) context_lens = context_lens.tolist() for i in range(batch_size): context_len = context_lens[i] q_offsets = torch.arange(context_len - next_n, context_len, device="cuda") weight_slice = ( weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous() ) for block_rk in range(cdiv(context_len, block_size)): block_idx = block_tables[i][block_rk] qx, kx = q[i], kv_cache[block_idx] k_offsets = torch.arange( block_rk * block_size, (block_rk + 1) * block_size, device="cuda" ) mask = (k_offsets[None, :] < context_len) & ( k_offsets[None, :] <= q_offsets[:, None] ) s = torch.where( mask[None, :, :], (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to( logits.dtype ), float("-inf"), ) s = torch.relu(s) * weight_slice[..., None] s = s.sum(dim=0) logits[ i * next_n : (i + 1) * next_n, block_rk * block_size : (block_rk + 1) * block_size, ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf")) return logits def rocm_fp8_paged_mqa_logits( q_fp8: torch.Tensor, kv_cache_fp8: torch.Tensor, weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor, schedule_metadata: torch.Tensor, max_model_len: int, ) -> torch.Tensor: """Compute FP8 MQA logits using paged KV-cache. Args: q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to `torch.float8_e4m3fn` by caller. kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last 4 bytes per (block,pos) store the `float` dequant scale. weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. context_lens: Tensor of shape [B], dtype int32; effective context length for each batch element. block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical block indices to physical blocks in the paged cache. schedule_metadata: Returned by `get_paged_mqa_logits_metadata`; used to distribute work across SMs. max_model_len: Maximum sequence length used to size the logits output. Returns: Logits tensor of shape [B * next_n, max_model_len], dtype `torch.float32`. 
""" if rocm_aiter_ops.is_enabled(): from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits_stage1 batch_size, next_n, heads, _ = q_fp8.shape out_qk = torch.full( (heads, batch_size * next_n, max_model_len), float("-inf"), device="cuda", dtype=torch.float32, ) deepgemm_fp8_paged_mqa_logits_stage1( q_fp8, kv_cache_fp8, weights, out_qk, context_lens, block_tables, max_model_len, ) return out_qk.sum(dim=0) else: return fp8_paged_mqa_logits_torch( q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len )