Deterministic Mode: Add 1-stage triton kernel for prefill (#11147)
Co-authored-by: Minglei Zhu <mingleizhu1122@gmail.com>
Co-authored-by: Binyao Jiang <bijiang@linkedin.com>
This commit is contained in:
@@ -10,7 +10,9 @@ from sglang.srt.layers.attention.triton_ops.decode_attention import (
|
||||
decode_attention_fwd_normal,
|
||||
)
|
||||
from sglang.srt.layers.attention.triton_ops.extend_attention import (
|
||||
build_unified_kv_indices,
|
||||
extend_attention_fwd,
|
||||
extend_attention_fwd_unified,
|
||||
redundant_attention,
|
||||
)
|
||||
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
||||
@@ -571,6 +573,204 @@ class TestTritonAttention(CustomTestCase):
|
||||
for B, H_Q, H_KV, D, D_V in configs:
|
||||
self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V)
|
||||
|
||||
def _test_extend_attention_unified_vs_regular_once(self, B, N_CTX, H_Q, H_KV, D):
    """Check that the unified (1-stage) extend kernel matches the 2-stage kernel.

    A random batch is generated in which every sequence stores its prefix
    tokens immediately followed by its extend tokens inside one shared KV
    buffer. ``extend_attention_fwd`` and ``extend_attention_fwd_unified`` are
    then run on the same data and their outputs are compared.

    Args:
        B: Number of sequences in the batch.
        N_CTX: Length bound; prefix and extend lengths are each sampled from
            ``[1, N_CTX // 2)``.
        H_Q: Number of query heads.
        H_KV: Number of key/value heads.
        D: Head dimension used for Q, K and V.
    """
    dtype = torch.bfloat16
    device = "cuda"

    b_seq_len_prefix = torch.randint(
        1, N_CTX // 2, (B,), dtype=torch.int32, device=device
    )
    b_seq_len_extend = torch.randint(
        1, N_CTX // 2, (B,), dtype=torch.int32, device=device
    )
    b_seq_len = b_seq_len_prefix + b_seq_len_extend

    # Start offset of every sequence inside the full KV buffer and inside the
    # packed extend-only tensors, respectively.
    b_start_loc = torch.zeros((B,), dtype=torch.int32, device=device)
    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
    b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=device)
    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)

    # Prefix KV indptr (CSR-style offsets into kv_indices).
    kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
    kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)

    # Host-side copies so the Python loop below slices with plain ints.
    prefix_lens = b_seq_len_prefix.tolist()
    extend_lens = b_seq_len_extend.tolist()
    seq_starts = b_start_loc.tolist()
    extend_starts = b_start_loc_extend.tolist()
    kv_offsets = kv_indptr.tolist()

    total_token_num = sum(prefix_lens) + sum(extend_lens)
    extend_token_num = sum(extend_lens)

    k_buffer = torch.empty(
        (total_token_num, H_KV, D), dtype=dtype, device=device
    ).normal_(mean=0.1, std=0.2)
    v_buffer = torch.empty(
        (total_token_num, H_KV, D), dtype=dtype, device=device
    ).normal_(mean=0.1, std=0.2)

    k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
    v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
    q_extend = torch.empty(
        (extend_token_num, H_Q, D), dtype=dtype, device=device
    ).normal_(mean=0.1, std=0.2)

    kv_indices = torch.zeros((sum(prefix_lens),), dtype=torch.int64, device=device)
    # Buffer slots of the extend tokens, packed in the same order as q_extend.
    # The unified kernel only sees K/V through indices, so these must point at
    # the exact rows that were copied into k_extend / v_extend; using the tail
    # of the buffer would feed the two kernels different data when B > 1.
    extend_kv_indices = torch.empty(
        (extend_token_num,), dtype=torch.int64, device=device
    )

    for i in range(B):
        prefix_lo = seq_starts[i]
        prefix_hi = prefix_lo + prefix_lens[i]
        buf_lo = prefix_hi
        buf_hi = buf_lo + extend_lens[i]
        ext_lo = extend_starts[i]
        ext_hi = ext_lo + extend_lens[i]

        kv_indices[kv_offsets[i] : kv_offsets[i + 1]] = torch.arange(
            prefix_lo, prefix_hi, dtype=torch.int64, device=device
        )
        extend_kv_indices[ext_lo:ext_hi] = torch.arange(
            buf_lo, buf_hi, dtype=torch.int64, device=device
        )
        k_extend[ext_lo:ext_hi] = k_buffer[buf_lo:buf_hi]
        v_extend[ext_lo:ext_hi] = v_buffer[buf_lo:buf_hi]

    # Setup for extend attention
    max_len_extend = max(extend_lens)
    qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
    qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)

    # Run 2-stage kernel
    o_regular = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
    extend_attention_fwd(
        q_extend,
        k_extend,
        v_extend,
        o_regular,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        custom_mask=None,
        is_causal=True,
        mask_indptr=None,
        max_len_extend=max_len_extend,
    )

    # Build unified KV indices (prefix indices followed by extend indices for
    # each sequence). b_start_loc_extend doubles as the per-sequence offset
    # into extend_kv_indices.
    unified_kv_indptr, unified_kv_indices, prefix_lens_out = build_unified_kv_indices(
        kv_indptr,
        kv_indices,
        b_start_loc_extend,
        b_seq_len_extend,
        extend_kv_indices,
        B,
    )

    # Run unified kernel
    o_unified = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
    extend_attention_fwd_unified(
        q_extend,
        o_unified,
        k_buffer,
        v_buffer,
        qo_indptr,
        unified_kv_indptr,
        unified_kv_indices,
        prefix_lens_out,
        max_len_extend=max_len_extend,
        custom_mask=None,
        mask_indptr=None,
        sm_scale=None,
        logit_cap=0.0,
        is_causal=True,
    )

    # Compare results
    # NOTE(review): the tolerance is kept as in the original test; it is loose
    # for bf16 and could probably be tightened once verified on hardware.
    self.assertTrue(
        torch.allclose(o_regular, o_unified, rtol=0.15, atol=0.15),
        f"Unified kernel output differs from 2-stage kernel. "
        f"Max diff: {(o_regular - o_unified).abs().max()}",
    )
