[Fix] Compatibility of window attention and cuda graph (#1090)

Author: Ying Sheng
Date: 2024-08-14 10:37:01 -07:00
Committed by: GitHub
Parent: a34dd86a7d
Commit: 96a2093ef0
7 changed files with 70 additions and 39 deletions


@@ -34,6 +34,7 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
+        reuse: bool = False,
         sliding_window_size: int = -1,
         logit_cap: int = -1,
         v_head_dim: int = -1,
@@ -47,6 +48,7 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
+        self.reuse = reuse
         self.sliding_window_size = sliding_window_size
         if (
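The new reuse flag defaults to False, so existing models are unaffected; a layer opts in when it should read key/value entries written by another layer instead of storing its own. A minimal construction sketch, assuming the usual leading num_heads/head_dim arguments and placeholder values (not taken from this diff):

    # Hypothetical call site; argument values are illustrative placeholders.
    self.attn = RadixAttention(
        num_heads=num_heads,
        head_dim=head_dim,
        scaling=head_dim**-0.5,
        num_kv_heads=num_kv_heads,
        layer_id=layer_id,
        reuse=True,  # skip store_kv_cache and attend over KV stored by a paired layer
    )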
@@ -127,8 +129,9 @@ class RadixAttention(nn.Module):
         if isinstance(prefill_wrapper_paged, list):
             prefill_wrapper_paged = prefill_wrapper_paged[1]
 
-        if not input_metadata.flashinfer_use_ragged:
-            self.store_kv_cache(k, v, input_metadata)
+        if not input_metadata.flashinfer_use_ragged or self.reuse:
+            if not self.reuse:
+                self.store_kv_cache(k, v, input_metadata)
 
             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
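With the widened condition, a reuse layer still takes the paged prefill path even when ragged prefill is enabled, but it never writes into the KV pool; the decode hunk below applies the same write guard. A condensed, illustrative view of the resulting policy (the helper function is a sketch, not part of the diff):

    def kv_policy(reuse: bool, flashinfer_use_ragged: bool):
        """Condensed view of the gating added above (illustrative only)."""
        # A reuse layer attends over KV written by another layer, so it must not
        # append its own entries, and it always goes through the paged wrapper.
        use_paged_prefill = (not flashinfer_use_ragged) or reuse
        write_kv_cache = not reuse
        return use_paged_prefill, write_kv_cache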
@@ -179,7 +182,8 @@ class RadixAttention(nn.Module):
         if isinstance(decode_wrapper, list):
             decode_wrapper = decode_wrapper[1]
 
-        self.store_kv_cache(k, v, input_metadata)
+        if not self.reuse:
+            self.store_kv_cache(k, v, input_metadata)
 
         o = decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
@@ -191,8 +195,10 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
-        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+        if k is not None:
+            assert v is not None
+            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
 
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
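forward previously assumed k and v were always tensors; guarding the reshapes lets a reuse layer call in with no new key/value projections and attend purely over entries already in the KV pool. A hedged usage sketch (call sites are illustrative, not from this diff):

    # attn is a RadixAttention instance.
    out = attn.forward(q, k, v, input_metadata)        # normal layer: fresh k/v are stored
    out = attn.forward(q, None, None, input_metadata)  # reuse layer: nothing new to store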