Refactor attention backend (#1381)
@@ -15,25 +15,14 @@ limitations under the License.
"""Radix attention."""

from typing import Optional

import torch
from flashinfer.cascade import merge_state
from torch import nn

from sglang.global_config import global_config
from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.model_runner import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata

class RadixAttention(nn.Module):
    """
    The attention layer implementation.
    Now it has two backends: FlashInfer and Triton.
    FlashInfer is faster and Triton is easier to customize.
    It supports two operators: extend (i.e. prefill with cached prefix) and decode.
    """

    def __init__(
@@ -43,8 +32,8 @@ class RadixAttention(nn.Module):
        scaling: float,
        num_kv_heads: int,
        layer_id: int,
        sliding_window_size: Optional[int] = None,
        logit_cap: int = -1,
        sliding_window_size: int = -1,
        logit_cap: float = 0.0,
        v_head_dim: int = -1,
    ):
        super().__init__()
@@ -56,164 +45,14 @@ class RadixAttention(nn.Module):
        self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
        self.scaling = scaling
        self.layer_id = layer_id
        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
        self.sliding_window_size = sliding_window_size if sliding_window_size else -1

        # Choose backend
        if (
            global_server_args_dict["attention_backend"] == "flashinfer"
            and self.qk_head_dim == self.v_head_dim
        ):
            self.extend_forward = self.extend_forward_flashinfer
            self.decode_forward = self.decode_forward_flashinfer
        elif global_server_args_dict["attention_backend"] == "triton":
            self.extend_forward = self.extend_forward_triton
            self.decode_forward = self.decode_forward_triton
        else:
            raise ValueError(
                f"Invalid attention backend: {global_server_args_dict['attention_backend']}"
            )

    def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
        if self.qk_head_dim != self.v_head_dim:
            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
        else:
            o = torch.empty_like(q)

        self.store_kv_cache(k, v, input_metadata)
        extend_attention_fwd(
            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, self.tp_q_head_num, self.v_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.triton_start_loc,
            input_metadata.seq_lens,
            input_metadata.triton_prefix_lens,
            input_metadata.extend_start_loc,
            input_metadata.extend_seq_lens,
            input_metadata.triton_max_seq_len,
            input_metadata.triton_max_extend_len,
            sm_scale=self.scaling,
            logit_cap=self.logit_cap,
        )

        return o

    def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
        if self.qk_head_dim != self.v_head_dim:
            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
        else:
            o = torch.empty_like(q)
        self.store_kv_cache(k, v, input_metadata)

        decode_attention_fwd(
            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
            o.view(-1, self.tp_q_head_num, self.v_head_dim),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.triton_start_loc,
            input_metadata.seq_lens,
            input_metadata.triton_max_seq_len,
            input_metadata.total_num_tokens,
            sm_scale=self.scaling,
            logit_cap=self.logit_cap,
        )

        return o

    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
        # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
        prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
        if self.sliding_window_size != -1:
            prefill_wrapper_paged = prefill_wrapper_paged[0]
        else:
            if isinstance(prefill_wrapper_paged, list):
                prefill_wrapper_paged = prefill_wrapper_paged[1]

        if not input_metadata.flashinfer_use_ragged:
            if k is not None:
                assert v is not None
                self.store_kv_cache(k, v, input_metadata)

            o = prefill_wrapper_paged.forward(
                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                causal=True,
                sm_scale=self.scaling,
                window_left=self.sliding_window_size,
                logits_soft_cap=self.logit_cap,
            )
        else:
            o1, s1 = (
                input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                    k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
                    v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
                    causal=True,
                    sm_scale=self.scaling,
                    logits_soft_cap=self.logit_cap,
                )
            )

            if input_metadata.extend_no_prefix:
                o = o1
            else:
                o2, s2 = prefill_wrapper_paged.forward_return_lse(
                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                    input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                    causal=False,
                    sm_scale=self.scaling,
                    logits_soft_cap=self.logit_cap,
                )

                o, _ = merge_state(o1, s1, o2, s2)

            self.store_kv_cache(k, v, input_metadata)

        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
            torch.cuda.synchronize()

        return o.view(-1, self.tp_q_head_num * self.head_dim)

    def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
        decode_wrapper = input_metadata.flashinfer_decode_wrapper
        if self.sliding_window_size != -1:
            decode_wrapper = decode_wrapper[0]
        else:
            if isinstance(decode_wrapper, list):
                decode_wrapper = decode_wrapper[1]

        if k is not None:
            assert v is not None
            self.store_kv_cache(k, v, input_metadata)

        o = decode_wrapper.forward(
            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
            sm_scale=self.scaling,
            logits_soft_cap=self.logit_cap,
        )

        return o.view(-1, self.tp_q_head_num * self.head_dim)
        self.logit_cap = logit_cap
        self.sliding_window_size = sliding_window_size or -1

    def forward(self, q, k, v, input_metadata: InputMetadata):
        if k is not None:
            # For cross-layer sharing, kv can be None
            assert v is not None
            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

        if input_metadata.forward_mode.is_extend():
            return self.extend_forward(q, k, v, input_metadata)
        elif input_metadata.forward_mode.is_decode():
            return self.decode_forward(q, k, v, input_metadata)

    def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
        input_metadata.token_to_kv_pool.set_kv_buffer(
            self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
        )
        return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)
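
The net effect of this diff is that RadixAttention no longer selects a backend or writes the KV cache itself: the backend-specific methods (extend_forward_triton, decode_forward_triton, extend_forward_flashinfer, decode_forward_flashinfer, store_kv_cache) are removed, and forward now delegates to input_metadata.attn_backend.forward(q, k, v, self, input_metadata). The backend classes themselves are not part of this file, so the following is only a minimal sketch of the dispatch shape that call implies, assuming each backend exposes per-mode methods; the names AttentionBackend, forward_extend, and forward_decode are illustrative assumptions, not identifiers taken from this commit.

# Hypothetical sketch only: class and method names below are assumptions, not code from this commit.
from abc import ABC, abstractmethod


class AttentionBackend(ABC):
    """One subclass per backend (e.g. FlashInfer, Triton), owning its own kernels and KV-cache writes."""

    def forward(self, q, k, v, layer, input_metadata):
        # Mirrors the dispatch that used to live in RadixAttention.forward:
        # decode goes to the decode kernel, everything else to extend/prefill.
        if input_metadata.forward_mode.is_decode():
            return self.forward_decode(q, k, v, layer, input_metadata)
        return self.forward_extend(q, k, v, layer, input_metadata)

    @abstractmethod
    def forward_extend(self, q, k, v, layer, input_metadata):
        ...

    @abstractmethod
    def forward_decode(self, q, k, v, layer, input_metadata):
        ...

Under this shape, RadixAttention keeps only per-layer parameters (head counts, scaling, layer_id, logit_cap, sliding_window_size) and passes itself as `layer`, so a backend can read those fields when it runs its kernels.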