514 lines
18 KiB
Python
514 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
"""
|
|
End-to-end attention solution with aiter kernels.
|
|
"""
|
|
|
|
import math
|
|
import os
|
|
from dataclasses import dataclass
|
|
from enum import Enum, auto
|
|
from functools import partial
|
|
from typing import TYPE_CHECKING, List, Optional, Union
|
|
|
|
import torch
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
from sglang.global_config import global_config
|
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
|
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
|
|
|
if TYPE_CHECKING:
|
|
from sglang.srt.layers.radix_attention import RadixAttention
|
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
|
from sglang.srt.speculative.spec_info import SpecInfo
|
|
|
|
try:
|
|
from aiter import mha_batch_prefill_func, paged_attention_ragged
|
|
except ImportError:
|
|
print(
|
|
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
|
|
)
|
|
|
|
|
|
class WrapperDispatch(Enum):
    """Kinds of attention wrapper dispatch.

    Nothing else in this file references the enum; the members are kept in
    their original order so the values produced by ``auto()`` (1, 2) are
    unchanged.
    """

    SLIDING_WINDOW = auto()
    CROSS_ATTENTION = auto()
|
|
|
|
|
|
@dataclass
class ForwardMetadata:
    """Index tensors and length bounds handed to the attention kernels.

    Built by the ``init_forward_metadata*`` methods of ``AiterAttnBackend``
    and read by its ``forward_extend`` / ``forward_decode``.
    """

    # Passed to the kernels as the KV index pointer; filled from a cumulative
    # sum of sequence lengths (or taken from ``spec_info``).
    kv_indptr: torch.Tensor
    # KV index tensor filled by ``create_flashinfer_kv_indices_triton`` (or
    # taken from ``spec_info``).
    kv_indices: torch.Tensor
    # Largest query length in the batch; the decode paths pass ``None``.
    max_q_len: Optional[int]
    # Largest KV length in the batch; the decode paths pass ``None``.
    max_kv_len: Optional[int]
