diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 3d40c9d75..a443b113d 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -262,6 +264,42 @@ class InputMetadata:
         )
 
 
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    max_context_len,
+    kv_indices_ptr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    req_to_token_ptr += req_pool_index * max_context_len
+    kv_indices_ptr += kv_indices_offset
+
+    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+    st_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = ld_offset < kv_end
+        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+        ld_offset += BLOCK_SIZE
+        st_offset += BLOCK_SIZE
+
+
 def update_flashinfer_indices(
     forward_mode,
     model_runner,
@@ -285,17 +323,18 @@ def update_flashinfer_indices(
 
         kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        kv_indices = torch.cat(
-            [
-                model_runner.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                ]
-                for i in range(batch_size)
-            ],
-            dim=0,
-        ).contiguous()
+
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(batch_size,)](
+            model_runner.req_to_token_pool.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            None,
+            model_runner.req_to_token_pool.req_to_token.size(1),
+            kv_indices,
+        )
+
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
 
         if forward_mode == ForwardMode.DECODE:
@@ -365,18 +404,17 @@ def update_flashinfer_indices(
 
        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        kv_indices = torch.cat(
-            [
-                model_runner.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i],
-                    kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
-                ]
-                for i in range(batch_size)
-            ],
-            dim=0,
-        ).contiguous()
+
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(batch_size,)](
+            model_runner.req_to_token_pool.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            kv_start_idx,
+            model_runner.req_to_token_pool.req_to_token.size(1),
+            kv_indices,
+        )
 
         if forward_mode == ForwardMode.DECODE:
             # CUDA graph uses different flashinfer_decode_wrapper
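
Note on the kernel above (commentary, not part of the patch): each Triton program handles one request. Program pid looks up its row of req_to_token via req_pool_indices, optionally skips a per-request prefix when kv_start_idx is given (the sliding-window branch passes it; the first branch passes None), and copies paged_kernel_lens[pid] token-slot indices into the packed kv_indices buffer starting at kv_indptr[pid], in chunks of BLOCK_SIZE = 512. This produces the CSR-style (kv_indptr, kv_indices) pair that FlashInfer's paged wrappers consume. A minimal pure-PyTorch sketch of the same semantics, with names mirroring the kernel arguments:

import torch

def create_kv_indices_reference(
    req_to_token,       # [max_batch, max_context_len] token-slot table
    req_pool_indices,   # [batch] row in req_to_token for each request
    paged_kernel_lens,  # [batch] number of KV entries per request
    kv_indptr,          # [batch + 1] exclusive prefix sum of the lengths
    kv_start_idx=None,  # [batch] optional per-request start offset
):
    # Same output as one launch of create_flashinfer_kv_indices_triton
    # with grid (batch,), computed serially on the host for clarity.
    kv_indices = torch.empty(
        int(kv_indptr[-1]), dtype=torch.int32, device=req_to_token.device
    )
    for pid in range(req_pool_indices.numel()):  # one Triton program per pid
        row = int(req_pool_indices[pid])
        start = int(kv_start_idx[pid]) if kv_start_idx is not None else 0
        length = int(paged_kernel_lens[pid])
        out = int(kv_indptr[pid])
        kv_indices[out : out + length] = req_to_token[row, start : start + length]
    return kv_indices
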
diff --git a/test/srt/test_create_kvindices.py b/test/srt/test_create_kvindices.py
new file mode 100644
index 000000000..230302f26
--- /dev/null
+++ b/test/srt/test_create_kvindices.py
@@ -0,0 +1,76 @@
+import itertools
+import unittest
+
+import numpy as np
+import torch
+
+from sglang.srt.model_executor.forward_batch_info import (
+    create_flashinfer_kv_indices_triton,
+)
+
+
+class TestCreateKvIndices(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _run_test(self, batch, max_batch, max_context_len):
+        req_to_token = torch.arange(
+            max_batch * max_context_len, dtype=torch.int32, device="cuda"
+        ).reshape((max_batch, max_context_len))
+        req_pool_indices = torch.tensor(
+            torch.from_numpy(
+                np.random.choice(range(max_batch), size=batch, replace=False)
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+        paged_kernel_lens = torch.tensor(
+            torch.from_numpy(
+                np.random.choice(range(max_context_len), size=batch, replace=False)
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
+        kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+
+        # ref
+        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+        kv_indices_ref = torch.cat(
+            [
+                req_to_token[req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]]
+                for i in range(batch)
+            ],
+            dim=0,
+        ).contiguous()
+
+        # triton
+        kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(batch,)](
+            req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            None,
+            req_to_token.size(1),
+            kv_indices_triton,
+        )
+
+        # Check
+        self.assertTrue(torch.equal(kv_indices_ref, kv_indices_triton))
+
+    def test_create_kvindices(self):
+        BATCH = [1, 37, 1786]
+        MAX_BATCH = 4096
+        MAX_CONTEXT_LEN = 4096
+        for batch in BATCH:
+            self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN)
+
+
+if __name__ == "__main__":
+    unittest.main()
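
The test checks bit-exact agreement with the old CPU gather for batch sizes 1, 37, and 1786. For a rough sense of why the Triton path helps, here is an illustrative micro-benchmark sketch (assumed shapes; these are not measurements from this PR): the old path synchronizes to the host and concatenates per-request slices, while the new path is a single device-side launch.

import time

import numpy as np
import torch

from sglang.srt.model_executor.forward_batch_info import (
    create_flashinfer_kv_indices_triton,
)

batch, max_batch, max_context_len = 256, 4096, 4096
req_to_token = torch.arange(
    max_batch * max_context_len, dtype=torch.int32, device="cuda"
).reshape(max_batch, max_context_len)
req_pool_indices = torch.from_numpy(
    np.random.choice(max_batch, size=batch, replace=False).astype(np.int32)
).cuda()
paged_kernel_lens = torch.from_numpy(
    np.random.choice(max_context_len, size=batch, replace=False).astype(np.int32)
).cuda()
kv_indptr = torch.zeros(batch + 1, dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

# Warm up the Triton JIT so compile time is not measured.
create_flashinfer_kv_indices_triton[(batch,)](
    req_to_token, req_pool_indices, paged_kernel_lens,
    kv_indptr, None, req_to_token.size(1),
    torch.empty(int(kv_indptr[-1]), dtype=torch.int32, device="cuda"),
)

# Old path: move lengths/indices to the host, slice per request, concatenate.
torch.cuda.synchronize()
t0 = time.perf_counter()
idx_cpu = req_pool_indices.cpu().numpy()
lens_cpu = paged_kernel_lens.cpu().numpy()
ref = torch.cat(
    [req_to_token[idx_cpu[i], : lens_cpu[i]] for i in range(batch)]
).contiguous()
torch.cuda.synchronize()
t_old = time.perf_counter() - t0

# New path: one Triton launch writes the packed indices in place.
torch.cuda.synchronize()
t0 = time.perf_counter()
out = torch.empty(int(kv_indptr[-1]), dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(batch,)](
    req_to_token,
    req_pool_indices,
    paged_kernel_lens,
    kv_indptr,
    None,
    req_to_token.size(1),
    out,
)
torch.cuda.synchronize()
t_new = time.perf_counter() - t0

assert torch.equal(ref, out)
print(f"cpu gather + cat: {t_old * 1e3:.2f} ms, triton: {t_new * 1e3:.2f} ms")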