Remove assertion in triton attention and add a unit test (#1385)
This commit is contained in:
@@ -199,8 +199,6 @@ def _decode_att_m_fwd(
|
||||
BLOCK = 32
|
||||
# shape constraints
|
||||
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
|
||||
assert Lq == Lk
|
||||
assert Lk in {16, 32, 64, 96, 128, 256}
|
||||
|
||||
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
||||
|
||||
@@ -482,8 +480,6 @@ def _decode_grouped_att_m_fwd(
|
||||
BLOCK = 32
|
||||
# shape constraints
|
||||
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
|
||||
assert Lq == Lk
|
||||
assert Lk in {16, 32, 64, 96, 128, 256, 576, 288}
|
||||
|
||||
if Lk == 576:
|
||||
BLOCK_DMODEL = 512
|
||||
|
||||
@@ -277,12 +277,6 @@ def extend_attention_fwd(
|
||||
o_extend.shape[-1],
|
||||
)
|
||||
|
||||
assert Lq == Lk and Lv == Lo
|
||||
|
||||
# TODO: is the assertion necessary?
|
||||
assert Lq in {16, 32, 64, 96, 128, 256, 576, 288}
|
||||
assert Lv in {16, 32, 64, 96, 128, 256, 512}
|
||||
|
||||
if Lq == 576:
|
||||
BLOCK_DMODEL = 512
|
||||
BLOCK_DPE = 64
|
||||
@@ -395,104 +389,3 @@ def redundant_attention(
|
||||
pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
|
||||
o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr]
|
||||
pt += cur_seq_len_extend
|
||||
|
||||
|
||||
def test_once(B, N_CTX, H_Q, H_KV, D):
    """Check ``extend_attention_fwd`` against ``redundant_attention`` on random data.

    Builds a random batch of ``B`` requests, each split into a prefix part and
    an extend part, runs both implementations on the same inputs, prints the
    mean and max absolute difference, and asserts that the outputs are close.

    Args:
        B: number of requests in the batch.
        N_CTX: prefix and extend lengths are each drawn from ``[1, N_CTX // 2)``.
        H_Q: number of query heads.
        H_KV: number of key/value heads.
        D: head dimension.

    Raises:
        AssertionError: if the two outputs are not close (``rtol=1e-2``).
    """
    dtype = torch.float16

    b_seq_len_prefix = torch.randint(
        1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
    )
    b_seq_len_extend = torch.randint(
        1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
    )
    b_seq_len = b_seq_len_prefix + b_seq_len_extend
    max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
    max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()

    b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
    req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32, device="cuda")

    # Exclusive prefix sums: where each request starts in the full buffer and
    # in the packed extend tensors.
    b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
    b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)

    # Copy the per-request sizes/offsets to the host once instead of indexing
    # device tensors inside the Python loops below.
    seq_lens = b_seq_len.tolist()
    prefix_lens = b_seq_len_prefix.tolist()
    extend_lens = b_seq_len_extend.tolist()
    starts = b_start_loc.tolist()
    extend_starts = b_start_loc_extend.tolist()

    for i, (start, seq_len) in enumerate(zip(starts, seq_lens)):
        req_to_tokens[i, :seq_len] = torch.arange(
            start, start + seq_len, dtype=torch.int32, device="cuda"
        )

    total_token_num = sum(seq_lens)
    extend_token_num = sum(extend_lens)
    k_buffer = torch.empty(
        (total_token_num, H_KV, D), dtype=dtype, device="cuda"
    ).normal_(mean=0.1, std=0.2)
    v_buffer = torch.empty(
        (total_token_num, H_KV, D), dtype=dtype, device="cuda"
    ).normal_(mean=0.1, std=0.2)

    k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
    v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
    q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
    for start, seq_len, prefix_len, ext_start, ext_len in zip(
        starts, seq_lens, prefix_lens, extend_starts, extend_lens
    ):
        buf_start = start + prefix_len
        buf_end = start + seq_len
        ext_end = ext_start + ext_len
        # The extend K/V rows are the non-prefix tail of each request's buffer.
        k_extend[ext_start:ext_end] = k_buffer[buf_start:buf_end]
        v_extend[ext_start:ext_end] = v_buffer[buf_start:buf_end]
        q_extend[ext_start:ext_end] = torch.empty(
            (ext_len, H_Q, D), dtype=dtype, device="cuda"
        ).normal_(mean=0.1, std=0.2)

    o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
    o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")

    extend_attention_fwd(
        q_extend,
        k_extend,
        v_extend,
        o_extend,
        k_buffer,
        v_buffer,
        req_to_tokens,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        b_seq_len_prefix,
        b_start_loc_extend,
        b_seq_len_extend,
        max_len_in_batch,
        max_len_extend,
    )

    redundant_attention(
        q_extend,
        k_extend,
        v_extend,
        o_redundant,
        k_buffer,
        v_buffer,
        req_to_tokens,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        b_seq_len_prefix,
        max_len_in_batch,
    )

    abs_diff = torch.abs(o_extend - o_redundant)
    print("Mean: ", torch.mean(abs_diff))
    print("Max: ", torch.max(abs_diff))

    assert torch.allclose(o_extend, o_redundant, rtol=1e-2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the comparison once per head dimension, in the original order.
    for head_dim in (128, 96):
        test_once(19, 12331, 12, 4, head_dim)
|
||||
|
||||
@@ -151,8 +151,6 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
||||
BLOCK = 64
|
||||
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
assert Lk in {16, 32, 64, 96, 128, 256}
|
||||
|
||||
sm_scale = 1.0 / (Lq**0.5)
|
||||
batch, head = b_seq_len.shape[0], q.shape[1]
|
||||
|
||||
213
test/srt/test_triton_attention_kernels.py
Normal file
213
test/srt/test_triton_attention_kernels.py
Normal file
@@ -0,0 +1,213 @@
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
|
||||
from sglang.srt.layers.triton_attention.extend_attention import (
|
||||
extend_attention_fwd,
|
||||
redundant_attention,
|
||||
)
|
||||
from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd
|
||||
|
||||
|
||||
@unittest.skipUnless(
    torch.cuda.is_available(), "Triton attention kernels require a CUDA device"
)
class TestExtendAttention(unittest.TestCase):
    """Tests for the Triton extend, prefill (context) and decode attention kernels.

    Every tensor in these tests is allocated with ``device="cuda"``, so the
    whole case is skipped when no CUDA device is available. The extend test
    compares two implementations; the context and decode tests only check that
    the kernels run without raising.
    """

    def _set_all_seeds(self, seed):
        """Set all random seeds for reproducibility."""
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def setUp(self):
        # Set seeds before each test method
        self._set_all_seeds(42)

    def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D):
        """Compare ``extend_attention_fwd`` with ``redundant_attention`` once.

        Args:
            B: number of requests in the batch.
            N_CTX: prefix and extend lengths are each drawn from
                ``[1, N_CTX // 2)``.
            H_Q: number of query heads.
            H_KV: number of key/value heads.
            D: head dimension.
        """
        dtype = torch.bfloat16

        b_seq_len_prefix = torch.randint(
            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
        )
        b_seq_len_extend = torch.randint(
            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
        )
        b_seq_len = b_seq_len_prefix + b_seq_len_extend
        max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
        max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()

        b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
        req_to_tokens = torch.empty(
            (B, max_len_in_batch), dtype=torch.int32, device="cuda"
        )

        # Exclusive prefix sums: where each request starts in the full buffer
        # and in the packed extend tensors.
        b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
        b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
        b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
        b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)

        # Copy the per-request sizes/offsets to the host once instead of
        # indexing device tensors inside the Python loops below.
        seq_lens = b_seq_len.tolist()
        prefix_lens = b_seq_len_prefix.tolist()
        extend_lens = b_seq_len_extend.tolist()
        starts = b_start_loc.tolist()
        extend_starts = b_start_loc_extend.tolist()

        for i, (start, seq_len) in enumerate(zip(starts, seq_lens)):
            req_to_tokens[i, :seq_len] = torch.arange(
                start, start + seq_len, dtype=torch.int32, device="cuda"
            )

        total_token_num = sum(seq_lens)
        extend_token_num = sum(extend_lens)
        k_buffer = torch.empty(
            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
        ).normal_(mean=0.1, std=0.2)
        v_buffer = torch.empty(
            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
        ).normal_(mean=0.1, std=0.2)

        k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
        v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
        q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
        for start, seq_len, prefix_len, ext_start, ext_len in zip(
            starts, seq_lens, prefix_lens, extend_starts, extend_lens
        ):
            buf_start = start + prefix_len
            buf_end = start + seq_len
            ext_end = ext_start + ext_len
            # The extend K/V rows are the non-prefix tail of each request's
            # slice of the buffer.
            k_extend[ext_start:ext_end] = k_buffer[buf_start:buf_end]
            v_extend[ext_start:ext_end] = v_buffer[buf_start:buf_end]
            q_extend[ext_start:ext_end] = torch.empty(
                (ext_len, H_Q, D), dtype=dtype, device="cuda"
            ).normal_(mean=0.1, std=0.2)

        o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
        o_redundant = torch.empty(
            (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
        )

        extend_attention_fwd(
            q_extend,
            k_extend,
            v_extend,
            o_extend,
            k_buffer,
            v_buffer,
            req_to_tokens,
            b_req_idx,
            b_start_loc,
            b_seq_len,
            b_seq_len_prefix,
            b_start_loc_extend,
            b_seq_len_extend,
            max_len_in_batch,
            max_len_extend,
        )

        redundant_attention(
            q_extend,
            k_extend,
            v_extend,
            o_redundant,
            k_buffer,
            v_buffer,
            req_to_tokens,
            b_req_idx,
            b_start_loc,
            b_seq_len,
            b_seq_len_prefix,
            max_len_in_batch,
        )

        self.assertTrue(torch.allclose(o_extend, o_redundant, rtol=1e-2))

    def test_extend_attention(self):
        """Run the extend-attention comparison for several head dimensions."""
        # Define the varying parameter values
        attention_values = [128, 96, 80, 13]

        # Each value runs as its own sub-test so a failure names the head dim.
        for value in attention_values:
            with self.subTest(head_dim=value):
                self._test_extend_attention_once(19, 12331, 12, 4, value)

    def _test_context_attention_once(self, head_dim):
        """Run ``context_attention_fwd`` once on a small two-sequence batch.

        Args:
            head_dim: size of the last dimension of q, k, v and o.
        """
        # Set up a simple test case
        num_heads = 4
        seq_lens = [8, 12]
        max_seq_len = max(seq_lens)
        total_len = sum(seq_lens)

        # Create random input tensors
        q = torch.randn(total_len, num_heads, head_dim, device="cuda")
        k = torch.randn(total_len, num_heads, head_dim, device="cuda")
        v = torch.randn(total_len, num_heads, head_dim, device="cuda")
        o = torch.zeros(total_len, num_heads, head_dim, device="cuda")

        # Create b_start_loc and b_seq_len tensors
        b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda")
        b_seq_len = torch.tensor(seq_lens, device="cuda")

        context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_seq_len)

    def test_context_attention(self):
        """Smoke-test ``context_attention_fwd`` for several head dimensions."""
        # Here we just ensure there is no error
        # TODO: correctness test
        head_dim = [128, 96, 80, 13]

        for dim in head_dim:
            with self.subTest(head_dim=dim):
                self._test_context_attention_once(dim)

    def _test_decode_attention_once(self, B, H_Q, H_KV, D):
        """Run ``decode_attention_fwd`` once with ten cached tokens per request.

        Args:
            B: number of requests in the batch.
            H_Q: number of query heads.
            H_KV: number of key/value heads.
            D: head dimension.
        """
        dtype = torch.bfloat16
        seq_len = 10  # This represents the number of tokens already in the sequence
        total_tokens = B * seq_len
        sm_scale = 1.0 / (D**0.5)

        # q represents the new token being generated, one per batch
        q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")

        # k_buffer and v_buffer represent all previous tokens
        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
        v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")

        # o will have the same shape as q
        o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")

        req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
        b_req_idx = torch.arange(B, device="cuda")
        b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
        b_seq_len = torch.full((B,), seq_len, device="cuda")

        decode_attention_fwd(
            q,
            k_buffer,
            v_buffer,
            o,
            req_to_token,
            b_req_idx,
            b_start_loc,
            b_seq_len,
            seq_len,
            total_tokens,
            sm_scale,
        )

    def test_decode_attention(self):
        """Smoke-test ``decode_attention_fwd`` for several head configurations."""
        # Here we just ensure there is no error
        # TODO: correctness test

        # Test configurations
        configs = [
            (2, 4, 4, 64),  # MHA
            (2, 4, 2, 64),  # GQA
            (2, 4, 4, 80),  # Non-standard head dim
            (2, 4, 4, 13),  # Prime number head dim
        ]

        for B, H_Q, H_KV, D in configs:
            with self.subTest(B=B, H_Q=H_Q, H_KV=H_KV, D=D):
                self._test_decode_attention_once(B, H_Q, H_KV, D)
|
||||
|
||||
|
||||
# Entry point: run every test case in this module when the file is executed directly.
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user