Update Triton extend backend interface (#3309)
This commit is contained in:
@@ -17,12 +17,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|||||||
def __init__(self, model_runner: ModelRunner):
|
def __init__(self, model_runner: ModelRunner):
|
||||||
# Lazy import to avoid the initialization of cuda context
|
# Lazy import to avoid the initialization of cuda context
|
||||||
from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
|
from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
|
||||||
|
extend_attention_fwd,
|
||||||
flash_decode_attention_fwd,
|
flash_decode_attention_fwd,
|
||||||
flash_decode_sparse_attention_fwd,
|
flash_decode_sparse_attention_fwd,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.attention.triton_ops.extend_attention import (
|
|
||||||
extend_attention_fwd,
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|||||||
@@ -37,6 +37,9 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||||
)
|
)
|
||||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
|
self.qo_indptr = torch.zeros(
|
||||||
|
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||||
|
)
|
||||||
|
|
||||||
self.num_head = (
|
self.num_head = (
|
||||||
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||||
@@ -54,6 +57,9 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
"""Init auxiliary variables for triton attention backend."""
|
"""Init auxiliary variables for triton attention backend."""
|
||||||
|
|
||||||
|
bs = forward_batch.batch_size
|
||||||
|
kv_indptr = self.kv_indptr
|
||||||
|
|
||||||
if forward_batch.forward_mode.is_decode():
|
if forward_batch.forward_mode.is_decode():
|
||||||
attn_logits = torch.empty(
|
attn_logits = torch.empty(
|
||||||
(
|
(
|
||||||
@@ -68,31 +74,59 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
max_extend_len = None
|
max_extend_len = None
|
||||||
|
|
||||||
kv_indptr = self.kv_indptr
|
|
||||||
bs = len(forward_batch.req_pool_indices)
|
|
||||||
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
||||||
kv_indptr = kv_indptr[: bs + 1]
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
kv_indices = torch.empty(
|
kv_indices = torch.empty(
|
||||||
forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
|
forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
|
||||||
)
|
)
|
||||||
create_flashinfer_kv_indices_triton[(bs,)](
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
forward_batch.req_to_token_pool.req_to_token,
|
self.req_to_token,
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
None,
|
None,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
forward_batch.req_to_token_pool.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
qo_indptr = None
|
||||||
|
custom_mask = None
|
||||||
else:
|
else:
|
||||||
|
kv_indptr[1 : bs + 1] = torch.cumsum(
|
||||||
|
forward_batch.extend_prefix_lens, dim=0
|
||||||
|
)
|
||||||
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
|
kv_indices = torch.empty(
|
||||||
|
forward_batch.extend_prefix_lens.sum().item(),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
|
self.req_to_token,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.extend_prefix_lens,
|
||||||
|
kv_indptr,
|
||||||
|
None,
|
||||||
|
kv_indices,
|
||||||
|
self.req_to_token.stride(0),
|
||||||
|
)
|
||||||
|
|
||||||
|
qo_indptr = self.qo_indptr
|
||||||
|
qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
|
||||||
|
qo_indptr = qo_indptr[: bs + 1]
|
||||||
|
custom_mask = None
|
||||||
|
|
||||||
attn_logits = None
|
attn_logits = None
|
||||||
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
|
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
|
||||||
|
|
||||||
kv_indptr = None
|
self.forward_metadata = (
|
||||||
kv_indices = None
|
attn_logits,
|
||||||
|
max_extend_len,
|
||||||
self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
qo_indptr,
|
||||||
|
custom_mask,
|
||||||
|
)
|
||||||
|
|
||||||
def init_cuda_graph_state(self, max_bs: int):
|
def init_cuda_graph_state(self, max_bs: int):
|
||||||
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
|
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
|
||||||
@@ -144,6 +178,8 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
None,
|
None,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
@@ -197,7 +233,9 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
layer, forward_batch.out_cache_loc, k, v
|
layer, forward_batch.out_cache_loc, k, v
|
||||||
)
|
)
|
||||||
|
|
||||||
_, max_extend_len, _, _ = self.forward_metadata
|
_, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask = (
|
||||||
|
self.forward_metadata
|
||||||
|
)
|
||||||
self.extend_attention_fwd(
|
self.extend_attention_fwd(
|
||||||
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||||
k.contiguous(),
|
k.contiguous(),
|
||||||
@@ -205,11 +243,9 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
||||||
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||||
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||||
forward_batch.req_to_token_pool.req_to_token,
|
qo_indptr,
|
||||||
forward_batch.req_pool_indices,
|
kv_indptr,
|
||||||
forward_batch.seq_lens,
|
kv_indices,
|
||||||
forward_batch.extend_seq_lens,
|
|
||||||
forward_batch.extend_start_loc,
|
|
||||||
max_extend_len,
|
max_extend_len,
|
||||||
layer.scaling,
|
layer.scaling,
|
||||||
layer.logit_cap,
|
layer.logit_cap,
|
||||||
@@ -235,7 +271,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
else:
|
else:
|
||||||
o = torch.empty_like(q)
|
o = torch.empty_like(q)
|
||||||
|
|
||||||
attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
|
attn_logits, _, kv_indptr, kv_indices, _, _ = self.forward_metadata
|
||||||
|
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
|||||||
@@ -3,6 +3,13 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
|
is_cuda_available = torch.cuda.is_available()
|
||||||
|
if is_cuda_available:
|
||||||
|
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||||
|
|
||||||
|
is_hip_ = is_hip()
|
||||||
|
|
||||||
if global_server_args_dict.get("attention_reduce_in_fp32", False):
|
if global_server_args_dict.get("attention_reduce_in_fp32", False):
|
||||||
REDUCE_TRITON_TYPE = tl.float32
|
REDUCE_TRITON_TYPE = tl.float32
|
||||||
@@ -274,9 +281,6 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq):
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def flash_decode_attention_fwd(
|
def flash_decode_attention_fwd(
|
||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
@@ -770,3 +774,333 @@ def flash_decode_sparse_attention_fwd(
|
|||||||
)
|
)
|
||||||
|
|
||||||
sparse_flash_decode_stage3(heavy_token_num, mid_out, mid_o_logexpsum, o, BLOCK_SEQ)
|
sparse_flash_decode_stage3(heavy_token_num, mid_out, mid_o_logexpsum, o, BLOCK_SEQ)
|
||||||
|
|
||||||
|
|
||||||
|
# Extend attention kernel for Double Sparsity
|
||||||
|
# Moved from https://github.com/sgl-project/sglang/blob/v0.4.2.post1/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
|
||||||
|
@triton.jit
|
||||||
|
def _fwd_kernel(
|
||||||
|
Q_Extend,
|
||||||
|
K_Extend,
|
||||||
|
V_Extend,
|
||||||
|
O_Extend,
|
||||||
|
K_Buffer,
|
||||||
|
V_Buffer,
|
||||||
|
Req_to_tokens,
|
||||||
|
B_req_idx,
|
||||||
|
B_Seq_Len,
|
||||||
|
B_Start_Loc_Extend,
|
||||||
|
B_Seq_Len_Extend,
|
||||||
|
sm_scale,
|
||||||
|
kv_group_num,
|
||||||
|
stride_qbs,
|
||||||
|
stride_qh,
|
||||||
|
stride_kbs,
|
||||||
|
stride_kh,
|
||||||
|
stride_vbs,
|
||||||
|
stride_vh,
|
||||||
|
stride_obs,
|
||||||
|
stride_oh,
|
||||||
|
stride_buf_kbs,
|
||||||
|
stride_buf_kh,
|
||||||
|
stride_buf_vbs,
|
||||||
|
stride_buf_vh,
|
||||||
|
stride_req_to_tokens_b,
|
||||||
|
logit_cap: tl.constexpr,
|
||||||
|
Lq: tl.constexpr,
|
||||||
|
Lv: tl.constexpr,
|
||||||
|
BLOCK_DMODEL: tl.constexpr,
|
||||||
|
BLOCK_DPE: tl.constexpr,
|
||||||
|
BLOCK_DV: tl.constexpr,
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
):
|
||||||
|
cur_seq = tl.program_id(0)
|
||||||
|
cur_head = tl.program_id(1)
|
||||||
|
cur_block_m = tl.program_id(2)
|
||||||
|
cur_kv_head = cur_head // kv_group_num
|
||||||
|
|
||||||
|
cur_seq_len = tl.load(B_Seq_Len + cur_seq)
|
||||||
|
cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
|
||||||
|
cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
|
||||||
|
|
||||||
|
cur_seq_prefix_start_in_loc = 0
|
||||||
|
cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
|
||||||
|
cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
|
||||||
|
|
||||||
|
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||||
|
offs_dv = tl.arange(0, BLOCK_DV)
|
||||||
|
offs_m = tl.arange(0, BLOCK_M)
|
||||||
|
mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
|
||||||
|
|
||||||
|
mask_d = offs_d < Lq
|
||||||
|
mask_dv = offs_dv < Lv
|
||||||
|
|
||||||
|
offs_q = (
|
||||||
|
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||||
|
* stride_qbs
|
||||||
|
+ cur_head * stride_qh
|
||||||
|
+ offs_d[None, :]
|
||||||
|
)
|
||||||
|
q = tl.load(
|
||||||
|
Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
if BLOCK_DPE > 0:
|
||||||
|
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
||||||
|
offs_qpe = (
|
||||||
|
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||||
|
* stride_qbs
|
||||||
|
+ cur_head * stride_qh
|
||||||
|
+ offs_dpe[None, :]
|
||||||
|
)
|
||||||
|
qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
|
||||||
|
|
||||||
|
# stage 1: compute scores with prefix
|
||||||
|
offs_n = tl.arange(0, BLOCK_N)
|
||||||
|
|
||||||
|
acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
|
||||||
|
deno = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||||
|
e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||||
|
|
||||||
|
for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
|
||||||
|
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||||
|
mask_n = (start_n + offs_n) < cur_seq_len_prefix
|
||||||
|
offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
|
||||||
|
cur_seq_prefix_start_in_loc + start_n + offs_n
|
||||||
|
)
|
||||||
|
offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
|
||||||
|
|
||||||
|
# load k in transposed way
|
||||||
|
offs_buf_k = (
|
||||||
|
offs_kv_loc[None, :] * stride_buf_kbs
|
||||||
|
+ cur_kv_head * stride_buf_kh
|
||||||
|
+ offs_d[:, None]
|
||||||
|
)
|
||||||
|
k = tl.load(
|
||||||
|
K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
qk = tl.dot(q.to(k.dtype), k)
|
||||||
|
if BLOCK_DPE > 0:
|
||||||
|
offs_kpe = (
|
||||||
|
offs_kv_loc[None, :] * stride_buf_kbs
|
||||||
|
+ cur_kv_head * stride_buf_kh
|
||||||
|
+ offs_dpe[:, None]
|
||||||
|
)
|
||||||
|
kpe = tl.load(
|
||||||
|
K_Buffer + offs_kpe,
|
||||||
|
mask=mask_n[None, :],
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
qk += tl.dot(qpe.to(kpe.dtype), kpe)
|
||||||
|
qk *= sm_scale
|
||||||
|
|
||||||
|
if logit_cap > 0:
|
||||||
|
qk = logit_cap * tanh(qk / logit_cap)
|
||||||
|
|
||||||
|
qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
|
||||||
|
|
||||||
|
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
||||||
|
re_scale = tl.exp(e_max - n_e_max)
|
||||||
|
p = tl.exp(qk - n_e_max[:, None])
|
||||||
|
deno = deno * re_scale + tl.sum(p, 1)
|
||||||
|
|
||||||
|
offs_buf_v = (
|
||||||
|
offs_kv_loc[:, None] * stride_buf_vbs
|
||||||
|
+ cur_kv_head * stride_buf_vh
|
||||||
|
+ offs_dv[None, :]
|
||||||
|
)
|
||||||
|
v = tl.load(
|
||||||
|
V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
|
||||||
|
)
|
||||||
|
p = p.to(v.dtype)
|
||||||
|
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
||||||
|
|
||||||
|
e_max = n_e_max
|
||||||
|
|
||||||
|
# stage 2: compute the trianlge part
|
||||||
|
|
||||||
|
cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
|
||||||
|
for start_n in range(0, cur_block_m_end, BLOCK_N):
|
||||||
|
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||||
|
mask_n = (start_n + offs_n) < cur_block_m_end
|
||||||
|
|
||||||
|
# load k in transposed way
|
||||||
|
offs_k = (
|
||||||
|
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
|
||||||
|
+ cur_kv_head * stride_kh
|
||||||
|
+ offs_d[:, None]
|
||||||
|
)
|
||||||
|
k = tl.load(
|
||||||
|
K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
qk = tl.dot(q, k, out_dtype=tl.float32)
|
||||||
|
if BLOCK_DPE > 0:
|
||||||
|
offs_kpe = (
|
||||||
|
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
|
||||||
|
* stride_kbs
|
||||||
|
+ cur_kv_head * stride_kh
|
||||||
|
+ offs_dpe[:, None]
|
||||||
|
)
|
||||||
|
kpe = tl.load(
|
||||||
|
K_Extend + offs_kpe,
|
||||||
|
mask=mask_n[None, :],
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
qk += tl.dot(qpe, kpe)
|
||||||
|
|
||||||
|
qk *= sm_scale
|
||||||
|
|
||||||
|
if logit_cap > 0:
|
||||||
|
qk = logit_cap * tanh(qk / logit_cap)
|
||||||
|
|
||||||
|
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
|
||||||
|
start_n + offs_n[None, :]
|
||||||
|
)
|
||||||
|
mask_causual &= mask_m[:, None] & mask_n[None, :]
|
||||||
|
qk = tl.where(mask_causual, qk, float("-inf"))
|
||||||
|
|
||||||
|
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
||||||
|
re_scale = tl.exp(e_max - n_e_max)
|
||||||
|
p = tl.exp(qk - n_e_max[:, None])
|
||||||
|
deno = deno * re_scale + tl.sum(p, 1)
|
||||||
|
|
||||||
|
offs_v = (
|
||||||
|
(cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
|
||||||
|
+ cur_kv_head * stride_vh
|
||||||
|
+ offs_dv[None, :]
|
||||||
|
)
|
||||||
|
v = tl.load(
|
||||||
|
V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
|
||||||
|
)
|
||||||
|
p = p.to(v.dtype)
|
||||||
|
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
||||||
|
|
||||||
|
e_max = n_e_max
|
||||||
|
|
||||||
|
offs_o = (
|
||||||
|
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||||
|
* stride_obs
|
||||||
|
+ cur_head * stride_oh
|
||||||
|
+ offs_dv[None, :]
|
||||||
|
)
|
||||||
|
tl.store(
|
||||||
|
O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extend_attention_fwd(
|
||||||
|
q_extend,
|
||||||
|
k_extend,
|
||||||
|
v_extend,
|
||||||
|
o_extend,
|
||||||
|
k_buffer,
|
||||||
|
v_buffer,
|
||||||
|
req_to_tokens,
|
||||||
|
b_req_idx,
|
||||||
|
b_seq_len,
|
||||||
|
b_seq_len_extend,
|
||||||
|
b_start_loc_extend,
|
||||||
|
max_len_extend,
|
||||||
|
sm_scale=None,
|
||||||
|
logit_cap=0.0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
||||||
|
|
||||||
|
k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
|
||||||
|
"""
|
||||||
|
Lq, Lk, Lv = (
|
||||||
|
q_extend.shape[-1],
|
||||||
|
k_extend.shape[-1],
|
||||||
|
v_extend.shape[-1],
|
||||||
|
)
|
||||||
|
|
||||||
|
if Lq == 576:
|
||||||
|
BLOCK_DMODEL = 512
|
||||||
|
BLOCK_DPE = 64
|
||||||
|
elif Lq == 288:
|
||||||
|
BLOCK_DMODEL = 256
|
||||||
|
BLOCK_DPE = 32
|
||||||
|
elif Lq == 192:
|
||||||
|
BLOCK_DMODEL = 128
|
||||||
|
BLOCK_DPE = 64
|
||||||
|
else:
|
||||||
|
BLOCK_DMODEL = triton.next_power_of_2(Lq)
|
||||||
|
BLOCK_DPE = 0
|
||||||
|
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||||
|
|
||||||
|
if is_hip_:
|
||||||
|
BLOCK_M, BLOCK_N = (64, 64)
|
||||||
|
num_warps = 4
|
||||||
|
|
||||||
|
else:
|
||||||
|
if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
|
||||||
|
if Lq <= 256:
|
||||||
|
BLOCK_M, BLOCK_N = (128, 64)
|
||||||
|
else:
|
||||||
|
BLOCK_M, BLOCK_N = (32, 64)
|
||||||
|
elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
||||||
|
if Lq <= 128:
|
||||||
|
BLOCK_M, BLOCK_N = (128, 128)
|
||||||
|
elif Lq <= 256:
|
||||||
|
BLOCK_M, BLOCK_N = (64, 64)
|
||||||
|
else:
|
||||||
|
BLOCK_M, BLOCK_N = (32, 64)
|
||||||
|
else:
|
||||||
|
BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
|
||||||
|
|
||||||
|
num_warps = 4 if Lk <= 64 else 8
|
||||||
|
|
||||||
|
sm_scale = sm_scale or 1.0 / (Lq**0.5)
|
||||||
|
batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
|
||||||
|
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
|
||||||
|
|
||||||
|
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
||||||
|
num_stages = 1
|
||||||
|
|
||||||
|
extra_kargs = {}
|
||||||
|
if is_hip_:
|
||||||
|
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
||||||
|
|
||||||
|
_fwd_kernel[grid](
|
||||||
|
q_extend,
|
||||||
|
k_extend,
|
||||||
|
v_extend,
|
||||||
|
o_extend,
|
||||||
|
k_buffer,
|
||||||
|
v_buffer,
|
||||||
|
req_to_tokens,
|
||||||
|
b_req_idx,
|
||||||
|
b_seq_len,
|
||||||
|
b_start_loc_extend,
|
||||||
|
b_seq_len_extend,
|
||||||
|
sm_scale,
|
||||||
|
kv_group_num,
|
||||||
|
q_extend.stride(0),
|
||||||
|
q_extend.stride(1),
|
||||||
|
k_extend.stride(0),
|
||||||
|
k_extend.stride(1),
|
||||||
|
v_extend.stride(0),
|
||||||
|
v_extend.stride(1),
|
||||||
|
o_extend.stride(0),
|
||||||
|
o_extend.stride(1),
|
||||||
|
k_buffer.stride(0),
|
||||||
|
k_buffer.stride(1),
|
||||||
|
v_buffer.stride(0),
|
||||||
|
v_buffer.stride(1),
|
||||||
|
req_to_tokens.stride(0),
|
||||||
|
logit_cap=logit_cap,
|
||||||
|
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||||
|
BLOCK_DPE=BLOCK_DPE,
|
||||||
|
BLOCK_DV=BLOCK_DV,
|
||||||
|
BLOCK_M=BLOCK_M,
|
||||||
|
BLOCK_N=BLOCK_N,
|
||||||
|
Lq=Lq,
|
||||||
|
Lv=Lv,
|
||||||
|
num_warps=num_warps,
|
||||||
|
num_stages=num_stages,
|
||||||
|
**extra_kargs,
|
||||||
|
)
|
||||||
|
|||||||
@@ -46,11 +46,9 @@ def _fwd_kernel(
|
|||||||
O_Extend,
|
O_Extend,
|
||||||
K_Buffer,
|
K_Buffer,
|
||||||
V_Buffer,
|
V_Buffer,
|
||||||
Req_to_tokens,
|
qo_indptr,
|
||||||
B_req_idx,
|
kv_indptr,
|
||||||
B_Seq_Len,
|
kv_indices,
|
||||||
B_Start_Loc_Extend,
|
|
||||||
B_Seq_Len_Extend,
|
|
||||||
sm_scale,
|
sm_scale,
|
||||||
kv_group_num,
|
kv_group_num,
|
||||||
stride_qbs,
|
stride_qbs,
|
||||||
@@ -65,7 +63,6 @@ def _fwd_kernel(
|
|||||||
stride_buf_kh,
|
stride_buf_kh,
|
||||||
stride_buf_vbs,
|
stride_buf_vbs,
|
||||||
stride_buf_vh,
|
stride_buf_vh,
|
||||||
stride_req_to_tokens_b,
|
|
||||||
logit_cap: tl.constexpr,
|
logit_cap: tl.constexpr,
|
||||||
Lq: tl.constexpr,
|
Lq: tl.constexpr,
|
||||||
Lv: tl.constexpr,
|
Lv: tl.constexpr,
|
||||||
@@ -80,13 +77,10 @@ def _fwd_kernel(
|
|||||||
cur_block_m = tl.program_id(2)
|
cur_block_m = tl.program_id(2)
|
||||||
cur_kv_head = cur_head // kv_group_num
|
cur_kv_head = cur_head // kv_group_num
|
||||||
|
|
||||||
cur_seq_len = tl.load(B_Seq_Len + cur_seq)
|
cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq)
|
||||||
cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
|
cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
|
||||||
cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
|
cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
|
||||||
|
cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
|
||||||
cur_seq_prefix_start_in_loc = 0
|
|
||||||
cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
|
|
||||||
cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
|
|
||||||
|
|
||||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||||
offs_dv = tl.arange(0, BLOCK_DV)
|
offs_dv = tl.arange(0, BLOCK_DV)
|
||||||
@@ -97,7 +91,7 @@ def _fwd_kernel(
|
|||||||
mask_dv = offs_dv < Lv
|
mask_dv = offs_dv < Lv
|
||||||
|
|
||||||
offs_q = (
|
offs_q = (
|
||||||
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||||
* stride_qbs
|
* stride_qbs
|
||||||
+ cur_head * stride_qh
|
+ cur_head * stride_qh
|
||||||
+ offs_d[None, :]
|
+ offs_d[None, :]
|
||||||
@@ -109,7 +103,7 @@ def _fwd_kernel(
|
|||||||
if BLOCK_DPE > 0:
|
if BLOCK_DPE > 0:
|
||||||
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
||||||
offs_qpe = (
|
offs_qpe = (
|
||||||
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||||
* stride_qbs
|
* stride_qbs
|
||||||
+ cur_head * stride_qh
|
+ cur_head * stride_qh
|
||||||
+ offs_dpe[None, :]
|
+ offs_dpe[None, :]
|
||||||
@@ -126,10 +120,9 @@ def _fwd_kernel(
|
|||||||
for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
|
for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
|
||||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||||
mask_n = (start_n + offs_n) < cur_seq_len_prefix
|
mask_n = (start_n + offs_n) < cur_seq_len_prefix
|
||||||
offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
|
offs_kv_loc = tl.load(
|
||||||
cur_seq_prefix_start_in_loc + start_n + offs_n
|
kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
|
||||||
)
|
)
|
||||||
offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
|
|
||||||
|
|
||||||
# load k in transposed way
|
# load k in transposed way
|
||||||
offs_buf_k = (
|
offs_buf_k = (
|
||||||
@@ -188,7 +181,7 @@ def _fwd_kernel(
|
|||||||
|
|
||||||
# load k in transposed way
|
# load k in transposed way
|
||||||
offs_k = (
|
offs_k = (
|
||||||
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
|
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
|
||||||
+ cur_kv_head * stride_kh
|
+ cur_kv_head * stride_kh
|
||||||
+ offs_d[:, None]
|
+ offs_d[:, None]
|
||||||
)
|
)
|
||||||
@@ -199,8 +192,7 @@ def _fwd_kernel(
|
|||||||
qk = tl.dot(q, k, out_dtype=tl.float32)
|
qk = tl.dot(q, k, out_dtype=tl.float32)
|
||||||
if BLOCK_DPE > 0:
|
if BLOCK_DPE > 0:
|
||||||
offs_kpe = (
|
offs_kpe = (
|
||||||
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
|
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
|
||||||
* stride_kbs
|
|
||||||
+ cur_kv_head * stride_kh
|
+ cur_kv_head * stride_kh
|
||||||
+ offs_dpe[:, None]
|
+ offs_dpe[:, None]
|
||||||
)
|
)
|
||||||
@@ -228,7 +220,7 @@ def _fwd_kernel(
|
|||||||
deno = deno * re_scale + tl.sum(p, 1)
|
deno = deno * re_scale + tl.sum(p, 1)
|
||||||
|
|
||||||
offs_v = (
|
offs_v = (
|
||||||
(cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
|
(cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
|
||||||
+ cur_kv_head * stride_vh
|
+ cur_kv_head * stride_vh
|
||||||
+ offs_dv[None, :]
|
+ offs_dv[None, :]
|
||||||
)
|
)
|
||||||
@@ -241,7 +233,7 @@ def _fwd_kernel(
|
|||||||
e_max = n_e_max
|
e_max = n_e_max
|
||||||
|
|
||||||
offs_o = (
|
offs_o = (
|
||||||
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||||
* stride_obs
|
* stride_obs
|
||||||
+ cur_head * stride_oh
|
+ cur_head * stride_oh
|
||||||
+ offs_dv[None, :]
|
+ offs_dv[None, :]
|
||||||
@@ -258,11 +250,9 @@ def extend_attention_fwd(
|
|||||||
o_extend,
|
o_extend,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
req_to_tokens,
|
qo_indptr,
|
||||||
b_req_idx,
|
kv_indptr,
|
||||||
b_seq_len,
|
kv_indices,
|
||||||
b_seq_len_extend,
|
|
||||||
b_start_loc_extend,
|
|
||||||
max_len_extend,
|
max_len_extend,
|
||||||
sm_scale=None,
|
sm_scale=None,
|
||||||
logit_cap=0.0,
|
logit_cap=0.0,
|
||||||
@@ -315,7 +305,7 @@ def extend_attention_fwd(
|
|||||||
num_warps = 4 if Lk <= 64 else 8
|
num_warps = 4 if Lk <= 64 else 8
|
||||||
|
|
||||||
sm_scale = sm_scale or 1.0 / (Lq**0.5)
|
sm_scale = sm_scale or 1.0 / (Lq**0.5)
|
||||||
batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
|
batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
|
||||||
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
|
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
|
||||||
|
|
||||||
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
||||||
@@ -332,11 +322,9 @@ def extend_attention_fwd(
|
|||||||
o_extend,
|
o_extend,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
req_to_tokens,
|
qo_indptr,
|
||||||
b_req_idx,
|
kv_indptr,
|
||||||
b_seq_len,
|
kv_indices,
|
||||||
b_start_loc_extend,
|
|
||||||
b_seq_len_extend,
|
|
||||||
sm_scale,
|
sm_scale,
|
||||||
kv_group_num,
|
kv_group_num,
|
||||||
q_extend.stride(0),
|
q_extend.stride(0),
|
||||||
@@ -351,7 +339,6 @@ def extend_attention_fwd(
|
|||||||
k_buffer.stride(1),
|
k_buffer.stride(1),
|
||||||
v_buffer.stride(0),
|
v_buffer.stride(0),
|
||||||
v_buffer.stride(1),
|
v_buffer.stride(1),
|
||||||
req_to_tokens.stride(0),
|
|
||||||
logit_cap=logit_cap,
|
logit_cap=logit_cap,
|
||||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||||
BLOCK_DPE=BLOCK_DPE,
|
BLOCK_DPE=BLOCK_DPE,
|
||||||
|
|||||||
@@ -45,16 +45,20 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
|
max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
|
||||||
|
|
||||||
b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
|
b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
|
||||||
req_to_tokens = torch.empty(
|
|
||||||
(B, max_len_in_batch), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
|
b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
|
||||||
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
|
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
|
||||||
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
|
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
|
||||||
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
|
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
|
||||||
|
|
||||||
|
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
|
||||||
|
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
|
||||||
|
kv_indices = torch.zeros(
|
||||||
|
(b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(B):
|
for i in range(B):
|
||||||
req_to_tokens[i, : b_seq_len[i]] = torch.arange(
|
kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
|
||||||
b_start_loc[i], b_start_loc[i] + b_seq_len[i]
|
b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
|
||||||
)
|
)
|
||||||
|
|
||||||
total_token_num = torch.sum(b_seq_len).item()
|
total_token_num = torch.sum(b_seq_len).item()
|
||||||
@@ -90,9 +94,10 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
b_seq_len_extend = b_seq_len - b_seq_len_prefix
|
b_seq_len_extend = b_seq_len - b_seq_len_prefix
|
||||||
b_start_loc_extend = torch.zeros_like(b_seq_len)
|
|
||||||
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
|
|
||||||
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
|
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
|
||||||
|
qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
|
||||||
|
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
|
||||||
|
|
||||||
extend_attention_fwd(
|
extend_attention_fwd(
|
||||||
q_extend,
|
q_extend,
|
||||||
k_extend,
|
k_extend,
|
||||||
@@ -100,11 +105,9 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
o_extend,
|
o_extend,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
req_to_tokens,
|
qo_indptr,
|
||||||
b_req_idx,
|
kv_indptr,
|
||||||
b_seq_len,
|
kv_indices,
|
||||||
b_seq_len_extend,
|
|
||||||
b_start_loc_extend,
|
|
||||||
max_len_extend,
|
max_len_extend,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user