# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.triton_utils import tl, triton


# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
# can be used to combine partial attention results (in the split-KV case)
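# For per-(head, token) log-sum-exps lse_p and lse_s of the two partial
# results, the merged result is
#     lse_out = log(exp(lse_p) + exp(lse_s))
#     out = exp(lse_p - lse_out) * out_p + exp(lse_s - lse_out) * out_s
# which the kernel below evaluates with the usual max-subtraction trick.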
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: torch.Tensor | None = None,
) -> None:
    """Merge prefix and suffix attention states in place into `output`.

    The `*_output` tensors are [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]; the
    `*_lse` tensors are [NUM_HEADS, NUM_TOKENS]. If `output_lse` is given,
    the merged log-sum-exp is written to it as well.
    """
    num_tokens = output.shape[0]
    num_query_heads = output.shape[1]
    head_size = output.shape[2]
    padded_head_size = triton.next_power_of_2(head_size)

    # TODO(woosuk): Use a CUDA kernel instead of Triton to minimize CPU overhead.
    merge_attn_states_kernel[(num_tokens, num_query_heads)](
        output,
        output_lse,
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        head_size,
        padded_head_size,
        output_lse is not None,
    )
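
# Example call (an illustrative sketch only; the shapes, dtype, and CUDA
# device below are assumptions for demonstration, not vLLM requirements):
#
#     num_tokens, num_heads, head_size = 4, 8, 128
#     out = torch.empty(num_tokens, num_heads, head_size, device="cuda")
#     out_lse = torch.empty(num_heads, num_tokens, device="cuda")
#     p_out, s_out = torch.randn_like(out), torch.randn_like(out)
#     p_lse = torch.randn(num_heads, num_tokens, device="cuda")
#     s_lse = torch.randn(num_heads, num_tokens, device="cuda")
#     merge_attn_states(out, p_out, p_lse, s_out, s_lse, out_lse)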


@triton.jit
def merge_attn_states_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    output_lse,  # [NUM_HEADS, NUM_TOKENS]
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse,  # [NUM_HEADS, NUM_TOKENS]
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse,  # [NUM_HEADS, NUM_TOKENS]
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,
    OUTPUT_LSE: tl.constexpr,
):
    token_idx = tl.program_id(0)
    num_tokens = tl.num_programs(0)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
    s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)

    # FA2 and FA3 behave differently when the sum-exp is 0, which arises with
    # zero-length seqlens: FA3 returns -inf there, while FA2 returns inf. If
    # we see an inf, assume FA2 and convert it to -inf for consistency and
    # correctness. Inf generally doesn't make sense in this context outside of
    # the undefined-behavior/FA2 case, so this is a safe assumption.
    p_lse = float("-inf") if p_lse == float("inf") else p_lse
    s_lse = float("-inf") if s_lse == float("inf") else s_lse

    max_lse = tl.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    # The precomputed exp values are reused below for the scale factors.
    p_se = tl.exp(p_lse)
    s_se = tl.exp(s_lse)
    out_se = p_se + s_se

    if OUTPUT_LSE:
        out_lse = tl.log(out_se) + max_lse
        tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse)

    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    p_out = tl.load(
        prefix_output
        + token_idx * num_heads * HEAD_SIZE
        + head_idx * HEAD_SIZE
        + head_arange,
        mask=head_mask,
    )
    s_out = tl.load(
        suffix_output
        + token_idx * num_heads * HEAD_SIZE
        + head_idx * HEAD_SIZE
        + head_arange,
        mask=head_mask,
    )

    # NOTE(woosuk): Be careful with the numerical stability.
    # We should compute the scale first, and then multiply it with the output.
    # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
    p_scale = p_se / out_se
    s_scale = s_se / out_se
    out = p_out * p_scale + s_out * s_scale
    tl.store(
        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
        out,
        mask=head_mask,
    )
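

# A pure-PyTorch reference of the same merge, as a sketch for sanity-checking
# the kernel above (this helper and its name are an illustrative assumption,
# not part of the vLLM API):
def _merge_attn_states_reference(
    prefix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS]
    suffix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS]
) -> tuple[torch.Tensor, torch.Tensor]:
    # Mirror the kernel's FA2 handling: map +inf LSEs to -inf.
    p_lse = prefix_lse.masked_fill(prefix_lse == float("inf"), float("-inf"))
    s_lse = suffix_lse.masked_fill(suffix_lse == float("inf"), float("-inf"))
    max_lse = torch.maximum(p_lse, s_lse)
    p_se = torch.exp(p_lse - max_lse)
    s_se = torch.exp(s_lse - max_lse)
    out_se = p_se + s_se
    out_lse = torch.log(out_se) + max_lse
    # Scales are [NUM_HEADS, NUM_TOKENS]; transpose and unsqueeze so they
    # broadcast over the [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] outputs.
    p_scale = (p_se / out_se).t().unsqueeze(-1)
    s_scale = (s_se / out_se).t().unsqueeze(-1)
    out = prefix_output * p_scale + suffix_output * s_scale
    return out, out_lse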