# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.triton_utils import tl, triton


# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
# Can be used to combine partial attention results (in the split-KV case).
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: torch.Tensor | None = None,
) -> None:
    num_tokens = output.shape[0]
    num_query_heads = output.shape[1]
    head_size = output.shape[2]
    padded_head_size = triton.next_power_of_2(head_size)

    # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
    merge_attn_states_kernel[(num_tokens, num_query_heads)](
        output,
        output_lse,
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        head_size,
        padded_head_size,
        output_lse is not None,
    )


@triton.jit
def merge_attn_states_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    output_lse,  # [NUM_HEADS, NUM_TOKENS]
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse,  # [NUM_HEADS, NUM_TOKENS]
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse,  # [NUM_HEADS, NUM_TOKENS]
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,
    OUTPUT_LSE: tl.constexpr,
):
    token_idx = tl.program_id(0)
    num_tokens = tl.num_programs(0)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
    s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)

    # FA2 and FA3 behave differently when the sum-exp is 0, which mainly
    # arises with zero-length sequences: FA3 returns -inf here while FA2
    # returns inf. If we see an inf, assume FA2 and convert it to -inf for
    # consistency and correctness. Inf generally doesn't make sense in this
    # context outside of the undefined-behavior/FA2 case, so I think this is
    # a safe assumption.
    p_lse = float("-inf") if p_lse == float("inf") else p_lse
    s_lse = float("-inf") if s_lse == float("inf") else s_lse

    max_lse = tl.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    # Reuse the precomputed exp values for the scale-factor computation.
    p_se = tl.exp(p_lse)
    s_se = tl.exp(s_lse)
    out_se = p_se + s_se

    if OUTPUT_LSE:
        out_lse = tl.log(out_se) + max_lse
        tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse)

    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    p_out = tl.load(
        prefix_output + token_idx * num_heads * HEAD_SIZE +
        head_idx * HEAD_SIZE + head_arange,
        mask=head_mask,
    )
    s_out = tl.load(
        suffix_output + token_idx * num_heads * HEAD_SIZE +
        head_idx * HEAD_SIZE + head_arange,
        mask=head_mask,
    )

    # NOTE(woosuk): Be careful with the numerical stability.
    # We should compute the scale first, and then multiply it with the output.
    # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
    p_scale = p_se / out_se
    s_scale = s_se / out_se
    out = p_out * p_scale + s_out * s_scale
    tl.store(
        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
        head_arange,
        out,
        mask=head_mask,
    )
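

# --- Usage sketch (illustrative; not part of the upstream module) ---
# A minimal, hedged example of calling merge_attn_states and checking it
# against a plain-PyTorch reference of the same merge:
#   out = (exp(p_lse) * p_out + exp(s_lse) * s_out)
#         / (exp(p_lse) + exp(s_lse))
# All shapes, values, and tolerances below are made-up assumptions for
# illustration; it assumes a CUDA device (the Triton kernel does not run on
# CPU) and float32 inputs.
if __name__ == "__main__":
    num_tokens, num_heads, head_size = 4, 8, 128

    prefix_output = torch.randn(num_tokens, num_heads, head_size,
                                device="cuda")
    suffix_output = torch.randn(num_tokens, num_heads, head_size,
                                device="cuda")
    prefix_lse = torch.randn(num_heads, num_tokens, device="cuda")
    suffix_lse = torch.randn(num_heads, num_tokens, device="cuda")

    output = torch.empty_like(prefix_output)
    output_lse = torch.empty_like(prefix_lse)
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse, output_lse)

    # Reference computation with the same max-subtraction trick for
    # numerical stability.
    max_lse = torch.maximum(prefix_lse, suffix_lse)
    p_se = torch.exp(prefix_lse - max_lse)  # [NUM_HEADS, NUM_TOKENS]
    s_se = torch.exp(suffix_lse - max_lse)
    # Reshape the scales to [NUM_TOKENS, NUM_HEADS, 1] so they broadcast
    # over the head dimension of the outputs.
    p_scale = (p_se / (p_se + s_se)).transpose(0, 1).unsqueeze(-1)
    s_scale = (s_se / (p_se + s_se)).transpose(0, 1).unsqueeze(-1)
    ref = prefix_output * p_scale + suffix_output * s_scale
    torch.testing.assert_close(output, ref, rtol=1e-3, atol=1e-3)
    print("merge_attn_states matches the PyTorch reference")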