Simplify flashinfer dispatch (#1552)
@@ -14,7 +14,10 @@ import torch.nn as nn

 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_utils import update_flashinfer_indices
+from sglang.srt.layers.attention.flashinfer_utils import (
+    WrapperDispatch,
+    update_flashinfer_indices,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_hip

@@ -53,10 +56,19 @@ class FlashInferAttnBackend(AttentionBackend):
             device="cuda",
         )

+        assert not (
+            model_runner.sliding_window_size is not None
+            and model_runner.has_cross_attention
+        ), "Sliding window and cross attention are not supported together"
+
+        self.num_wrappers = 1
+        self.dispatch_reason = None
         if model_runner.sliding_window_size is not None:
             self.num_wrappers = 2
-        else:
-            self.num_wrappers = 1
+            self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
+        elif model_runner.has_cross_attention:
+            self.num_wrappers = 2
+            self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION

         # NOTE: we do not use ragged attention when there are multiple wrappers
         self.prefill_wrapper_ragged = (
@@ -88,8 +100,12 @@ class FlashInferAttnBackend(AttentionBackend):
         if self.num_wrappers == 1:
             return 0

-        # TODO: make sure the idx is related to sliding window size
-        return layer.sliding_window_size == -1
+        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            return layer.sliding_window_size == -1
+        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            return layer.is_cross_attention
+
+        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
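The two hunks above replace the hard-coded sliding-window special case with an explicit dispatch reason: the backend records why it needs two wrappers and picks a wrapper index per layer accordingly. A minimal, self-contained sketch of that selection rule (the flat function signature here is illustrative only; in the diff the logic lives on FlashInferAttnBackend and reads the attributes off the layer object):

from enum import Enum, auto


class WrapperDispatch(Enum):
    SLIDING_WINDOW = auto()
    CROSS_ATTENTION = auto()


def get_wrapper_idx(num_wrappers, dispatch_reason, sliding_window_size, is_cross_attention):
    # With a single wrapper there is nothing to dispatch on.
    if num_wrappers == 1:
        return 0
    if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
        # Layers without a window (size == -1) use the full-attention wrapper 1,
        # windowed layers use wrapper 0.
        return int(sliding_window_size == -1)
    if dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
        # Cross-attention layers use wrapper 1, self-attention layers wrapper 0.
        return int(is_cross_attention)
    raise ValueError(f"Unknown dispatch reason: {dispatch_reason}")


# A Gemma-2-style model: windowed layers -> wrapper 0, global layers -> wrapper 1.
assert get_wrapper_idx(2, WrapperDispatch.SLIDING_WINDOW, 4096, False) == 0
assert get_wrapper_idx(2, WrapperDispatch.SLIDING_WINDOW, -1, False) == 1

In the diff the method simply returns the boolean, which Python happily uses as the 0/1 wrapper index.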
@@ -1,8 +1,15 @@
+from enum import Enum, auto
+
 import torch
 import triton
 import triton.language as tl


+class WrapperDispatch(Enum):
+    SLIDING_WINDOW = auto()
+    CROSS_ATTENTION = auto()
+
+
 @triton.jit
 def create_flashinfer_kv_indices_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]
@@ -80,67 +87,6 @@ class FlashinferUpdater:
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )

-    def _init_indices_no_sliding_window(self):
-        if self.use_ragged:
-            paged_kernel_lens = self.prefix_lens
-        else:
-            paged_kernel_lens = self.seq_lens
-
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_indices = torch.empty(
-            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
-        )
-
-        create_flashinfer_kv_indices_triton[(self.batch_size,)](
-            self.model_runner.req_to_token_pool.req_to_token,
-            self.req_pool_indices,
-            paged_kernel_lens,
-            self.kv_indptr,
-            None,
-            self.kv_indices,
-            self.model_runner.req_to_token_pool.req_to_token.size(1),
-        )
-
-    def _init_indices_sliding_window(self, wrapper_id):
-        if wrapper_id == 0:
-            # window attention use paged only
-            if self.forward_mode.is_decode():
-                paged_kernel_lens = torch.minimum(
-                    self.seq_lens,
-                    torch.tensor(self.model_runner.sliding_window_size + 1),
-                )
-            else:
-                paged_kernel_lens = torch.minimum(
-                    self.seq_lens,
-                    torch.tensor(self.model_runner.sliding_window_size)
-                    + self.seq_lens
-                    - self.prefix_lens,
-                )
-        else:
-            # full attention
-            paged_kernel_lens = self.seq_lens
-
-        kv_start_idx = self.seq_lens - paged_kernel_lens
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_indices = torch.empty(
-            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
-        )
-        create_flashinfer_kv_indices_triton[(self.batch_size,)](
-            self.model_runner.req_to_token_pool.req_to_token,
-            self.req_pool_indices,
-            paged_kernel_lens,
-            self.kv_indptr,
-            kv_start_idx,
-            self.kv_indices,
-            self.model_runner.req_to_token_pool.req_to_token.size(1),
-        )
-
     def _update_decode_indices(self, decode_wrapper):
         assert not isinstance(decode_wrapper, list)
         decode_wrapper.end_forward()
@@ -189,8 +135,53 @@ class FlashinferUpdater:
                 1,
             )

-    def update_indices_no_sliding_window(self):
-        self._init_indices_no_sliding_window()
+    def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0):
+        if dispatch_reason is None:
+            if self.use_ragged:
+                paged_kernel_lens = self.prefix_lens
+            else:
+                paged_kernel_lens = self.seq_lens
+            self.kv_start_idx = None
+        elif dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            if wrapper_id == 0:
+                # window attention use paged only
+                if self.forward_mode.is_decode():
+                    paged_kernel_lens = torch.minimum(
+                        self.seq_lens,
+                        torch.tensor(self.model_runner.sliding_window_size + 1),
+                    )
+                else:
+                    paged_kernel_lens = torch.minimum(
+                        self.seq_lens,
+                        torch.tensor(self.model_runner.sliding_window_size)
+                        + self.seq_lens
+                        - self.prefix_lens,
+                    )
+            else:
+                # full attention
+                paged_kernel_lens = self.seq_lens
+            self.kv_start_idx = self.seq_lens - paged_kernel_lens
+
+        self.kv_indptr = torch.zeros(
+            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        self.kv_indices = torch.empty(
+            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
+        )
+
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.model_runner.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            paged_kernel_lens,
+            self.kv_indptr,
+            self.kv_start_idx,
+            self.kv_indices,
+            self.model_runner.req_to_token_pool.req_to_token.size(1),
+        )
+
+    def _update_indicess_single_wrapper(self):
+        self._get_indices()

         if self.forward_mode.is_decode():
             self._update_decode_indices(self.decode_wrappers[0])
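In the sliding-window branch of _get_indices above, the number of paged KV entries each request keeps is capped by the window. A small standalone example of the two length formulas with made-up values (window of 4 tokens, two requests); the tensors and numbers are illustrative, not from the source:

import torch

sliding_window_size = 4
seq_lens = torch.tensor([3, 10])     # total tokens per request
prefix_lens = torch.tensor([0, 6])   # tokens already cached before this extend

# Decode (wrapper 0): keep at most the last window + 1 tokens.
decode_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))
print(decode_lens)      # tensor([3, 5])

# Extend/prefill (wrapper 0): the window plus the tokens being extended now.
extend_lens = torch.minimum(
    seq_lens, torch.tensor(sliding_window_size) + seq_lens - prefix_lens
)
print(extend_lens)      # tensor([3, 8])

# Offset of the first KV entry each request actually reads (kv_start_idx).
kv_start_idx = seq_lens - extend_lens
print(kv_start_idx)     # tensor([0, 2])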
@@ -200,11 +191,13 @@ class FlashinferUpdater:
                 self.prefill_wrappers_paged[0],
             )

-    def update_indices_sliding_window(self):
-        assert self.use_ragged is False
+    def _update_indices_cross_attention(self):
+        pass
+
+    def _update_indices_sliding_window(self):
+        assert self.use_ragged is False
         for wrapper_id in range(2):
-            self._init_indices_sliding_window(wrapper_id)
+            self._get_indices(WrapperDispatch.SLIDING_WINDOW, wrapper_id)
             if self.forward_mode.is_decode():
                 self._update_decode_indices(self.decode_wrappers[wrapper_id])
             else:
@@ -233,7 +226,12 @@ def update_flashinfer_indices(
         use_ragged,
     )

-    if model_runner.sliding_window_size is None:
-        updater.update_indices_no_sliding_window()
+    dispatch_reason = model_runner.attn_backend.dispatch_reason
+
+    if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+        updater._update_indices_sliding_window()
+    elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+        updater._update_indices_cross_attention()
     else:
-        updater.update_indices_sliding_window()
+        assert model_runner.attn_backend.num_wrappers == 1
+        updater._update_indicess_single_wrapper()
@@ -32,9 +32,10 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
-        sliding_window_size: int = -1,
         logit_cap: float = 0.0,
         v_head_dim: int = -1,
+        sliding_window_size: int = -1,
+        is_cross_attention: bool = False,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
@@ -47,6 +48,7 @@ class RadixAttention(nn.Module):
         self.layer_id = layer_id
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
+        self.is_cross_attention = is_cross_attention

     def forward(self, q, k, v, forward_batch: ForwardBatch):
         if k is not None:
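The `or -1` normalization above is what lets call sites pass None for layers without a window (as the gemma2 hunk below does with `else None`): both None and 0 collapse to -1, the sentinel the backend's dispatch check compares against. A tiny sketch of just that normalization, separate from the real class:

def normalize_window(sliding_window_size):
    # Mirrors `self.sliding_window_size = sliding_window_size or -1`:
    # None (no window configured) and 0 both become -1.
    return sliding_window_size or -1

assert normalize_window(None) == -1
assert normalize_window(0) == -1
assert normalize_window(4096) == 4096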
@@ -231,6 +231,7 @@ class ModelRunner:
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
+        self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
         self.is_generation = is_generation_model(
             self.model_config.hf_config.architectures, self.server_args.is_embedding
         )
@@ -453,6 +454,10 @@ class ModelRunner:
                 "Window attention is not supported in the triton attention backend. "
                 "Please use `--attention-backend flashinfer`."
             )
+            assert not self.has_cross_attention, (
+                "Cross attention is not supported in the triton attention backend. "
+                "Please use `--attention-backend flashinfer`."
+            )
             self.attn_backend = TritonAttnBackend(self)
         else:
             raise ValueError(
@@ -163,12 +163,12 @@ class Gemma2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_idx,
+            logit_cap=self.config.attn_logit_softcapping,
             sliding_window_size=(
                 get_attention_sliding_window_size(config)
                 if use_sliding_window
                 else None
             ),
-            logit_cap=self.config.attn_logit_softcapping,
         )

     def forward(