Update Triton extend backend interface (#3309)
This commit is contained in:
@@ -17,12 +17,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
||||
def __init__(self, model_runner: ModelRunner):
|
||||
# Lazy import to avoid the initialization of cuda context
|
||||
from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
|
||||
extend_attention_fwd,
|
||||
flash_decode_attention_fwd,
|
||||
flash_decode_sparse_attention_fwd,
|
||||
)
|
||||
from sglang.srt.layers.attention.triton_ops.extend_attention import (
|
||||
extend_attention_fwd,
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -37,6 +37,9 @@ class TritonAttnBackend(AttentionBackend):
|
||||
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
self.qo_indptr = torch.zeros(
|
||||
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
|
||||
self.num_head = (
|
||||
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||
@@ -54,6 +57,9 @@ class TritonAttnBackend(AttentionBackend):
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
"""Init auxiliary variables for triton attention backend."""
|
||||
|
||||
bs = forward_batch.batch_size
|
||||
kv_indptr = self.kv_indptr
|
||||
|
||||
if forward_batch.forward_mode.is_decode():
|
||||
attn_logits = torch.empty(
|
||||
(
|
||||
@@ -68,31 +74,59 @@ class TritonAttnBackend(AttentionBackend):
|
||||
|
||||
max_extend_len = None
|
||||
|
||||
kv_indptr = self.kv_indptr
|
||||
bs = len(forward_batch.req_pool_indices)
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = torch.empty(
|
||||
forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
|
||||
forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
|
||||
)
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
forward_batch.req_to_token_pool.req_to_token,
|
||||
self.req_to_token,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
kv_indptr,
|
||||
None,
|
||||
kv_indices,
|
||||
forward_batch.req_to_token_pool.req_to_token.stride(0),
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
|
||||
qo_indptr = None
|
||||
custom_mask = None
|
||||
else:
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(
|
||||
forward_batch.extend_prefix_lens, dim=0
|
||||
)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = torch.empty(
|
||||
forward_batch.extend_prefix_lens.sum().item(),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.extend_prefix_lens,
|
||||
kv_indptr,
|
||||
None,
|
||||
kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
|
||||
qo_indptr = self.qo_indptr
|
||||
qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
|
||||
qo_indptr = qo_indptr[: bs + 1]
|
||||
custom_mask = None
|
||||
|
||||
attn_logits = None
|
||||
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
|
||||
|
||||
kv_indptr = None
|
||||
kv_indices = None
|
||||
|
||||
self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices
|
||||
self.forward_metadata = (
|
||||
attn_logits,
|
||||
max_extend_len,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
qo_indptr,
|
||||
custom_mask,
|
||||
)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int):
|
||||
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
|
||||
@@ -144,6 +178,8 @@ class TritonAttnBackend(AttentionBackend):
|
||||
None,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
@@ -197,7 +233,9 @@ class TritonAttnBackend(AttentionBackend):
|
||||
layer, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
|
||||
_, max_extend_len, _, _ = self.forward_metadata
|
||||
_, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask = (
|
||||
self.forward_metadata
|
||||
)
|
||||
self.extend_attention_fwd(
|
||||
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||
k.contiguous(),
|
||||
@@ -205,11 +243,9 @@ class TritonAttnBackend(AttentionBackend):
|
||||
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
||||
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||
forward_batch.req_to_token_pool.req_to_token,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
forward_batch.extend_seq_lens,
|
||||
forward_batch.extend_start_loc,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
max_extend_len,
|
||||
layer.scaling,
|
||||
layer.logit_cap,
|
||||
@@ -235,7 +271,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
else:
|
||||
o = torch.empty_like(q)
|
||||
|
||||
attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
|
||||
attn_logits, _, kv_indptr, kv_indices, _, _ = self.forward_metadata
|
||||
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
|
||||
@@ -3,6 +3,13 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
is_cuda_available = torch.cuda.is_available()
|
||||
if is_cuda_available:
|
||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||
|
||||
is_hip_ = is_hip()
|
||||
|
||||
if global_server_args_dict.get("attention_reduce_in_fp32", False):
|
||||
REDUCE_TRITON_TYPE = tl.float32
|
||||
@@ -274,9 +281,6 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq):
|
||||
return
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def flash_decode_attention_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
@@ -770,3 +774,333 @@ def flash_decode_sparse_attention_fwd(
|
||||
)
|
||||
|
||||
sparse_flash_decode_stage3(heavy_token_num, mid_out, mid_o_logexpsum, o, BLOCK_SEQ)
|
||||
|
||||
|
||||
# Extend attention kernel for Double Sparsity
|
||||
# Moved from https://github.com/sgl-project/sglang/blob/v0.4.2.post1/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q_Extend,
|
||||
K_Extend,
|
||||
V_Extend,
|
||||
O_Extend,
|
||||
K_Buffer,
|
||||
V_Buffer,
|
||||
Req_to_tokens,
|
||||
B_req_idx,
|
||||
B_Seq_Len,
|
||||
B_Start_Loc_Extend,
|
||||
B_Seq_Len_Extend,
|
||||
sm_scale,
|
||||
kv_group_num,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_buf_kbs,
|
||||
stride_buf_kh,
|
||||
stride_buf_vbs,
|
||||
stride_buf_vh,
|
||||
stride_req_to_tokens_b,
|
||||
logit_cap: tl.constexpr,
|
||||
Lq: tl.constexpr,
|
||||
Lv: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_DPE: tl.constexpr,
|
||||
BLOCK_DV: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
cur_seq = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
cur_block_m = tl.program_id(2)
|
||||
cur_kv_head = cur_head // kv_group_num
|
||||
|
||||
cur_seq_len = tl.load(B_Seq_Len + cur_seq)
|
||||
cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
|
||||
cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
|
||||
|
||||
cur_seq_prefix_start_in_loc = 0
|
||||
cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
|
||||
cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
|
||||
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_dv = tl.arange(0, BLOCK_DV)
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
|
||||
|
||||
mask_d = offs_d < Lq
|
||||
mask_dv = offs_dv < Lv
|
||||
|
||||
offs_q = (
|
||||
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||
* stride_qbs
|
||||
+ cur_head * stride_qh
|
||||
+ offs_d[None, :]
|
||||
)
|
||||
q = tl.load(
|
||||
Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
|
||||
)
|
||||
|
||||
if BLOCK_DPE > 0:
|
||||
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
||||
offs_qpe = (
|
||||
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||
* stride_qbs
|
||||
+ cur_head * stride_qh
|
||||
+ offs_dpe[None, :]
|
||||
)
|
||||
qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
|
||||
|
||||
# stage 1: compute scores with prefix
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
|
||||
deno = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
|
||||
for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
mask_n = (start_n + offs_n) < cur_seq_len_prefix
|
||||
offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
|
||||
cur_seq_prefix_start_in_loc + start_n + offs_n
|
||||
)
|
||||
offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
|
||||
|
||||
# load k in transposed way
|
||||
offs_buf_k = (
|
||||
offs_kv_loc[None, :] * stride_buf_kbs
|
||||
+ cur_kv_head * stride_buf_kh
|
||||
+ offs_d[:, None]
|
||||
)
|
||||
k = tl.load(
|
||||
K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
|
||||
)
|
||||
|
||||
qk = tl.dot(q.to(k.dtype), k)
|
||||
if BLOCK_DPE > 0:
|
||||
offs_kpe = (
|
||||
offs_kv_loc[None, :] * stride_buf_kbs
|
||||
+ cur_kv_head * stride_buf_kh
|
||||
+ offs_dpe[:, None]
|
||||
)
|
||||
kpe = tl.load(
|
||||
K_Buffer + offs_kpe,
|
||||
mask=mask_n[None, :],
|
||||
other=0.0,
|
||||
)
|
||||
qk += tl.dot(qpe.to(kpe.dtype), kpe)
|
||||
qk *= sm_scale
|
||||
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
|
||||
|
||||
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
p = tl.exp(qk - n_e_max[:, None])
|
||||
deno = deno * re_scale + tl.sum(p, 1)
|
||||
|
||||
offs_buf_v = (
|
||||
offs_kv_loc[:, None] * stride_buf_vbs
|
||||
+ cur_kv_head * stride_buf_vh
|
||||
+ offs_dv[None, :]
|
||||
)
|
||||
v = tl.load(
|
||||
V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
|
||||
)
|
||||
p = p.to(v.dtype)
|
||||
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
||||
|
||||
e_max = n_e_max
|
||||
|
||||
# stage 2: compute the trianlge part
|
||||
|
||||
cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
|
||||
for start_n in range(0, cur_block_m_end, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
mask_n = (start_n + offs_n) < cur_block_m_end
|
||||
|
||||
# load k in transposed way
|
||||
offs_k = (
|
||||
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
|
||||
+ cur_kv_head * stride_kh
|
||||
+ offs_d[:, None]
|
||||
)
|
||||
k = tl.load(
|
||||
K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
|
||||
)
|
||||
|
||||
qk = tl.dot(q, k, out_dtype=tl.float32)
|
||||
if BLOCK_DPE > 0:
|
||||
offs_kpe = (
|
||||
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
|
||||
* stride_kbs
|
||||
+ cur_kv_head * stride_kh
|
||||
+ offs_dpe[:, None]
|
||||
)
|
||||
kpe = tl.load(
|
||||
K_Extend + offs_kpe,
|
||||
mask=mask_n[None, :],
|
||||
other=0.0,
|
||||
)
|
||||
qk += tl.dot(qpe, kpe)
|
||||
|
||||
qk *= sm_scale
|
||||
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
|
||||
start_n + offs_n[None, :]
|
||||
)
|
||||
mask_causual &= mask_m[:, None] & mask_n[None, :]
|
||||
qk = tl.where(mask_causual, qk, float("-inf"))
|
||||
|
||||
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
p = tl.exp(qk - n_e_max[:, None])
|
||||
deno = deno * re_scale + tl.sum(p, 1)
|
||||
|
||||
offs_v = (
|
||||
(cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
|
||||
+ cur_kv_head * stride_vh
|
||||
+ offs_dv[None, :]
|
||||
)
|
||||
v = tl.load(
|
||||
V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
|
||||
)
|
||||
p = p.to(v.dtype)
|
||||
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
||||
|
||||
e_max = n_e_max
|
||||
|
||||
offs_o = (
|
||||
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||
* stride_obs
|
||||
+ cur_head * stride_oh
|
||||
+ offs_dv[None, :]
|
||||
)
|
||||
tl.store(
|
||||
O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
|
||||
)
|
||||
|
||||
|
||||
def extend_attention_fwd(
|
||||
q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_tokens,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
b_seq_len_extend,
|
||||
b_start_loc_extend,
|
||||
max_len_extend,
|
||||
sm_scale=None,
|
||||
logit_cap=0.0,
|
||||
):
|
||||
"""
|
||||
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
||||
|
||||
k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
|
||||
"""
|
||||
Lq, Lk, Lv = (
|
||||
q_extend.shape[-1],
|
||||
k_extend.shape[-1],
|
||||
v_extend.shape[-1],
|
||||
)
|
||||
|
||||
if Lq == 576:
|
||||
BLOCK_DMODEL = 512
|
||||
BLOCK_DPE = 64
|
||||
elif Lq == 288:
|
||||
BLOCK_DMODEL = 256
|
||||
BLOCK_DPE = 32
|
||||
elif Lq == 192:
|
||||
BLOCK_DMODEL = 128
|
||||
BLOCK_DPE = 64
|
||||
else:
|
||||
BLOCK_DMODEL = triton.next_power_of_2(Lq)
|
||||
BLOCK_DPE = 0
|
||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||
|
||||
if is_hip_:
|
||||
BLOCK_M, BLOCK_N = (64, 64)
|
||||
num_warps = 4
|
||||
|
||||
else:
|
||||
if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
|
||||
if Lq <= 256:
|
||||
BLOCK_M, BLOCK_N = (128, 64)
|
||||
else:
|
||||
BLOCK_M, BLOCK_N = (32, 64)
|
||||
elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
||||
if Lq <= 128:
|
||||
BLOCK_M, BLOCK_N = (128, 128)
|
||||
elif Lq <= 256:
|
||||
BLOCK_M, BLOCK_N = (64, 64)
|
||||
else:
|
||||
BLOCK_M, BLOCK_N = (32, 64)
|
||||
else:
|
||||
BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
|
||||
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
|
||||
sm_scale = sm_scale or 1.0 / (Lq**0.5)
|
||||
batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
|
||||
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
|
||||
|
||||
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
||||
num_stages = 1
|
||||
|
||||
extra_kargs = {}
|
||||
if is_hip_:
|
||||
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
||||
|
||||
_fwd_kernel[grid](
|
||||
q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_tokens,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
b_start_loc_extend,
|
||||
b_seq_len_extend,
|
||||
sm_scale,
|
||||
kv_group_num,
|
||||
q_extend.stride(0),
|
||||
q_extend.stride(1),
|
||||
k_extend.stride(0),
|
||||
k_extend.stride(1),
|
||||
v_extend.stride(0),
|
||||
v_extend.stride(1),
|
||||
o_extend.stride(0),
|
||||
o_extend.stride(1),
|
||||
k_buffer.stride(0),
|
||||
k_buffer.stride(1),
|
||||
v_buffer.stride(0),
|
||||
v_buffer.stride(1),
|
||||
req_to_tokens.stride(0),
|
||||
logit_cap=logit_cap,
|
||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_DPE=BLOCK_DPE,
|
||||
BLOCK_DV=BLOCK_DV,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
Lq=Lq,
|
||||
Lv=Lv,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
**extra_kargs,
|
||||
)
|
||||
|
||||
@@ -46,11 +46,9 @@ def _fwd_kernel(
|
||||
O_Extend,
|
||||
K_Buffer,
|
||||
V_Buffer,
|
||||
Req_to_tokens,
|
||||
B_req_idx,
|
||||
B_Seq_Len,
|
||||
B_Start_Loc_Extend,
|
||||
B_Seq_Len_Extend,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
sm_scale,
|
||||
kv_group_num,
|
||||
stride_qbs,
|
||||
@@ -65,7 +63,6 @@ def _fwd_kernel(
|
||||
stride_buf_kh,
|
||||
stride_buf_vbs,
|
||||
stride_buf_vh,
|
||||
stride_req_to_tokens_b,
|
||||
logit_cap: tl.constexpr,
|
||||
Lq: tl.constexpr,
|
||||
Lv: tl.constexpr,
|
||||
@@ -80,13 +77,10 @@ def _fwd_kernel(
|
||||
cur_block_m = tl.program_id(2)
|
||||
cur_kv_head = cur_head // kv_group_num
|
||||
|
||||
cur_seq_len = tl.load(B_Seq_Len + cur_seq)
|
||||
cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
|
||||
cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
|
||||
|
||||
cur_seq_prefix_start_in_loc = 0
|
||||
cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
|
||||
cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
|
||||
cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq)
|
||||
cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
|
||||
cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
|
||||
cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
|
||||
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_dv = tl.arange(0, BLOCK_DV)
|
||||
@@ -97,7 +91,7 @@ def _fwd_kernel(
|
||||
mask_dv = offs_dv < Lv
|
||||
|
||||
offs_q = (
|
||||
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||
* stride_qbs
|
||||
+ cur_head * stride_qh
|
||||
+ offs_d[None, :]
|
||||
@@ -109,7 +103,7 @@ def _fwd_kernel(
|
||||
if BLOCK_DPE > 0:
|
||||
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
||||
offs_qpe = (
|
||||
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||
* stride_qbs
|
||||
+ cur_head * stride_qh
|
||||
+ offs_dpe[None, :]
|
||||
@@ -126,10 +120,9 @@ def _fwd_kernel(
|
||||
for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
mask_n = (start_n + offs_n) < cur_seq_len_prefix
|
||||
offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
|
||||
cur_seq_prefix_start_in_loc + start_n + offs_n
|
||||
offs_kv_loc = tl.load(
|
||||
kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
|
||||
)
|
||||
offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
|
||||
|
||||
# load k in transposed way
|
||||
offs_buf_k = (
|
||||
@@ -188,7 +181,7 @@ def _fwd_kernel(
|
||||
|
||||
# load k in transposed way
|
||||
offs_k = (
|
||||
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
|
||||
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
|
||||
+ cur_kv_head * stride_kh
|
||||
+ offs_d[:, None]
|
||||
)
|
||||
@@ -199,8 +192,7 @@ def _fwd_kernel(
|
||||
qk = tl.dot(q, k, out_dtype=tl.float32)
|
||||
if BLOCK_DPE > 0:
|
||||
offs_kpe = (
|
||||
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
|
||||
* stride_kbs
|
||||
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
|
||||
+ cur_kv_head * stride_kh
|
||||
+ offs_dpe[:, None]
|
||||
)
|
||||
@@ -228,7 +220,7 @@ def _fwd_kernel(
|
||||
deno = deno * re_scale + tl.sum(p, 1)
|
||||
|
||||
offs_v = (
|
||||
(cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
|
||||
(cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
|
||||
+ cur_kv_head * stride_vh
|
||||
+ offs_dv[None, :]
|
||||
)
|
||||
@@ -241,7 +233,7 @@ def _fwd_kernel(
|
||||
e_max = n_e_max
|
||||
|
||||
offs_o = (
|
||||
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||
* stride_obs
|
||||
+ cur_head * stride_oh
|
||||
+ offs_dv[None, :]
|
||||
@@ -258,11 +250,9 @@ def extend_attention_fwd(
|
||||
o_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_tokens,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
b_seq_len_extend,
|
||||
b_start_loc_extend,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
max_len_extend,
|
||||
sm_scale=None,
|
||||
logit_cap=0.0,
|
||||
@@ -315,7 +305,7 @@ def extend_attention_fwd(
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
|
||||
sm_scale = sm_scale or 1.0 / (Lq**0.5)
|
||||
batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
|
||||
batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
|
||||
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
|
||||
|
||||
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
||||
@@ -332,11 +322,9 @@ def extend_attention_fwd(
|
||||
o_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_tokens,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
b_start_loc_extend,
|
||||
b_seq_len_extend,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
sm_scale,
|
||||
kv_group_num,
|
||||
q_extend.stride(0),
|
||||
@@ -351,7 +339,6 @@ def extend_attention_fwd(
|
||||
k_buffer.stride(1),
|
||||
v_buffer.stride(0),
|
||||
v_buffer.stride(1),
|
||||
req_to_tokens.stride(0),
|
||||
logit_cap=logit_cap,
|
||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_DPE=BLOCK_DPE,
|
||||
|
||||
Reference in New Issue
Block a user