diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py new file mode 100644 index 000000000..4cad3d8aa --- /dev/null +++ b/python/sglang/srt/layers/attention/__init__.py @@ -0,0 +1,49 @@ +from abc import ABC, abstractmethod + +from torch import nn + +from sglang.srt.model_executor.forward_batch_info import ForwardBatch + + +class AttentionBackend(ABC): + """The base class of attention backends""" + + @abstractmethod + def init_forward_metadata(self, forward_batch: ForwardBatch): + """Init the metadata for a forward pass.""" + raise NotImplementedError() + + def init_cuda_graph_state(self, max_bs: int): + """Init the global shared states for cuda graph.""" + raise NotImplementedError() + + def init_forward_metadata_capture_cuda_graph( + self, bs: int, req_pool_indices, seq_lens + ): + """Init the metadata for a forward pass for capturing a cuda graph.""" + raise NotImplementedError() + + def init_forward_metadata_replay_cuda_graph( + self, bs: int, req_pool_indices, seq_lens + ): + """Init the metadata for a forward pass for replaying a cuda graph.""" + raise NotImplementedError() + + def get_cuda_graph_seq_len_fill_value(self): + """Get the fill value for padded seq lens. Typically, it is 0 or 1.""" + raise NotImplementedError() + + def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + """Run forward on an attention layer.""" + if forward_batch.forward_mode.is_decode(): + return self.forward_decode(q, k, v, layer, forward_batch) + else: + return self.forward_extend(q, k, v, layer, forward_batch) + + def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + """Run a forward for decode.""" + raise NotImplementedError() + + def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + """Run a forward for extend.""" + raise NotImplementedError() diff --git a/python/sglang/srt/layers/attention_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py similarity index 58% rename from python/sglang/srt/layers/attention_backend.py rename to python/sglang/srt/layers/attention/flashinfer_backend.py index cff0a707a..305e46345 100644 --- a/python/sglang/srt/layers/attention_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -7,15 +7,14 @@ FlashInfer is faster and Triton is easier to customize. Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
""" -from abc import ABC, abstractmethod from typing import TYPE_CHECKING import torch import torch.nn as nn from sglang.global_config import global_config -from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices -from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.layers.attention import AttentionBackend +from sglang.srt.layers.attention.flashinfer_utils import update_flashinfer_indices from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import is_hip @@ -33,50 +32,6 @@ if not is_hip(): from flashinfer.decode import _grouped_size_compiled_for_decode_kernels -class AttentionBackend(ABC): - """The base class of attention backends""" - - @abstractmethod - def init_forward_metadata(self, forward_batch: ForwardBatch): - """Init the metadata for a forward pass.""" - raise NotImplementedError() - - def init_cuda_graph_state(self, max_bs: int): - """Init the global shared states for cuda graph.""" - raise NotImplementedError() - - def init_forward_metadata_capture_cuda_graph( - self, bs: int, req_pool_indices, seq_lens - ): - """Init the metadata for a forward pass for capturing a cuda graph.""" - raise NotImplementedError() - - def init_forward_metadata_replay_cuda_graph( - self, bs: int, req_pool_indices, seq_lens - ): - """Init the metadata for a forward pass for replying a cuda graph.""" - raise NotImplementedError() - - def get_cuda_graph_seq_len_fill_value(self): - """Get the fill value for padded seq lens. Typically, it is 0 or 1.""" - raise NotImplementedError() - - def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): - """Run forward on an attention layer.""" - if forward_batch.forward_mode.is_decode(): - return self.forward_decode(q, k, v, layer, forward_batch) - else: - return self.forward_extend(q, k, v, layer, forward_batch) - - def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): - """Run a forward for decode.""" - raise NotImplementedError() - - def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): - """Run a forward for extend.""" - raise NotImplementedError() - - class FlashInferAttnBackend(AttentionBackend): """Flashinfer attention kernels.""" @@ -329,151 +284,3 @@ class FlashInferAttnBackend(AttentionBackend): ) return o.view(-1, layer.tp_q_head_num * layer.head_dim) - - -class TritonAttnBackend(AttentionBackend): - def __init__(self, model_runner: ModelRunner): - # Lazy import to avoid the initialization of cuda context - from sglang.srt.layers.triton_attention.decode_attention import ( - decode_attention_fwd, - ) - from sglang.srt.layers.triton_attention.extend_attention import ( - extend_attention_fwd, - ) - - super().__init__() - - self.decode_attention_fwd = decode_attention_fwd - self.extend_attention_fwd = extend_attention_fwd - self.num_head = ( - model_runner.model_config.num_attention_heads // model_runner.tp_size - ) - - if global_server_args_dict.get("triton_attention_reduce_in_fp32", False): - self.reduce_dtype = torch.float32 - else: - self.reduce_dtype = torch.float16 - - self.forward_metadata = None - - self.cuda_graph_max_seq_len = model_runner.model_config.context_len - - def init_forward_metadata(self, forward_batch: ForwardBatch): - """Init auxiliary variables for triton attention backend.""" - - if forward_batch.forward_mode.is_decode(): - start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32) - start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0) 
- - total_num_tokens = torch.sum(forward_batch.seq_lens).item() - attn_logits = torch.empty( - (self.num_head, total_num_tokens), - dtype=self.reduce_dtype, - device="cuda", - ) - - max_seq_len = torch.max(forward_batch.seq_lens).item() - max_extend_len = None - else: - start_loc = attn_logits = max_seq_len = None - prefix_lens = forward_batch.extend_prefix_lens - max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item() - - self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len - - def init_cuda_graph_state(self, max_bs: int): - self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len - - self.cuda_graph_start_loc = torch.zeros( - (max_bs,), dtype=torch.int32, device="cuda" - ) - self.cuda_graph_attn_logits = torch.empty( - ( - self.num_head, - self.cuda_graph_max_total_num_tokens, - ), - dtype=self.reduce_dtype, - device="cuda", - ) - - def init_forward_metadata_capture_cuda_graph( - self, bs: int, req_pool_indices, seq_lens - ): - self.forward_metadata = ( - self.cuda_graph_start_loc, - self.cuda_graph_attn_logits, - self.cuda_graph_max_seq_len, - None, - ) - - def init_forward_metadata_replay_cuda_graph( - self, bs: int, req_pool_indices, seq_lens - ): - self.cuda_graph_start_loc.zero_() - self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) - - def get_cuda_graph_seq_len_fill_value(self): - return 1 - - def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): - # TODO: reuse the buffer across layers - if layer.qk_head_dim != layer.v_head_dim: - o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) - else: - o = torch.empty_like(q) - - forward_batch.token_to_kv_pool.set_kv_buffer( - layer.layer_id, forward_batch.out_cache_loc, k, v - ) - - start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata - self.extend_attention_fwd( - q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - k.contiguous(), - v.contiguous(), - o.view(-1, layer.tp_q_head_num, layer.v_head_dim), - forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), - forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), - forward_batch.req_to_token_pool.req_to_token, - forward_batch.req_pool_indices, - forward_batch.seq_lens, - forward_batch.extend_seq_lens, - forward_batch.extend_start_loc, - max_extend_len, - layer.scaling, - layer.logit_cap, - ) - return o - - def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): - # During torch.compile, there is a bug in rotary_emb that causes the - # output value to have a 3D tensor shape. This reshapes the output correctly. 
- q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) - - # TODO: reuse the buffer across layers - if layer.qk_head_dim != layer.v_head_dim: - o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) - else: - o = torch.empty_like(q) - - start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata - - forward_batch.token_to_kv_pool.set_kv_buffer( - layer.layer_id, forward_batch.out_cache_loc, k, v - ) - - self.decode_attention_fwd( - q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), - forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), - o.view(-1, layer.tp_q_head_num, layer.v_head_dim), - forward_batch.req_to_token_pool.req_to_token, - forward_batch.req_pool_indices, - start_loc, - forward_batch.seq_lens, - attn_logits, - max_seq_len, - layer.scaling, - layer.logit_cap, - ) - return o diff --git a/python/sglang/srt/layers/flashinfer_utils.py b/python/sglang/srt/layers/attention/flashinfer_utils.py similarity index 100% rename from python/sglang/srt/layers/flashinfer_utils.py rename to python/sglang/srt/layers/attention/flashinfer_utils.py diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py new file mode 100644 index 000000000..82b9596bf --- /dev/null +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +import torch.nn as nn + +from sglang.srt.layers.attention import AttentionBackend +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.forward_batch_info import ForwardBatch + +if TYPE_CHECKING: + from sglang.srt.model_executor.model_runner import ModelRunner + + +class TritonAttnBackend(AttentionBackend): + def __init__(self, model_runner: ModelRunner): + # Lazy import to avoid the initialization of cuda context + from sglang.srt.layers.attention.triton_ops.decode_attention import ( + decode_attention_fwd, + ) + from sglang.srt.layers.attention.triton_ops.extend_attention import ( + extend_attention_fwd, + ) + + super().__init__() + + self.decode_attention_fwd = decode_attention_fwd + self.extend_attention_fwd = extend_attention_fwd + self.num_head = ( + model_runner.model_config.num_attention_heads // model_runner.tp_size + ) + + if global_server_args_dict.get("triton_attention_reduce_in_fp32", False): + self.reduce_dtype = torch.float32 + else: + self.reduce_dtype = torch.float16 + + self.forward_metadata = None + + self.cuda_graph_max_seq_len = model_runner.model_config.context_len + + def init_forward_metadata(self, forward_batch: ForwardBatch): + """Init auxiliary variables for triton attention backend.""" + + if forward_batch.forward_mode.is_decode(): + start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32) + start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0) + + total_num_tokens = torch.sum(forward_batch.seq_lens).item() + attn_logits = torch.empty( + (self.num_head, total_num_tokens), + dtype=self.reduce_dtype, + device="cuda", + ) + + max_seq_len = torch.max(forward_batch.seq_lens).item() + max_extend_len = None + else: + start_loc = attn_logits = max_seq_len = None + prefix_lens = forward_batch.extend_prefix_lens + max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item() + + self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len + + def init_cuda_graph_state(self, max_bs: 
int): + self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len + + self.cuda_graph_start_loc = torch.zeros( + (max_bs,), dtype=torch.int32, device="cuda" + ) + self.cuda_graph_attn_logits = torch.empty( + ( + self.num_head, + self.cuda_graph_max_total_num_tokens, + ), + dtype=self.reduce_dtype, + device="cuda", + ) + + def init_forward_metadata_capture_cuda_graph( + self, bs: int, req_pool_indices, seq_lens + ): + self.forward_metadata = ( + self.cuda_graph_start_loc, + self.cuda_graph_attn_logits, + self.cuda_graph_max_seq_len, + None, + ) + + def init_forward_metadata_replay_cuda_graph( + self, bs: int, req_pool_indices, seq_lens + ): + self.cuda_graph_start_loc.zero_() + self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) + + def get_cuda_graph_seq_len_fill_value(self): + return 1 + + def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + # TODO: reuse the buffer across layers + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + forward_batch.token_to_kv_pool.set_kv_buffer( + layer.layer_id, forward_batch.out_cache_loc, k, v + ) + + start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata + self.extend_attention_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k.contiguous(), + v.contiguous(), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), + forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), + forward_batch.req_to_token_pool.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.extend_seq_lens, + forward_batch.extend_start_loc, + max_extend_len, + layer.scaling, + layer.logit_cap, + ) + return o + + def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + # During torch.compile, there is a bug in rotary_emb that causes the + # output value to have a 3D tensor shape. This reshapes the output correctly. 
+ q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) + + # TODO: reuse the buffer across layers + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata + + forward_batch.token_to_kv_pool.set_kv_buffer( + layer.layer_id, forward_batch.out_cache_loc, k, v + ) + + self.decode_attention_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), + forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + forward_batch.req_to_token_pool.req_to_token, + forward_batch.req_pool_indices, + start_loc, + forward_batch.seq_lens, + attn_logits, + max_seq_len, + layer.scaling, + layer.logit_cap, + ) + return o diff --git a/python/sglang/srt/layers/triton_attention/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py similarity index 100% rename from python/sglang/srt/layers/triton_attention/decode_attention.py rename to python/sglang/srt/layers/attention/triton_ops/decode_attention.py diff --git a/python/sglang/srt/layers/triton_attention/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py similarity index 99% rename from python/sglang/srt/layers/triton_attention/extend_attention.py rename to python/sglang/srt/layers/attention/triton_ops/extend_attention.py index 3cf150d8d..919ef3d2e 100644 --- a/python/sglang/srt/layers/triton_attention/extend_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py @@ -22,7 +22,9 @@ import torch import triton import triton.language as tl -from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd +from sglang.srt.layers.attention.triton_ops.prefill_attention import ( + context_attention_fwd, +) CUDA_CAPABILITY = torch.cuda.get_device_capability() diff --git a/python/sglang/srt/layers/triton_attention/prefill_attention.py b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py similarity index 100% rename from python/sglang/srt/layers/triton_attention/prefill_attention.py rename to python/sglang/srt/layers/attention/triton_ops/prefill_attention.py diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 6351d54e3..0fdf300cd 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -37,7 +37,7 @@ import numpy as np import torch if TYPE_CHECKING: - from sglang.srt.layers.attention_backend import AttentionBackend + from sglang.srt.layers.attention import AttentionBackend from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d4687a0a5..63cd1d3d6 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -39,7 +39,8 @@ from vllm.model_executor.models import ModelRegistry from sglang.srt.configs.model_config import AttentionArch, ModelConfig from sglang.srt.constrained import disable_cache -from sglang.srt.layers.attention_backend import 
FlashInferAttnBackend, TritonAttnBackend +from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend +from sglang.srt.layers.attention.triton_backend import TritonAttnBackend from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import Sampler from sglang.srt.lora.lora_manager import LoRAManager diff --git a/scripts/deprecated/test_flashinfer.py b/scripts/deprecated/test_flashinfer.py index 7f0a081f6..2929d7bb8 100644 --- a/scripts/deprecated/test_flashinfer.py +++ b/scripts/deprecated/test_flashinfer.py @@ -6,8 +6,8 @@ from flashinfer import ( ) from flashinfer.decode import _grouped_size_compiled_for_decode_kernels -from sglang.srt.layers.token_attention import token_attention_fwd -from sglang.srt.layers.triton_attention.extend_attention import ( +from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd +from sglang.srt.layers.attention.triton_ops.extend_attention import ( extend_attention_fwd, redundant_attention, ) @@ -159,7 +159,7 @@ def test_batch_decode_with_paged_kv_cache( b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0) max_len_in_batch = kv_len other_kv_index = 0 - token_attention_fwd( + decode_attention_fwd( q, k_buffer, v_buffer, diff --git a/test/srt/test_create_kvindices.py b/test/srt/test_create_kvindices.py index 8fb0231d8..383b7ded5 100644 --- a/test/srt/test_create_kvindices.py +++ b/test/srt/test_create_kvindices.py @@ -4,7 +4,9 @@ import unittest import numpy as np import torch -from sglang.srt.layers.flashinfer_utils import create_flashinfer_kv_indices_triton +from sglang.srt.layers.attention.flashinfer_utils import ( + create_flashinfer_kv_indices_triton, +) class TestCreateKvIndices(unittest.TestCase): diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py index b312a8c30..539b4d4e0 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py @@ -3,12 +3,14 @@ import unittest import torch -from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd -from sglang.srt.layers.triton_attention.extend_attention import ( +from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd +from sglang.srt.layers.attention.triton_ops.extend_attention import ( extend_attention_fwd, redundant_attention, ) -from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd +from sglang.srt.layers.attention.triton_ops.prefill_attention import ( + context_attention_fwd, +) class TestExtendAttention(unittest.TestCase):
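
Reviewer note (illustrative, not part of the diff): with the base interface now exported from sglang.srt.layers.attention, adding a new backend means subclassing AttentionBackend and providing init_forward_metadata (the only @abstractmethod) plus forward_extend and forward_decode; the cuda-graph hooks only matter once cuda graphs are enabled, and the base forward() already dispatches between decode and extend via forward_batch.forward_mode.is_decode(). The sketch below is a hypothetical, unoptimized torch-only backend written against the interface shown in this diff. The name NaiveTorchAttnBackend, the single-causal-sequence simplification, and the plain-MHA head-count assumption are mine, not something this PR adds.

# Hypothetical example -- not added by this diff. It only illustrates how a
# custom backend plugs into the AttentionBackend interface that now lives in
# python/sglang/srt/layers/attention/__init__.py.
import torch
import torch.nn as nn

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class NaiveTorchAttnBackend(AttentionBackend):
    """Unoptimized reference backend used only to illustrate the interface."""

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        # The only @abstractmethod. Real backends precompute indices/buffers
        # here (see FlashInferAttnBackend / TritonAttnBackend); this sketch
        # needs nothing.
        self.forward_metadata = None

    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        # Persist the new k/v into the pool, as the real backends do, so later
        # decode steps could read the cached prefix.
        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, forward_batch.out_cache_loc, k, v
        )

        # Simplifications (assumptions of this sketch, not of the PR): the whole
        # batch is treated as one causal sequence, and k/v are assumed to have
        # the same head count as q (plain MHA, no GQA).
        q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim).transpose(0, 1)
        k = k.view(-1, layer.tp_q_head_num, layer.qk_head_dim).transpose(0, 1)
        v = v.view(-1, layer.tp_q_head_num, layer.v_head_dim).transpose(0, 1)

        scores = torch.matmul(q, k.transpose(-1, -2)) * layer.scaling
        causal = torch.triu(
            torch.full(
                scores.shape[-2:], float("-inf"),
                dtype=scores.dtype, device=scores.device,
            ),
            diagonal=1,
        )
        o = torch.matmul(torch.softmax(scores + causal, dim=-1), v)
        return o.transpose(0, 1).reshape(-1, layer.tp_q_head_num * layer.v_head_dim)

    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        # A real decode gathers the cached prefix per request via
        # forward_batch.req_to_token_pool.req_to_token and the token_to_kv_pool
        # buffers; that plumbing is intentionally omitted from this sketch.
        raise NotImplementedError("decode is left out of this illustration")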