|
|
|
|
|
|
# Module-level placeholder initialised to None; nothing else in this file reads
# or assigns it (the backend allocates its own per-instance workspace buffer).
global_workspace_buffer = None
|
|
|
|
# Partition size used to derive the backend's max_num_partitions from the
# model context length; also passed as the final argument to
# paged_attention_ragged in forward_decode.
_AITER_PARTITION_SIZE_ROCM = 256
|
|
|
|
|
|
class AiterAttnBackend(AttentionBackend):
    """Attention backend that dispatches to aiter kernels.

    Extend-style passes call ``mha_batch_prefill_func`` and decode passes call
    ``paged_attention_ragged``. The ``init_forward_metadata*`` methods prepare
    the per-batch index tensors and publish them as ``self.forward_metadata``.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        """Read model/pool configuration and allocate reusable buffers.

        Args:
            model_runner: Supplies the device, model config, KV cache dtype and
                the request-to-token pool.
            skip_prefill: When True no ``AiterIndicesUpdaterPrefill`` is
                created (so the extend/verify metadata paths are unavailable)
                and ``init_cuda_graph_state`` skips the custom-mask buffer.
            kv_indptr_buf: Optional pre-allocated tensor used as
                ``self.kv_indptr`` instead of allocating a new one.
        """
        super().__init__()

        model_config = model_runner.model_config
        self.device = model_runner.device
        self.is_multimodal = model_config.is_multimodal
        self.num_head = model_config.num_attention_heads // get_attention_tp_size()
        self.head_dim = model_config.head_dim
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
        self.num_kv_head = model_config.get_num_kv_heads(get_attention_tp_size())
        self.kv_cache_dtype = model_runner.kv_cache_dtype

        self.req_to_token = model_runner.req_to_token_pool.req_to_token

        # Parse constants
        self.max_context_len = model_config.context_len
        self.skip_prefill = skip_prefill

        max_bs = model_runner.req_to_token_pool.size

        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        self.kv_last_page_len = torch.ones(
            (max_bs,), dtype=torch.int32, device=model_runner.device
        )
        self.qo_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
        )

        # Create prefill indices updater. It captures the kv_indptr,
        # kv_last_page_len and qo_indptr buffers allocated above, so it must be
        # constructed after them.
        if not skip_prefill:
            self.indices_updater_prefill = AiterIndicesUpdaterPrefill(
                model_runner, self
            )

        # aiter kernel related initialization
        self.max_num_partitions = (
            self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
        ) // _AITER_PARTITION_SIZE_ROCM

        nbytes_per_qo_elem = torch.finfo(torch.float32).bits // 8

        # Byte-sized scratch buffer: one float32-sized slot per
        # (request, head, partition, head_dim element) plus two 4-byte values
        # per (request, head, partition).
        # NOTE(review): presumably partial outputs plus per-partition softmax
        # statistics -- confirm against the aiter kernel.
        partitions_total = max_bs * self.num_head * self.max_num_partitions
        self.workspace_buffer = torch.empty(
            partitions_total * self.head_dim * nbytes_per_qo_elem
            + 2 * partitions_total * 4,
            dtype=torch.uint8,
            device=self.device,
        )

        self.scale = float(1.0 / (self.head_dim**0.5))
        # k_scale and v_scale intentionally share a single tensor.
        self.k_scale = self.v_scale = torch.tensor(
            [1.0], dtype=torch.float32, device=self.device
        )
        # Separate from ``kv_last_page_len`` above; this one is what
        # forward_decode hands to the kernel.
        self.kv_last_page_lens = torch.ones(
            (max_bs,), dtype=torch.int32, device=self.device
        )

        self.logits_soft_cap = 0.0

        self.forward_metadata: Optional[ForwardMetadata] = None

    def _fill_decode_kv_indices(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        kv_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Write decode KV index data in place and return the indptr slice.

        Stores the cumulative sum of ``seq_lens`` into ``self.kv_indptr[1:bs+1]``
        and launches ``create_flashinfer_kv_indices_triton`` with one program
        per request to populate ``kv_indices``.

        Returns:
            The ``self.kv_indptr[:bs + 1]`` view.
        """
        kv_indptr = self.kv_indptr
        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
        kv_indptr = kv_indptr[: bs + 1]
        create_flashinfer_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices,
            seq_lens,
            kv_indptr,
            None,
            kv_indices,
            self.req_to_token.stride(0),
        )
        return kv_indptr

    def _refresh_prefill_metadata(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: Optional[torch.Tensor],
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInfo],
    ) -> None:
        """Run the prefill indices updater and publish its state as metadata."""
        updater = self.indices_updater_prefill
        updater.update(
            req_pool_indices,
            seq_lens,
            seq_lens_sum,
            prefix_lens=prefix_lens,
            encoder_lens=encoder_lens,
            spec_info=spec_info,
        )
        self.forward_metadata = ForwardMetadata(
            updater.kv_indptr,
            updater.kv_indices,
            updater.max_q_len,
            updater.max_kv_len,
        )

    def init_forward_metadata(self, forward_batch: ForwardBatch) -> None:
        """Prepare ``self.forward_metadata`` for a regular (non-graph) pass.

        Decode/idle batches build fresh KV indices (or reuse those carried by
        ``spec_info``); every other mode goes through the prefill indices
        updater.
        """
        mode = forward_batch.forward_mode

        if mode.is_decode_or_idle():
            spec_info = forward_batch.spec_info
            if spec_info is None:
                kv_indices = torch.zeros(
                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                )
                kv_indptr = self._fill_decode_kv_indices(
                    forward_batch.batch_size,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    kv_indices,
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

            # The decode kernel does not take max lengths, so they stay None.
            self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)

        elif mode.is_draft_extend() or mode.is_target_verify():
            # Both speculative modes build their metadata the same way.
            self._refresh_prefill_metadata(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                encoder_lens=forward_batch.encoder_lens,
                spec_info=forward_batch.spec_info,
            )

        else:
            # Normal extend: the prefix lengths come from the batch.
            self._refresh_prefill_metadata(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=forward_batch.extend_prefix_lens,
                encoder_lens=forward_batch.encoder_lens,
                spec_info=None,
            )

    def init_cuda_graph_state(
        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
    ) -> None:
        """Allocate (or adopt) the persistent buffers used under CUDA graphs.

        Args:
            max_bs: Largest batch size the graph buffers must accommodate.
            kv_indices_buf: Optional pre-allocated KV index buffer to adopt
                instead of allocating ``max_bs * max_context_len`` int32s.
        """
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.int32,
                device=self.device,
            )
        else:
            self.cuda_graph_kv_indices = kv_indices_buf

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.uint8,
                device=self.device,
            )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ) -> None:
        """Prepare ``self.forward_metadata`` while capturing a CUDA graph.

        Raises:
            ValueError: If ``forward_mode`` is neither decode/idle nor
                target-verify.
        """
        if forward_mode.is_decode_or_idle():
            if spec_info is None:
                # Write into the persistent buffer so replay can refresh it
                # in place.
                kv_indices = self.cuda_graph_kv_indices
                kv_indptr = self._fill_decode_kv_indices(
                    bs, req_pool_indices, seq_lens, kv_indices
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
            self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)

        elif forward_mode.is_target_verify():
            seq_lens_sum = seq_lens.sum().item()
            self._refresh_prefill_metadata(
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                prefix_lens=None,
                encoder_lens=encoder_lens,
                spec_info=spec_info,
            )

        else:
            raise ValueError(f"Invalid mode: {forward_mode=}")

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ) -> None:
        """Refresh the index buffers before replaying a captured CUDA graph.

        The buffers are updated in place and ``self.forward_metadata`` is not
        reassigned here.

        Raises:
            ValueError: If ``forward_mode`` is neither decode/idle nor
                target-verify.
        """
        if forward_mode.is_decode_or_idle():
            kv_indices = self.cuda_graph_kv_indices
            if spec_info is None:
                self._fill_decode_kv_indices(
                    bs, req_pool_indices[:bs], seq_lens[:bs], kv_indices
                )
            else:
                # Copy the speculative indices into the persistent buffers.
                kv_indptr = self.kv_indptr
                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices

        elif forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
                seq_lens_sum,
                prefix_lens=None,
                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
                spec_info=spec_info,
            )
        else:
            raise ValueError("Invalid forward mode")

    def get_cuda_graph_seq_len_fill_value(self) -> int:
        """Return the sequence-length fill value for CUDA graph buffers (1)."""
        return 1

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ) -> torch.Tensor:
        """Run extend (prefill-style) attention through ``mha_batch_prefill_func``.

        When ``k`` is given and ``save_kv_cache`` is true, the new K/V are first
        written to the KV pool; attention then reads K/V from the pool buffers.

        Returns:
            The attention output viewed as
            ``(-1, layer.tp_q_head_num * layer.head_dim)``.
        """
        if layer.is_cross_attention:
            cache_loc = forward_batch.encoder_out_cache_loc
        else:
            cache_loc = forward_batch.out_cache_loc

        self.logits_soft_cap = layer.logit_cap

        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                )

        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)

        # An indptr over ``batch_size`` requests has ``batch_size + 1`` entries.
        num_indptr = forward_batch.batch_size + 1
        metadata = self.forward_metadata

        o = mha_batch_prefill_func(
            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
            k_cache,
            v_cache,
            self.qo_indptr[:num_indptr],
            metadata.kv_indptr[:num_indptr],
            metadata.kv_indices,
            metadata.max_q_len,
            metadata.max_kv_len,
            causal=True,
            logits_soft_cap=self.logits_soft_cap,
            alibi_slopes=None,
            return_lse=False,
            return_attn_probs=False,
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ) -> torch.Tensor:
        """Run decode attention through ``paged_attention_ragged``.

        Returns:
            The output tensor of shape
            ``(num_rows, layer.tp_q_head_num * layer.v_head_dim)``, written in
            place by the kernel.
        """
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        self.logits_soft_cap = layer.logit_cap
        kv_pool = forward_batch.token_to_kv_pool
        paged_attention_ragged(
            # The output rows hold v_head_dim values per head, so the view must
            # use v_head_dim; using qk_head_dim is only correct when the two
            # dimensions happen to be equal.
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            self.workspace_buffer,
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            kv_pool.get_key_buffer(layer.layer_id).view(
                -1, 1, layer.tp_k_head_num, layer.qk_head_dim
            ),
            kv_pool.get_value_buffer(layer.layer_id).view(
                -1, 1, layer.tp_v_head_num, layer.v_head_dim
            ),
            self.scale,
            self.forward_metadata.kv_indptr,
            self.forward_metadata.kv_indices,
            self.kv_last_page_lens,
            1,  # presumably the page/block size -- confirm against aiter
            self.max_num_partitions,
            None,
            "auto",
            "NHD",
            self.logits_soft_cap,
            self.k_scale,
            self.v_scale,
            None,
            _AITER_PARTITION_SIZE_ROCM,
        )

        return o
