[Fix] Window attention compatible with RadixAttention and chunked prefill (#1112)

Ying Sheng, 2024-08-15 10:33:20 -07:00 (committed by GitHub)
parent 9195d1362a, commit 93d4e354d8
5 changed files with 37 additions and 56 deletions

@@ -15,6 +15,8 @@ limitations under the License.
 """Radix attention."""
 
+from typing import Optional
+
 import torch
 from flashinfer.cascade import merge_state
 from torch import nn
 
@@ -34,8 +36,7 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
-        reuse: bool = False,
-        sliding_window_size: int = -1,
+        sliding_window_size: Optional[int] = None,
         logit_cap: int = -1,
         v_head_dim: int = -1,
     ):
@@ -48,8 +49,7 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
-        self.reuse = reuse
-        self.sliding_window_size = sliding_window_size
+        self.sliding_window_size = sliding_window_size if sliding_window_size else -1
 
         if (
             not global_server_args_dict.get("disable_flashinfer", False)
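The two constructor hunks above are the user-visible part of the fix: `reuse` is dropped, callers pass `sliding_window_size` as an `Optional[int]`, and the layer normalizes `None` (or 0) back to the `-1` sentinel that the flashinfer code paths test against. A minimal construction sketch, assuming the module path `sglang.srt.layers.radix_attention` and the leading `num_heads`/`head_dim` parameters that these hunks do not show; the head counts and the alternating-window layout are illustrative, not taken from the diff:

    from sglang.srt.layers.radix_attention import RadixAttention

    # Hypothetical model setup: every other layer uses a 4096-token window.
    head_dim = 128
    num_layers = 4
    layers = [
        RadixAttention(
            num_heads=32,
            head_dim=head_dim,
            scaling=head_dim**-0.5,
            num_kv_heads=8,
            layer_id=layer_id,
            # None means full attention; a positive int enables a sliding window.
            # The layer maps falsy values back to the -1 "no window" sentinel.
            sliding_window_size=4096 if layer_id % 2 == 0 else None,
        )
        for layer_id in range(num_layers)
    ]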
@@ -118,16 +118,16 @@ class RadixAttention(nn.Module):
     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
         prefill_wrapper_ragged = input_metadata.flashinfer_prefill_wrapper_ragged
         prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
-        if self.sliding_window_size != -1 or self.reuse:
+        if self.sliding_window_size != -1:
             prefill_wrapper_paged = prefill_wrapper_paged[0]
         else:
             if isinstance(prefill_wrapper_paged, list):
                 prefill_wrapper_paged = prefill_wrapper_paged[1]
 
-        if not input_metadata.flashinfer_use_ragged or self.reuse:
-            if not self.reuse:
+        if not input_metadata.flashinfer_use_ragged:
+            if k is not None:
                 assert v is not None
                 self.store_kv_cache(k, v, input_metadata)
 
             o = prefill_wrapper_paged.forward(
@@ -139,21 +139,20 @@ class RadixAttention(nn.Module):
                 logits_soft_cap=self.logit_cap,
             )
         else:
-            o1, s1 = prefill_wrapper_ragged.forward_return_lse(
-                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
-                v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
-                causal=True,
-                sm_scale=self.scaling,
-                window_left=self.sliding_window_size,
-                logits_soft_cap=self.logit_cap,
+            o1, s1 = (
+                input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
+                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                    k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+                    v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+                    causal=True,
+                    sm_scale=self.scaling,
+                    logits_soft_cap=self.logit_cap,
+                )
             )
 
             if input_metadata.extend_no_prefix:
                 o = o1
             else:
-                # TODO window attention + radix attention will come up in next PR
-                assert self.sliding_window_size == -1
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
                     q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                     input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
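This hunk is the heart of the compatibility fix. When a request has a cached prefix (RadixAttention or chunked prefill), attention over the newly extended tokens is computed with the ragged wrapper and attention over the cached KV with the paged wrapper; each call returns its output together with its log-sum-exp via forward_return_lse, and the two partial results are combined with flashinfer's merge_state (imported at the top of the file). Dropping the TODO/assert is what lets sliding-window layers take this prefix path. A rough, pure-PyTorch sketch of the merge that merge_state performs, with tensor layouts assumed for illustration (this is not the flashinfer kernel):

    import torch


    def merge_partial_attention(o1, s1, o2, s2):
        """Combine two partial attention results as if attention had been run
        over the union of their KV sets.

        o1, o2: (tokens, heads, head_dim) partial attention outputs
        s1, s2: (tokens, heads) log-sum-exp of the attention scores per part
        """
        s_max = torch.maximum(s1, s2)
        w1 = torch.exp(s1 - s_max)  # relative weight of the first partial result
        w2 = torch.exp(s2 - s_max)  # relative weight of the second partial result
        o = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / (w1 + w2).unsqueeze(-1)
        s = s_max + torch.log(w1 + w2)  # log-sum-exp over the combined KV set
        return o, s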
@@ -179,7 +178,8 @@ class RadixAttention(nn.Module):
             if isinstance(decode_wrapper, list):
                 decode_wrapper = decode_wrapper[1]
 
-        if not self.reuse:
+        if k is not None:
+            assert v is not None
             self.store_kv_cache(k, v, input_metadata)
 
         o = decode_wrapper.forward(
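Both the extend and decode paths now share the same convention for picking a flashinfer wrapper: when the metadata carries a list of wrappers, index 0 serves sliding-window layers and index 1 full-attention layers, while a plain (non-list) wrapper is used directly. A small sketch of that selection logic; the helper name is hypothetical and only mirrors the pattern in the hunks above:

    # Hypothetical helper: index 0 = sliding-window wrapper, index 1 = full attention.
    def select_wrapper(wrapper_or_list, sliding_window_size):
        if not isinstance(wrapper_or_list, list):
            return wrapper_or_list  # single wrapper, no window/full split configured
        if sliding_window_size != -1:
            return wrapper_or_list[0]  # sliding-window layers
        return wrapper_or_list[1]  # full-attention layers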