536 lines
18 KiB
Python
536 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Optional
|
|
|
|
import torch
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
|
|
|
if TYPE_CHECKING:
|
|
from sglang.srt.layers.radix_attention import RadixAttention
|
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
|
from sglang.srt.speculative.spec_info import SpecInfo
|
|
|
|
try:
    from aiter import paged_attention_rocm
except ImportError:
    # aiter supplies `paged_attention_rocm`, the decode kernel used by
    # AiterDecodeAttnBackend. The import is best-effort so this module can
    # still be imported without aiter; in that case `paged_attention_rocm`
    # stays unbound and constructing the backend fails. Report through
    # logging (not stdout) so the message reaches the normal log handlers.
    import logging

    logging.getLogger(__name__).warning(
        "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
    )
|
|
|
|
from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd
|
|
|
|
# Partition size (in tokens) for aiter's partitioned decode kernel. It sizes
# max_num_partitions and the decode workspace in AiterDecodeAttnBackend and is
# passed as the last argument to paged_attention_rocm.
_AITER_PARTITION_SIZE_ROCM: int = 256
|
|
|
|
|
|
class AiterDecodeAttnBackend(AttentionBackend):
    """Attention backend using aiter for decode and Triton for extend modes.

    Decode runs through aiter's ROCm paged-attention kernel
    (``paged_attention_rocm``). Extend/prefill, speculative target-verify and
    draft-extend run through the Triton ``extend_attention_fwd`` kernel.

    ``self.forward_metadata`` is a 7-tuple::

        (attn_logits, max_extend_len, kv_indptr, kv_indices,
         qo_indptr, custom_mask, mask_indptr)

    It is built by the ``init_forward_metadata*`` methods and unpacked by
    ``forward_extend`` / ``forward_decode``. ``attn_logits`` is always ``None``
    in this backend; it is kept only so the tuple layout stays fixed.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        """Allocate index buffers and the aiter decode workspace.

        Args:
            model_runner: Source of the request/KV pools, model config,
                server args and device.
            skip_prefill: When True, the extend-side buffers (``qo_indptr``,
                ``mask_indptr`` and, later, the CUDA-graph custom mask) are
                not allocated.
            kv_indptr_buf: Optional externally owned buffer to use as
                ``kv_indptr``; a private zero buffer is allocated otherwise.
        """
        super().__init__()

        self.decode_attention_fwd = paged_attention_rocm
        self.extend_attention_fwd = extend_attention_fwd

        self.skip_prefill = skip_prefill

        max_bs = model_runner.req_to_token_pool.size

        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        self.req_to_token = model_runner.req_to_token_pool.req_to_token

        if not self.skip_prefill:
            self.qo_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
            self.mask_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int64, device=model_runner.device
            )

        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

        # tp sharding on number of heads
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.head_dim = model_runner.model_config.head_dim

        # triton prefill initialization
        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
        self.num_v_head = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-2]

        self.forward_metadata = None

        self.max_context_len = model_runner.model_config.context_len
        self.device = model_runner.device
        self.kv_cache_dtype = model_runner.kv_cache_dtype
        self.q_dtype = model_runner.model_config.dtype

        # aiter decode initialization: the context is split into partitions of
        # _AITER_PARTITION_SIZE_ROCM tokens (ceil division).
        self.max_num_partitions = (
            self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
        ) // _AITER_PARTITION_SIZE_ROCM

        nbytes_per_qo_elem = torch.finfo(torch.float32).bits // 8

        # Byte workspace for the decode kernel: one fp32 value per
        # (request, head, partition, head_dim) plus two 4-byte values per
        # (request, head, partition). NOTE(review): presumably partial outputs
        # and per-partition softmax statistics -- confirm against aiter.
        self.workspace_buffer = torch.empty(
            (max_bs * self.num_head * self.max_num_partitions * self.head_dim)
            * nbytes_per_qo_elem
            + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
            dtype=torch.uint8,
            device=self.device,
        )

        # Kept for compatibility; forward_decode uses layer.scaling so decode
        # matches the scale forward_extend applies.
        self.scale = float(1.0 / (self.head_dim**0.5))
        self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
            self.device
        )
        self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
            self.device
        )

    def _fill_kv_indices(self, bs, req_pool_indices, seq_lens, kv_indptr, kv_indices):
        """Launch the Triton gather that writes each request's KV slots into kv_indices."""
        create_flashinfer_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices,
            seq_lens,
            kv_indptr,
            None,
            kv_indices,
            self.req_to_token.stride(0),
        )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Build ``self.forward_metadata`` for an eager (non-CUDA-graph) pass."""
        bs = forward_batch.batch_size
        kv_indptr = self.kv_indptr
        spec_info = forward_batch.spec_info

        if forward_batch.forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = torch.zeros(
                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                )
                self._fill_kv_indices(
                    bs,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    kv_indptr,
                    kv_indices,
                )
            else:
                # Speculative decoding supplies ready-made KV indices.
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

            attn_logits = None  # accommodate forward_metadata format
            qo_indptr = None
            custom_mask = None
            mask_indptr = None
            max_extend_len = None
        elif forward_batch.forward_mode.is_target_verify():
            bs = len(forward_batch.req_pool_indices)
            # Every request verifies exactly num_draft_tokens query tokens.
            qo_indptr = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.zeros(
                kv_indptr[-1], dtype=torch.int32, device=self.device
            )
            self._fill_kv_indices(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                kv_indptr,
                kv_indices,
            )

            custom_mask = spec_info.custom_mask
            # Each request's mask spans num_draft_tokens rows by
            # (seq_len + num_draft_tokens) columns.
            seq_mask_len = self.num_draft_tokens * (
                forward_batch.seq_lens + self.num_draft_tokens
            )
            mask_indptr = self.mask_indptr
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
            mask_indptr = mask_indptr[: bs + 1]
            max_extend_len = self.num_draft_tokens
            attn_logits = None
        elif forward_batch.forward_mode.is_draft_extend():
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    self.req_to_token,
                )
            )
            mask_indptr = None
            max_extend_len = torch.max(spec_info.accept_length).item()
            attn_logits = None
        else:
            # Regular extend: KV indices cover only the cached prefix; the new
            # tokens' k/v are passed to the extend kernel directly.
            kv_indptr[1 : bs + 1] = torch.cumsum(
                forward_batch.extend_prefix_lens, dim=0
            )
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.zeros(
                forward_batch.extend_prefix_lens.sum().item(),
                dtype=torch.int32,
                device=self.device,
            )
            self._fill_kv_indices(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.extend_prefix_lens,
                kv_indptr,
                kv_indices,
            )

            qo_indptr = self.qo_indptr
            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
            mask_indptr = None
            attn_logits = None
            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

        self.forward_metadata = (
            attn_logits,
            max_extend_len,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
        )

    def init_cuda_graph_state(
        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
    ):
        """Allocate the fixed-size buffers reused across CUDA-graph capture/replay.

        Args:
            max_bs: Largest batch size that will be captured.
            kv_indices_buf: Optional externally owned buffer for KV indices;
                otherwise one of ``max_bs * max_context_len`` int32 entries is
                allocated.
        """
        self.cuda_graph_attn_logits = torch.zeros(
            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
            dtype=torch.float32,
            device=self.device,
        )
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.int32,
                device=self.device,
            )
        else:
            self.cuda_graph_kv_indices = kv_indices_buf

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.uint8,
                device=self.device,
            )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        """Build ``self.forward_metadata`` over the persistent buffers for graph capture.

        Only decode/idle and target-verify modes are supported.

        Raises:
            ValueError: For any other forward mode.
        """
        assert encoder_lens is None, "Not supported"

        if forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr = self.kv_indptr
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = self.cuda_graph_kv_indices
                self._fill_kv_indices(
                    bs, req_pool_indices, seq_lens, kv_indptr, kv_indices
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

            attn_logits = None
            max_extend_len = None
            qo_indptr = None
            custom_mask = None
            mask_indptr = None
        elif forward_mode.is_target_verify():
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            self._fill_kv_indices(bs, req_pool_indices, seq_lens, kv_indptr, kv_indices)

            custom_mask = self.cuda_graph_custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
            max_extend_len = self.num_draft_tokens
            attn_logits = None
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
            )

        self.forward_metadata = (
            attn_logits,
            max_extend_len,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Refresh the persistent buffers in place before replaying a captured graph.

        ``self.forward_metadata`` is not reassigned. The captured graph reads
        the same buffers that are rewritten here. ``seq_lens_sum`` and
        ``seq_lens_cpu`` are unused.

        Raises:
            ValueError: For forward modes other than decode/idle and target-verify.
        """
        # NOTE: encoder_lens expected to be zeros or None
        if forward_mode.is_decode_or_idle():
            # Update kv_indptr, kv_indices
            kv_indptr = self.kv_indptr
            kv_indices = self.cuda_graph_kv_indices
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                self._fill_kv_indices(
                    bs, req_pool_indices[:bs], seq_lens[:bs], kv_indptr, kv_indices
                )
            else:
                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
        elif forward_mode.is_target_verify():
            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
            bs = len(req_pool_indices)
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            self._fill_kv_indices(bs, req_pool_indices, seq_lens, kv_indptr, kv_indices)

            custom_mask = self.cuda_graph_custom_mask
            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
            )

    def get_cuda_graph_seq_len_fill_value(self):
        """Return the value used to pad seq_lens for unused CUDA-graph slots."""
        return 1

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        """Run extend-style attention with the Triton extend kernel.

        Optionally writes ``k``/``v`` into the KV pool first. Returns a new
        output tensor with ``tp_q_head_num * v_head_dim`` columns per token
        when ``qk_head_dim != v_head_dim``; otherwise it has ``q``'s shape.
        """
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        (
            _,
            max_extend_len,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
        ) = self.forward_metadata

        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask,
            mask_indptr,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        """Run decode attention with aiter's paged-attention kernel.

        Optionally writes ``k``/``v`` into the KV pool first. The softmax
        scale is ``layer.scaling``, the same value ``forward_extend`` uses.
        Returns a new output tensor with ``tp_q_head_num * v_head_dim``
        columns per token.
        """
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        _, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        self.decode_attention_fwd(
            # o holds v_head_dim values per head, so view it with v_head_dim
            # (identical to qk_head_dim when the two are equal).
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            self.workspace_buffer,
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
                -1, 1, layer.tp_k_head_num, layer.qk_head_dim
            ),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
                -1, 1, layer.tp_v_head_num, layer.v_head_dim
            ),
            layer.scaling,
            kv_indptr,
            kv_indices,
            self.kv_last_page_lens,
            1,
            self.max_num_partitions,
            None,
            "auto",
            "NHD",
            layer.logit_cap,
            self.k_scale,
            self.v_scale,
            None,
            _AITER_PARTITION_SIZE_ROCM,
        )

        return o
|
|
|
|
|
|
# Triton kernel: gather each request's KV-cache slot indices into a flat
# kv_indices array.
#
# Every launch site in this file uses a grid of (bs,), i.e. one program per
# request, and passes kv_start_idx=None. Program `pid`:
#   - reads row req_pool_indices[pid] of the req_to_token table;
#   - reads page_kernel_lens[pid] entries starting at column kv_start
#     (kv_start_idx[pid] when provided, otherwise 0);
#   - writes them to kv_indices starting at offset kv_indptr[pid].
# The copy is done in masked chunks of BLOCK_SIZE elements.
@triton.jit
def create_flashinfer_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)

    # Which req_to_token row this program reads, and where its output begins.
    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    kv_indices_offset = tl.load(kv_indptr + pid)

    # Column range [kv_start, kv_end) to copy. With kv_start_idx == None the
    # start stays at 0.
    kv_start = 0
    kv_end = 0
    if kv_start_idx:
        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
        kv_end = kv_start
    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

    # Copy in BLOCK_SIZE chunks; the mask drops lanes past the range end.
    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < kv_end - kv_start
        data = tl.load(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + kv_start
            + offset,
            mask=mask,
        )
        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
|