Simplify flashinfer dispatch (#1552)
This commit is contained in:
@@ -14,7 +14,10 @@ import torch.nn as nn
|
|||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.attention import AttentionBackend
|
from sglang.srt.layers.attention import AttentionBackend
|
||||||
from sglang.srt.layers.attention.flashinfer_utils import update_flashinfer_indices
|
from sglang.srt.layers.attention.flashinfer_utils import (
|
||||||
|
WrapperDispatch,
|
||||||
|
update_flashinfer_indices,
|
||||||
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
@@ -53,10 +56,19 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert not (
|
||||||
|
model_runner.sliding_window_size is not None
|
||||||
|
and model_runner.has_cross_attention
|
||||||
|
), "Sliding window and cross attention are not supported together"
|
||||||
|
|
||||||
|
self.num_wrappers = 1
|
||||||
|
self.dispatch_reason = None
|
||||||
if model_runner.sliding_window_size is not None:
|
if model_runner.sliding_window_size is not None:
|
||||||
self.num_wrappers = 2
|
self.num_wrappers = 2
|
||||||
else:
|
self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
|
||||||
self.num_wrappers = 1
|
elif model_runner.has_cross_attention:
|
||||||
|
self.num_wrappers = 2
|
||||||
|
self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
|
||||||
|
|
||||||
# NOTE: we do not use ragged attention when there are multiple wrappers
|
# NOTE: we do not use ragged attention when there are multiple wrappers
|
||||||
self.prefill_wrapper_ragged = (
|
self.prefill_wrapper_ragged = (
|
||||||
@@ -88,8 +100,12 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
if self.num_wrappers == 1:
|
if self.num_wrappers == 1:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# TODO: make sure the idx is related to sliding window size
|
if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
||||||
return layer.sliding_window_size == -1
|
return layer.sliding_window_size == -1
|
||||||
|
if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
||||||
|
return layer.is_cross_attention
|
||||||
|
|
||||||
|
raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
if forward_batch.forward_mode.is_decode():
|
if forward_batch.forward_mode.is_decode():
|
||||||
|
|||||||
@@ -1,8 +1,15 @@
|
|||||||
|
from enum import Enum, auto
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
class WrapperDispatch(Enum):
|
||||||
|
SLIDING_WINDOW = auto()
|
||||||
|
CROSS_ATTENTION = auto()
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def create_flashinfer_kv_indices_triton(
|
def create_flashinfer_kv_indices_triton(
|
||||||
req_to_token_ptr, # [max_batch, max_context_len]
|
req_to_token_ptr, # [max_batch, max_context_len]
|
||||||
@@ -80,67 +87,6 @@ class FlashinferUpdater:
|
|||||||
(self.batch_size,), dtype=torch.int32, device="cuda"
|
(self.batch_size,), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_indices_no_sliding_window(self):
|
|
||||||
if self.use_ragged:
|
|
||||||
paged_kernel_lens = self.prefix_lens
|
|
||||||
else:
|
|
||||||
paged_kernel_lens = self.seq_lens
|
|
||||||
|
|
||||||
self.kv_indptr = torch.zeros(
|
|
||||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
|
||||||
self.kv_indices = torch.empty(
|
|
||||||
self.kv_indptr[-1], dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
create_flashinfer_kv_indices_triton[(self.batch_size,)](
|
|
||||||
self.model_runner.req_to_token_pool.req_to_token,
|
|
||||||
self.req_pool_indices,
|
|
||||||
paged_kernel_lens,
|
|
||||||
self.kv_indptr,
|
|
||||||
None,
|
|
||||||
self.kv_indices,
|
|
||||||
self.model_runner.req_to_token_pool.req_to_token.size(1),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _init_indices_sliding_window(self, wrapper_id):
|
|
||||||
if wrapper_id == 0:
|
|
||||||
# window attention use paged only
|
|
||||||
if self.forward_mode.is_decode():
|
|
||||||
paged_kernel_lens = torch.minimum(
|
|
||||||
self.seq_lens,
|
|
||||||
torch.tensor(self.model_runner.sliding_window_size + 1),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
paged_kernel_lens = torch.minimum(
|
|
||||||
self.seq_lens,
|
|
||||||
torch.tensor(self.model_runner.sliding_window_size)
|
|
||||||
+ self.seq_lens
|
|
||||||
- self.prefix_lens,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# full attention
|
|
||||||
paged_kernel_lens = self.seq_lens
|
|
||||||
|
|
||||||
kv_start_idx = self.seq_lens - paged_kernel_lens
|
|
||||||
self.kv_indptr = torch.zeros(
|
|
||||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
|
||||||
self.kv_indices = torch.empty(
|
|
||||||
self.kv_indptr[-1], dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
create_flashinfer_kv_indices_triton[(self.batch_size,)](
|
|
||||||
self.model_runner.req_to_token_pool.req_to_token,
|
|
||||||
self.req_pool_indices,
|
|
||||||
paged_kernel_lens,
|
|
||||||
self.kv_indptr,
|
|
||||||
kv_start_idx,
|
|
||||||
self.kv_indices,
|
|
||||||
self.model_runner.req_to_token_pool.req_to_token.size(1),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _update_decode_indices(self, decode_wrapper):
|
def _update_decode_indices(self, decode_wrapper):
|
||||||
assert not isinstance(decode_wrapper, list)
|
assert not isinstance(decode_wrapper, list)
|
||||||
decode_wrapper.end_forward()
|
decode_wrapper.end_forward()
|
||||||
@@ -189,8 +135,53 @@ class FlashinferUpdater:
|
|||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_indices_no_sliding_window(self):
|
def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0):
|
||||||
self._init_indices_no_sliding_window()
|
if dispatch_reason is None:
|
||||||
|
if self.use_ragged:
|
||||||
|
paged_kernel_lens = self.prefix_lens
|
||||||
|
else:
|
||||||
|
paged_kernel_lens = self.seq_lens
|
||||||
|
self.kv_start_idx = None
|
||||||
|
elif dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
||||||
|
if wrapper_id == 0:
|
||||||
|
# window attention use paged only
|
||||||
|
if self.forward_mode.is_decode():
|
||||||
|
paged_kernel_lens = torch.minimum(
|
||||||
|
self.seq_lens,
|
||||||
|
torch.tensor(self.model_runner.sliding_window_size + 1),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
paged_kernel_lens = torch.minimum(
|
||||||
|
self.seq_lens,
|
||||||
|
torch.tensor(self.model_runner.sliding_window_size)
|
||||||
|
+ self.seq_lens
|
||||||
|
- self.prefix_lens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# full attention
|
||||||
|
paged_kernel_lens = self.seq_lens
|
||||||
|
self.kv_start_idx = self.seq_lens - paged_kernel_lens
|
||||||
|
|
||||||
|
self.kv_indptr = torch.zeros(
|
||||||
|
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
|
self.kv_indices = torch.empty(
|
||||||
|
self.kv_indptr[-1], dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
|
create_flashinfer_kv_indices_triton[(self.batch_size,)](
|
||||||
|
self.model_runner.req_to_token_pool.req_to_token,
|
||||||
|
self.req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
self.kv_indptr,
|
||||||
|
self.kv_start_idx,
|
||||||
|
self.kv_indices,
|
||||||
|
self.model_runner.req_to_token_pool.req_to_token.size(1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_indicess_single_wrapper(self):
|
||||||
|
self._get_indices()
|
||||||
|
|
||||||
if self.forward_mode.is_decode():
|
if self.forward_mode.is_decode():
|
||||||
self._update_decode_indices(self.decode_wrappers[0])
|
self._update_decode_indices(self.decode_wrappers[0])
|
||||||
@@ -200,11 +191,13 @@ class FlashinferUpdater:
|
|||||||
self.prefill_wrappers_paged[0],
|
self.prefill_wrappers_paged[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_indices_sliding_window(self):
|
def _update_indices_cross_attention(self):
|
||||||
assert self.use_ragged is False
|
pass
|
||||||
|
|
||||||
|
def _update_indices_sliding_window(self):
|
||||||
|
assert self.use_ragged is False
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
self._init_indices_sliding_window(wrapper_id)
|
self._get_indices(WrapperDispatch.SLIDING_WINDOW, wrapper_id)
|
||||||
if self.forward_mode.is_decode():
|
if self.forward_mode.is_decode():
|
||||||
self._update_decode_indices(self.decode_wrappers[wrapper_id])
|
self._update_decode_indices(self.decode_wrappers[wrapper_id])
|
||||||
else:
|
else:
|
||||||
@@ -233,7 +226,12 @@ def update_flashinfer_indices(
|
|||||||
use_ragged,
|
use_ragged,
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_runner.sliding_window_size is None:
|
dispatch_reason = model_runner.attn_backend.dispatch_reason
|
||||||
updater.update_indices_no_sliding_window()
|
|
||||||
|
if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
||||||
|
updater._update_indices_sliding_window()
|
||||||
|
elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
||||||
|
updater._update_indices_cross_attention()
|
||||||
else:
|
else:
|
||||||
updater.update_indices_sliding_window()
|
assert model_runner.attn_backend.num_wrappers == 1
|
||||||
|
updater._update_indicess_single_wrapper()
|
||||||
|
|||||||
@@ -32,9 +32,10 @@ class RadixAttention(nn.Module):
|
|||||||
scaling: float,
|
scaling: float,
|
||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
sliding_window_size: int = -1,
|
|
||||||
logit_cap: float = 0.0,
|
logit_cap: float = 0.0,
|
||||||
v_head_dim: int = -1,
|
v_head_dim: int = -1,
|
||||||
|
sliding_window_size: int = -1,
|
||||||
|
is_cross_attention: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_q_head_num = num_heads
|
self.tp_q_head_num = num_heads
|
||||||
@@ -47,6 +48,7 @@ class RadixAttention(nn.Module):
|
|||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
self.logit_cap = logit_cap
|
self.logit_cap = logit_cap
|
||||||
self.sliding_window_size = sliding_window_size or -1
|
self.sliding_window_size = sliding_window_size or -1
|
||||||
|
self.is_cross_attention = is_cross_attention
|
||||||
|
|
||||||
def forward(self, q, k, v, forward_batch: ForwardBatch):
|
def forward(self, q, k, v, forward_batch: ForwardBatch):
|
||||||
if k is not None:
|
if k is not None:
|
||||||
|
|||||||
@@ -231,6 +231,7 @@ class ModelRunner:
|
|||||||
if hasattr(self.model, "get_attention_sliding_window_size")
|
if hasattr(self.model, "get_attention_sliding_window_size")
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
|
||||||
self.is_generation = is_generation_model(
|
self.is_generation = is_generation_model(
|
||||||
self.model_config.hf_config.architectures, self.server_args.is_embedding
|
self.model_config.hf_config.architectures, self.server_args.is_embedding
|
||||||
)
|
)
|
||||||
@@ -453,6 +454,10 @@ class ModelRunner:
|
|||||||
"Window attention is not supported in the triton attention backend. "
|
"Window attention is not supported in the triton attention backend. "
|
||||||
"Please use `--attention-backend flashinfer`."
|
"Please use `--attention-backend flashinfer`."
|
||||||
)
|
)
|
||||||
|
assert not self.has_cross_attention, (
|
||||||
|
"Cross attention is not supported in the triton attention backend. "
|
||||||
|
"Please use `--attention-backend flashinfer`."
|
||||||
|
)
|
||||||
self.attn_backend = TritonAttnBackend(self)
|
self.attn_backend = TritonAttnBackend(self)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -163,12 +163,12 @@ class Gemma2Attention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_idx,
|
layer_id=layer_idx,
|
||||||
|
logit_cap=self.config.attn_logit_softcapping,
|
||||||
sliding_window_size=(
|
sliding_window_size=(
|
||||||
get_attention_sliding_window_size(config)
|
get_attention_sliding_window_size(config)
|
||||||
if use_sliding_window
|
if use_sliding_window
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
logit_cap=self.config.attn_logit_softcapping,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
Reference in New Issue
Block a user