[Fix] Window attention compatible with RadixAttention and chunked prefill (#1112)
This commit is contained in:
@@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
"""Radix attention."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from flashinfer.cascade import merge_state
|
||||
from torch import nn
|
||||
@@ -34,8 +36,7 @@ class RadixAttention(nn.Module):
|
||||
scaling: float,
|
||||
num_kv_heads: int,
|
||||
layer_id: int,
|
||||
reuse: bool = False,
|
||||
sliding_window_size: int = -1,
|
||||
sliding_window_size: Optional[int] = None,
|
||||
logit_cap: int = -1,
|
||||
v_head_dim: int = -1,
|
||||
):
|
||||
@@ -48,8 +49,7 @@ class RadixAttention(nn.Module):
|
||||
self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
|
||||
self.scaling = scaling
|
||||
self.layer_id = layer_id
|
||||
self.reuse = reuse
|
||||
self.sliding_window_size = sliding_window_size
|
||||
self.sliding_window_size = sliding_window_size if sliding_window_size else -1
|
||||
|
||||
if (
|
||||
not global_server_args_dict.get("disable_flashinfer", False)
|
||||
@@ -118,16 +118,16 @@ class RadixAttention(nn.Module):
|
||||
|
||||
def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
|
||||
# using two wrappers is unnecessary in the current PR, but are prepared for future PRs
|
||||
prefill_wrapper_ragged = input_metadata.flashinfer_prefill_wrapper_ragged
|
||||
prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
|
||||
if self.sliding_window_size != -1 or self.reuse:
|
||||
if self.sliding_window_size != -1:
|
||||
prefill_wrapper_paged = prefill_wrapper_paged[0]
|
||||
else:
|
||||
if isinstance(prefill_wrapper_paged, list):
|
||||
prefill_wrapper_paged = prefill_wrapper_paged[1]
|
||||
|
||||
if not input_metadata.flashinfer_use_ragged or self.reuse:
|
||||
if not self.reuse:
|
||||
if not input_metadata.flashinfer_use_ragged:
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
self.store_kv_cache(k, v, input_metadata)
|
||||
|
||||
o = prefill_wrapper_paged.forward(
|
||||
@@ -139,21 +139,20 @@ class RadixAttention(nn.Module):
|
||||
logits_soft_cap=self.logit_cap,
|
||||
)
|
||||
else:
|
||||
o1, s1 = prefill_wrapper_ragged.forward_return_lse(
|
||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
||||
k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
|
||||
v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
|
||||
causal=True,
|
||||
sm_scale=self.scaling,
|
||||
window_left=self.sliding_window_size,
|
||||
logits_soft_cap=self.logit_cap,
|
||||
o1, s1 = (
|
||||
input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
|
||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
||||
k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
|
||||
v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
|
||||
causal=True,
|
||||
sm_scale=self.scaling,
|
||||
logits_soft_cap=self.logit_cap,
|
||||
)
|
||||
)
|
||||
|
||||
if input_metadata.extend_no_prefix:
|
||||
o = o1
|
||||
else:
|
||||
# TODO window attention + radix attention will come up in next PR
|
||||
assert self.sliding_window_size == -1
|
||||
o2, s2 = prefill_wrapper_paged.forward_return_lse(
|
||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
||||
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
|
||||
@@ -179,7 +178,8 @@ class RadixAttention(nn.Module):
|
||||
if isinstance(decode_wrapper, list):
|
||||
decode_wrapper = decode_wrapper[1]
|
||||
|
||||
if not self.reuse:
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
self.store_kv_cache(k, v, input_metadata)
|
||||
|
||||
o = decode_wrapper.forward(
|
||||
|
||||
@@ -194,6 +194,7 @@ class InputMetadata:
|
||||
if (
|
||||
forward_mode != ForwardMode.DECODE
|
||||
and int(torch.sum(ret.seq_lens)) > 4096
|
||||
and model_runner.sliding_window_size is None
|
||||
):
|
||||
flashinfer_use_ragged = True
|
||||
ret.init_flashinfer_handlers(
|
||||
@@ -322,22 +323,25 @@ def update_flashinfer_indices(
|
||||
1,
|
||||
)
|
||||
else:
|
||||
# window attention use paged only
|
||||
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
||||
for wrapper_id in range(2):
|
||||
if flashinfer_use_ragged and wrapper_id == 1:
|
||||
# full attention use ragged+paged
|
||||
paged_kernel_lens = prefix_lens
|
||||
if wrapper_id == 0:
|
||||
if forward_mode == ForwardMode.DECODE:
|
||||
paged_kernel_lens = torch.minimum(
|
||||
seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
|
||||
)
|
||||
else:
|
||||
paged_kernel_lens = torch.minimum(
|
||||
seq_lens,
|
||||
torch.tensor(model_runner.sliding_window_size)
|
||||
+ seq_lens
|
||||
- prefix_lens,
|
||||
)
|
||||
else:
|
||||
# window attention use paged only
|
||||
paged_kernel_lens = seq_lens
|
||||
|
||||
if wrapper_id == 0 and forward_mode == ForwardMode.DECODE:
|
||||
paged_kernel_lens = torch.minimum(
|
||||
paged_kernel_lens, torch.tensor(model_runner.sliding_window_size)
|
||||
)
|
||||
kv_start_idx = seq_lens - paged_kernel_lens
|
||||
else:
|
||||
kv_start_idx = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
|
||||
kv_start_idx = seq_lens - paged_kernel_lens
|
||||
|
||||
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
||||
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
@@ -376,17 +380,6 @@ def update_flashinfer_indices(
|
||||
)
|
||||
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
||||
|
||||
if flashinfer_use_ragged and wrapper_id == 1:
|
||||
model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
|
||||
model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
|
||||
qo_indptr,
|
||||
qo_indptr,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
# cached part
|
||||
model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward()
|
||||
model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward(
|
||||
qo_indptr,
|
||||
|
||||
@@ -334,11 +334,7 @@ class ModelRunner:
|
||||
dtype=torch.uint8,
|
||||
device="cuda",
|
||||
)
|
||||
self.flashinfer_prefill_wrapper_ragged = (
|
||||
BatchPrefillWithRaggedKVCacheWrapper(
|
||||
self.flashinfer_workspace_buffer, "NHD"
|
||||
)
|
||||
)
|
||||
self.flashinfer_prefill_wrapper_ragged = None
|
||||
self.flashinfer_prefill_wrapper_paged = []
|
||||
self.flashinfer_decode_wrapper = []
|
||||
for i in range(2):
|
||||
|
||||
@@ -213,7 +213,7 @@ class Gemma2Attention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_idx,
|
||||
sliding_window_size=get_window_size(config) if use_sliding_window else -1,
|
||||
sliding_window_size=get_window_size(config) if use_sliding_window else None,
|
||||
logit_cap=self.config.attn_logit_softcapping,
|
||||
)
|
||||
|
||||
|
||||
@@ -450,16 +450,8 @@ class ServerArgs:
|
||||
self.dp_size > 1 and self.node_rank is not None
|
||||
), "multi-node data parallel is not supported"
|
||||
if "gemma-2" in self.model_path.lower():
|
||||
logger.info(
|
||||
f"When using sliding window in gemma-2, disable radix_cache, regex_jump_forward, and turn on flashinfer."
|
||||
)
|
||||
# FIXME: compatibility with radix attention
|
||||
self.disable_radix_cache = True
|
||||
# FIXME: compatibility with jump forward
|
||||
self.disable_regex_jump_forward = True
|
||||
logger.info(f"When using sliding window in gemma-2, turn on flashinfer.")
|
||||
self.disable_flashinfer = False
|
||||
# FIXME: compatibility with chunked prefill
|
||||
self.chunked_prefill_size = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
||||
Reference in New Issue
Block a user