ROCm: Flex Attention Enablement with custom backends (#4178)

Co-authored-by: linsun12 <linsun12@amd.com>
This commit is contained in:
HAI
2025-03-07 04:38:53 -08:00
committed by GitHub
parent c827c671f7
commit 0beea4503f
7 changed files with 1434 additions and 35 deletions

View File

@@ -0,0 +1,605 @@
from __future__ import annotations
"""
end to end attention solution with aiter kernels
"""
import math
import os
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Union
import torch
import triton
import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
# flashinfer AMD fork
from flashinfer import BatchPrefillWithPagedKVCacheWrapper
try:
from aiter import paged_attention_rocm
except ImportError:
print(
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
)
class WrapperDispatch(Enum):
SLIDING_WINDOW = auto()
CROSS_ATTENTION = auto()
@dataclass
class DecodeMetadata:
kv_indptr: torch.Tensor
kv_indices: torch.Tensor
@dataclass
class PrefillMetadata:
prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper
extend_no_prefix: bool
global_workspace_buffer = None
_AITER_PARTITION_SIZE_ROCM = 256
class AiterAttnBackend(AttentionBackend):
def __init__(
self,
model_runner: ModelRunner,
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
):
super().__init__()
self.device = model_runner.device
self.is_multimodal = model_runner.model_config.is_multimodal
self.num_head = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.head_dim = model_runner.model_config.head_dim
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
self.kv_cache_dtype = model_runner.kv_cache_dtype
self.req_to_token = model_runner.req_to_token_pool.req_to_token
# Parse constants
self.max_context_len = model_runner.model_config.context_len
self.skip_prefill = skip_prefill
# Qwen2 models require higher flashinfer workspace size
if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
global_config.flashinfer_workspace_size = 512 * 1024 * 1024
global global_workspace_buffer
if global_workspace_buffer is None:
global_workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device=model_runner.device,
)
self.workspace_buffer = global_workspace_buffer
max_bs = model_runner.req_to_token_pool.size
if kv_indptr_buf is None:
self.kv_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
else:
self.kv_indptr = kv_indptr_buf
self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
self.qo_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD", backend="fa2"
)
self.prefill_wrapper_verify = BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD"
)
# Create prefill indices updater
if not skip_prefill:
self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
model_runner, self
)
# aiter kernel related initialization
self.max_num_partitions = (
self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
) // _AITER_PARTITION_SIZE_ROCM
nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8
self.workspace_buffer = torch.empty(
(max_bs * self.num_head * self.max_num_partitions * self.head_dim)
* nbyes_per_qo_elem
+ 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
dtype=torch.uint8,
device=self.device,
)
self.scale = float(1.0 / (self.head_dim**0.5))
self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
self.device
)
self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
self.device
)
self.logits_soft_cap = 0.0
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
self.decode_cuda_graph_metadata = {}
self.prefill_cuda_graph_metadata = {}
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode_or_idle():
# update for aiter
# create kv_indices and kv_inptr
bs = forward_batch.batch_size
kv_indptr = self.kv_indptr
spec_info = forward_batch.spec_info
if spec_info is None:
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.zeros(
forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1
self.forward_metadata = DecodeMetadata(kv_indptr, kv_indices)
elif forward_batch.forward_mode.is_draft_extend():
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
prefill_wrapper=self.prefill_wrapper_paged,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
)
self.forward_metadata = PrefillMetadata(
self.prefill_wrapper_paged, False, False
)
elif forward_batch.forward_mode.is_target_verify():
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
prefill_wrapper=self.prefill_wrapper_verify,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
)
self.forward_metadata = PrefillMetadata(
self.prefill_wrapper_verify, False, False
)
else:
prefix_lens = forward_batch.extend_prefix_lens
if self.is_multimodal:
extend_no_prefix = False
else:
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens,
prefill_wrapper=self.prefill_wrapper_paged,
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
)
self.forward_metadata = PrefillMetadata(
self.prefill_wrapper_paged, extend_no_prefix
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
):
if kv_indices_buf is None:
self.cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.int32,
device=self.device,
)
else:
self.cuda_graph_kv_indices = kv_indices_buf
if not self.skip_prefill:
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.uint8,
device=self.device,
)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
if forward_mode.is_decode_or_idle():
if spec_info is None:
kv_indptr = self.kv_indptr
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
self.forward_metadata = DecodeMetadata(kv_indptr, kv_indices)
self.decode_cuda_graph_metadata[bs] = self.forward_metadata
elif forward_mode.is_target_verify():
prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
use_cuda_graph=False,
qo_indptr_buf=self.cuda_graph_qo_indptr[: bs + 1],
paged_kv_indptr_buf=self.kv_indptr[: bs + 1],
paged_kv_indices_buf=self.cuda_graph_kv_indices,
paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
custom_mask_buf=self.cuda_graph_custom_mask,
mask_indptr_buf=self.cuda_graph_qk_indptr[: bs + 1],
)
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_prefill.update(
req_pool_indices,
seq_lens,
seq_lens_sum,
prefix_lens=None,
prefill_wrapper=prefill_wrapper,
encoder_lens=encoder_lens,
spec_info=spec_info,
)
self.prefill_cuda_graph_metadata[bs] = prefill_wrapper
self.forward_metadata = PrefillMetadata(prefill_wrapper, False)
else:
raise ValueError(f"Invalid mode: {forward_mode=}")
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
kv_indptr = self.kv_indptr
kv_indices = self.cuda_graph_kv_indices
if spec_info is None:
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
kv_indptr = kv_indptr[: bs + 1]
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
seq_lens[:bs],
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
else:
kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
elif forward_mode.is_target_verify():
self.indices_updater_prefill.update(
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens_sum,
prefix_lens=None,
prefill_wrapper=self.prefill_cuda_graph_metadata[bs],
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
spec_info=spec_info,
)
else:
raise ValueError("Invalid forward mode")
def get_cuda_graph_seq_len_fill_value(self):
return 1
def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
self.logits_soft_cap = layer.logit_cap
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
o = prefill_wrapper_paged.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=not layer.is_cross_attention,
sm_scale=layer.scaling,
window_left=layer.sliding_window_size,
logits_soft_cap=self.logits_soft_cap,
k_scale=layer.k_scale,
v_scale=layer.v_scale,
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
self.logits_soft_cap = layer.logit_cap
paged_attention_rocm(
o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
self.workspace_buffer,
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
-1, 1, layer.tp_k_head_num, layer.qk_head_dim
),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
-1, 1, layer.tp_v_head_num, layer.v_head_dim
),
self.scale,
self.forward_metadata.kv_indptr,
self.forward_metadata.kv_indices,
self.kv_last_page_lens,
1,
self.max_num_partitions,
None,
"auto",
"NHD",
self.logits_soft_cap,
self.k_scale,
self.v_scale,
None,
_AITER_PARTITION_SIZE_ROCM,
)
return o
class FlashInferIndicesUpdaterPrefill:
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
# Parse Constants
self.num_qo_heads = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
self.head_dim = model_runner.model_config.head_dim
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
self.sliding_window_size = model_runner.sliding_window_size
self.attn_backend = attn_backend
# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
self.qo_indptr = attn_backend.qo_indptr
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.update = self.update_single_wrapper
def update(
self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
def update_single_wrapper(
self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
):
paged_kernel_lens = seq_lens
paged_kernel_lens_sum = seq_lens_sum
self.call_begin_forward(
prefill_wrapper,
req_pool_indices,
paged_kernel_lens,
paged_kernel_lens_sum,
seq_lens,
prefix_lens,
None,
self.kv_indptr,
self.qo_indptr,
spec_info,
)
def call_begin_forward(
self,
wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
seq_lens: torch.Tensor,
prefix_lens: torch.Tensor,
kv_start_idx: torch.Tensor,
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
spec_info: Optional[SpecInfo],
):
bs = len(req_pool_indices)
if spec_info is None:
# Normal extend
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum + 256,
dtype=torch.int32,
device=req_pool_indices.device,
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
kv_start_idx,
kv_indices,
self.req_to_token.shape[1],
)
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
else:
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
req_pool_indices,
paged_kernel_lens,
self.req_to_token,
)
)
# cached part
# adding logits_soft_cap arg in plan() stage
wrapper_paged.begin_forward(
qo_indptr,
kv_indptr,
kv_indices,
self.kv_last_page_len[:bs],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
1,
q_data_type=self.q_data_type,
custom_mask=custom_mask,
non_blocking=True,
logits_soft_cap=self.attn_backend.logits_soft_cap,
)
@triton.jit
def create_flashinfer_kv_indices_triton(
req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices_ptr,
page_kernel_lens_ptr,
kv_indptr,
kv_start_idx,
kv_indices_ptr,
req_to_token_ptr_stride: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(axis=0)
req_pool_index = tl.load(req_pool_indices_ptr + pid)
kv_indices_offset = tl.load(kv_indptr + pid)
kv_start = 0
kv_end = 0
if kv_start_idx:
kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
kv_end = kv_start
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < kv_end - kv_start
data = tl.load(
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ kv_start
+ offset,
mask=mask,
)
tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)