|
|
|
|
|
|
class AiterIndicesUpdaterPrefill:
    """Builds the KV/QO index data consumed by the extend attention kernel.

    Shares the ``kv_indptr``, ``kv_last_page_len`` and ``qo_indptr`` buffers of
    the owning attention backend and exposes the results of the latest
    ``update`` call as ``kv_indices``, ``max_q_len`` and ``max_kv_len``.
    """

    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        """Capture model constants and the backend's shared index buffers."""
        # Parse Constants
        model_config = model_runner.model_config
        self.num_qo_heads = model_config.num_attention_heads // get_attention_tp_size()
        self.num_kv_heads = model_config.get_num_kv_heads(get_attention_tp_size())
        self.head_dim = model_config.head_dim
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.sliding_window_size = model_runner.sliding_window_size
        self.attn_backend = attn_backend

        # Buffers and wrappers (same tensor objects as the backend's).
        self.kv_indptr = attn_backend.kv_indptr
        self.kv_last_page_len = attn_backend.kv_last_page_len
        self.qo_indptr = attn_backend.qo_indptr
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        # Instance attribute that shadows the placeholder ``update`` method.
        self.update = self.update_single_wrapper

        self.kv_indices = None
        self.max_q_len = 0
        self.max_kv_len = 0

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: Optional[torch.Tensor],
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInfo],
    ):
        """Signature placeholder; ``__init__`` rebinds it per instance.

        Raises:
            NotImplementedError: Always, when the class-level method is called.
        """
        # Keep the signature for type checking. It will be assigned during runtime.
        raise NotImplementedError()

    def update_single_wrapper(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: Optional[torch.Tensor],
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInfo],
    ):
        """Recompute the index data for one batch.

        Without ``spec_info`` (normal extend) the shared ``kv_indptr`` and
        ``qo_indptr`` buffers are rewritten in place, a new ``kv_indices``
        tensor is allocated and filled, and ``max_kv_len`` / ``max_q_len`` are
        refreshed. With ``spec_info`` the indices come from
        ``spec_info.generate_attn_arg_prefill``.

        Args:
            req_pool_indices: Per-request rows of the request-to-token table.
            seq_lens: Per-request KV lengths.
            seq_lens_sum: Sum of ``seq_lens``, used to size ``kv_indices``.
            prefix_lens: Per-request prefix lengths; subtracted from
                ``seq_lens`` in the normal extend path, unused otherwise.
            encoder_lens: Accepted for interface compatibility; not read here.
            spec_info: Speculative-decoding info, or None for normal extend.
        """
        if spec_info is None:
            # Normal extend
            bs = len(req_pool_indices)

            kv_indptr = self.kv_indptr
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]

            # NOTE(review): the extra 256 entries look like tail padding for
            # the kernel -- confirm the required amount against aiter.
            kv_indices_padding = 256
            kv_indices = torch.empty(
                seq_lens_sum + kv_indices_padding,
                dtype=torch.int32,
                device=req_pool_indices.device,
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                # Row stride of the table, as at every other call site in this
                # file (equal to shape[1] for a contiguous 2-D tensor).
                self.req_to_token.stride(0),
            )

            self.max_kv_len = torch.max(seq_lens).item()

            extend_lens = seq_lens - prefix_lens
            self.max_q_len = torch.max(extend_lens).item()

            self.qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
        else:
            # NOTE(review): only kv_indices is retained; the returned indptr
            # tensors and custom mask are dropped, and max_q_len / max_kv_len
            # keep their previous values -- confirm this is intended.
            kv_indices, _kv_indptr, _qo_indptr, _custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    req_pool_indices,
                    seq_lens,
                    self.req_to_token,
                )
            )

        self.kv_indices = kv_indices
|