Optimize the update flashinfer indices (#1262)
This commit is contained in:
@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
@@ -262,6 +264,42 @@ class InputMetadata:
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def create_flashinfer_kv_indices_triton(
|
||||
req_to_token_ptr, # [max_batch, max_context_len]
|
||||
req_pool_indices_ptr,
|
||||
page_kernel_lens_ptr,
|
||||
kv_indptr,
|
||||
kv_start_idx,
|
||||
max_context_len,
|
||||
kv_indices_ptr,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 512
|
||||
pid = tl.program_id(axis=0)
|
||||
req_pool_index = tl.load(req_pool_indices_ptr + pid)
|
||||
kv_indices_offset = tl.load(kv_indptr + pid)
|
||||
|
||||
kv_start = 0
|
||||
kv_end = 0
|
||||
if kv_start_idx:
|
||||
kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
|
||||
kv_end = kv_start
|
||||
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
|
||||
|
||||
req_to_token_ptr += req_pool_index * max_context_len
|
||||
kv_indices_ptr += kv_indices_offset
|
||||
|
||||
ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
|
||||
st_offset = tl.arange(0, BLOCK_SIZE)
|
||||
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
||||
for _ in range(num_loop):
|
||||
mask = ld_offset < kv_end
|
||||
data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
|
||||
tl.store(kv_indices_ptr + st_offset, data, mask=mask)
|
||||
ld_offset += BLOCK_SIZE
|
||||
st_offset += BLOCK_SIZE
|
||||
|
||||
|
||||
def update_flashinfer_indices(
|
||||
forward_mode,
|
||||
model_runner,
|
||||
@@ -285,17 +323,18 @@ def update_flashinfer_indices(
|
||||
|
||||
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
||||
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
|
||||
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
|
||||
kv_indices = torch.cat(
|
||||
[
|
||||
model_runner.req_to_token_pool.req_to_token[
|
||||
req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
|
||||
]
|
||||
for i in range(batch_size)
|
||||
],
|
||||
dim=0,
|
||||
).contiguous()
|
||||
|
||||
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
|
||||
create_flashinfer_kv_indices_triton[(batch_size,)](
|
||||
model_runner.req_to_token_pool.req_to_token,
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
kv_indptr,
|
||||
None,
|
||||
model_runner.req_to_token_pool.req_to_token.size(1),
|
||||
kv_indices,
|
||||
)
|
||||
|
||||
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
||||
|
||||
if forward_mode == ForwardMode.DECODE:
|
||||
@@ -365,18 +404,17 @@ def update_flashinfer_indices(
|
||||
|
||||
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
||||
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
|
||||
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
|
||||
kv_indices = torch.cat(
|
||||
[
|
||||
model_runner.req_to_token_pool.req_to_token[
|
||||
req_pool_indices_cpu[i],
|
||||
kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
|
||||
]
|
||||
for i in range(batch_size)
|
||||
],
|
||||
dim=0,
|
||||
).contiguous()
|
||||
|
||||
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
|
||||
create_flashinfer_kv_indices_triton[(batch_size,)](
|
||||
model_runner.req_to_token_pool.req_to_token,
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
kv_indptr,
|
||||
kv_start_idx,
|
||||
model_runner.req_to_token_pool.req_to_token.size(1),
|
||||
kv_indices,
|
||||
)
|
||||
|
||||
if forward_mode == ForwardMode.DECODE:
|
||||
# CUDA graph uses different flashinfer_decode_wrapper
|
||||
|
||||
Reference in New Issue
Block a user