# sglang/test/srt/test_create_kvindices.py
# Snapshot: 2025-08-21 18:15:06 -07:00 (118 lines, 3.9 KiB, Python)

import itertools
import unittest
import numpy as np
import torch
from sglang.srt.layers.attention.utils import (
create_flashinfer_kv_indices_triton,
create_flashmla_kv_indices_triton,
)
from sglang.test.test_utils import CustomTestCase
class TestCreateKvIndices(CustomTestCase):
    """Check the Triton KV-index kernels against a plain PyTorch reference.

    ``create_flashinfer_kv_indices_triton`` fills a flat index array addressed
    through ``kv_indptr``; ``create_flashmla_kv_indices_triton`` fills a
    ``(batch, max_pages)`` table. For every request, both are expected to hold
    the ascending, distinct page ids that the request's tokens live in.
    """

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        # Remember the current default device so tearDownClass can restore it
        # instead of leaking "cuda" into unrelated tests run in the same
        # process. torch.get_default_device is missing from older PyTorch
        # releases; fall back to "cpu", PyTorch's built-in default.
        get_default_device = getattr(torch, "get_default_device", None)
        cls._saved_default_device = (
            get_default_device() if get_default_device is not None else "cpu"
        )
        torch.set_default_device("cuda")

    @classmethod
    def tearDownClass(cls):
        torch.set_default_device(cls._saved_default_device)
        super().tearDownClass()

    def _reference_kv_indices(
        self, req_to_token, req_pool_indices, seq_lens, kv_indptr, page_size
    ):
        """Build the expected flat page-id array with plain PyTorch ops.

        Args:
            req_to_token: 2-D table mapping (request row, token position) to a
                KV slot.
            req_pool_indices: row of ``req_to_token`` used by each request.
            seq_lens: number of tokens in each request.
            kv_indptr: exclusive prefix sum of per-request page counts
                (length ``batch + 1``).
            page_size: tokens per KV page.

        Returns:
            An int32 CUDA tensor of length ``kv_indptr[-1]`` whose slice
            ``[kv_indptr[i], kv_indptr[i + 1])`` holds request ``i``'s pages.
        """
        # Copy offsets to the host once so the per-request slicing below uses
        # Python ints rather than CUDA scalar tensors.
        indptr = kv_indptr.tolist()
        kv_indices_ref = torch.empty(indptr[-1], dtype=torch.int32, device="cuda")
        rows_and_lens = zip(req_pool_indices.tolist(), seq_lens.tolist())
        for i, (req_pool_idx, seq_len) in enumerate(rows_and_lens):
            start, end = indptr[i], indptr[i + 1]
            curr_num_pages = end - start
            # unique() returns sorted values, so pages come out ascending.
            curr_pages = (req_to_token[req_pool_idx, :seq_len] // page_size).unique()
            self.assertEqual(
                len(curr_pages),
                curr_num_pages,
                f"req {i} has #{curr_num_pages} pages, but got {len(curr_pages)} pages",
            )
            kv_indices_ref[start:end] = curr_pages
        return kv_indices_ref

    def _run_test(self, batch, max_batch, max_context_len, page_size):
        """Run both kernels on one random configuration and compare to the reference.

        Args:
            batch: number of requests. It must not exceed ``max_batch`` or
                ``max_context_len``, because request rows and sequence lengths
                are sampled without replacement.
            max_batch: number of rows in the synthetic ``req_to_token`` table.
            max_context_len: number of token slots per row.
            page_size: tokens per KV page.
        """
        # Local generator: deterministic without touching NumPy's global state.
        rng = np.random.default_rng(9)

        # Row r maps token t to slot r * max_context_len + t, so a request's
        # page ids are simply the distinct values of slot // page_size.
        req_to_token = torch.arange(
            max_batch * max_context_len, dtype=torch.int32, device="cuda"
        ).reshape((max_batch, max_context_len))
        # The block table: distinct request rows, plus distinct sequence lengths
        # in [0, max_context_len). from_numpy(...).to(...) avoids the warning
        # that torch.tensor(<tensor>) emits when copy-constructing.
        req_pool_indices = torch.from_numpy(
            rng.choice(max_batch, size=batch, replace=False)
        ).to(dtype=torch.int32, device="cuda")
        seq_lens = torch.from_numpy(
            rng.choice(max_context_len, size=batch, replace=False)
        ).to(dtype=torch.int32, device="cuda")

        num_pages_per_req = (seq_lens + page_size - 1) // page_size
        kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device="cuda")
        kv_indptr[1:] = torch.cumsum(num_pages_per_req, dim=0)

        kv_indices_ref = self._reference_kv_indices(
            req_to_token, req_pool_indices, seq_lens, kv_indptr, page_size
        )

        # FlashInfer layout: flat array addressed through kv_indptr.
        kv_indices_triton = torch.empty(
            kv_indices_ref.numel(), dtype=torch.int32, device="cuda"
        )
        create_flashinfer_kv_indices_triton[(batch,)](
            req_to_token,
            req_pool_indices,
            seq_lens,
            kv_indptr,
            None,
            kv_indices_triton,
            req_to_token.size(1),
            page_size,
        )

        # FlashMLA layout: one padded row of max_pages entries per request.
        # Ceil division keeps every row large enough even when max_context_len
        # is not a multiple of page_size.
        max_pages = (max_context_len + page_size - 1) // page_size
        kv_indices_flashmla = torch.empty(
            batch, max_pages, dtype=torch.int32, device="cuda"
        )
        # NOTE(review): page_size is passed positionally as the last argument;
        # confirm it binds to the kernel's page-size parameter.
        create_flashmla_kv_indices_triton[(batch,)](
            req_to_token,
            req_pool_indices,
            seq_lens,
            None,
            kv_indices_flashmla,
            req_to_token.size(1),
            max_pages,
            page_size,
        )

        self.assertTrue(
            torch.equal(kv_indices_ref, kv_indices_triton),
            f"flashinfer kv_indices mismatch (batch={batch}, page_size={page_size})",
        )
        # The FlashMLA table comes from torch.empty, so only the first
        # num_pages_per_req[i] entries of row i are compared. Boolean-mask
        # indexing flattens in row-major order, which matches the flat
        # per-request layout of kv_indices_ref.
        valid = (
            torch.arange(max_pages, device="cuda")[None, :]
            < num_pages_per_req[:, None]
        )
        self.assertTrue(
            torch.equal(kv_indices_flashmla[valid], kv_indices_ref),
            f"flashmla kv_indices mismatch (batch={batch}, page_size={page_size})",
        )

    def test_create_kvindices(self):
        batch_sizes = [4, 37, 512, 1786]
        max_batch = 4096
        max_context_len = 4096
        page_sizes = [1, 2, 16, 64]

        # Small batch: only the smallest page size.
        small_batch = batch_sizes[0]
        for page_size in page_sizes[:1]:
            with self.subTest(batch=small_batch, page_size=page_size):
                print(
                    f"Running test for page size: {page_size} and batch size: {small_batch}"
                )
                self._run_test(small_batch, max_batch, max_context_len, page_size)

        # Larger batches: every page size.
        for batch, page_size in itertools.product(batch_sizes[1:], page_sizes):
            with self.subTest(batch=batch, page_size=page_size):
                print(
                    f"Running test for batch size: {batch} and page size: {page_size}"
                )
                self._run_test(batch, max_batch, max_context_len, page_size)
if __name__ == "__main__":
    # Executed directly as a script: discover and run this module's tests.
    unittest.main(module="__main__")