Deterministic Mode: Add 1-stage triton kernel for prefill (#11147)

Co-authored-by: Minglei Zhu <mingleizhu1122@gmail.com>
Co-authored-by: Binyao Jiang <bijiang@linkedin.com>
This commit is contained in:
Stefan He
2025-10-19 10:47:36 -07:00
committed by GitHub
parent 7a020e0f3b
commit 4fff1ec1d9
5 changed files with 879 additions and 46 deletions

View File

@@ -10,7 +10,9 @@ from sglang.srt.layers.attention.triton_ops.decode_attention import (
decode_attention_fwd_normal,
)
from sglang.srt.layers.attention.triton_ops.extend_attention import (
build_unified_kv_indices,
extend_attention_fwd,
extend_attention_fwd_unified,
redundant_attention,
)
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
@@ -571,6 +573,204 @@ class TestTritonAttention(CustomTestCase):
for B, H_Q, H_KV, D, D_V in configs:
self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V)
def _test_extend_attention_unified_vs_regular_once(self, B, N_CTX, H_Q, H_KV, D):
"""Test that unified kernel produces same results as 2-stage kernel."""
dtype = torch.bfloat16
b_seq_len_prefix = torch.randint(
1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
)
b_seq_len_extend = torch.randint(
1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
)
b_seq_len = b_seq_len_prefix + b_seq_len_extend
b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
# Setup prefix KV indices
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
kv_indices = torch.zeros(
(b_seq_len_prefix.sum().item(),), dtype=torch.int64, device="cuda"
)
for i in range(B):
kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
)
total_token_num = torch.sum(b_seq_len).item()
extend_token_num = torch.sum(b_seq_len_extend).item()
k_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
v_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
for i in range(B):
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
extend_start = b_start_loc_extend[i]
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
k_extend[extend_start:extend_end] = k_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
v_extend[extend_start:extend_end] = v_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
q_extend[extend_start:extend_end] = torch.empty(
(b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
# Setup for extend attention
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
# Run 2-stage kernel
o_regular = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
extend_attention_fwd(
q_extend,
k_extend,
v_extend,
o_regular,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
custom_mask=None,
is_causal=True,
mask_indptr=None,
max_len_extend=max_len_extend,
)
# Build unified KV indices
extend_kv_indices = torch.arange(
total_token_num - extend_token_num,
total_token_num,
dtype=torch.int64,
device="cuda",
)
extend_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
extend_start_loc[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
unified_kv_indptr, unified_kv_indices, prefix_lens = build_unified_kv_indices(
kv_indptr,
kv_indices,
extend_start_loc,
b_seq_len_extend,
extend_kv_indices,
B,
)
# Run unified kernel
o_unified = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
extend_attention_fwd_unified(
q_extend,
o_unified,
k_buffer,
v_buffer,
qo_indptr,
unified_kv_indptr,
unified_kv_indices,
prefix_lens,
max_len_extend=max_len_extend,
custom_mask=None,
mask_indptr=None,
sm_scale=None,
logit_cap=0.0,
is_causal=True,
)
# Compare results
self.assertTrue(
torch.allclose(o_regular, o_unified, rtol=0.15, atol=0.15),
f"Unified kernel output differs from 2-stage kernel. "
f"Max diff: {(o_regular - o_unified).abs().max()}",
)
def test_extend_attention_unified_vs_regular(self):
"""Test unified kernel matches 2-stage kernel across different configs."""
configs = [
(4, 512, 32, 8, 128), # Standard config
(2, 2048, 32, 8, 128), # Long sequence (test 2048 specifically)
(8, 256, 64, 8, 80), # Non-standard head dim
]
for B, N_CTX, H_Q, H_KV, D in configs:
with self.subTest(B=B, N_CTX=N_CTX, H_Q=H_Q, H_KV=H_KV, D=D):
self._test_extend_attention_unified_vs_regular_once(
B, N_CTX, H_Q, H_KV, D
)
def test_build_unified_kv_indices(self):
"""Test build_unified_kv_indices correctness."""
B = 4
dtype = torch.int64
device = "cuda"
# Setup test data
prefix_lens = torch.tensor([10, 20, 15, 25], dtype=torch.int32, device=device)
extend_lens = torch.tensor([5, 3, 7, 4], dtype=torch.int32, device=device)
# Build prefix indices
prefix_kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
prefix_kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)
prefix_kv_indices = torch.arange(
prefix_lens.sum().item(), dtype=dtype, device=device
)
# Build extend indices
extend_start_loc = torch.zeros((B,), dtype=torch.int32, device=device)
extend_start_loc[1:] = torch.cumsum(extend_lens[:-1], dim=0)
extend_kv_indices = torch.arange(
prefix_lens.sum().item(),
prefix_lens.sum().item() + extend_lens.sum().item(),
dtype=dtype,
device=device,
)
# Build unified indices
unified_kv_indptr, unified_kv_indices, returned_prefix_lens = (
build_unified_kv_indices(
prefix_kv_indptr,
prefix_kv_indices,
extend_start_loc,
extend_lens,
extend_kv_indices,
B,
)
)
# Verify unified_kv_indptr
expected_lens = prefix_lens + extend_lens
expected_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
expected_indptr[1:] = torch.cumsum(expected_lens, dim=0)
self.assertTrue(torch.equal(unified_kv_indptr, expected_indptr))
# Verify prefix_lens
self.assertTrue(torch.equal(returned_prefix_lens, prefix_lens))
# Verify unified_kv_indices structure
for i in range(B):
start_idx = int(unified_kv_indptr[i])
end_idx = int(unified_kv_indptr[i + 1])
prefix_len = int(prefix_lens[i])
extend_len = int(extend_lens[i])
# Check that prefix and extend are concatenated correctly
unified_seq = unified_kv_indices[start_idx:end_idx]
self.assertEqual(len(unified_seq), prefix_len + extend_len)
if __name__ == "__main__":
unittest.main()