Optimize Triton decoding kernel for dynamic workload (#4553)
This commit is contained in:
@@ -39,6 +39,7 @@ class AttentionBackend(ABC):
|
|||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self,
|
self,
|
||||||
bs: int,
|
bs: int,
|
||||||
|
num_kv_heads: int,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
|
|||||||
@@ -349,6 +349,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self,
|
self,
|
||||||
bs: int,
|
bs: int,
|
||||||
|
num_kv_heads: int,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
@@ -1062,6 +1063,7 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
def call_fn(i, forward_batch):
|
def call_fn(i, forward_batch):
|
||||||
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
|
-1,
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
seq_lens_sum=-1,
|
seq_lens_sum=-1,
|
||||||
|
|||||||
@@ -279,6 +279,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
|||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self,
|
self,
|
||||||
bs: int,
|
bs: int,
|
||||||
|
num_kv_heads: int,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
@@ -791,6 +792,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
|||||||
def call_fn(i, forward_batch):
|
def call_fn(i, forward_batch):
|
||||||
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
|
-1,
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
seq_lens_sum=-1,
|
seq_lens_sum=-1,
|
||||||
|
|||||||
@@ -4,11 +4,13 @@ from typing import TYPE_CHECKING, Optional, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
from sglang.srt.utils import get_bool_env_var, get_device_core_count
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
@@ -16,6 +18,51 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def get_num_kv_splits_triton(
|
||||||
|
num_kv_splits_ptr,
|
||||||
|
seq_lens_ptr,
|
||||||
|
bs,
|
||||||
|
num_head,
|
||||||
|
num_kv_head,
|
||||||
|
max_kv_splits,
|
||||||
|
device_core_count,
|
||||||
|
MAX_BS: tl.constexpr,
|
||||||
|
):
|
||||||
|
# TODO: this method is tunable
|
||||||
|
offs_b = tl.arange(0, MAX_BS)
|
||||||
|
mask_b = offs_b < bs
|
||||||
|
|
||||||
|
seq_lens = tl.load(seq_lens_ptr + offs_b, mask=mask_b, other=0)
|
||||||
|
max_seq_len = tl.max(seq_lens)
|
||||||
|
seq_lens = tl.load(seq_lens_ptr + offs_b, mask=mask_b, other=max_seq_len)
|
||||||
|
min_seq_len = tl.min(seq_lens)
|
||||||
|
if max_seq_len * 8 < min_seq_len * 10:
|
||||||
|
min_seq_len = max_seq_len
|
||||||
|
max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
|
||||||
|
kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
|
||||||
|
|
||||||
|
# NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
|
||||||
|
ext_seq_len = tl.cast(tl.cdiv(max_seq_len, 256), tl.float32)
|
||||||
|
ext_device_core_count = device_core_count * tl.maximum(
|
||||||
|
tl.cast(tl.ceil(tl.log2(ext_seq_len)), tl.int32), 1
|
||||||
|
)
|
||||||
|
block_h, num_kv_group = 16, num_head // num_kv_head
|
||||||
|
if num_kv_group == 1:
|
||||||
|
bh_grid = bs * num_head
|
||||||
|
else:
|
||||||
|
# from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
|
||||||
|
block_h = tl.minimum(block_h, num_kv_group)
|
||||||
|
bh_grid = bs * tl.cdiv(num_head, block_h)
|
||||||
|
max_kv_splits_2 = tl.minimum(tl.cdiv(ext_device_core_count, bh_grid), max_kv_splits)
|
||||||
|
kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
|
||||||
|
|
||||||
|
num_kv_splits = tl.maximum(
|
||||||
|
tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
|
||||||
|
)
|
||||||
|
tl.store(num_kv_splits_ptr + offs_b, num_kv_splits, mask=mask_b)
|
||||||
|
|
||||||
|
|
||||||
class TritonAttnBackend(AttentionBackend):
|
class TritonAttnBackend(AttentionBackend):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -64,7 +111,10 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
|
self.static_kv_splits = get_bool_env_var(
|
||||||
|
"SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
|
||||||
|
)
|
||||||
|
self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
|
||||||
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
|
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
|
||||||
|
|
||||||
self.forward_metadata = None
|
self.forward_metadata = None
|
||||||
@@ -72,6 +122,30 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.max_context_len = model_runner.model_config.context_len
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
|
|
||||||
self.device = model_runner.device
|
self.device = model_runner.device
|
||||||
|
self.device_core_count = get_device_core_count(model_runner.gpu_id)
|
||||||
|
|
||||||
|
def get_num_kv_splits(
|
||||||
|
self,
|
||||||
|
num_kv_splits: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
bs: int,
|
||||||
|
num_kv_head: int,
|
||||||
|
):
|
||||||
|
MAX_SCHEDULE_BS = 4096
|
||||||
|
if self.static_kv_splits or self.device_core_count <= 0 or bs > MAX_SCHEDULE_BS:
|
||||||
|
num_kv_splits.fill_(self.max_kv_splits)
|
||||||
|
return
|
||||||
|
|
||||||
|
get_num_kv_splits_triton[(1,)](
|
||||||
|
num_kv_splits,
|
||||||
|
seq_lens,
|
||||||
|
bs,
|
||||||
|
self.num_head,
|
||||||
|
num_kv_head,
|
||||||
|
self.max_kv_splits,
|
||||||
|
self.device_core_count,
|
||||||
|
MAX_BS=MAX_SCHEDULE_BS,
|
||||||
|
)
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
"""Init auxiliary variables for triton attention backend."""
|
"""Init auxiliary variables for triton attention backend."""
|
||||||
@@ -100,15 +174,35 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||||
bs = kv_indptr.shape[0] - 1
|
bs = kv_indptr.shape[0] - 1
|
||||||
|
|
||||||
attn_logits = torch.empty(
|
attn_logits = [
|
||||||
(
|
torch.empty(
|
||||||
bs,
|
(
|
||||||
self.num_head,
|
bs,
|
||||||
self.num_kv_splits,
|
self.num_head,
|
||||||
self.v_head_dim + 1,
|
self.max_kv_splits,
|
||||||
|
self.v_head_dim,
|
||||||
|
),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
),
|
),
|
||||||
dtype=torch.float32,
|
torch.empty(
|
||||||
device=self.device,
|
(
|
||||||
|
bs,
|
||||||
|
self.num_head,
|
||||||
|
self.max_kv_splits,
|
||||||
|
),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
|
||||||
|
|
||||||
|
num_kv_heads = self.num_head
|
||||||
|
if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
|
||||||
|
if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
|
||||||
|
num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
|
||||||
|
self.get_num_kv_splits(
|
||||||
|
num_kv_splits, forward_batch.seq_lens, bs, num_kv_heads
|
||||||
)
|
)
|
||||||
|
|
||||||
qo_indptr = None
|
qo_indptr = None
|
||||||
@@ -148,6 +242,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
|
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
|
||||||
mask_indptr = mask_indptr[: bs + 1]
|
mask_indptr = mask_indptr[: bs + 1]
|
||||||
max_extend_len = self.num_draft_tokens
|
max_extend_len = self.num_draft_tokens
|
||||||
|
num_kv_splits = None
|
||||||
attn_logits = None
|
attn_logits = None
|
||||||
elif forward_batch.forward_mode.is_draft_extend():
|
elif forward_batch.forward_mode.is_draft_extend():
|
||||||
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||||
@@ -160,6 +255,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
mask_indptr = None
|
mask_indptr = None
|
||||||
max_extend_len = torch.max(spec_info.accept_length).item()
|
max_extend_len = torch.max(spec_info.accept_length).item()
|
||||||
|
num_kv_splits = None
|
||||||
attn_logits = None
|
attn_logits = None
|
||||||
else:
|
else:
|
||||||
kv_indptr[1 : bs + 1] = torch.cumsum(
|
kv_indptr[1 : bs + 1] = torch.cumsum(
|
||||||
@@ -188,10 +284,12 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
mask_indptr = None
|
mask_indptr = None
|
||||||
attn_logits = None
|
attn_logits = None
|
||||||
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
|
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
|
||||||
|
num_kv_splits = None
|
||||||
|
|
||||||
self.forward_metadata = (
|
self.forward_metadata = (
|
||||||
attn_logits,
|
attn_logits,
|
||||||
max_extend_len,
|
max_extend_len,
|
||||||
|
num_kv_splits,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
qo_indptr,
|
qo_indptr,
|
||||||
@@ -202,10 +300,20 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
def init_cuda_graph_state(
|
def init_cuda_graph_state(
|
||||||
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
|
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
|
||||||
):
|
):
|
||||||
self.cuda_graph_attn_logits = torch.zeros(
|
self.cuda_graph_attn_logits = [
|
||||||
(max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
|
torch.zeros(
|
||||||
dtype=torch.float32,
|
(max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
|
||||||
device=self.device,
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
torch.zeros(
|
||||||
|
(max_bs, self.num_head, self.max_kv_splits),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
self.cuda_graph_num_kv_splits = torch.full(
|
||||||
|
(max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
|
||||||
)
|
)
|
||||||
if kv_indices_buf is None:
|
if kv_indices_buf is None:
|
||||||
self.cuda_graph_kv_indices = torch.zeros(
|
self.cuda_graph_kv_indices = torch.zeros(
|
||||||
@@ -255,6 +363,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
attn_logits = self.cuda_graph_attn_logits
|
attn_logits = self.cuda_graph_attn_logits
|
||||||
max_extend_len = None
|
max_extend_len = None
|
||||||
|
num_kv_splits = self.cuda_graph_num_kv_splits
|
||||||
qo_indptr = None
|
qo_indptr = None
|
||||||
custom_mask = None
|
custom_mask = None
|
||||||
mask_indptr = None
|
mask_indptr = None
|
||||||
@@ -285,6 +394,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
mask_indptr = self.mask_indptr[: bs + 1]
|
mask_indptr = self.mask_indptr[: bs + 1]
|
||||||
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
|
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
|
||||||
max_extend_len = self.num_draft_tokens
|
max_extend_len = self.num_draft_tokens
|
||||||
|
num_kv_splits = None
|
||||||
attn_logits = None
|
attn_logits = None
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -294,6 +404,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.forward_metadata = (
|
self.forward_metadata = (
|
||||||
attn_logits,
|
attn_logits,
|
||||||
max_extend_len,
|
max_extend_len,
|
||||||
|
num_kv_splits,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
qo_indptr,
|
qo_indptr,
|
||||||
@@ -304,6 +415,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self,
|
self,
|
||||||
bs: int,
|
bs: int,
|
||||||
|
num_kv_head: int,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
@@ -317,6 +429,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
# Update kv_indptr, kv_indices
|
# Update kv_indptr, kv_indices
|
||||||
kv_indptr = self.kv_indptr
|
kv_indptr = self.kv_indptr
|
||||||
kv_indices = self.cuda_graph_kv_indices
|
kv_indices = self.cuda_graph_kv_indices
|
||||||
|
num_kv_splits = self.cuda_graph_num_kv_splits
|
||||||
if spec_info is None:
|
if spec_info is None:
|
||||||
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
|
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
|
||||||
kv_indptr = kv_indptr[: bs + 1]
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
@@ -332,6 +445,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
else:
|
else:
|
||||||
kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
|
kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
|
||||||
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
|
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
|
||||||
|
self.get_num_kv_splits(num_kv_splits, seq_lens, bs, num_kv_head)
|
||||||
elif forward_mode.is_target_verify():
|
elif forward_mode.is_target_verify():
|
||||||
# Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
|
# Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
|
||||||
bs = len(req_pool_indices)
|
bs = len(req_pool_indices)
|
||||||
@@ -391,6 +505,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
(
|
(
|
||||||
_,
|
_,
|
||||||
max_extend_len,
|
max_extend_len,
|
||||||
|
_,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
qo_indptr,
|
qo_indptr,
|
||||||
@@ -435,7 +550,9 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
else:
|
else:
|
||||||
o = torch.empty_like(q)
|
o = torch.empty_like(q)
|
||||||
|
|
||||||
attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata
|
attn_logits, _, num_kv_splits, kv_indptr, kv_indices, _, _, _ = (
|
||||||
|
self.forward_metadata
|
||||||
|
)
|
||||||
|
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
@@ -450,7 +567,8 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
attn_logits,
|
attn_logits,
|
||||||
self.num_kv_splits,
|
num_kv_splits,
|
||||||
|
self.max_kv_splits,
|
||||||
layer.scaling,
|
layer.scaling,
|
||||||
layer.logit_cap,
|
layer.logit_cap,
|
||||||
)
|
)
|
||||||
@@ -493,6 +611,9 @@ class TritonMultiStepDraftBackend:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.max_context_len = self.attn_backends[0].max_context_len
|
self.max_context_len = self.attn_backends[0].max_context_len
|
||||||
|
self.num_head = (
|
||||||
|
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||||
|
)
|
||||||
self.device = model_runner.device
|
self.device = model_runner.device
|
||||||
# Cached variables for generate_draft_decode_kv_indices
|
# Cached variables for generate_draft_decode_kv_indices
|
||||||
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
||||||
@@ -579,9 +700,15 @@ class TritonMultiStepDraftBackend:
|
|||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self, forward_batch: ForwardBatch, bs: int
|
self, forward_batch: ForwardBatch, bs: int
|
||||||
):
|
):
|
||||||
|
num_kv_heads = self.num_head
|
||||||
|
if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
|
||||||
|
if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
|
||||||
|
num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
|
||||||
|
|
||||||
def call_fn(i, forward_batch):
|
def call_fn(i, forward_batch):
|
||||||
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
|
num_kv_heads,
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
seq_lens_sum=-1,
|
seq_lens_sum=-1,
|
||||||
|
|||||||
@@ -37,6 +37,9 @@ logger.warning(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_MIN_BLOCK_KV = 32
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def tanh(x):
|
def tanh(x):
|
||||||
# Tanh is just a scaled sigmoid
|
# Tanh is just a scaled sigmoid
|
||||||
@@ -52,6 +55,8 @@ def _fwd_kernel_stage1(
|
|||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
Att_Out,
|
Att_Out,
|
||||||
|
Att_Lse,
|
||||||
|
num_kv_splits,
|
||||||
stride_qbs,
|
stride_qbs,
|
||||||
stride_qh,
|
stride_qh,
|
||||||
stride_buf_kbs,
|
stride_buf_kbs,
|
||||||
@@ -65,7 +70,7 @@ def _fwd_kernel_stage1(
|
|||||||
BLOCK_DMODEL: tl.constexpr,
|
BLOCK_DMODEL: tl.constexpr,
|
||||||
BLOCK_DV: tl.constexpr,
|
BLOCK_DV: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
NUM_KV_SPLITS: tl.constexpr,
|
MIN_BLOCK_KV: tl.constexpr,
|
||||||
logit_cap: tl.constexpr,
|
logit_cap: tl.constexpr,
|
||||||
Lk: tl.constexpr,
|
Lk: tl.constexpr,
|
||||||
Lv: tl.constexpr,
|
Lv: tl.constexpr,
|
||||||
@@ -83,11 +88,13 @@ def _fwd_kernel_stage1(
|
|||||||
|
|
||||||
cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
|
cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
|
||||||
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
|
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
|
||||||
|
kv_splits = tl.load(num_kv_splits + cur_batch)
|
||||||
|
|
||||||
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
|
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
|
||||||
q = tl.load(Q + off_q, mask=mask_d, other=0.0)
|
|
||||||
|
|
||||||
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
|
kv_len_per_split = (
|
||||||
|
tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
|
||||||
|
)
|
||||||
split_kv_start = kv_len_per_split * split_kv_id
|
split_kv_start = kv_len_per_split * split_kv_id
|
||||||
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
|
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
|
||||||
|
|
||||||
@@ -96,6 +103,7 @@ def _fwd_kernel_stage1(
|
|||||||
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
|
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
|
||||||
|
|
||||||
if split_kv_end > split_kv_start:
|
if split_kv_end > split_kv_start:
|
||||||
|
q = tl.load(Q + off_q, mask=mask_d, other=0.0)
|
||||||
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
||||||
offs_n = start_n + tl.arange(0, BLOCK_N)
|
offs_n = start_n + tl.arange(0, BLOCK_N)
|
||||||
kv_loc = tl.load(
|
kv_loc = tl.load(
|
||||||
@@ -158,11 +166,10 @@ def _fwd_kernel_stage1(
|
|||||||
cur_batch * stride_mid_ob
|
cur_batch * stride_mid_ob
|
||||||
+ cur_head * stride_mid_oh
|
+ cur_head * stride_mid_oh
|
||||||
+ split_kv_id * stride_mid_os
|
+ split_kv_id * stride_mid_os
|
||||||
+ Lv
|
) // Lv
|
||||||
)
|
|
||||||
|
|
||||||
tl.store(
|
tl.store(
|
||||||
Att_Out + offs_mid_o_1,
|
Att_Lse + offs_mid_o_1,
|
||||||
e_max + tl.log(e_sum),
|
e_max + tl.log(e_sum),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -172,9 +179,11 @@ def _decode_att_m_fwd(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
att_out,
|
att_out,
|
||||||
|
att_lse,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap,
|
logit_cap,
|
||||||
):
|
):
|
||||||
@@ -182,13 +191,13 @@ def _decode_att_m_fwd(
|
|||||||
# [TODO] work around SGPR limit on MI3xx
|
# [TODO] work around SGPR limit on MI3xx
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
BLOCK = 8
|
BLOCK = 8
|
||||||
NUM_KV_SPLITS = num_kv_splits
|
MAX_KV_SPLITS = max_kv_splits
|
||||||
Lk = k_buffer.shape[-1]
|
Lk = k_buffer.shape[-1]
|
||||||
Lv = v_buffer.shape[-1]
|
Lv = v_buffer.shape[-1]
|
||||||
|
|
||||||
batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
|
batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
|
||||||
|
|
||||||
grid = (batch, head_num, NUM_KV_SPLITS)
|
grid = (batch, head_num, MAX_KV_SPLITS)
|
||||||
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
||||||
|
|
||||||
if kv_group_num == 1:
|
if kv_group_num == 1:
|
||||||
@@ -209,6 +218,8 @@ def _decode_att_m_fwd(
|
|||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
att_out,
|
att_out,
|
||||||
|
att_lse,
|
||||||
|
num_kv_splits,
|
||||||
q.stride(0),
|
q.stride(0),
|
||||||
q.stride(1),
|
q.stride(1),
|
||||||
k_buffer.stride(0),
|
k_buffer.stride(0),
|
||||||
@@ -222,7 +233,7 @@ def _decode_att_m_fwd(
|
|||||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||||
BLOCK_DV=BLOCK_DV,
|
BLOCK_DV=BLOCK_DV,
|
||||||
BLOCK_N=BLOCK,
|
BLOCK_N=BLOCK,
|
||||||
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
MIN_BLOCK_KV=_MIN_BLOCK_KV,
|
||||||
logit_cap=logit_cap,
|
logit_cap=logit_cap,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=2,
|
num_stages=2,
|
||||||
@@ -240,6 +251,8 @@ def _fwd_grouped_kernel_stage1(
|
|||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
Att_Out,
|
Att_Out,
|
||||||
|
Att_Lse,
|
||||||
|
num_kv_splits,
|
||||||
stride_qbs,
|
stride_qbs,
|
||||||
stride_qh,
|
stride_qh,
|
||||||
stride_buf_kbs,
|
stride_buf_kbs,
|
||||||
@@ -256,7 +269,7 @@ def _fwd_grouped_kernel_stage1(
|
|||||||
BLOCK_DV: tl.constexpr,
|
BLOCK_DV: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
BLOCK_H: tl.constexpr,
|
BLOCK_H: tl.constexpr,
|
||||||
NUM_KV_SPLITS: tl.constexpr,
|
MIN_BLOCK_KV: tl.constexpr,
|
||||||
logit_cap: tl.constexpr,
|
logit_cap: tl.constexpr,
|
||||||
Lk: tl.constexpr,
|
Lk: tl.constexpr,
|
||||||
Lv: tl.constexpr,
|
Lv: tl.constexpr,
|
||||||
@@ -281,9 +294,9 @@ def _fwd_grouped_kernel_stage1(
|
|||||||
|
|
||||||
cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
|
cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
|
||||||
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
|
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
|
||||||
|
kv_splits = tl.load(num_kv_splits + cur_batch)
|
||||||
|
|
||||||
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
|
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
|
||||||
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
|
|
||||||
|
|
||||||
if BLOCK_DPE > 0:
|
if BLOCK_DPE > 0:
|
||||||
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
||||||
@@ -291,11 +304,10 @@ def _fwd_grouped_kernel_stage1(
|
|||||||
off_qpe = (
|
off_qpe = (
|
||||||
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
|
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
|
||||||
)
|
)
|
||||||
qpe = tl.load(
|
|
||||||
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
|
kv_len_per_split = (
|
||||||
|
tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
|
||||||
|
)
|
||||||
split_kv_start = kv_len_per_split * split_kv_id
|
split_kv_start = kv_len_per_split * split_kv_id
|
||||||
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
|
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
|
||||||
|
|
||||||
@@ -304,6 +316,11 @@ def _fwd_grouped_kernel_stage1(
|
|||||||
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
|
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
|
||||||
|
|
||||||
if split_kv_end > split_kv_start:
|
if split_kv_end > split_kv_start:
|
||||||
|
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
|
||||||
|
if BLOCK_DPE > 0:
|
||||||
|
qpe = tl.load(
|
||||||
|
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
|
||||||
|
)
|
||||||
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
||||||
offs_n = start_n + tl.arange(0, BLOCK_N)
|
offs_n = start_n + tl.arange(0, BLOCK_N)
|
||||||
kv_loc = tl.load(
|
kv_loc = tl.load(
|
||||||
@@ -380,11 +397,10 @@ def _fwd_grouped_kernel_stage1(
|
|||||||
cur_batch * stride_mid_ob
|
cur_batch * stride_mid_ob
|
||||||
+ cur_head * stride_mid_oh
|
+ cur_head * stride_mid_oh
|
||||||
+ split_kv_id * stride_mid_os
|
+ split_kv_id * stride_mid_os
|
||||||
+ Lv
|
) // Lv
|
||||||
)
|
|
||||||
|
|
||||||
tl.store(
|
tl.store(
|
||||||
Att_Out + offs_mid_o_1,
|
Att_Lse + offs_mid_o_1,
|
||||||
e_max + tl.log(e_sum),
|
e_max + tl.log(e_sum),
|
||||||
mask=mask_h,
|
mask=mask_h,
|
||||||
)
|
)
|
||||||
@@ -395,9 +411,11 @@ def _decode_grouped_att_m_fwd(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
att_out,
|
att_out,
|
||||||
|
att_lse,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap,
|
logit_cap,
|
||||||
):
|
):
|
||||||
@@ -424,11 +442,11 @@ def _decode_grouped_att_m_fwd(
|
|||||||
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
||||||
|
|
||||||
BLOCK_H = 16
|
BLOCK_H = 16
|
||||||
NUM_KV_SPLITS = num_kv_splits
|
MAX_KV_SPLITS = max_kv_splits
|
||||||
grid = (
|
grid = (
|
||||||
batch,
|
batch,
|
||||||
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
|
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
|
||||||
NUM_KV_SPLITS,
|
MAX_KV_SPLITS,
|
||||||
)
|
)
|
||||||
|
|
||||||
extra_kargs = {}
|
extra_kargs = {}
|
||||||
@@ -447,6 +465,8 @@ def _decode_grouped_att_m_fwd(
|
|||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
att_out,
|
att_out,
|
||||||
|
att_lse,
|
||||||
|
num_kv_splits,
|
||||||
q.stride(0),
|
q.stride(0),
|
||||||
q.stride(1),
|
q.stride(1),
|
||||||
k_buffer.stride(0),
|
k_buffer.stride(0),
|
||||||
@@ -463,7 +483,7 @@ def _decode_grouped_att_m_fwd(
|
|||||||
BLOCK_DV=BLOCK_DV,
|
BLOCK_DV=BLOCK_DV,
|
||||||
BLOCK_N=BLOCK,
|
BLOCK_N=BLOCK,
|
||||||
BLOCK_H=BLOCK_H,
|
BLOCK_H=BLOCK_H,
|
||||||
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
MIN_BLOCK_KV=_MIN_BLOCK_KV,
|
||||||
logit_cap=logit_cap,
|
logit_cap=logit_cap,
|
||||||
num_warps=4,
|
num_warps=4,
|
||||||
num_stages=num_stages,
|
num_stages=num_stages,
|
||||||
@@ -476,14 +496,17 @@ def _decode_grouped_att_m_fwd(
|
|||||||
@triton.jit
|
@triton.jit
|
||||||
def _fwd_kernel_stage2(
|
def _fwd_kernel_stage2(
|
||||||
Mid_O,
|
Mid_O,
|
||||||
|
Mid_O_1,
|
||||||
O,
|
O,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
|
num_kv_splits,
|
||||||
stride_mid_ob,
|
stride_mid_ob,
|
||||||
stride_mid_oh,
|
stride_mid_oh,
|
||||||
stride_mid_os,
|
stride_mid_os,
|
||||||
stride_obs,
|
stride_obs,
|
||||||
stride_oh,
|
stride_oh,
|
||||||
NUM_KV_SPLITS: tl.constexpr,
|
MAX_KV_SPLITS: tl.constexpr,
|
||||||
|
MIN_BLOCK_KV: tl.constexpr,
|
||||||
BLOCK_DV: tl.constexpr,
|
BLOCK_DV: tl.constexpr,
|
||||||
Lv: tl.constexpr,
|
Lv: tl.constexpr,
|
||||||
):
|
):
|
||||||
@@ -493,6 +516,7 @@ def _fwd_kernel_stage2(
|
|||||||
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
|
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
|
||||||
kv_indptr + cur_batch
|
kv_indptr + cur_batch
|
||||||
)
|
)
|
||||||
|
kv_splits = tl.load(num_kv_splits + cur_batch)
|
||||||
|
|
||||||
offs_d = tl.arange(0, BLOCK_DV)
|
offs_d = tl.arange(0, BLOCK_DV)
|
||||||
mask_d = offs_d < Lv
|
mask_d = offs_d < Lv
|
||||||
@@ -502,10 +526,12 @@ def _fwd_kernel_stage2(
|
|||||||
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
|
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
|
||||||
|
|
||||||
offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
|
offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
|
||||||
offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
|
offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv
|
||||||
|
kv_len_per_split = (
|
||||||
|
tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
|
||||||
|
)
|
||||||
|
|
||||||
for split_kv_id in range(0, NUM_KV_SPLITS):
|
for split_kv_id in range(0, MAX_KV_SPLITS):
|
||||||
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
|
|
||||||
split_kv_start = kv_len_per_split * split_kv_id
|
split_kv_start = kv_len_per_split * split_kv_id
|
||||||
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
|
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
|
||||||
|
|
||||||
@@ -513,7 +539,7 @@ def _fwd_kernel_stage2(
|
|||||||
tv = tl.load(
|
tv = tl.load(
|
||||||
Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
|
Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
|
||||||
)
|
)
|
||||||
tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
|
tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv)
|
||||||
n_e_max = tl.maximum(tlogic, e_max)
|
n_e_max = tl.maximum(tlogic, e_max)
|
||||||
|
|
||||||
old_scale = tl.exp(e_max - n_e_max)
|
old_scale = tl.exp(e_max - n_e_max)
|
||||||
@@ -533,17 +559,19 @@ def _fwd_kernel_stage2(
|
|||||||
|
|
||||||
def _decode_softmax_reducev_fwd(
|
def _decode_softmax_reducev_fwd(
|
||||||
logits,
|
logits,
|
||||||
|
lse,
|
||||||
q,
|
q,
|
||||||
o,
|
o,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
max_kv_splits,
|
||||||
):
|
):
|
||||||
batch, head_num = q.shape[0], q.shape[1]
|
batch, head_num = q.shape[0], q.shape[1]
|
||||||
Lv = v_buffer.shape[-1]
|
Lv = v_buffer.shape[-1]
|
||||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||||
|
|
||||||
NUM_KV_SPLITS = num_kv_splits
|
MAX_KV_SPLITS = max_kv_splits
|
||||||
|
|
||||||
extra_kargs = {}
|
extra_kargs = {}
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
@@ -554,14 +582,17 @@ def _decode_softmax_reducev_fwd(
|
|||||||
grid = (batch, head_num)
|
grid = (batch, head_num)
|
||||||
_fwd_kernel_stage2[grid](
|
_fwd_kernel_stage2[grid](
|
||||||
logits,
|
logits,
|
||||||
|
lse,
|
||||||
o,
|
o,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
|
num_kv_splits,
|
||||||
logits.stride(0),
|
logits.stride(0),
|
||||||
logits.stride(1),
|
logits.stride(1),
|
||||||
logits.stride(2),
|
logits.stride(2),
|
||||||
o.stride(0),
|
o.stride(0),
|
||||||
o.stride(1),
|
o.stride(1),
|
||||||
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
MAX_KV_SPLITS=MAX_KV_SPLITS,
|
||||||
|
MIN_BLOCK_KV=_MIN_BLOCK_KV,
|
||||||
BLOCK_DV=BLOCK_DV,
|
BLOCK_DV=BLOCK_DV,
|
||||||
Lv=Lv,
|
Lv=Lv,
|
||||||
num_warps=4,
|
num_warps=4,
|
||||||
@@ -579,6 +610,7 @@ def decode_attention_fwd_normal(
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
attn_logits,
|
attn_logits,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap=0.0,
|
logit_cap=0.0,
|
||||||
):
|
):
|
||||||
@@ -586,14 +618,25 @@ def decode_attention_fwd_normal(
|
|||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
attn_logits,
|
attn_logits[0],
|
||||||
|
attn_logits[1],
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap,
|
logit_cap,
|
||||||
)
|
)
|
||||||
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
|
_decode_softmax_reducev_fwd(
|
||||||
|
attn_logits[0],
|
||||||
|
attn_logits[1],
|
||||||
|
q,
|
||||||
|
o,
|
||||||
|
v_buffer,
|
||||||
|
kv_indptr,
|
||||||
|
num_kv_splits,
|
||||||
|
max_kv_splits,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def decode_attention_fwd_grouped(
|
def decode_attention_fwd_grouped(
|
||||||
@@ -605,6 +648,7 @@ def decode_attention_fwd_grouped(
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
attn_logits,
|
attn_logits,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap=0.0,
|
logit_cap=0.0,
|
||||||
):
|
):
|
||||||
@@ -612,14 +656,25 @@ def decode_attention_fwd_grouped(
|
|||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
attn_logits,
|
attn_logits[0],
|
||||||
|
attn_logits[1],
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap,
|
logit_cap,
|
||||||
)
|
)
|
||||||
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
|
_decode_softmax_reducev_fwd(
|
||||||
|
attn_logits[0],
|
||||||
|
attn_logits[1],
|
||||||
|
q,
|
||||||
|
o,
|
||||||
|
v_buffer,
|
||||||
|
kv_indptr,
|
||||||
|
num_kv_splits,
|
||||||
|
max_kv_splits,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def decode_attention_fwd(
|
def decode_attention_fwd(
|
||||||
@@ -631,12 +686,13 @@ def decode_attention_fwd(
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
attn_logits,
|
attn_logits,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap=0.0,
|
logit_cap=0.0,
|
||||||
):
|
):
|
||||||
assert num_kv_splits == attn_logits.shape[2]
|
assert max_kv_splits == attn_logits[0].shape[2]
|
||||||
assert q.shape[0] <= kv_indptr.shape[0] - 1
|
assert q.shape[0] <= kv_indptr.shape[0] - 1
|
||||||
assert q.shape[0] <= attn_logits.shape[0]
|
assert q.shape[0] <= attn_logits[0].shape[0]
|
||||||
|
|
||||||
kv_group_num = q.shape[1] // v_buffer.shape[1]
|
kv_group_num = q.shape[1] // v_buffer.shape[1]
|
||||||
|
|
||||||
@@ -651,6 +707,7 @@ def decode_attention_fwd(
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
attn_logits,
|
attn_logits,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap,
|
logit_cap,
|
||||||
)
|
)
|
||||||
@@ -665,6 +722,7 @@ def decode_attention_fwd(
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
attn_logits,
|
attn_logits,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap,
|
logit_cap,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import tqdm
|
|||||||
from sglang.srt.custom_op import CustomOp
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
|
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
|
||||||
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
|
from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
|
||||||
from sglang.srt.layers.torchao_utils import save_gemlite_cache
|
from sglang.srt.layers.torchao_utils import save_gemlite_cache
|
||||||
@@ -195,6 +196,9 @@ class CudaGraphRunner:
|
|||||||
# Attention backend
|
# Attention backend
|
||||||
self.max_bs = max(self.capture_bs)
|
self.max_bs = max(self.capture_bs)
|
||||||
self.max_num_token = self.max_bs * self.num_tokens_per_bs
|
self.max_num_token = self.max_bs * self.num_tokens_per_bs
|
||||||
|
self.num_head = (
|
||||||
|
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||||
|
)
|
||||||
self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
|
self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
|
||||||
self.seq_len_fill_value = (
|
self.seq_len_fill_value = (
|
||||||
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||||
@@ -503,9 +507,15 @@ class CudaGraphRunner:
|
|||||||
if hasattr(forward_batch.spec_info, "hidden_states"):
|
if hasattr(forward_batch.spec_info, "hidden_states"):
|
||||||
self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
|
self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
|
||||||
|
|
||||||
|
num_kv_heads = self.num_head
|
||||||
|
if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
|
||||||
|
if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
|
||||||
|
num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
|
||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
|
num_kv_heads,
|
||||||
self.req_pool_indices,
|
self.req_pool_indices,
|
||||||
self.seq_lens,
|
self.seq_lens,
|
||||||
forward_batch.seq_lens_sum + (bs - raw_bs),
|
forward_batch.seq_lens_sum + (bs - raw_bs),
|
||||||
|
|||||||
@@ -228,7 +228,8 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
seq_len = 10 # This represents the number of tokens already in the sequence
|
seq_len = 10 # This represents the number of tokens already in the sequence
|
||||||
total_tokens = B * seq_len
|
total_tokens = B * seq_len
|
||||||
sm_scale = 1.0 / (D**0.5)
|
sm_scale = 1.0 / (D**0.5)
|
||||||
num_kv_splits = 8
|
max_kv_splits = 8
|
||||||
|
num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
# q represents the new token being generated, one per batch
|
# q represents the new token being generated, one per batch
|
||||||
q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
|
q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
|
||||||
@@ -247,7 +248,12 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
kv_indices = torch.arange(total_tokens, device="cuda")
|
kv_indices = torch.arange(total_tokens, device="cuda")
|
||||||
|
|
||||||
attn_logits = torch.empty(
|
attn_logits = torch.empty(
|
||||||
(B, H_Q, num_kv_splits, D + 1),
|
(B, H_Q, max_kv_splits, D),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
attn_lse = torch.empty(
|
||||||
|
(B, H_Q, max_kv_splits),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
@@ -259,8 +265,9 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
o,
|
o,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
attn_logits,
|
(attn_logits, attn_lse),
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -284,7 +291,8 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
seq_len = S # This represents the number of tokens already in the sequence
|
seq_len = S # This represents the number of tokens already in the sequence
|
||||||
total_tokens = B * seq_len
|
total_tokens = B * seq_len
|
||||||
sm_scale = 1.0 / (D**0.5)
|
sm_scale = 1.0 / (D**0.5)
|
||||||
num_kv_splits = 8
|
max_kv_splits = 8
|
||||||
|
num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
# q represents the new token being generated, one per batch
|
# q represents the new token being generated, one per batch
|
||||||
q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
|
q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
|
||||||
@@ -304,7 +312,12 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
kv_indices = torch.arange(total_tokens, device="cuda")
|
kv_indices = torch.arange(total_tokens, device="cuda")
|
||||||
|
|
||||||
attn_logits = torch.empty(
|
attn_logits = torch.empty(
|
||||||
(B, H_Q, num_kv_splits, D_V + 1),
|
(B, H_Q, max_kv_splits, D_V),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
attn_lse = torch.empty(
|
||||||
|
(B, H_Q, max_kv_splits),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
@@ -316,13 +329,19 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
o,
|
o,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
attn_logits,
|
(attn_logits, attn_lse),
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_logits1 = torch.empty(
|
attn_logits1 = torch.empty(
|
||||||
(B, H_Q, num_kv_splits, D_V + 1),
|
(B, H_Q, max_kv_splits, D_V),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
attn_lse1 = torch.empty(
|
||||||
|
(B, H_Q, max_kv_splits, D_V),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
@@ -334,8 +353,9 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
o_grouped,
|
o_grouped,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
attn_logits1,
|
(attn_logits1, attn_lse1),
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user