From 6d0fa73ece7d8e6694ce9a435b5204acaed20876 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 17 Oct 2024 22:54:14 -0700 Subject: [PATCH] Simplify flashinfer utilities (#1704) --- .../sglang/srt/layers/attention/__init__.py | 32 +- .../attention/double_sparsity_backend.py | 4 +- .../layers/attention/flashinfer_backend.py | 435 ++++++++++++++---- .../srt/layers/attention/flashinfer_utils.py | 237 ---------- .../srt/layers/attention/triton_backend.py | 4 +- python/sglang/srt/managers/schedule_batch.py | 8 +- .../srt/model_executor/forward_batch_info.py | 5 +- .../sglang/srt/model_executor/model_runner.py | 3 + 8 files changed, 391 insertions(+), 337 deletions(-) delete mode 100644 python/sglang/srt/layers/attention/flashinfer_utils.py diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index 4cad3d8aa..f6d10170c 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +import torch from torch import nn from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -18,13 +19,13 @@ class AttentionBackend(ABC): raise NotImplementedError() def init_forward_metadata_capture_cuda_graph( - self, bs: int, req_pool_indices, seq_lens + self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor ): """Init the metadata for a forward pass for capturing a cuda graph.""" raise NotImplementedError() def init_forward_metadata_replay_cuda_graph( - self, bs: int, req_pool_indices, seq_lens + self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor ): """Init the metadata for a forward pass for replying a cuda graph.""" raise NotImplementedError() @@ -33,17 +34,38 @@ class AttentionBackend(ABC): """Get the fill value for padded seq lens. 
Typically, it is 0 or 1.""" raise NotImplementedError() - def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: nn.Module, + forward_batch: ForwardBatch, + ): """Run forward on an attention layer.""" if forward_batch.forward_mode.is_decode(): return self.forward_decode(q, k, v, layer, forward_batch) else: return self.forward_extend(q, k, v, layer, forward_batch) - def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: nn.Module, + forward_batch: ForwardBatch, + ): """Run a forward for decode.""" raise NotImplementedError() - def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + def forward_extend( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: nn.Module, + forward_batch: ForwardBatch, + ): """Run a forward for extend.""" raise NotImplementedError() diff --git a/python/sglang/srt/layers/attention/double_sparsity_backend.py b/python/sglang/srt/layers/attention/double_sparsity_backend.py index 2ddbee44f..e2cd98ec2 100644 --- a/python/sglang/srt/layers/attention/double_sparsity_backend.py +++ b/python/sglang/srt/layers/attention/double_sparsity_backend.py @@ -134,7 +134,7 @@ class DoubleSparseAttnBackend(AttentionBackend): ) def init_forward_metadata_capture_cuda_graph( - self, bs: int, req_pool_indices, seq_lens + self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor ): self.forward_metadata = ( self.cuda_graph_start_loc, @@ -144,7 +144,7 @@ class DoubleSparseAttnBackend(AttentionBackend): ) def init_forward_metadata_replay_cuda_graph( - self, bs: int, req_pool_indices, seq_lens + self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor ): self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 0c9ca8f9d..c2cfa5fb6 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -7,18 +7,17 @@ FlashInfer is faster and Triton is easier to customize. Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode. 
""" +from enum import Enum, auto from typing import TYPE_CHECKING import torch import torch.nn as nn +import triton +import triton.language as tl from sglang.global_config import global_config from sglang.srt.layers.attention import AttentionBackend -from sglang.srt.layers.attention.flashinfer_utils import ( - WrapperDispatch, - update_flashinfer_indices, -) -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import is_flashinfer_available if TYPE_CHECKING: @@ -34,13 +33,18 @@ if is_flashinfer_available(): from flashinfer.decode import _grouped_size_compiled_for_decode_kernels +class WrapperDispatch(Enum): + SLIDING_WINDOW = auto() + CROSS_ATTENTION = auto() + + class FlashInferAttnBackend(AttentionBackend): """Flashinfer attention kernels.""" def __init__(self, model_runner: ModelRunner): super().__init__() - self.model_runner = model_runner + # Parse constants if not _grouped_size_compiled_for_decode_kernels( model_runner.model_config.num_attention_heads // model_runner.tp_size, model_runner.model_config.get_num_kv_heads(model_runner.tp_size), @@ -48,27 +52,43 @@ class FlashInferAttnBackend(AttentionBackend): self.decode_use_tensor_cores = True else: self.decode_use_tensor_cores = False - - self.workspace_buffer = torch.empty( - global_config.flashinfer_workspace_size, - dtype=torch.uint8, - device="cuda", - ) + self.max_context_len = model_runner.model_config.context_len assert not ( model_runner.sliding_window_size is not None and model_runner.has_cross_attention ), "Sliding window and cross attention are not supported together" - self.num_wrappers = 1 - self.dispatch_reason = None if model_runner.sliding_window_size is not None: self.num_wrappers = 2 self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW elif model_runner.has_cross_attention: self.num_wrappers = 2 self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION + else: + self.num_wrappers = 1 + self.dispatch_reason = None + # Allocate buffers + self.workspace_buffer = torch.empty( + global_config.flashinfer_workspace_size, + dtype=torch.uint8, + device=model_runner.device, + ) + max_bs = model_runner.req_to_token_pool.size + self.kv_indptr = [ + torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device) + for _ in range(self.num_wrappers) + ] + self.kv_last_page_len = torch.ones( + (max_bs,), dtype=torch.int32, device=model_runner.device + ) + self.qo_indptr = [ + torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device) + for _ in range(self.num_wrappers) + ] + + # Create wrappers # NOTE: we do not use ragged attention when there are multiple wrappers self.prefill_wrapper_ragged = ( BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD") @@ -92,26 +112,23 @@ class FlashInferAttnBackend(AttentionBackend): ) ) + # Create indices updater + self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self) + self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill( + model_runner, self + ) + + # Other metadata self.forward_metadata = None self.cuda_graph_metadata = {} - def _get_wrapper_idx(self, layer: nn.Module): - if self.num_wrappers == 1: - return 0 - - if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: - return layer.sliding_window_size == -1 - if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: - return layer.is_cross_attention - - raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}") - def 
init_forward_metadata(self, forward_batch: ForwardBatch): if forward_batch.forward_mode.is_decode(): - prefix_lens = None - use_ragged = False - extend_no_prefix = False - total_num_tokens = None + self.indices_updater_decode.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + ) + self.forward_metadata = (self.decode_wrappers,) else: prefix_lens = forward_batch.extend_prefix_lens @@ -123,48 +140,32 @@ class FlashInferAttnBackend(AttentionBackend): ): use_ragged = True - total_num_tokens = torch.sum(forward_batch.seq_lens).item() extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item() - update_flashinfer_indices( - forward_batch.forward_mode, - self.model_runner, - forward_batch.req_pool_indices, - forward_batch.seq_lens, - prefix_lens, - use_ragged=use_ragged, - ) + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + prefix_lens, + use_ragged, + ) - self.forward_metadata = ( - use_ragged, - extend_no_prefix, - total_num_tokens, - self.decode_wrappers, - ) + self.forward_metadata = ( + use_ragged, + extend_no_prefix, + ) def init_cuda_graph_state(self, max_bs: int): - self.cuda_graph_kv_indptr = torch.zeros( - (max_bs + 1,), dtype=torch.int32, device="cuda" - ) - self.cuda_graph_kv_indices = torch.zeros( - (max_bs * self.model_runner.model_config.context_len,), + cuda_graph_kv_indices = torch.zeros( + (max_bs * self.max_context_len,), dtype=torch.int32, device="cuda", ) - self.cuda_graph_kv_last_page_len = torch.ones( - (max_bs,), dtype=torch.int32, device="cuda" - ) - - # NOTE: the buffers are always in the form of list - self.cuda_graph_kv_indptr = [self.cuda_graph_kv_indptr] + [ - self.cuda_graph_kv_indptr.clone() for _ in range(self.num_wrappers - 1) - ] - self.cuda_graph_kv_indices = [self.cuda_graph_kv_indices] + [ - self.cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1) + self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [ + cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1) ] def init_forward_metadata_capture_cuda_graph( - self, bs: int, req_pool_indices, seq_lens + self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor ): decode_wrappers = [] for i in range(self.num_wrappers): @@ -174,35 +175,21 @@ class FlashInferAttnBackend(AttentionBackend): "NHD", use_cuda_graph=True, use_tensor_cores=self.decode_use_tensor_cores, - paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1], + paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1], paged_kv_indices_buffer=self.cuda_graph_kv_indices[i], - paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs], + paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs], ) ) - update_flashinfer_indices( - ForwardMode.DECODE, - self.model_runner, - req_pool_indices, - seq_lens, - None, - decode_wrappers, - ) - + self.indices_updater_decode.update(req_pool_indices, seq_lens, decode_wrappers) self.cuda_graph_metadata[bs] = decode_wrappers - - self.forward_metadata = (False, False, None, decode_wrappers) + self.forward_metadata = (decode_wrappers,) def init_forward_metadata_replay_cuda_graph( - self, bs: int, req_pool_indices, seq_lens + self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor ): - update_flashinfer_indices( - ForwardMode.DECODE, - self.model_runner, - req_pool_indices[:bs], - seq_lens[:bs], - None, - self.cuda_graph_metadata[bs], + self.indices_updater_decode.update( + req_pool_indices[:bs], seq_lens[:bs], self.cuda_graph_metadata[bs] ) def 
get_cuda_graph_seq_len_fill_value(self): @@ -213,7 +200,7 @@ class FlashInferAttnBackend(AttentionBackend): self._get_wrapper_idx(layer) ] - use_ragged, extend_no_prefix, _, _ = self.forward_metadata + use_ragged, extend_no_prefix = self.forward_metadata if not use_ragged: if k is not None: @@ -259,7 +246,7 @@ class FlashInferAttnBackend(AttentionBackend): return o.view(-1, layer.tp_q_head_num * layer.head_dim) def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): - decode_wrapper = self.forward_metadata[-1][self._get_wrapper_idx(layer)] + decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)] if k is not None: assert v is not None @@ -275,3 +262,285 @@ class FlashInferAttnBackend(AttentionBackend): ) return o.view(-1, layer.tp_q_head_num * layer.head_dim) + + def _get_wrapper_idx(self, layer: nn.Module): + if self.num_wrappers == 1: + return 0 + + if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: + return layer.sliding_window_size == -1 + if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: + return layer.is_cross_attention + + raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}") + + +class FlashInferIndicesUpdaterDecode: + def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): + # Constants + self.num_qo_heads = ( + model_runner.model_config.num_attention_heads // model_runner.tp_size + ) + self.num_kv_heads = model_runner.model_config.get_num_kv_heads( + model_runner.tp_size + ) + self.head_dim = model_runner.model_config.head_dim + self.data_type = model_runner.kv_cache_dtype + self.q_data_type = model_runner.dtype + self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1) + self.sliding_window_size = model_runner.sliding_window_size + + # Buffers and wrappers + self.kv_indptr = attn_backend.kv_indptr + self.kv_last_page_len = attn_backend.kv_last_page_len + self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.decode_wrappers = attn_backend.decode_wrappers + + # Dispatch + if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: + self.update = self.update_sliding_window + elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: + self.update = self.update_cross_attention + else: + assert attn_backend.num_wrappers == 1 + self.update = self.update_single_wrapper + + def update_single_wrapper(self, req_pool_indices, seq_lens, decode_wrappers=None): + decode_wrappers = decode_wrappers or self.decode_wrappers + self.call_begin_forward( + decode_wrappers[0], req_pool_indices, seq_lens, self.kv_indptr[0], None + ) + + def update_sliding_window(self, req_pool_indices, seq_lens, decode_wrappers=None): + decode_wrappers = decode_wrappers or self.decode_wrappers + + for wrapper_id in range(2): + if wrapper_id == 0: + # Sliding window attention + paged_kernel_lens = torch.minimum( # TODO: replace this with clamp + seq_lens, + torch.tensor(self.sliding_window_size + 1), + ) + else: + # Full attention + paged_kernel_lens = seq_lens + + kv_start_idx = seq_lens - paged_kernel_lens + + self.call_begin_forward( + decode_wrappers[wrapper_id], + req_pool_indices, + paged_kernel_lens, + self.kv_indptr[wrapper_id], + kv_start_idx, + ) + + def update_cross_attention(self): + raise NotImplementedError() + + def call_begin_forward( + self, wrapper, req_pool_indices, paged_kernel_lens, kv_indptr, kv_start_idx + ): + bs = len(req_pool_indices) + kv_indptr = kv_indptr[: bs + 1] + # TODO: optimize the blocking call on kv_indptr[-1] + kv_indptr[1:] = 
torch.cumsum(paged_kernel_lens, dim=0) + kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda") + + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + self.max_context_len, + ) + + wrapper.end_forward() + wrapper.begin_forward( + kv_indptr, + kv_indices, + self.kv_last_page_len[:bs], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + data_type=self.data_type, + q_data_type=self.q_data_type, + ) + + +class FlashInferIndicesUpdaterPrefill: + def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): + # Constants + self.num_qo_heads = ( + model_runner.model_config.num_attention_heads // model_runner.tp_size + ) + self.num_kv_heads = model_runner.model_config.get_num_kv_heads( + model_runner.tp_size + ) + self.head_dim = model_runner.model_config.head_dim + self.data_type = model_runner.kv_cache_dtype + self.q_data_type = model_runner.dtype + self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1) + self.sliding_window_size = model_runner.sliding_window_size + + # Buffers and wrappers + self.kv_indptr = attn_backend.kv_indptr + self.kv_last_page_len = attn_backend.kv_last_page_len + self.qo_indptr = attn_backend.qo_indptr + self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.wrapper_ragged = attn_backend.prefill_wrapper_ragged + self.wrappers_paged = attn_backend.prefill_wrappers_paged + + # Dispatch + if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: + self.update = self.update_sliding_window + elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: + self.update = self.update_cross_attention + else: + assert attn_backend.num_wrappers == 1 + self.update = self.update_single_wrapper + + def update_single_wrapper( + self, req_pool_indices, seq_lens, prefix_lens, use_ragged + ): + if use_ragged: + paged_kernel_lens = prefix_lens + else: + paged_kernel_lens = seq_lens + + self.call_begin_forward( + self.wrapper_ragged, + self.wrappers_paged[0], + req_pool_indices, + paged_kernel_lens, + seq_lens, + prefix_lens, + None, + self.kv_indptr[0], + self.qo_indptr[0], + use_ragged, + ) + + def update_sliding_window( + self, req_pool_indices, seq_lens, prefix_lens, use_ragged + ): + for wrapper_id in range(2): + if wrapper_id == 0: + # window attention use paged only + paged_kernel_lens = torch.minimum( + seq_lens, + torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens, + ) + else: + # full attention + paged_kernel_lens = seq_lens + kv_start_idx = seq_lens - paged_kernel_lens + + self.call_begin_forward( + self.wrapper_ragged, + self.wrappers_paged[wrapper_id], + req_pool_indices, + paged_kernel_lens, + seq_lens, + prefix_lens, + kv_start_idx, + self.kv_indptr[wrapper_id], + self.qo_indptr[wrapper_id], + use_ragged, + ) + + def update_cross_attention(self): + raise NotImplementedError() + + def call_begin_forward( + self, + wrapper_ragged, + wrapper_paged, + req_pool_indices, + paged_kernel_lens, + seq_lens, + prefix_lens, + kv_start_idx, + kv_indptr, + qo_indptr, + use_ragged, + ): + bs = len(req_pool_indices) + kv_indptr = kv_indptr[: bs + 1] + kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda") + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + self.max_context_len, + ) + + qo_indptr = 
qo_indptr[: bs + 1] + qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0) + + # extend part + if use_ragged: + wrapper_ragged.end_forward() + wrapper_ragged.begin_forward( + qo_indptr, + qo_indptr, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + ) + + # cached part + wrapper_paged.end_forward() + wrapper_paged.begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + self.kv_last_page_len[:bs], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + ) + + +@triton.jit +def create_flashinfer_kv_indices_triton( + req_to_token_ptr, # [max_batch, max_context_len] + req_pool_indices_ptr, + page_kernel_lens_ptr, + kv_indptr, + kv_start_idx, + kv_indices_ptr, + max_context_len: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 512 + pid = tl.program_id(axis=0) + req_pool_index = tl.load(req_pool_indices_ptr + pid) + kv_indices_offset = tl.load(kv_indptr + pid) + + kv_start = 0 + kv_end = 0 + if kv_start_idx: + kv_start = tl.load(kv_start_idx + pid).to(tl.int32) + kv_end = kv_start + kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32) + + req_to_token_ptr += req_pool_index * max_context_len + kv_indices_ptr += kv_indices_offset + + ld_offset = kv_start + tl.arange(0, BLOCK_SIZE) + st_offset = tl.arange(0, BLOCK_SIZE) + num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) + for _ in range(num_loop): + mask = ld_offset < kv_end + data = tl.load(req_to_token_ptr + ld_offset, mask=mask) + tl.store(kv_indices_ptr + st_offset, data, mask=mask) + ld_offset += BLOCK_SIZE + st_offset += BLOCK_SIZE diff --git a/python/sglang/srt/layers/attention/flashinfer_utils.py b/python/sglang/srt/layers/attention/flashinfer_utils.py deleted file mode 100644 index 796203c93..000000000 --- a/python/sglang/srt/layers/attention/flashinfer_utils.py +++ /dev/null @@ -1,237 +0,0 @@ -from enum import Enum, auto - -import torch -import triton -import triton.language as tl - - -class WrapperDispatch(Enum): - SLIDING_WINDOW = auto() - CROSS_ATTENTION = auto() - - -@triton.jit -def create_flashinfer_kv_indices_triton( - req_to_token_ptr, # [max_batch, max_context_len] - req_pool_indices_ptr, - page_kernel_lens_ptr, - kv_indptr, - kv_start_idx, - kv_indices_ptr, - max_context_len: tl.constexpr, -): - BLOCK_SIZE: tl.constexpr = 512 - pid = tl.program_id(axis=0) - req_pool_index = tl.load(req_pool_indices_ptr + pid) - kv_indices_offset = tl.load(kv_indptr + pid) - - kv_start = 0 - kv_end = 0 - if kv_start_idx: - kv_start = tl.load(kv_start_idx + pid).to(tl.int32) - kv_end = kv_start - kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32) - - req_to_token_ptr += req_pool_index * max_context_len - kv_indices_ptr += kv_indices_offset - - ld_offset = kv_start + tl.arange(0, BLOCK_SIZE) - st_offset = tl.arange(0, BLOCK_SIZE) - num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) - for _ in range(num_loop): - mask = ld_offset < kv_end - data = tl.load(req_to_token_ptr + ld_offset, mask=mask) - tl.store(kv_indices_ptr + st_offset, data, mask=mask) - ld_offset += BLOCK_SIZE - st_offset += BLOCK_SIZE - - -class FlashinferUpdater: - def __init__( - self, - forward_mode, - model_runner, - req_pool_indices, - seq_lens, - prefix_lens, - decode_wrappers=None, - use_ragged=False, - ): - self.forward_mode = forward_mode - self.model_runner = model_runner - self.req_pool_indices = req_pool_indices - self.seq_lens = seq_lens - self.prefix_lens = prefix_lens - self.use_ragged = use_ragged - - self.num_qo_heads = ( - model_runner.model_config.num_attention_heads // model_runner.tp_size - ) - self.num_kv_heads = 
model_runner.model_config.get_num_kv_heads( - model_runner.tp_size - ) - self.head_dim = model_runner.model_config.head_dim - self.batch_size = len(req_pool_indices) - - self.decode_wrappers = ( - decode_wrappers or self.model_runner.attn_backend.decode_wrappers - ) - self.prefill_wrapper_ragged = ( - self.model_runner.attn_backend.prefill_wrapper_ragged - ) - self.prefill_wrappers_paged = ( - self.model_runner.attn_backend.prefill_wrappers_paged - ) - - self.kv_last_page_len = torch.ones( - (self.batch_size,), dtype=torch.int32, device="cuda" - ) - - def _update_decode_indices(self, decode_wrapper): - assert not isinstance(decode_wrapper, list) - decode_wrapper.end_forward() - decode_wrapper.begin_forward( - self.kv_indptr, - self.kv_indices, - self.kv_last_page_len, - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - 1, - data_type=self.model_runner.kv_cache_dtype, - q_data_type=self.model_runner.dtype, - ) - - def _update_extend_indices(self, ragged_wrapper, paged_wrapper): - assert not isinstance(paged_wrapper, list) - assert not isinstance(ragged_wrapper, list) - - # extend part - qo_indptr = torch.zeros( - (self.batch_size + 1,), dtype=torch.int32, device="cuda" - ) - qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0) - - if self.use_ragged: - ragged_wrapper.end_forward() - ragged_wrapper.begin_forward( - qo_indptr, - qo_indptr, - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - ) - - # cached part - paged_wrapper.end_forward() - paged_wrapper.begin_forward( - qo_indptr, - self.kv_indptr, - self.kv_indices, - self.kv_last_page_len, - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - 1, - ) - - def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0): - if dispatch_reason is None: - if self.use_ragged: - paged_kernel_lens = self.prefix_lens - else: - paged_kernel_lens = self.seq_lens - self.kv_start_idx = None - elif dispatch_reason == WrapperDispatch.SLIDING_WINDOW: - if wrapper_id == 0: - # window attention use paged only - if self.forward_mode.is_decode(): - paged_kernel_lens = torch.minimum( - self.seq_lens, - torch.tensor(self.model_runner.sliding_window_size + 1), - ) - else: - paged_kernel_lens = torch.minimum( - self.seq_lens, - torch.tensor(self.model_runner.sliding_window_size) - + self.seq_lens - - self.prefix_lens, - ) - else: - # full attention - paged_kernel_lens = self.seq_lens - self.kv_start_idx = self.seq_lens - paged_kernel_lens - - self.kv_indptr = torch.zeros( - (self.batch_size + 1,), dtype=torch.int32, device="cuda" - ) - self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) - self.kv_indices = torch.empty( - self.kv_indptr[-1], dtype=torch.int32, device="cuda" - ) - - create_flashinfer_kv_indices_triton[(self.batch_size,)]( - self.model_runner.req_to_token_pool.req_to_token, - self.req_pool_indices, - paged_kernel_lens, - self.kv_indptr, - self.kv_start_idx, - self.kv_indices, - self.model_runner.req_to_token_pool.req_to_token.size(1), - ) - - def _update_indicess_single_wrapper(self): - self._get_indices() - - if self.forward_mode.is_decode(): - self._update_decode_indices(self.decode_wrappers[0]) - else: - self._update_extend_indices( - self.prefill_wrapper_ragged, - self.prefill_wrappers_paged[0], - ) - - def _update_indices_cross_attention(self): - pass - - def _update_indices_sliding_window(self): - assert self.use_ragged is False - for wrapper_id in range(2): - self._get_indices(WrapperDispatch.SLIDING_WINDOW, wrapper_id) - if self.forward_mode.is_decode(): - 
self._update_decode_indices(self.decode_wrappers[wrapper_id]) - else: - self._update_extend_indices( - None, - self.prefill_wrappers_paged[wrapper_id], - ) - - -def update_flashinfer_indices( - forward_mode, - model_runner, - req_pool_indices, - seq_lens, - prefix_lens, - decode_wrappers=None, - use_ragged=False, -): - updater = FlashinferUpdater( - forward_mode, - model_runner, - req_pool_indices, - seq_lens, - prefix_lens, - decode_wrappers, - use_ragged, - ) - - dispatch_reason = model_runner.attn_backend.dispatch_reason - - if dispatch_reason == WrapperDispatch.SLIDING_WINDOW: - updater._update_indices_sliding_window() - elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION: - updater._update_indices_cross_attention() - else: - assert model_runner.attn_backend.num_wrappers == 1 - updater._update_indicess_single_wrapper() diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 14c849a9f..e1f5bf371 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -81,7 +81,7 @@ class TritonAttnBackend(AttentionBackend): ) def init_forward_metadata_capture_cuda_graph( - self, bs: int, req_pool_indices, seq_lens + self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor ): self.forward_metadata = ( self.cuda_graph_start_loc, @@ -91,7 +91,7 @@ class TritonAttnBackend(AttentionBackend): ) def init_forward_metadata_replay_cuda_graph( - self, bs: int, req_pool_indices, seq_lens + self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor ): self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 0eeb3359e..f64b0699a 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -744,7 +744,6 @@ class ScheduleBatch: self.forward_mode = ForwardMode.DECODE self.input_ids = self.output_ids - self.seq_lens.add_(1) self.output_ids = None if self.sampling_info.penalizer_orchestrator: self.sampling_info.penalizer_orchestrator.cumulate_output_tokens( @@ -755,9 +754,10 @@ class ScheduleBatch: bs = len(self.reqs) self.out_cache_loc = self.alloc_token_slots(bs) - self.req_to_token_pool.req_to_token[ - self.req_pool_indices, self.seq_lens - 1 - ] = self.out_cache_loc + self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = ( + self.out_cache_loc + ) + self.seq_lens.add_(1) def filter_batch( self, diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index be6a72afd..fb7109464 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -134,9 +134,7 @@ class ForwardBatch: ) # Init position information - if ret.forward_mode.is_decode(): - ret.positions = (ret.seq_lens - 1).to(torch.int64) - else: + if not ret.forward_mode.is_decode(): ret.positions = torch.tensor( np.concatenate( [ @@ -164,7 +162,6 @@ class ForwardBatch: ret.req_to_token_pool = model_runner.req_to_token_pool ret.token_to_kv_pool = model_runner.token_to_kv_pool ret.attn_backend = model_runner.attn_backend - model_runner.attn_backend.init_forward_metadata(ret) # Init lora information if model_runner.server_args.lora_paths is not None: diff --git a/python/sglang/srt/model_executor/model_runner.py 
b/python/sglang/srt/model_executor/model_runner.py index d583dcd34..a60ac1c70 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -551,11 +551,14 @@ class ModelRunner: ): return self.cuda_graph_runner.replay(forward_batch) + forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64) + self.attn_backend.init_forward_metadata(forward_batch) return self.model.forward( forward_batch.input_ids, forward_batch.positions, forward_batch ) def forward_extend(self, forward_batch: ForwardBatch): + self.attn_backend.init_forward_metadata(forward_batch) if self.is_generation: return self.model.forward( forward_batch.input_ids, forward_batch.positions, forward_batch
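
For reference, the Triton kernel `create_flashinfer_kv_indices_triton` carried over into flashinfer_backend.py gathers each request's KV slot row from `req_to_token` into one flat `kv_indices` buffer, using `kv_indptr` as the per-request write offsets and an optional `kv_start_idx` to skip tokens excluded by a sliding window. A minimal pure-PyTorch sketch of the same computation is shown below; the helper name and the toy shapes are illustrative assumptions and are not part of the patch, but the sketch can serve as a reference when testing the Triton kernel.

import torch

def create_flashinfer_kv_indices_reference(
    req_to_token,       # [max_batch, max_context_len] int32 slot table
    req_pool_indices,   # [bs] row index of each request in req_to_token
    paged_kernel_lens,  # [bs] number of KV tokens to index per request
    kv_indptr,          # [bs + 1] exclusive prefix sum of paged_kernel_lens
    kv_start_idx,       # [bs] per-request start offset, or None
    kv_indices,         # [kv_indptr[-1]] output buffer
):
    # Illustrative torch equivalent of create_flashinfer_kv_indices_triton:
    # copy each request's window of KV slots into the flat kv_indices buffer.
    bs = req_pool_indices.numel()
    for i in range(bs):
        start = int(kv_start_idx[i]) if kv_start_idx is not None else 0
        length = int(paged_kernel_lens[i])
        row = req_to_token[int(req_pool_indices[i]), start : start + length]
        kv_indices[int(kv_indptr[i]) : int(kv_indptr[i]) + length] = row

# Hypothetical toy usage: two requests with 3 and 5 cached KV tokens.
req_to_token = torch.arange(2 * 8, dtype=torch.int32).reshape(2, 8)
req_pool_indices = torch.tensor([1, 0])
lens = torch.tensor([3, 5], dtype=torch.int32)
kv_indptr = torch.zeros(3, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(lens, dim=0)
kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
create_flashinfer_kv_indices_reference(
    req_to_token, req_pool_indices, lens, kv_indptr, None, kv_indices
)
# kv_indices is now [8, 9, 10, 0, 1, 2, 3, 4].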