|
||||
|
||||
def test_extend_attention_unified_vs_regular(self):
    """Run the unified-vs-2-stage comparison over several shape configurations."""
    field_names = ("B", "N_CTX", "H_Q", "H_KV", "D")
    cases = (
        (4, 512, 32, 8, 128),  # Standard config
        (2, 2048, 32, 8, 128),  # Long sequence (test 2048 specifically)
        (8, 256, 64, 8, 80),  # Non-standard head dim
    )

    for case in cases:
        # Label each sub-test with its parameters so failures identify the case.
        labels = dict(zip(field_names, case))
        with self.subTest(**labels):
            self._test_extend_attention_unified_vs_regular_once(*case)
|
||||
|
||||
def test_build_unified_kv_indices(self):
    """Verify the indptr, prefix lengths and index contents from build_unified_kv_indices.

    For a fixed batch of four sequences the test checks that:
      * the returned indptr is the cumulative sum of prefix + extend lengths,
      * the returned prefix lengths equal the input prefix lengths,
      * each sequence's slice of the unified indices is its prefix indices
        followed by its extend indices.
    """
    B = 4
    dtype = torch.int64
    device = "cuda"

    # Setup test data
    prefix_lens = torch.tensor([10, 20, 15, 25], dtype=torch.int32, device=device)
    extend_lens = torch.tensor([5, 3, 7, 4], dtype=torch.int32, device=device)
    total_prefix = prefix_lens.sum().item()
    total_extend = extend_lens.sum().item()

    # Build prefix indices
    prefix_kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
    prefix_kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)
    prefix_kv_indices = torch.arange(total_prefix, dtype=dtype, device=device)

    # Build extend indices; they occupy the slots right after all prefix slots.
    extend_start_loc = torch.zeros((B,), dtype=torch.int32, device=device)
    extend_start_loc[1:] = torch.cumsum(extend_lens[:-1], dim=0)
    extend_kv_indices = torch.arange(
        total_prefix,
        total_prefix + total_extend,
        dtype=dtype,
        device=device,
    )

    # Build unified indices
    unified_kv_indptr, unified_kv_indices, returned_prefix_lens = (
        build_unified_kv_indices(
            prefix_kv_indptr,
            prefix_kv_indices,
            extend_start_loc,
            extend_lens,
            extend_kv_indices,
            B,
        )
    )

    # Verify unified_kv_indptr
    expected_lens = prefix_lens + extend_lens
    expected_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
    expected_indptr[1:] = torch.cumsum(expected_lens, dim=0)
    self.assertTrue(torch.equal(unified_kv_indptr, expected_indptr))

    # Verify prefix_lens
    self.assertTrue(torch.equal(returned_prefix_lens, prefix_lens))

    # Verify unified_kv_indices structure
    for i in range(B):
        start_idx = int(unified_kv_indptr[i])
        end_idx = int(unified_kv_indptr[i + 1])
        prefix_len = int(prefix_lens[i])
        extend_len = int(extend_lens[i])

        unified_seq = unified_kv_indices[start_idx:end_idx]
        self.assertEqual(len(unified_seq), prefix_len + extend_len)

        # Check that prefix and extend are concatenated correctly: compare the
        # actual index values, not just the slice length. Lists are compared
        # so that a dtype difference between tensors cannot affect the result.
        prefix_start = int(prefix_kv_indptr[i])
        expected_prefix = prefix_kv_indices[prefix_start : prefix_start + prefix_len]
        extend_start = int(extend_start_loc[i])
        expected_extend = extend_kv_indices[extend_start : extend_start + extend_len]
        self.assertEqual(unified_seq[:prefix_len].tolist(), expected_prefix.tolist())
        self.assertEqual(unified_seq[prefix_len:].tolist(), expected_extend.tolist())
|
||||
|
||||
|
||||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user