Update Triton decode backend interface (#3292)
This commit is contained in:
@@ -5,6 +5,9 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.attention import AttentionBackend
|
from sglang.srt.layers.attention import AttentionBackend
|
||||||
|
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||||
|
create_flashinfer_kv_indices_triton,
|
||||||
|
)
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
|
||||||
@@ -29,6 +32,12 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.decode_attention_fwd = decode_attention_fwd
|
self.decode_attention_fwd = decode_attention_fwd
|
||||||
self.extend_attention_fwd = extend_attention_fwd
|
self.extend_attention_fwd = extend_attention_fwd
|
||||||
|
|
||||||
|
max_bs = model_runner.req_to_token_pool.size
|
||||||
|
self.kv_indptr = torch.zeros(
|
||||||
|
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||||
|
)
|
||||||
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
|
|
||||||
self.num_head = (
|
self.num_head = (
|
||||||
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||||
)
|
)
|
||||||
@@ -58,11 +67,32 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
max_extend_len = None
|
max_extend_len = None
|
||||||
|
|
||||||
|
kv_indptr = self.kv_indptr
|
||||||
|
bs = len(forward_batch.req_pool_indices)
|
||||||
|
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
||||||
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
|
kv_indices = torch.empty(
|
||||||
|
forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
|
forward_batch.req_to_token_pool.req_to_token,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
kv_indptr,
|
||||||
|
None,
|
||||||
|
kv_indices,
|
||||||
|
forward_batch.req_to_token_pool.req_to_token.stride(0),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
attn_logits = None
|
attn_logits = None
|
||||||
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
|
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
|
||||||
|
|
||||||
self.forward_metadata = attn_logits, max_extend_len
|
kv_indptr = None
|
||||||
|
kv_indices = None
|
||||||
|
|
||||||
|
self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices
|
||||||
|
|
||||||
def init_cuda_graph_state(self, max_bs: int):
|
def init_cuda_graph_state(self, max_bs: int):
|
||||||
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
|
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
|
||||||
@@ -73,7 +103,12 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.cuda_graph_attn_logits = torch.empty(
|
self.cuda_graph_attn_logits = torch.empty(
|
||||||
(max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
|
(max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device="cuda",
|
device=self.device,
|
||||||
|
)
|
||||||
|
self.cuda_graph_kv_indices = torch.zeros(
|
||||||
|
(max_bs * self.cuda_graph_max_seq_len),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
@@ -90,9 +125,25 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
assert forward_mode.is_decode(), "Not supported"
|
assert forward_mode.is_decode(), "Not supported"
|
||||||
assert spec_info is None, "Not supported"
|
assert spec_info is None, "Not supported"
|
||||||
|
|
||||||
|
kv_indptr = self.kv_indptr
|
||||||
|
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
|
||||||
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
|
kv_indices = self.cuda_graph_kv_indices
|
||||||
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
|
self.req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
kv_indptr,
|
||||||
|
None,
|
||||||
|
kv_indices,
|
||||||
|
self.req_to_token.stride(0),
|
||||||
|
)
|
||||||
|
|
||||||
self.forward_metadata = (
|
self.forward_metadata = (
|
||||||
self.cuda_graph_attn_logits,
|
self.cuda_graph_attn_logits,
|
||||||
None,
|
None,
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
@@ -109,6 +160,20 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.cuda_graph_start_loc.zero_()
|
self.cuda_graph_start_loc.zero_()
|
||||||
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
||||||
|
|
||||||
|
kv_indptr = self.kv_indptr
|
||||||
|
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
|
||||||
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
|
kv_indices = self.cuda_graph_kv_indices
|
||||||
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
|
self.req_to_token,
|
||||||
|
req_pool_indices[:bs],
|
||||||
|
seq_lens[:bs],
|
||||||
|
kv_indptr,
|
||||||
|
None,
|
||||||
|
kv_indices,
|
||||||
|
self.req_to_token.stride(0),
|
||||||
|
)
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
@@ -132,7 +197,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
layer, forward_batch.out_cache_loc, k, v
|
layer, forward_batch.out_cache_loc, k, v
|
||||||
)
|
)
|
||||||
|
|
||||||
_, max_extend_len = self.forward_metadata
|
_, max_extend_len, _, _ = self.forward_metadata
|
||||||
self.extend_attention_fwd(
|
self.extend_attention_fwd(
|
||||||
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||||
k.contiguous(),
|
k.contiguous(),
|
||||||
@@ -170,7 +235,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
else:
|
else:
|
||||||
o = torch.empty_like(q)
|
o = torch.empty_like(q)
|
||||||
|
|
||||||
attn_logits, _ = self.forward_metadata
|
attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
|
||||||
|
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
@@ -182,9 +247,8 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||||
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||||
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
||||||
forward_batch.req_to_token_pool.req_to_token,
|
kv_indptr,
|
||||||
forward_batch.req_pool_indices,
|
kv_indices,
|
||||||
forward_batch.seq_lens,
|
|
||||||
attn_logits,
|
attn_logits,
|
||||||
self.num_kv_splits,
|
self.num_kv_splits,
|
||||||
layer.scaling,
|
layer.scaling,
|
||||||
|
|||||||
@@ -49,11 +49,9 @@ def _fwd_kernel_stage1(
|
|||||||
K_Buffer,
|
K_Buffer,
|
||||||
V_Buffer,
|
V_Buffer,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
Req_to_tokens,
|
kv_indptr,
|
||||||
B_req_idx,
|
kv_indices,
|
||||||
B_Seqlen,
|
|
||||||
Att_Out,
|
Att_Out,
|
||||||
stride_req_to_tokens_b,
|
|
||||||
stride_qbs,
|
stride_qbs,
|
||||||
stride_qh,
|
stride_qh,
|
||||||
stride_buf_kbs,
|
stride_buf_kbs,
|
||||||
@@ -82,8 +80,9 @@ def _fwd_kernel_stage1(
|
|||||||
offs_dv = tl.arange(0, BLOCK_DV)
|
offs_dv = tl.arange(0, BLOCK_DV)
|
||||||
mask_d = offs_d < Lk
|
mask_d = offs_d < Lk
|
||||||
mask_dv = offs_dv < Lv
|
mask_dv = offs_dv < Lv
|
||||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
|
||||||
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
|
cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
|
||||||
|
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
|
||||||
|
|
||||||
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
|
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
|
||||||
q = tl.load(Q + off_q, mask=mask_d, other=0.0)
|
q = tl.load(Q + off_q, mask=mask_d, other=0.0)
|
||||||
@@ -100,7 +99,7 @@ def _fwd_kernel_stage1(
|
|||||||
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
||||||
offs_n = start_n + tl.arange(0, BLOCK_N)
|
offs_n = start_n + tl.arange(0, BLOCK_N)
|
||||||
kv_loc = tl.load(
|
kv_loc = tl.load(
|
||||||
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
|
kv_indices + cur_batch_kv_start_idx + offs_n,
|
||||||
mask=offs_n < split_kv_end,
|
mask=offs_n < split_kv_end,
|
||||||
other=0,
|
other=0,
|
||||||
)
|
)
|
||||||
@@ -173,9 +172,8 @@ def _decode_att_m_fwd(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
att_out,
|
att_out,
|
||||||
Req_to_tokens,
|
kv_indptr,
|
||||||
B_req_idx,
|
kv_indices,
|
||||||
B_Seqlen,
|
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap,
|
logit_cap,
|
||||||
@@ -188,7 +186,7 @@ def _decode_att_m_fwd(
|
|||||||
Lk = k_buffer.shape[-1]
|
Lk = k_buffer.shape[-1]
|
||||||
Lv = v_buffer.shape[-1]
|
Lv = v_buffer.shape[-1]
|
||||||
|
|
||||||
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
|
||||||
|
|
||||||
grid = (batch, head_num, NUM_KV_SPLITS)
|
grid = (batch, head_num, NUM_KV_SPLITS)
|
||||||
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
||||||
@@ -208,11 +206,9 @@ def _decode_att_m_fwd(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
Req_to_tokens,
|
kv_indptr,
|
||||||
B_req_idx,
|
kv_indices,
|
||||||
B_Seqlen,
|
|
||||||
att_out,
|
att_out,
|
||||||
Req_to_tokens.stride(0),
|
|
||||||
q.stride(0),
|
q.stride(0),
|
||||||
q.stride(1),
|
q.stride(1),
|
||||||
k_buffer.stride(0),
|
k_buffer.stride(0),
|
||||||
@@ -241,11 +237,9 @@ def _fwd_grouped_kernel_stage1(
|
|||||||
K_Buffer,
|
K_Buffer,
|
||||||
V_Buffer,
|
V_Buffer,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
Req_to_tokens,
|
kv_indptr,
|
||||||
B_req_idx,
|
kv_indices,
|
||||||
B_Seqlen,
|
|
||||||
Att_Out,
|
Att_Out,
|
||||||
stride_req_to_tokens_b,
|
|
||||||
stride_qbs,
|
stride_qbs,
|
||||||
stride_qh,
|
stride_qh,
|
||||||
stride_buf_kbs,
|
stride_buf_kbs,
|
||||||
@@ -284,8 +278,9 @@ def _fwd_grouped_kernel_stage1(
|
|||||||
offs_dv = tl.arange(0, BLOCK_DV)
|
offs_dv = tl.arange(0, BLOCK_DV)
|
||||||
mask_d = offs_d < Lk
|
mask_d = offs_d < Lk
|
||||||
mask_dv = offs_dv < Lv
|
mask_dv = offs_dv < Lv
|
||||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
|
||||||
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
|
cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
|
||||||
|
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
|
||||||
|
|
||||||
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
|
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
|
||||||
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
|
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
|
||||||
@@ -312,7 +307,7 @@ def _fwd_grouped_kernel_stage1(
|
|||||||
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
||||||
offs_n = start_n + tl.arange(0, BLOCK_N)
|
offs_n = start_n + tl.arange(0, BLOCK_N)
|
||||||
kv_loc = tl.load(
|
kv_loc = tl.load(
|
||||||
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
|
kv_indices + cur_batch_kv_start_idx + offs_n,
|
||||||
mask=offs_n < split_kv_end,
|
mask=offs_n < split_kv_end,
|
||||||
other=0,
|
other=0,
|
||||||
)
|
)
|
||||||
@@ -400,9 +395,8 @@ def _decode_grouped_att_m_fwd(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
att_out,
|
att_out,
|
||||||
Req_to_tokens,
|
kv_indptr,
|
||||||
B_req_idx,
|
kv_indices,
|
||||||
B_Seqlen,
|
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap,
|
logit_cap,
|
||||||
@@ -426,7 +420,7 @@ def _decode_grouped_att_m_fwd(
|
|||||||
BLOCK_DPE = 0
|
BLOCK_DPE = 0
|
||||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||||
|
|
||||||
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
|
||||||
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
||||||
|
|
||||||
BLOCK_H = 16
|
BLOCK_H = 16
|
||||||
@@ -450,11 +444,9 @@ def _decode_grouped_att_m_fwd(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
Req_to_tokens,
|
kv_indptr,
|
||||||
B_req_idx,
|
kv_indices,
|
||||||
B_Seqlen,
|
|
||||||
att_out,
|
att_out,
|
||||||
Req_to_tokens.stride(0),
|
|
||||||
q.stride(0),
|
q.stride(0),
|
||||||
q.stride(1),
|
q.stride(1),
|
||||||
k_buffer.stride(0),
|
k_buffer.stride(0),
|
||||||
@@ -485,7 +477,7 @@ def _decode_grouped_att_m_fwd(
|
|||||||
def _fwd_kernel_stage2(
|
def _fwd_kernel_stage2(
|
||||||
Mid_O,
|
Mid_O,
|
||||||
O,
|
O,
|
||||||
B_Seqlen,
|
kv_indptr,
|
||||||
stride_mid_ob,
|
stride_mid_ob,
|
||||||
stride_mid_oh,
|
stride_mid_oh,
|
||||||
stride_mid_os,
|
stride_mid_os,
|
||||||
@@ -498,7 +490,9 @@ def _fwd_kernel_stage2(
|
|||||||
cur_batch = tl.program_id(0)
|
cur_batch = tl.program_id(0)
|
||||||
cur_head = tl.program_id(1)
|
cur_head = tl.program_id(1)
|
||||||
|
|
||||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
|
||||||
|
kv_indptr + cur_batch
|
||||||
|
)
|
||||||
|
|
||||||
offs_d = tl.arange(0, BLOCK_DV)
|
offs_d = tl.arange(0, BLOCK_DV)
|
||||||
mask_d = offs_d < Lv
|
mask_d = offs_d < Lv
|
||||||
@@ -542,7 +536,7 @@ def _decode_softmax_reducev_fwd(
|
|||||||
q,
|
q,
|
||||||
o,
|
o,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
b_seq_len,
|
kv_indptr,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
):
|
):
|
||||||
batch, head_num = q.shape[0], q.shape[1]
|
batch, head_num = q.shape[0], q.shape[1]
|
||||||
@@ -561,7 +555,7 @@ def _decode_softmax_reducev_fwd(
|
|||||||
_fwd_kernel_stage2[grid](
|
_fwd_kernel_stage2[grid](
|
||||||
logits,
|
logits,
|
||||||
o,
|
o,
|
||||||
b_seq_len,
|
kv_indptr,
|
||||||
logits.stride(0),
|
logits.stride(0),
|
||||||
logits.stride(1),
|
logits.stride(1),
|
||||||
logits.stride(2),
|
logits.stride(2),
|
||||||
@@ -581,9 +575,8 @@ def decode_attention_fwd_normal(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
req_to_token,
|
kv_indptr,
|
||||||
b_req_idx,
|
kv_indices,
|
||||||
b_seq_len,
|
|
||||||
attn_logits,
|
attn_logits,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
@@ -594,14 +587,13 @@ def decode_attention_fwd_normal(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
attn_logits,
|
attn_logits,
|
||||||
req_to_token,
|
kv_indptr,
|
||||||
b_req_idx,
|
kv_indices,
|
||||||
b_seq_len,
|
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap,
|
logit_cap,
|
||||||
)
|
)
|
||||||
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
|
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
|
||||||
|
|
||||||
|
|
||||||
def decode_attention_fwd_grouped(
|
def decode_attention_fwd_grouped(
|
||||||
@@ -609,9 +601,8 @@ def decode_attention_fwd_grouped(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
req_to_token,
|
kv_indptr,
|
||||||
b_req_idx,
|
kv_indices,
|
||||||
b_seq_len,
|
|
||||||
attn_logits,
|
attn_logits,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
@@ -622,14 +613,13 @@ def decode_attention_fwd_grouped(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
attn_logits,
|
attn_logits,
|
||||||
req_to_token,
|
kv_indptr,
|
||||||
b_req_idx,
|
kv_indices,
|
||||||
b_seq_len,
|
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap,
|
logit_cap,
|
||||||
)
|
)
|
||||||
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
|
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
|
||||||
|
|
||||||
|
|
||||||
def decode_attention_fwd(
|
def decode_attention_fwd(
|
||||||
@@ -637,9 +627,8 @@ def decode_attention_fwd(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
req_to_token,
|
kv_indptr,
|
||||||
b_req_idx,
|
kv_indices,
|
||||||
b_seq_len,
|
|
||||||
attn_logits,
|
attn_logits,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
@@ -655,9 +644,8 @@ def decode_attention_fwd(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
req_to_token,
|
kv_indptr,
|
||||||
b_req_idx,
|
kv_indices,
|
||||||
b_seq_len,
|
|
||||||
attn_logits,
|
attn_logits,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
@@ -670,9 +658,8 @@ def decode_attention_fwd(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
req_to_token,
|
kv_indptr,
|
||||||
b_req_idx,
|
kv_indices,
|
||||||
b_seq_len,
|
|
||||||
attn_logits,
|
attn_logits,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
|
|||||||
@@ -194,10 +194,12 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
# o will have the same shape as q
|
# o will have the same shape as q
|
||||||
o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
|
o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
|
|
||||||
b_req_idx = torch.arange(B, device="cuda")
|
|
||||||
b_seq_len = torch.full((B,), seq_len, device="cuda")
|
b_seq_len = torch.full((B,), seq_len, device="cuda")
|
||||||
|
|
||||||
|
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
|
||||||
|
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
|
||||||
|
kv_indices = torch.arange(total_tokens, device="cuda")
|
||||||
|
|
||||||
attn_logits = torch.empty(
|
attn_logits = torch.empty(
|
||||||
(B, H_Q, num_kv_splits, D + 1),
|
(B, H_Q, num_kv_splits, D + 1),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
@@ -209,9 +211,8 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
req_to_token,
|
kv_indptr,
|
||||||
b_req_idx,
|
kv_indices,
|
||||||
b_seq_len,
|
|
||||||
attn_logits,
|
attn_logits,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
@@ -250,10 +251,12 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
|
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
|
||||||
o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
|
o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
|
|
||||||
b_req_idx = torch.arange(B, device="cuda")
|
|
||||||
b_seq_len = torch.full((B,), seq_len, device="cuda")
|
b_seq_len = torch.full((B,), seq_len, device="cuda")
|
||||||
|
|
||||||
|
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
|
||||||
|
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
|
||||||
|
kv_indices = torch.arange(total_tokens, device="cuda")
|
||||||
|
|
||||||
attn_logits = torch.empty(
|
attn_logits = torch.empty(
|
||||||
(B, H_Q, num_kv_splits, D_V + 1),
|
(B, H_Q, num_kv_splits, D_V + 1),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
@@ -265,9 +268,8 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
req_to_token,
|
kv_indptr,
|
||||||
b_req_idx,
|
kv_indices,
|
||||||
b_seq_len,
|
|
||||||
attn_logits,
|
attn_logits,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
@@ -284,9 +286,8 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o_grouped,
|
o_grouped,
|
||||||
req_to_token,
|
kv_indptr,
|
||||||
b_req_idx,
|
kv_indices,
|
||||||
b_seq_len,
|
|
||||||
attn_logits1,
|
attn_logits1,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
|
|||||||
Reference in New Issue
Block a user