Add Support for Page Size greater than 1 for Flashinfer MLA Backend (#8593)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -4,7 +4,10 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||
from sglang.srt.layers.attention.utils import (
|
||||
create_flashinfer_kv_indices_triton,
|
||||
create_flashmla_kv_indices_triton,
|
||||
)
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
@@ -15,10 +18,14 @@ class TestCreateKvIndices(CustomTestCase):
|
||||
raise unittest.SkipTest("CUDA is not available")
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
def _run_test(self, batch, max_batch, max_context_len):
|
||||
def _run_test(self, batch, max_batch, max_context_len, page_size):
|
||||
np.random.seed(9)
|
||||
PAGE_SIZE = page_size
|
||||
req_to_token = torch.arange(
|
||||
max_batch * max_context_len, dtype=torch.int32, device="cuda"
|
||||
).reshape((max_batch, max_context_len))
|
||||
|
||||
# the block table
|
||||
req_pool_indices = torch.tensor(
|
||||
torch.from_numpy(
|
||||
np.random.choice(range(max_batch), size=batch, replace=False)
|
||||
@@ -26,49 +33,84 @@ class TestCreateKvIndices(CustomTestCase):
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
paged_kernel_lens = torch.tensor(
|
||||
seq_lens = torch.tensor(
|
||||
torch.from_numpy(
|
||||
np.random.choice(range(max_context_len), size=batch, replace=False)
|
||||
),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
num_pages_per_req = (seq_lens + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device="cuda")
|
||||
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
kv_indptr[1:] = torch.cumsum(num_pages_per_req, dim=0)
|
||||
|
||||
# ref
|
||||
kv_indices_ref = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
|
||||
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
|
||||
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
|
||||
kv_indices_ref = torch.cat(
|
||||
[
|
||||
req_to_token[req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]]
|
||||
for i in range(batch)
|
||||
],
|
||||
dim=0,
|
||||
).contiguous()
|
||||
seq_lens_cpu = seq_lens.cpu().numpy()
|
||||
for i in range(batch):
|
||||
kv_indptr_req = kv_indptr[i]
|
||||
num_toks_seq = seq_lens_cpu[i]
|
||||
curr_req_pool = req_pool_indices_cpu[i]
|
||||
curr_num_pages = num_pages_per_req[i]
|
||||
curr_token_ids = req_to_token[curr_req_pool]
|
||||
curr_pages = (curr_token_ids[:num_toks_seq] // PAGE_SIZE).unique()
|
||||
assert (
|
||||
len(curr_pages) == curr_num_pages
|
||||
), f"req {i} has #{curr_num_pages} pages, but got {len(curr_pages)} pages"
|
||||
kv_indices_ref[kv_indptr_req : kv_indptr_req + curr_num_pages] = curr_pages
|
||||
|
||||
# triton
|
||||
kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
|
||||
create_flashinfer_kv_indices_triton[(batch,)](
|
||||
req_to_token,
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
seq_lens,
|
||||
kv_indptr,
|
||||
None,
|
||||
kv_indices_triton,
|
||||
req_to_token.size(1),
|
||||
PAGE_SIZE,
|
||||
)
|
||||
max_pages = max_context_len // PAGE_SIZE
|
||||
kv_indices_flashmla = torch.empty(
|
||||
batch, max_pages, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
create_flashmla_kv_indices_triton[(batch,)](
|
||||
req_to_token,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
None,
|
||||
kv_indices_flashmla,
|
||||
req_to_token.size(1),
|
||||
max_pages,
|
||||
PAGE_SIZE,
|
||||
)
|
||||
# Check
|
||||
self.assertTrue(torch.equal(kv_indices_ref, kv_indices_triton))
|
||||
|
||||
def test_create_kvindices(self):
|
||||
BATCH = [1, 37, 1786]
|
||||
BATCH = [4, 37, 512, 1786]
|
||||
MAX_BATCH = 4096
|
||||
MAX_CONTEXT_LEN = 4096
|
||||
for batch in BATCH:
|
||||
self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN)
|
||||
PAGE_SIZE = [1, 2, 16, 64]
|
||||
# for debug
|
||||
# BATCH = [4]
|
||||
# MAX_BATCH = 4
|
||||
# MAX_CONTEXT_LEN = 10
|
||||
# Test for small batch size
|
||||
for page_size in PAGE_SIZE[:1]:
|
||||
print(f"Running test for page size: {page_size} and batch size: {BATCH[0]}")
|
||||
self._run_test(BATCH[0], MAX_BATCH, MAX_CONTEXT_LEN, page_size)
|
||||
|
||||
# Test for larger batch size
|
||||
for batch in BATCH[1:]:
|
||||
for page_size in PAGE_SIZE:
|
||||
print(
|
||||
f"Running test for batch size: {batch} and page size: {page_size}"
|
||||
)
|
||||
self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN, page_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -120,5 +120,49 @@ class TestFlashinferMLAMTP(CustomTestCase):
|
||||
self.assertGreater(avg_spec_accept_length, 2.5)
|
||||
|
||||
|
||||
class TestFlashinferMLAPageSize16(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
other_args = ["--trust-remote-code"]
|
||||
if torch.cuda.is_available() and torch.version.cuda:
|
||||
other_args.extend(
|
||||
[
|
||||
"--cuda-graph-max-bs",
|
||||
"4",
|
||||
"--attention-backend",
|
||||
"flashinfer",
|
||||
"--page-size",
|
||||
"16",
|
||||
]
|
||||
)
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=other_args,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
print(metrics)
|
||||
|
||||
self.assertGreater(metrics["accuracy"], 0.615)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user