# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Authors:
#  - Burkhard Ringlein
#  - Jan van Lunteren
#  - Chih-Chieh Yang
#  - Thomas Parnell

import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton

from .prefix_prefill import context_attention_fwd

# Numeric limits of the platform's FP8 dtype; used as the default clamp range
# (FP8_MIN / FP8_MAX) when the kernel writes FP8-scaled output.
float8_info = torch.finfo(current_platform.fp8_dtype())


# Ceiling division usable from inside Triton kernels: the smallest integer q
# with q * y >= x (for positive y).
@triton.jit
def cdiv_fn(x, y):
    return (x + y - 1) // y


# Paged-attention kernel for single-token (decode) queries.
#
# Launch grid is (num_seqs, num_kv_heads): each program handles one sequence
# and one KV head, and processes all query heads that map onto that KV head
# at once (padded up to `num_queries_per_kv_padded`).
#
# The kernel walks the sequence's block table one KV-cache block at a time and
# keeps a running maximum (M), running normalizer (L) and running weighted sum
# (acc), rescaling the previous state by exp(M_old - M_new) on every block.
#
# When `filter_by_query_len` is set, programs whose sequence has more than one
# query token return immediately without writing any output.
@triton.jit
def kernel_paged_attention_2d(
    output_ptr,  # [num_tokens, num_query_heads, head_size]
    query_ptr,  # [num_tokens, num_query_heads, head_size]
    key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
    value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
    sink_ptr,  # [num_query_heads]
    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
    seq_lens_ptr,  # [num_seqs]
    alibi_slopes_ptr,  # [num_query_heads]
    scale,  # float32
    k_scale,  # float32, dereferenced with tl.load (pointer to the scale)
    v_scale,  # float32, dereferenced with tl.load (pointer to the scale)
    out_scale_inv,  # dereferenced with tl.load, only when USE_FP8 is set
    num_query_heads: tl.constexpr,  # int
    num_queries_per_kv: tl.constexpr,  # int
    num_queries_per_kv_padded: tl.constexpr,  # int
    block_table_stride: tl.int64,  # int
    query_stride_0: tl.int64,  # int
    query_stride_1: tl.int64,  # int, should be equal to head_size
    output_stride_0: tl.int64,  # int
    output_stride_1: tl.int64,  # int, should be equal to head_size
    BLOCK_SIZE: tl.constexpr,  # int
    HEAD_SIZE: tl.constexpr,  # int
    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
    USE_ALIBI_SLOPES: tl.constexpr,  # bool
    SLIDING_WINDOW: tl.constexpr,  # int
    x: tl.constexpr,  # int
    stride_k_cache_0: tl.int64,  # int
    stride_k_cache_1: tl.int64,  # int
    stride_k_cache_2: tl.int64,  # int
    stride_k_cache_3: tl.int64,  # int
    stride_k_cache_4: tl.int64,  # int
    stride_v_cache_0: tl.int64,  # int
    stride_v_cache_1: tl.int64,  # int
    stride_v_cache_2: tl.int64,  # int
    stride_v_cache_3: tl.int64,  # int
    filter_by_query_len: tl.constexpr,  # bool
    query_start_len_ptr,  # [num_seqs+1]
    USE_SINKS: tl.constexpr,  # bool
    USE_FP8: tl.constexpr,
    FP8_MIN: tl.constexpr = float8_info.min,
    FP8_MAX: tl.constexpr = float8_info.max,
):
    seq_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)

    if filter_by_query_len:
        # query_start_len_ptr holds cumulative token offsets; the difference of
        # two consecutive entries is this sequence's query length.
        cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
        cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
        cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
        # Only single-token queries are handled here; skip everything else.
        if cur_batch_query_len > 1:
            return
    else:
        cur_batch_in_all_start_index = seq_idx

    # All query heads sharing this KV head, padded to a power-of-two width.
    query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
        0, num_queries_per_kv_padded
    )

    query_offset = (
        cur_batch_in_all_start_index * query_stride_0
        + query_head_idx[:, None] * query_stride_1
    )

    # Mask out the padding lanes: heads beyond this KV head's group and heads
    # beyond the total number of query heads.
    head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
    head_mask = head_mask & (query_head_idx < num_query_heads)

    # Mask out the padding columns of the head dimension.
    dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, 0).to(tl.int1)

    # Q : (num_queries_per_kv_padded, HEAD_SIZE_PADDED)
    Q = tl.load(
        query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
        mask=dim_mask[None, :] & head_mask[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    # Running maximum. With sinks, it starts at the per-head sink value so
    # that the initial L = 1 accounts for exp(sink - sink) in the normalizer.
    if not USE_SINKS:
        M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
    else:
        M = tl.load(
            sink_ptr + query_head_idx,
            mask=head_mask,
            other=float("-inf"),
        ).to(dtype=tl.float32)

    L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
    acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], dtype=tl.float32)

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(
            alibi_slopes_ptr + query_head_idx, mask=head_mask, other=0.0
        )

    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)

    # iterate through tiles
    for j in range(0, num_blocks):
        # Translate the logical block index into a physical cache block.
        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)

        offs_n = tl.arange(0, BLOCK_SIZE)
        offs_d = tl.arange(0, HEAD_SIZE_PADDED)

        v_offset = (
            physical_block_idx * stride_v_cache_0
            + kv_head_idx * stride_v_cache_1
            + offs_d[None, :] * stride_v_cache_2
            + offs_n[:, None] * stride_v_cache_3
        )

        # The key cache splits the head dimension into (head_size // x, x),
        # so a head-dim index d maps to (d // x, d % x).
        k_offset = (
            physical_block_idx * stride_k_cache_0
            + kv_head_idx * stride_k_cache_1
            + (offs_d[:, None] // x) * stride_k_cache_2
            + offs_n[None, :] * stride_k_cache_3
            + (offs_d[:, None] % x) * stride_k_cache_4
        )

        # K : (HEAD_SIZE_PADDED, BLOCK_SIZE)
        K_load = tl.load(key_cache_ptr + k_offset, mask=dim_mask[:, None], other=0.0)

        # FP8 caches are dequantized with the loaded scale and cast to Q's dtype.
        if K_load.dtype.is_fp8():
            K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (BLOCK_SIZE, HEAD_SIZE_PADDED)
        V_load = tl.load(value_cache_ptr + v_offset, mask=dim_mask[None, :], other=0.0)

        if V_load.dtype.is_fp8():
            V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        # Positions covered by this block; positions >= seq_len are masked.
        seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
        seq_mask = seq_offset[None, :] < boundary

        # S : (num_queries_per_kv_padded, BLOCK_SIZE)
        # Start from 0 for valid entries and -inf for masked ones, then add
        # the scaled scores on top.
        S = tl.where(head_mask[:, None] & seq_mask, 0.0, float("-inf")).to(tl.float32)
        S += scale * tl.dot(Q, K)

        context_len = seq_len - 1

        # NOTE(review): positions outside the window get -10000 rather than
        # -inf; kept as-is to preserve the kernel's numerics.
        if SLIDING_WINDOW > 0:
            S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S, -10000)

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

        # compute running maximum
        # m_j : (num_queries_per_kv_padded,)
        m_j = tl.maximum(M, tl.max(S, axis=1))

        # P : (num_queries_per_kv_padded, BLOCK_SIZE)
        P = tl.exp(S - m_j[:, None])

        # l_j : (num_queries_per_kv_padded,)
        l_j = tl.sum(P, axis=1)

        # alpha : (num_queries_per_kv_padded,)
        # Rescale factor that moves the previous state to the new maximum.
        alpha = tl.exp(M - m_j)

        # acc : (num_queries_per_kv_padded, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (num_queries_per_kv_padded, HEAD_SIZE_PADDED)
        acc += tl.dot(P.to(V.dtype), V)

    # epilogue
    acc = acc / L[:, None]
    if USE_FP8:
        acc = acc * tl.load(out_scale_inv)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)

    output_offset = (
        cur_batch_in_all_start_index * output_stride_0
        + query_head_idx * output_stride_1
    )

    tl.store(
        output_ptr + output_offset[:, None] + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
        acc,
        mask=dim_mask[None, :] & head_mask[:, None],
    )


