Optimize the update flashinfer indices (#1262)
This commit is contained in:
@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, List
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
@@ -262,6 +264,42 @@ class InputMetadata:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def create_flashinfer_kv_indices_triton(
|
||||||
|
req_to_token_ptr, # [max_batch, max_context_len]
|
||||||
|
req_pool_indices_ptr,
|
||||||
|
page_kernel_lens_ptr,
|
||||||
|
kv_indptr,
|
||||||
|
kv_start_idx,
|
||||||
|
max_context_len,
|
||||||
|
kv_indices_ptr,
|
||||||
|
):
|
||||||
|
BLOCK_SIZE: tl.constexpr = 512
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
req_pool_index = tl.load(req_pool_indices_ptr + pid)
|
||||||
|
kv_indices_offset = tl.load(kv_indptr + pid)
|
||||||
|
|
||||||
|
kv_start = 0
|
||||||
|
kv_end = 0
|
||||||
|
if kv_start_idx:
|
||||||
|
kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
|
||||||
|
kv_end = kv_start
|
||||||
|
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
|
||||||
|
|
||||||
|
req_to_token_ptr += req_pool_index * max_context_len
|
||||||
|
kv_indices_ptr += kv_indices_offset
|
||||||
|
|
||||||
|
ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
|
||||||
|
st_offset = tl.arange(0, BLOCK_SIZE)
|
||||||
|
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
||||||
|
for _ in range(num_loop):
|
||||||
|
mask = ld_offset < kv_end
|
||||||
|
data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
|
||||||
|
tl.store(kv_indices_ptr + st_offset, data, mask=mask)
|
||||||
|
ld_offset += BLOCK_SIZE
|
||||||
|
st_offset += BLOCK_SIZE
|
||||||
|
|
||||||
|
|
||||||
def update_flashinfer_indices(
|
def update_flashinfer_indices(
|
||||||
forward_mode,
|
forward_mode,
|
||||||
model_runner,
|
model_runner,
|
||||||
@@ -285,17 +323,18 @@ def update_flashinfer_indices(
|
|||||||
|
|
||||||
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
||||||
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
|
|
||||||
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
|
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
|
||||||
kv_indices = torch.cat(
|
create_flashinfer_kv_indices_triton[(batch_size,)](
|
||||||
[
|
model_runner.req_to_token_pool.req_to_token,
|
||||||
model_runner.req_to_token_pool.req_to_token[
|
req_pool_indices,
|
||||||
req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
|
paged_kernel_lens,
|
||||||
]
|
kv_indptr,
|
||||||
for i in range(batch_size)
|
None,
|
||||||
],
|
model_runner.req_to_token_pool.req_to_token.size(1),
|
||||||
dim=0,
|
kv_indices,
|
||||||
).contiguous()
|
)
|
||||||
|
|
||||||
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
if forward_mode == ForwardMode.DECODE:
|
if forward_mode == ForwardMode.DECODE:
|
||||||
@@ -365,18 +404,17 @@ def update_flashinfer_indices(
|
|||||||
|
|
||||||
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
||||||
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
|
|
||||||
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
|
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
|
||||||
kv_indices = torch.cat(
|
create_flashinfer_kv_indices_triton[(batch_size,)](
|
||||||
[
|
model_runner.req_to_token_pool.req_to_token,
|
||||||
model_runner.req_to_token_pool.req_to_token[
|
req_pool_indices,
|
||||||
req_pool_indices_cpu[i],
|
paged_kernel_lens,
|
||||||
kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
|
kv_indptr,
|
||||||
]
|
kv_start_idx,
|
||||||
for i in range(batch_size)
|
model_runner.req_to_token_pool.req_to_token.size(1),
|
||||||
],
|
kv_indices,
|
||||||
dim=0,
|
)
|
||||||
).contiguous()
|
|
||||||
|
|
||||||
if forward_mode == ForwardMode.DECODE:
|
if forward_mode == ForwardMode.DECODE:
|
||||||
# CUDA graph uses different flashinfer_decode_wrapper
|
# CUDA graph uses different flashinfer_decode_wrapper
|
||||||
|
|||||||
76
test/srt/test_create_kvindices.py
Normal file
76
test/srt/test_create_kvindices.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
import itertools
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
create_flashinfer_kv_indices_triton,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateKvIndices(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
raise unittest.SkipTest("CUDA is not available")
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
|
def _run_test(self, batch, max_batch, max_context_len):
|
||||||
|
req_to_token = torch.arange(
|
||||||
|
max_batch * max_context_len, dtype=torch.int32, device="cuda"
|
||||||
|
).reshape((max_batch, max_context_len))
|
||||||
|
req_pool_indices = torch.tensor(
|
||||||
|
torch.from_numpy(
|
||||||
|
np.random.choice(range(max_batch), size=batch, replace=False)
|
||||||
|
),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
paged_kernel_lens = torch.tensor(
|
||||||
|
torch.from_numpy(
|
||||||
|
np.random.choice(range(max_context_len), size=batch, replace=False)
|
||||||
|
),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
|
||||||
|
kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device="cuda")
|
||||||
|
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
|
|
||||||
|
# ref
|
||||||
|
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
|
||||||
|
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
|
||||||
|
kv_indices_ref = torch.cat(
|
||||||
|
[
|
||||||
|
req_to_token[req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]]
|
||||||
|
for i in range(batch)
|
||||||
|
],
|
||||||
|
dim=0,
|
||||||
|
).contiguous()
|
||||||
|
|
||||||
|
# triton
|
||||||
|
kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
|
||||||
|
create_flashinfer_kv_indices_triton[(batch,)](
|
||||||
|
req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
kv_indptr,
|
||||||
|
None,
|
||||||
|
req_to_token.size(1),
|
||||||
|
kv_indices_triton,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check
|
||||||
|
self.assertTrue(torch.equal(kv_indices_ref, kv_indices_triton))
|
||||||
|
|
||||||
|
def test_create_kvindices(self):
|
||||||
|
BATCH = [1, 37, 1786]
|
||||||
|
MAX_BATCH = 4096
|
||||||
|
MAX_CONTEXT_LEN = 4096
|
||||||
|
for batch in BATCH:
|
||||||
|
self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user