diff --git a/python/sglang/srt/layers/flashinfer_utils.py b/python/sglang/srt/layers/flashinfer_utils.py
new file mode 100644
index 000000000..1f9ab1514
--- /dev/null
+++ b/python/sglang/srt/layers/flashinfer_utils.py
@@ -0,0 +1,237 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    max_context_len,
+    kv_indices_ptr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    req_to_token_ptr += req_pool_index * max_context_len
+    kv_indices_ptr += kv_indices_offset
+
+    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+    st_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = ld_offset < kv_end
+        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+        ld_offset += BLOCK_SIZE
+        st_offset += BLOCK_SIZE
+
+
+class FlashinferUpdater:
+    def __init__(
+        self,
+        forward_mode,
+        model_runner,
+        req_pool_indices,
+        seq_lens,
+        prefix_lens,
+        flashinfer_decode_wrapper=None,
+        flashinfer_use_ragged=False,
+    ):
+        self.forward_mode = forward_mode
+        self.model_runner = model_runner
+        self.req_pool_indices = req_pool_indices
+        self.seq_lens = seq_lens
+        self.prefix_lens = prefix_lens
+        self.flashinfer_use_ragged = flashinfer_use_ragged
+
+        self.num_qo_heads = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            model_runner.tp_size
+        )
+        self.head_dim = model_runner.model_config.head_dim
+        self.batch_size = len(req_pool_indices)
+
+        self.kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
+
+        (
+            self.flashinfer_decode_wrapper,
+            self.flashinfer_prefill_wrapper_ragged,
+            self.flashinfer_prefill_wrapper_paged,
+        ) = (
+            flashinfer_decode_wrapper,
+            self.model_runner.flashinfer_prefill_wrapper_ragged,
+            self.model_runner.flashinfer_prefill_wrapper_paged,
+        )
+        # CUDA graph uses different flashinfer_decode_wrapper
+        if self.flashinfer_decode_wrapper is None:
+            self.flashinfer_decode_wrapper = self.model_runner.flashinfer_decode_wrapper
+
+    def _init_indices_no_window(self):
+        if self.flashinfer_use_ragged:
+            paged_kernel_lens = self.prefix_lens
+        else:
+            paged_kernel_lens = self.seq_lens
+
+        self.kv_indptr = torch.zeros(
+            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        self.kv_indices = torch.empty(
+            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
+        )
+
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.model_runner.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            paged_kernel_lens,
+            self.kv_indptr,
+            None,
+            self.model_runner.req_to_token_pool.req_to_token.size(1),
+            self.kv_indices,
+        )
+
+    def _init_indices_window(self, wrapper_id):
+        # window attention use paged only
+        if wrapper_id == 0:
+            if self.forward_mode.is_decode():
+                paged_kernel_lens = torch.minimum(
+                    self.seq_lens,
+                    torch.tensor(self.model_runner.sliding_window_size + 1),
+                )
+            else:
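+                # Extend: the paged KV region is capped at the sliding window
+                # plus the new tokens computed in this pass (seq_lens - prefix_lens).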
+                paged_kernel_lens = torch.minimum(
+                    self.seq_lens,
+                    torch.tensor(self.model_runner.sliding_window_size)
+                    + self.seq_lens
+                    - self.prefix_lens,
+                )
+        else:
+            paged_kernel_lens = self.seq_lens
+
+        kv_start_idx = self.seq_lens - paged_kernel_lens
+        self.kv_indptr = torch.zeros(
+            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        self.kv_indices = torch.empty(
+            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
+        )
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.model_runner.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            paged_kernel_lens,
+            self.kv_indptr,
+            kv_start_idx,
+            self.model_runner.req_to_token_pool.req_to_token.size(1),
+            self.kv_indices,
+        )
+
+    def _update_decode_indices(self, decode_wrapper):
+        decode_wrapper.end_forward()
+        decode_wrapper.begin_forward(
+            self.kv_indptr,
+            self.kv_indices,
+            self.kv_last_page_len,
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            data_type=self.model_runner.kv_cache_dtype,
+            q_data_type=self.model_runner.dtype,
+        )
+
+    def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
+        # extend part
+        qo_indptr = torch.zeros(
+            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)
+
+        if self.flashinfer_use_ragged:
+            ragged_wrapper.end_forward()
+            ragged_wrapper.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+            )
+
+        # cached part
+        paged_wrapper.end_forward()
+        paged_wrapper.begin_forward(
+            qo_indptr,
+            self.kv_indptr,
+            self.kv_indices,
+            self.kv_last_page_len,
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+        )
+
+    def update_indices_no_window(self):
+        self._init_indices_no_window()
+
+        if self.forward_mode.is_decode():
+            self._update_decode_indices(self.flashinfer_decode_wrapper)
+        else:
+            self._update_extend_indices(
+                self.flashinfer_prefill_wrapper_ragged,
+                self.flashinfer_prefill_wrapper_paged,
+            )
+
+    def update_indices_window(self):
+        assert self.flashinfer_use_ragged is False
+
+        for wrapper_id in range(2):
+            self._init_indices_window(wrapper_id)
+            if self.forward_mode.is_decode():
+                self._update_decode_indices(self.flashinfer_decode_wrapper[wrapper_id])
+            else:
+                self._update_extend_indices(
+                    None,
+                    self.flashinfer_prefill_wrapper_paged[wrapper_id],
+                )
+
+
+def update_flashinfer_indices(
+    forward_mode,
+    model_runner,
+    req_pool_indices,
+    seq_lens,
+    prefix_lens,
+    flashinfer_decode_wrapper=None,
+    flashinfer_use_ragged=False,
+):
+    flashinfer_updater = FlashinferUpdater(
+        forward_mode,
+        model_runner,
+        req_pool_indices,
+        seq_lens,
+        prefix_lens,
+        flashinfer_decode_wrapper,
+        flashinfer_use_ragged,
+    )
+
+    if model_runner.sliding_window_size is None:
+        flashinfer_updater.update_indices_no_window()
+    else:
+        flashinfer_updater.update_indices_window()
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index e46177bdb..b6000734a 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -349,6 +349,7 @@ class ScheduleBatch:
 
     # For mixed chunekd prefill
     prefix_lens_cpu: List[int] = None
+    running_bs: int = None
 
     # For processing logprobs
    return_logprob: bool = False
@@ -446,6 +447,9 @@
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
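+        # Mixed chunked prefill: the running decode batch is folded into this
+        # extend batch, so record how many requests it contributes.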
+        self.forward_mode = ForwardMode.MIXED
+        self.running_bs = running_batch.batch_size()
+
         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
         prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
         prefix_lens_cpu.extend(
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 4459213b0..c24dd5084 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -25,6 +25,7 @@ from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
@@ -32,11 +33,7 @@
 )
 from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardMode,
-    InputMetadata,
-    update_flashinfer_indices,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index c158b3ce2..f3bed6bcf 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -22,8 +22,8 @@ from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
-import triton
-import triton.language as tl
+
+from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -39,16 +39,21 @@ class ForwardMode(IntEnum):
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
+    # Contains both EXTEND and DECODE.
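+    # (set by ScheduleBatch.mix_with_running when a chunked prefill is mixed
+    # with the running decode batch)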
+    MIXED = auto()
 
     def is_prefill(self):
         return self == ForwardMode.PREFILL
 
     def is_extend(self):
-        return self == ForwardMode.EXTEND
+        return self == ForwardMode.EXTEND or self == ForwardMode.MIXED
 
     def is_decode(self):
         return self == ForwardMode.DECODE
 
+    def is_mixed(self):
+        return self == ForwardMode.MIXED
+
 
 @dataclass
 class InputMetadata:
@@ -270,192 +275,3 @@ class InputMetadata:
             model_runner.flashinfer_decode_wrapper,
             flashinfer_use_ragged,
         )
-
-
-@triton.jit
-def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices_ptr,
-    page_kernel_lens_ptr,
-    kv_indptr,
-    kv_start_idx,
-    max_context_len,
-    kv_indices_ptr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(axis=0)
-    req_pool_index = tl.load(req_pool_indices_ptr + pid)
-    kv_indices_offset = tl.load(kv_indptr + pid)
-
-    kv_start = 0
-    kv_end = 0
-    if kv_start_idx:
-        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
-        kv_end = kv_start
-    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
-
-    req_to_token_ptr += req_pool_index * max_context_len
-    kv_indices_ptr += kv_indices_offset
-
-    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
-    st_offset = tl.arange(0, BLOCK_SIZE)
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for _ in range(num_loop):
-        mask = ld_offset < kv_end
-        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
-        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
-        ld_offset += BLOCK_SIZE
-        st_offset += BLOCK_SIZE
-
-
-def update_flashinfer_indices(
-    forward_mode,
-    model_runner,
-    req_pool_indices,
-    seq_lens,
-    prefix_lens,
-    flashinfer_decode_wrapper=None,
-    flashinfer_use_ragged=False,
-):
-    """Init auxiliary variables for FlashInfer attention backend."""
-    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
-    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
-    head_dim = model_runner.model_config.head_dim
-    batch_size = len(req_pool_indices)
-
-    if model_runner.sliding_window_size is None:
-        if flashinfer_use_ragged:
-            paged_kernel_lens = prefix_lens
-        else:
-            paged_kernel_lens = seq_lens
-
-        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-
-        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
-        create_flashinfer_kv_indices_triton[(batch_size,)](
-            model_runner.req_to_token_pool.req_to_token,
-            req_pool_indices,
-            paged_kernel_lens,
-            kv_indptr,
-            None,
-            model_runner.req_to_token_pool.req_to_token.size(1),
-            kv_indices,
-        )
-
-        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
-        if forward_mode.is_decode():
-            # CUDA graph uses different flashinfer_decode_wrapper
-            if flashinfer_decode_wrapper is None:
-                flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
-
-            flashinfer_decode_wrapper.end_forward()
-            flashinfer_decode_wrapper.begin_forward(
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                data_type=model_runner.kv_cache_dtype,
-                q_data_type=model_runner.dtype,
-            )
-        else:
-            # extend part
-            qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-            qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
-            if flashinfer_use_ragged:
-                model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-                model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-                    qo_indptr,
-                    qo_indptr,
-                    num_qo_heads,
-                    num_kv_heads,
-                    head_dim,
-                )
-
-            # cached part
-            model_runner.flashinfer_prefill_wrapper_paged.end_forward()
-            model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-            )
-    else:
-        # window attention use paged only
-        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-        for wrapper_id in range(2):
-            if wrapper_id == 0:
-                if forward_mode.is_decode():
-                    paged_kernel_lens = torch.minimum(
-                        seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
-                    )
-                else:
-                    paged_kernel_lens = torch.minimum(
-                        seq_lens,
-                        torch.tensor(model_runner.sliding_window_size)
-                        + seq_lens
-                        - prefix_lens,
-                    )
-            else:
-                paged_kernel_lens = seq_lens
-
-            kv_start_idx = seq_lens - paged_kernel_lens
-
-            kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-            kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-
-            kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
-            create_flashinfer_kv_indices_triton[(batch_size,)](
-                model_runner.req_to_token_pool.req_to_token,
-                req_pool_indices,
-                paged_kernel_lens,
-                kv_indptr,
-                kv_start_idx,
-                model_runner.req_to_token_pool.req_to_token.size(1),
-                kv_indices,
-            )
-
-            if forward_mode.is_decode():
-                # CUDA graph uses different flashinfer_decode_wrapper
-                if flashinfer_decode_wrapper is None:
-                    flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
-
-                flashinfer_decode_wrapper[wrapper_id].end_forward()
-                flashinfer_decode_wrapper[wrapper_id].begin_forward(
-                    kv_indptr,
-                    kv_indices,
-                    kv_last_page_len,
-                    num_qo_heads,
-                    num_kv_heads,
-                    head_dim,
-                    1,
-                    data_type=model_runner.kv_cache_dtype,
-                    q_data_type=model_runner.dtype,
-                )
-            else:
-                # extend part
-                qo_indptr = torch.zeros(
-                    (batch_size + 1,), dtype=torch.int32, device="cuda"
-                )
-                qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
-                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward()
-                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward(
-                    qo_indptr,
-                    kv_indptr,
-                    kv_indices,
-                    kv_last_page_len,
-                    num_qo_heads,
-                    num_kv_heads,
-                    head_dim,
-                    1,
-                )
diff --git a/test/srt/test_create_kvindices.py b/test/srt/test_create_kvindices.py
index 230302f26..2159cca95 100644
--- a/test/srt/test_create_kvindices.py
+++ b/test/srt/test_create_kvindices.py
@@ -4,9 +4,7 @@ import unittest
 import numpy as np
 import torch
 
-from sglang.srt.model_executor.forward_batch_info import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.flashinfer_utils import create_flashinfer_kv_indices_triton
 
 
 class TestCreateKvIndices(unittest.TestCase):