def chunked_prefill_paged_decode(
    query,
    key,
    value,
    output,
    kv_cache_dtype,
    key_cache,
    value_cache,
    block_table,
    query_start_loc,
    seq_lens,
    max_seq_len,
    max_query_len,
    k_scale,
    v_scale,
    alibi_slopes=None,
    sliding_window=None,
    sm_scale=None,
    output_scale=None,
    # Optional tensor for sinks
    sinks=None,
):
    """Run attention over a paged KV cache for a mixed prefill/decode batch.

    If ``max_query_len > 1`` the batch is first passed to
    ``context_attention_fwd`` with ``skip_decode=True``. Afterwards the
    single-token sequences are handled either by the ROCm custom paged
    attention op (when ``use_rocm_custom_paged_attention`` approves the
    configuration) or by the Triton kernel ``kernel_paged_attention_2d``,
    which itself skips every sequence whose query length is greater than one.

    Results are written into ``output`` in place; nothing is returned.

    Args:
        query: Query tensor, indexed as
            ``[num_tokens, num_query_heads, head_size]``.
        key: Key tensor; ``key.shape[1]`` is taken as the number of KV heads.
        value: Value tensor, forwarded to ``context_attention_fwd``.
        output: Destination tensor, same layout as ``query``.
        kv_cache_dtype: KV-cache dtype name. If it contains ``"fp8"`` the
            caches are reinterpreted (``.view``) as the matching FP8 dtype.
        key_cache: Paged key cache, 5-D
            ``[num_blks, num_kv_heads, head_size // x, blk_size, x]``.
        value_cache: Paged value cache, 4-D
            ``[num_blks, num_kv_heads, head_size, blk_size]``.
        block_table: Per-sequence table of physical block indices.
        query_start_loc: Cumulative query offsets, ``num_seqs + 1`` entries.
        seq_lens: Per-sequence total lengths.
        max_seq_len: Maximum sequence length in the batch.
        max_query_len: Maximum query length in the batch.
        k_scale: Key dequantization scale (loaded inside the kernel).
        v_scale: Value dequantization scale (loaded inside the kernel).
        alibi_slopes: Optional per-query-head ALiBi slopes.
        sliding_window: Optional window size; ``None`` or non-positive values
            disable the sliding window.
        sm_scale: Softmax scale. Defaults to ``1 / sqrt(head_size)``.
        output_scale: Optional output scale; when given, its reciprocal is
            applied and the result is clamped to the FP8 range.
        sinks: Optional per-query-head attention sink values.

    Raises:
        ValueError: If ``kv_cache_dtype`` contains ``"fp8"`` but is not one of
            the supported FP8 variants.
    """
    num_query_heads = query.shape[1]
    num_kv_heads = key.shape[1]
    head_size = query.shape[2]

    if sm_scale is None:
        # The softmax scale is derived from the head dimension (query.shape[2]),
        # not from the number of query heads (query.shape[1]).
        sm_scale = 1.0 / (head_size**0.5)

    use_alibi_slopes = alibi_slopes is not None

    if sliding_window is None or sliding_window <= 0:
        sliding_window = 0

    if max_query_len > 1:
        context_attention_fwd(
            q=query,
            k=key,
            v=value,
            o=output,
            kv_cache_dtype=kv_cache_dtype,
            k_cache=key_cache,
            v_cache=value_cache,
            b_loc=block_table,
            b_start_loc=query_start_loc,
            b_seq_len=seq_lens,
            max_seq_len=max_seq_len,
            max_input_len=max_query_len,
            k_scale=k_scale,
            v_scale=v_scale,
            alibi_slopes=alibi_slopes,
            sliding_window=sliding_window,
            sm_scale=sm_scale,
            skip_decode=True,
            fp8_out_scale=output_scale,
            sinks=sinks,
        )

    block_size = value_cache.shape[3]
    num_seqs = len(seq_lens)
    num_queries_per_kv = num_query_heads // num_kv_heads

    # Conversion of FP8 Tensor from uint8 storage to
    # appropriate torch.dtype for interpretation by Triton
    if "fp8" in kv_cache_dtype:
        assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
        assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]

        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            target_dtype = current_platform.fp8_dtype()
        elif kv_cache_dtype == "fp8_e5m2":
            target_dtype = torch.float8_e5m2
        else:
            raise ValueError(f"Unsupported FP8 dtype: {kv_cache_dtype}")

        key_cache = key_cache.view(target_dtype)
        value_cache = value_cache.view(target_dtype)

    # The head-group dimension is padded to a power of two, and to at least 16.
    num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), 16)

    # Imported here rather than at module level, as in the original code.
    from vllm.platforms.rocm import use_rocm_custom_paged_attention

    use_custom = use_rocm_custom_paged_attention(
        query.dtype,
        head_size,
        block_size,
        num_queries_per_kv,
        max_seq_len,
        sliding_window,
        kv_cache_dtype,
        alibi_slopes,
        sinks,
    )
    if use_custom:
        _PARTITION_SIZE_ROCM = 256
        max_num_partitions = (
            max_seq_len + _PARTITION_SIZE_ROCM - 1
        ) // _PARTITION_SIZE_ROCM
        assert _PARTITION_SIZE_ROCM % block_size == 0
        total_num_seq = block_table.shape[0]
        # Scratch buffers with one slot per (sequence, head, partition).
        tmp_output = torch.empty(
            size=(total_num_seq, num_query_heads, max_num_partitions, head_size),
            dtype=query.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(total_num_seq, num_query_heads, max_num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

        ops.paged_attention_rocm(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale=sm_scale,
            block_tables=block_table,
            seq_lens=seq_lens,
            query_start_loc=query_start_loc,
            block_size=block_size,
            max_seq_len=max_seq_len,
            alibi_slopes=alibi_slopes,
            kv_cache_dtype=kv_cache_dtype,
            k_scale=k_scale,
            v_scale=v_scale,
            fp8_out_scale=output_scale,
        )
    else:
        # One Triton program per (sequence, KV head).
        kernel_paged_attention_2d[
            (
                num_seqs,
                num_kv_heads,
            )
        ](
            output_ptr=output,
            query_ptr=query,
            key_cache_ptr=key_cache,
            value_cache_ptr=value_cache,
            sink_ptr=sinks,
            block_tables_ptr=block_table,
            seq_lens_ptr=seq_lens,
            alibi_slopes_ptr=alibi_slopes,
            scale=sm_scale,
            k_scale=k_scale,
            v_scale=v_scale,
            out_scale_inv=1.0 / output_scale if output_scale is not None else 1.0,
            num_query_heads=num_query_heads,
            num_queries_per_kv=num_queries_per_kv,
            num_queries_per_kv_padded=num_queries_per_kv_padded,
            block_table_stride=block_table.stride(0),
            query_stride_0=query.stride(0),
            query_stride_1=query.stride(1),
            output_stride_0=output.stride(0),
            output_stride_1=output.stride(1),
            BLOCK_SIZE=block_size,
            HEAD_SIZE=head_size,
            HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
            USE_ALIBI_SLOPES=use_alibi_slopes,
            SLIDING_WINDOW=sliding_window,
            x=key_cache.shape[4],
            stride_k_cache_0=key_cache.stride(0),
            stride_k_cache_1=key_cache.stride(1),
            stride_k_cache_2=key_cache.stride(2),
            stride_k_cache_3=key_cache.stride(3),
            stride_k_cache_4=key_cache.stride(4),
            stride_v_cache_0=value_cache.stride(0),
            stride_v_cache_1=value_cache.stride(1),
            stride_v_cache_2=value_cache.stride(2),
            stride_v_cache_3=value_cache.stride(3),
            filter_by_query_len=True,
            query_start_len_ptr=query_start_loc,
            USE_SINKS=sinks is not None,
            USE_FP8=output_scale is not None,
        )