Optimize the update flashinfer indices (#1262)

This commit is contained in:
xiaobochen
2024-08-31 23:40:28 -07:00
committed by GitHub
parent 6cc9c52521
commit d134c139a1
2 changed files with 137 additions and 23 deletions

View File

@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, List
import numpy as np
import torch
import triton
import triton.language as tl
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -262,6 +264,42 @@ class InputMetadata:
)
@triton.jit
def create_flashinfer_kv_indices_triton(
req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices_ptr,
page_kernel_lens_ptr,
kv_indptr,
kv_start_idx,
max_context_len,
kv_indices_ptr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(axis=0)
req_pool_index = tl.load(req_pool_indices_ptr + pid)
kv_indices_offset = tl.load(kv_indptr + pid)
kv_start = 0
kv_end = 0
if kv_start_idx:
kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
kv_end = kv_start
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
req_to_token_ptr += req_pool_index * max_context_len
kv_indices_ptr += kv_indices_offset
ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
st_offset = tl.arange(0, BLOCK_SIZE)
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
for _ in range(num_loop):
mask = ld_offset < kv_end
data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
tl.store(kv_indices_ptr + st_offset, data, mask=mask)
ld_offset += BLOCK_SIZE
st_offset += BLOCK_SIZE
def update_flashinfer_indices(
forward_mode,
model_runner,
@@ -285,17 +323,18 @@ def update_flashinfer_indices(
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
kv_indices = torch.cat(
[
model_runner.req_to_token_pool.req_to_token[
req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
]
for i in range(batch_size)
],
dim=0,
).contiguous()
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(batch_size,)](
model_runner.req_to_token_pool.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
None,
model_runner.req_to_token_pool.req_to_token.size(1),
kv_indices,
)
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
if forward_mode == ForwardMode.DECODE:
@@ -365,18 +404,17 @@ def update_flashinfer_indices(
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
kv_indices = torch.cat(
[
model_runner.req_to_token_pool.req_to_token[
req_pool_indices_cpu[i],
kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
]
for i in range(batch_size)
],
dim=0,
).contiguous()
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(batch_size,)](
model_runner.req_to_token_pool.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
kv_start_idx,
model_runner.req_to_token_pool.req_to_token.size(1),
kv_indices,
)
if forward_mode == ForwardMode.DECODE:
# CUDA graph uses different flashinfer_decode_wrapper