View File

@@ -0,0 +1,535 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import torch
import triton
import triton.language as tl
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
try:
from aiter import paged_attention_rocm
except ImportError:
print(
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
)
from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd
_AITER_PARTITION_SIZE_ROCM = 256
class AiterDecodeAttnBackend(AttentionBackend):
def __init__(
self,
model_runner: ModelRunner,
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
):
super().__init__()
self.decode_attention_fwd = paged_attention_rocm
self.extend_attention_fwd = extend_attention_fwd
self.skip_prefill = skip_prefill
max_bs = model_runner.req_to_token_pool.size
if kv_indptr_buf is None:
self.kv_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
else:
self.kv_indptr = kv_indptr_buf
self.req_to_token = model_runner.req_to_token_pool.req_to_token
if not self.skip_prefill:
self.qo_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
self.mask_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int64, device=model_runner.device
)
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
# tp sharding on number of heads
self.num_head = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.head_dim = model_runner.model_config.head_dim
# triton prefill initialization
self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
self.num_v_head = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-2]
self.forward_metadata = None
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device
self.kv_cache_dtype = model_runner.kv_cache_dtype
self.q_dtype = model_runner.model_config.dtype
# aiter decode initialization
self.max_num_partitions = (
self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
) // _AITER_PARTITION_SIZE_ROCM
nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8
self.workspace_buffer = torch.empty(
(max_bs * self.num_head * self.max_num_partitions * self.head_dim)
* nbyes_per_qo_elem
+ 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
dtype=torch.uint8,
device=self.device,
)
self.scale = float(1.0 / (self.head_dim**0.5))
self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
self.device
)
self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
self.device
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init auxiliary variables"""
bs = forward_batch.batch_size
kv_indptr = self.kv_indptr
spec_info = forward_batch.spec_info
if forward_batch.forward_mode.is_decode_or_idle():
if spec_info is None:
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.zeros(
forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
)
# prepare kv_indices and kv_indptr
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1
attn_logits = None # accomodate forward_metadata format
qo_indptr = None
custom_mask = None
mask_indptr = None
max_extend_len = None
elif forward_batch.forward_mode.is_target_verify():
bs = len(forward_batch.req_pool_indices)
qo_indptr = torch.arange(
0,
(1 + bs) * self.num_draft_tokens,
step=self.num_draft_tokens,
dtype=torch.int32,
device=self.device,
)
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.zeros(
kv_indptr[-1], dtype=torch.int32, device=self.device
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
custom_mask = spec_info.custom_mask
seq_mask_len = self.num_draft_tokens * (
forward_batch.seq_lens + self.num_draft_tokens
)
mask_indptr = self.mask_indptr
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
mask_indptr = mask_indptr[: bs + 1]
max_extend_len = self.num_draft_tokens
attn_logits = None
elif forward_batch.forward_mode.is_draft_extend():
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
self.req_to_token,
)
)
mask_indptr = None
max_extend_len = torch.max(spec_info.accept_length).item()
attn_logits = None
else:
kv_indptr[1 : bs + 1] = torch.cumsum(
forward_batch.extend_prefix_lens, dim=0
)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.zeros(
forward_batch.extend_prefix_lens.sum().item(),
dtype=torch.int32,
device=self.device,
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.extend_prefix_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
qo_indptr = self.qo_indptr
qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
mask_indptr = None
attn_logits = None
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
self.forward_metadata = (
attn_logits,
max_extend_len,
kv_indptr,
kv_indices,
qo_indptr,
custom_mask,
mask_indptr,
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
):
self.cuda_graph_attn_logits = torch.zeros(
(max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
dtype=torch.float32,
device=self.device,
)
if kv_indices_buf is None:
self.cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.int32,
device=self.device,
)
else:
self.cuda_graph_kv_indices = kv_indices_buf
if not self.skip_prefill:
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.uint8,
device=self.device,
)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
assert encoder_lens is None, "Not supported"
if forward_mode.is_decode_or_idle():
if spec_info is None:
kv_indptr = self.kv_indptr
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
attn_logits = None
max_extend_len = None
qo_indptr = None
custom_mask = None
mask_indptr = None
elif forward_mode.is_target_verify():
qo_indptr = self.qo_indptr[: bs + 1]
qo_indptr[: bs + 1] = torch.arange(
0,
(1 + bs) * self.num_draft_tokens,
step=self.num_draft_tokens,
dtype=torch.int32,
device=self.device,
)
kv_indptr = self.kv_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
custom_mask = self.cuda_graph_custom_mask
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
mask_indptr = self.mask_indptr[: bs + 1]
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
max_extend_len = self.num_draft_tokens
attn_logits = None
else:
raise ValueError(
f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
)
self.forward_metadata = (
attn_logits,
max_extend_len,
kv_indptr,
kv_indices,
qo_indptr,
custom_mask,
mask_indptr,
)
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
seq_lens_cpu: Optional[torch.Tensor],
):
# NOTE: encoder_lens expected to be zeros or None
if forward_mode.is_decode_or_idle():
# Update kv_indptr, kv_indices
kv_indptr = self.kv_indptr
kv_indices = self.cuda_graph_kv_indices
if spec_info is None:
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
kv_indptr = kv_indptr[: bs + 1]
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
seq_lens[:bs],
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
else:
kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
elif forward_mode.is_target_verify():
# Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
bs = len(req_pool_indices)
qo_indptr = self.qo_indptr[: bs + 1]
qo_indptr[: bs + 1] = torch.arange(
0,
(1 + bs) * self.num_draft_tokens,
step=self.num_draft_tokens,
dtype=torch.int32,
device=self.device,
)
kv_indptr = self.kv_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
custom_mask = self.cuda_graph_custom_mask
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
mask_indptr = self.mask_indptr[: bs + 1]
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
else:
raise ValueError(
f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
)
def get_cuda_graph_seq_len_fill_value(self):
return 1
def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
(
_,
max_extend_len,
kv_indptr,
kv_indices,
qo_indptr,
custom_mask,
mask_indptr,
) = self.forward_metadata
self.extend_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k.contiguous(),
v.contiguous(),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
qo_indptr,
kv_indptr,
kv_indices,
custom_mask,
mask_indptr,
max_extend_len,
layer.scaling,
layer.logit_cap,
)
return o
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)
attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
self.decode_attention_fwd(
o.view(
-1, layer.tp_q_head_num, layer.qk_head_dim
), # (bs, head_num_q, head_dim_q)
self.workspace_buffer,
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
-1, 1, layer.tp_k_head_num, layer.qk_head_dim
),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
-1, 1, layer.tp_v_head_num, layer.v_head_dim
),
self.scale,
kv_indptr,
kv_indices,
self.kv_last_page_lens,
1,
self.max_num_partitions,
None,
"auto",
"NHD",
layer.logit_cap,
self.k_scale,
self.v_scale,
None,
_AITER_PARTITION_SIZE_ROCM,
)
return o
@triton.jit
def create_flashinfer_kv_indices_triton(
req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices_ptr,
page_kernel_lens_ptr,
kv_indptr,
kv_start_idx,
kv_indices_ptr,
req_to_token_ptr_stride: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(axis=0)
req_pool_index = tl.load(req_pool_indices_ptr + pid)
kv_indices_offset = tl.load(kv_indptr + pid)
kv_start = 0
kv_end = 0
if kv_start_idx:
kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
kv_end = kv_start
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < kv_end - kv_start
data = tl.load(
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ kv_start
+ offset,
mask=mask,
)
tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)

View File

@@ -79,6 +79,12 @@ from sglang.srt.utils import (
)
from sglang.utils import get_exception_traceback
is_hip_ = is_hip()
if is_hip_:
from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
from sglang.srt.layers.attention.aiter_decode_backend import AiterDecodeAttnBackend
logger = logging.getLogger(__name__)
@@ -641,7 +647,7 @@ class ModelRunner:
if self.server_args.kv_cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
if is_hip(): # Using natively supported format
if is_hip_: # Using natively supported format
self.kv_cache_dtype = torch.float8_e5m2fnuz
else:
self.kv_cache_dtype = torch.float8_e5m2
@@ -778,33 +784,59 @@ class ModelRunner:
def init_attention_backend(self):
"""Init attention kernel backend."""
if self.server_args.attention_backend == "flashinfer":
# Init streams
if self.server_args.speculative_algorithm == "EAGLE":
self.plan_stream_for_flashinfer = torch.cuda.Stream()
if is_cuda():
if self.server_args.attention_backend == "flashinfer":
# Init streams
if self.server_args.speculative_algorithm == "EAGLE":
self.plan_stream_for_flashinfer = torch.cuda.Stream()
self.attn_backend = FlashInferAttnBackend(self)
elif self.server_args.attention_backend == "triton":
assert self.sliding_window_size is None, (
"Window attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
assert not self.model_config.is_encoder_decoder, (
"Cross attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
if self.server_args.enable_double_sparsity:
self.attn_backend = DoubleSparseAttnBackend(self)
self.attn_backend = FlashInferAttnBackend(self)
elif self.server_args.attention_backend == "triton":
assert self.sliding_window_size is None, (
"Window attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
assert not self.model_config.is_encoder_decoder, (
"Cross attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
if self.server_args.enable_double_sparsity:
self.attn_backend = DoubleSparseAttnBackend(self)
else:
self.attn_backend = TritonAttnBackend(self)
elif self.server_args.attention_backend == "torch_native":
self.attn_backend = TorchNativeAttnBackend(self)
elif self.server_args.attention_backend == "flashinfer_mla":
self.attn_backend = FlashInferMLAAttnBackend(self)
else:
self.attn_backend = TritonAttnBackend(self)
elif self.server_args.attention_backend == "torch_native":
self.attn_backend = TorchNativeAttnBackend(self)
elif self.server_args.attention_backend == "flashinfer_mla":
self.attn_backend = FlashInferMLAAttnBackend(self)
else:
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"
)
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"
)
elif is_hip_:
# AMD hip supported attention backends
if self.server_args.attention_backend == "aiter":
self.attn_backend = AiterAttnBackend(self)
elif self.server_args.attention_backend == "aiter_decode":
self.attn_backend = AiterDecodeAttnBackend(self)
elif self.server_args.attention_backend == "triton":
assert self.sliding_window_size is None, (
"Window attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
assert not self.model_config.is_encoder_decoder, (
"Cross attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
if self.server_args.enable_double_sparsity:
self.attn_backend = DoubleSparseAttnBackend(self)
else:
self.attn_backend = TritonAttnBackend(self)
elif self.server_args.attention_backend == "torch_native":
self.attn_backend = TorchNativeAttnBackend(self)
else:
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"
)
def init_double_sparsity_channel_config(self, selected_channel):
selected_channel = "." + selected_channel + "_proj"

View File

@@ -710,13 +710,23 @@ class ServerArgs:
)
# Kernel backend
parser.add_argument(
"--attention-backend",
type=str,
choices=["flashinfer", "triton", "torch_native"],
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
)
if is_hip():
parser.add_argument(
"--attention-backend",
type=str,
choices=["triton", "torch_native", "aiter", "aiter_decode"],
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
)
else:
parser.add_argument(
"--attention-backend",
type=str,
choices=["flashinfer", "triton", "torch_native"],
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
)
parser.add_argument(
"--sampling-backend",
type=str,