diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index d7c5ff520..f742083f1 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -188,6 +188,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | Arguments | Description | Defaults |
 |-----------|-------------|----------|
 | `--attention-backend` | Choose the kernels for attention layers. | None |
+| `--decode-attention-backend` | (Experimental) Set the backend for decode attention computation. Takes priority over `--attention-backend`. | None |
+| `--prefill-attention-backend` | (Experimental) Set the backend for prefill attention computation. Takes priority over `--attention-backend`. | None |
 | `--sampling-backend` | Choose the kernels for sampling layers. | None |
 | `--grammar-backend` | Choose the backend for grammar-guided decoding. | None |
 | `--mm-attention-backend` | Set multimodal attention backend. | None |
diff --git a/python/sglang/srt/layers/attention/hybrid_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_attn_backend.py
new file mode 100644
index 000000000..370961864
--- /dev/null
+++ b/python/sglang/srt/layers/attention/hybrid_attn_backend.py
@@ -0,0 +1,100 @@
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+
+
+class HybridAttnBackend(AttentionBackend):
+    """Support different backends for prefill and decode."""
+
+    def __init__(
+        self, prefill_backend: AttentionBackend, decode_backend: AttentionBackend
+    ):
+        self.prefill_backend = prefill_backend
+        self.decode_backend = decode_backend
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        if forward_batch.forward_mode.is_decode():
+            self.decode_backend.init_forward_metadata(forward_batch)
+        else:
+            self.prefill_backend.init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        self.decode_backend.init_forward_metadata_capture_cuda_graph(
+            bs,
+            num_tokens,
+            req_pool_indices,
+            seq_lens,
+            encoder_lens,
+            forward_mode,
+            spec_info,
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        self.decode_backend.init_forward_metadata_replay_cuda_graph(
+            bs,
+            req_pool_indices,
+            seq_lens,
+            seq_lens_sum,
+            encoder_lens,
+            forward_mode,
+            spec_info,
+            seq_lens_cpu,
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return self.decode_backend.get_cuda_graph_seq_len_fill_value()
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        return self.decode_backend.forward_decode(
+            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+        )
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        return self.prefill_backend.forward_extend(
+            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+        )
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 5d174db77..e698bf85b 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1690,16 +1690,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
+        if self.forward_mode.is_decode_or_idle():
+            attention_backend_str = global_server_args_dict["decode_attention_backend"]
+        else:
+            attention_backend_str = global_server_args_dict["prefill_attention_backend"]
         # Create seq_lens_cpu when needed
         if (
-            global_server_args_dict["attention_backend"] == "fa3"
+            attention_backend_str == "fa3"
             or (
                 global_server_args_dict["use_mla_backend"]
-                and global_server_args_dict["attention_backend"] == "flashinfer"
+                and attention_backend_str == "flashinfer"
             )
-            or global_server_args_dict["attention_backend"] == "flashmla"
-            or global_server_args_dict["attention_backend"] == "cutlass_mla"
-            or global_server_args_dict["attention_backend"] == "ascend"
+            or attention_backend_str == "flashmla"
+            or attention_backend_str == "cutlass_mla"
+            or attention_backend_str == "ascend"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index fbb08077c..13555adeb 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -1308,9 +1308,58 @@ class ModelRunner:
         else:
             self.attn_backend = self._get_attention_backend()

-    # TODO unify with 6338
     def _get_attention_backend(self):
-        if self.server_args.attention_backend == "flashinfer":
+        """Init attention kernel backend."""
+        self.decode_attention_backend_str = (
+            self.server_args.decode_attention_backend
+            if self.server_args.decode_attention_backend
+            else self.server_args.attention_backend
+        )
+        self.prefill_attention_backend_str = (
+            self.server_args.prefill_attention_backend
+            if self.server_args.prefill_attention_backend
+            else self.server_args.attention_backend
+        )
+        if self.decode_attention_backend_str != self.prefill_attention_backend_str:
+            assert (
+                self.server_args.speculative_algorithm is None
+            ), "Currently HybridAttentionBackend does not support speculative decoding."
+            from sglang.srt.layers.attention.hybrid_attn_backend import (
+                HybridAttnBackend,
+            )
+
+            attn_backend = HybridAttnBackend(
+                decode_backend=self._get_attention_backend_from_str(
+                    self.decode_attention_backend_str
+                ),
+                prefill_backend=self._get_attention_backend_from_str(
+                    self.prefill_attention_backend_str
+                ),
+            )
+            logger.info(
+                f"Using hybrid attention backend for decode and prefill: "
+                f"decode_backend={self.decode_attention_backend_str}, "
+                f"prefill_backend={self.prefill_attention_backend_str}."
+            )
+            logger.warning(
+                f"Attention backend specified by --attention-backend or the default backend might be overridden. "
+ f"The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem." + ) + else: + attn_backend = self._get_attention_backend_from_str( + self.server_args.attention_backend + ) + + global_server_args_dict.update( + { + "decode_attention_backend": self.decode_attention_backend_str, + "prefill_attention_backend": self.prefill_attention_backend_str, + } + ) + return attn_backend + + def _get_attention_backend_from_str(self, backend_str: str): + if backend_str == "flashinfer": if not self.use_mla_backend: from sglang.srt.layers.attention.flashinfer_backend import ( FlashInferAttnBackend, @@ -1318,7 +1367,11 @@ class ModelRunner: # Init streams if self.server_args.speculative_algorithm == "EAGLE": - self.plan_stream_for_flashinfer = torch.cuda.Stream() + if ( + not hasattr(self, "plan_stream_for_flashinfer") + or not self.plan_stream_for_flashinfer + ): + self.plan_stream_for_flashinfer = torch.cuda.Stream() return FlashInferAttnBackend(self) else: from sglang.srt.layers.attention.flashinfer_mla_backend import ( @@ -1326,15 +1379,15 @@ class ModelRunner: ) return FlashInferMLAAttnBackend(self) - elif self.server_args.attention_backend == "aiter": + elif backend_str == "aiter": from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend return AiterAttnBackend(self) - elif self.server_args.attention_backend == "ascend": + elif backend_str == "ascend": from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend return AscendAttnBackend(self) - elif self.server_args.attention_backend == "triton": + elif backend_str == "triton": assert not self.model_config.is_encoder_decoder, ( "Cross attention is not supported in the triton attention backend. " "Please use `--attention-backend flashinfer`." @@ -1349,17 +1402,17 @@ class ModelRunner: from sglang.srt.layers.attention.triton_backend import TritonAttnBackend return TritonAttnBackend(self) - elif self.server_args.attention_backend == "torch_native": + elif backend_str == "torch_native": from sglang.srt.layers.attention.torch_native_backend import ( TorchNativeAttnBackend, ) return TorchNativeAttnBackend(self) - elif self.server_args.attention_backend == "flashmla": + elif backend_str == "flashmla": from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend return FlashMLABackend(self) - elif self.server_args.attention_backend == "fa3": + elif backend_str == "fa3": assert ( torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend ) or torch.cuda.get_device_capability()[0] == 9, ( @@ -1371,7 +1424,7 @@ class ModelRunner: ) return FlashAttentionBackend(self) - elif self.server_args.attention_backend == "cutlass_mla": + elif backend_str == "cutlass_mla": from sglang.srt.layers.attention.cutlass_mla_backend import ( CutlassMLABackend, ) @@ -1385,9 +1438,7 @@ class ModelRunner: logger.info(f"Intel AMX attention backend is enabled.") return IntelAMXAttnBackend(self) else: - raise ValueError( - f"Invalid attention backend: {self.server_args.attention_backend}" - ) + raise ValueError(f"Invalid attention backend: {backend_str}") def init_double_sparsity_channel_config(self, selected_channel): selected_channel = "." 
+ selected_channel + "_proj" @@ -1475,7 +1526,10 @@ class ModelRunner: if self.support_pp: kwargs["pp_proxy_tensors"] = pp_proxy_tensors return self.model.forward( - forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs + forward_batch.input_ids, + forward_batch.positions, + forward_batch, + **kwargs, ) def forward_extend( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index be6ef9bf3..777b8e0c8 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -925,7 +925,10 @@ class DeepseekV2AttentionMLA(nn.Module): self.disable_chunked_prefix_cache = global_server_args_dict[ "disable_chunked_prefix_cache" ] - self.attention_backend = global_server_args_dict["attention_backend"] + + self.current_attention_backend = ( + None # Attention backend used by current forward batch + ) self.rocm_fused_decode_mla = get_bool_env_var( "SGLANG_ROCM_FUSED_DECODE_MLA", "false" ) @@ -1009,9 +1012,16 @@ class DeepseekV2AttentionMLA(nn.Module): else: return AttnForwardMethod.MLA - if self.attention_backend == "ascend": + # Determine attention backend used by current forward batch + if forward_batch.forward_mode.is_decode_or_idle(): + attention_backend = global_server_args_dict["decode_attention_backend"] + else: + attention_backend = global_server_args_dict["prefill_attention_backend"] + self.current_attention_backend = attention_backend + + if attention_backend == "ascend": return AttnForwardMethod.MLA - elif self.attention_backend == "flashinfer": + elif attention_backend == "flashinfer": # Flashinfer MLA: Do not absorb when enabling ragged prefill if ( not self.flashinfer_mla_disable_ragged @@ -1023,7 +1033,7 @@ class DeepseekV2AttentionMLA(nn.Module): return AttnForwardMethod.MHA else: return _dispatch_mla_subtype() - elif self.attention_backend == "fa3": + elif attention_backend == "fa3": # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences. 
if forward_batch.extend_prefix_lens_cpu is not None: sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu) @@ -1040,7 +1050,7 @@ class DeepseekV2AttentionMLA(nn.Module): return AttnForwardMethod.MHA_CHUNKED_KV else: return _dispatch_mla_subtype() - elif self.attention_backend == "aiter": + elif attention_backend == "aiter": if ( forward_batch.forward_mode.is_extend() and not forward_batch.forward_mode.is_target_verify() @@ -1288,9 +1298,9 @@ class DeepseekV2AttentionMLA(nn.Module): self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator ): if ( - self.attention_backend == "fa3" - or self.attention_backend == "flashinfer" - or self.attention_backend == "cutlass_mla" + self.current_attention_backend == "fa3" + or self.current_attention_backend == "flashinfer" + or self.current_attention_backend == "cutlass_mla" ): attn_output = self.attn_mqa( q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f1497d2a6..b0e6fbab3 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -151,6 +151,8 @@ class ServerArgs: # Kernel backend attention_backend: Optional[str] = None + decode_attention_backend: Optional[str] = None + prefill_attention_backend: Optional[str] = None sampling_backend: Optional[str] = None grammar_backend: Optional[str] = None mm_attention_backend: Optional[str] = None @@ -387,13 +389,19 @@ class ServerArgs: ) self.page_size = 128 - if self.attention_backend == "flashmla": + if ( + self.attention_backend == "flashmla" + or self.decode_attention_backend == "flashmla" + ): logger.warning( "FlashMLA only supports a page_size of 64, change page_size to 64." ) self.page_size = 64 - if self.attention_backend == "cutlass_mla": + if ( + self.attention_backend == "cutlass_mla" + or self.decode_attention_backend == "cutlass_mla" + ): logger.warning( "Cutlass MLA only supports a page_size of 128, change page_size to 128." 
) @@ -1213,6 +1221,35 @@ class ServerArgs: default=ServerArgs.attention_backend, help="Choose the kernels for attention layers.", ) + parser.add_argument( + "--decode-attention-backend", + type=str, + choices=[ + "flashinfer", + "triton", + "torch_native", + "fa3", + "flashmla", + "cutlass_mla", + ], + default=ServerArgs.decode_attention_backend, + help="Choose the kernels for decode attention layers (have priority over --attention-backend).", + ) + + parser.add_argument( + "--prefill-attention-backend", + type=str, + choices=[ + "flashinfer", + "triton", + "torch_native", + "fa3", + "flashmla", + "cutlass_mla", + ], + default=ServerArgs.prefill_attention_backend, + help="Choose the kernels for prefill attention layers (have priority over --attention-backend).", + ) parser.add_argument( "--sampling-backend", type=str, diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index ed30b3687..60c010e31 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -491,6 +491,8 @@ class SRTRunner: lora_paths: List[str] = None, max_loras_per_batch: int = 4, attention_backend: Optional[str] = None, + prefill_attention_backend: Optional[str] = None, + decode_attention_backend: Optional[str] = None, lora_backend: str = "triton", disable_cuda_graph: bool = False, disable_radix_cache: bool = False, @@ -540,6 +542,8 @@ class SRTRunner: max_loras_per_batch=max_loras_per_batch, lora_backend=lora_backend, attention_backend=attention_backend, + prefill_attention_backend=prefill_attention_backend, + decode_attention_backend=decode_attention_backend, disable_cuda_graph=disable_cuda_graph, disable_radix_cache=disable_radix_cache, chunked_prefill_size=chunked_prefill_size, diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 19ff9d560..c9876e161 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -109,6 +109,7 @@ suites = { TestFile("test_vision_openai_server_b.py", 620), TestFile("test_w8a8_quantization.py", 46), TestFile("test_reasoning_parser.py", 5), + TestFile("test_hybrid_attn_backend.py", 100), ], "per-commit-amd": [ TestFile("models/lora/test_lora_backend.py", 99), diff --git a/test/srt/test_hybrid_attn_backend.py b/test/srt/test_hybrid_attn_backend.py new file mode 100644 index 000000000..6791447f4 --- /dev/null +++ b/test/srt/test_hybrid_attn_backend.py @@ -0,0 +1,109 @@ +import os +import unittest +from types import SimpleNamespace + +import requests + +from sglang.srt.utils import get_device_sm, kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_MODEL_NAME_FOR_TEST_MLA, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +GSM_DATASET_PATH = None + +# Default server arguments shared across all tests +DEFAULT_SERVER_ARGS = [ + "--trust-remote-code", + "--cuda-graph-max-bs", + "8", + "--prefill-attention-backend", + "fa3", + "--decode-attention-backend", + "flashinfer", +] + + +@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher") +class TestHybridAttnBackendBase(CustomTestCase): + + model = DEFAULT_MODEL_NAME_FOR_TEST + base_url = DEFAULT_URL_FOR_TEST + accuracy_threshold = 0.65 # derived tests need to override this + speculative_decode = False + spec_decode_threshold = 1.0 # derived spec decoding tests need to override this + + @classmethod + def get_server_args(cls): + """Return the arguments for the server launch. 
Override in subclasses.""" + return DEFAULT_SERVER_ARGS + + @classmethod + def setUpClass(cls): + # disable deep gemm precompile to make launch server faster + # please don't do this if you want to make your inference workload faster + os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false" + os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=cls.get_server_args(), + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + requests.get(self.base_url + "/flush_cache") + + args = SimpleNamespace( + num_shots=4, + num_questions=100, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + data_path=GSM_DATASET_PATH, + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"{metrics=}") + + # Use the appropriate metric key based on the test class + metric_key = "accuracy" + self.assertGreater(metrics[metric_key], self.accuracy_threshold) + + if self.speculative_decode: + server_info = requests.get(self.base_url + "/get_server_info") + avg_spec_accept_length = server_info.json()["internal_states"][0][ + "avg_spec_accept_length" + ] + print(f"{avg_spec_accept_length=}") + self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold) + + +class TestHybridAttnBackendMLA(TestHybridAttnBackendBase): + accuracy_threshold = 0.60 + model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + + @classmethod + def get_server_args(cls): + return DEFAULT_SERVER_ARGS + + +class TestHybridAttnBackendTorchCompile(TestHybridAttnBackendBase): + accuracy_threshold = 0.65 + + @classmethod + def get_server_args(cls): + return DEFAULT_SERVER_ARGS + ["--enable-torch-compile"] + + +if __name__ == "__main__": + unittest.main()
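For reviewers who want to try the new flags outside the test suite, here is a minimal sketch that reuses `popen_launch_server` and `kill_process_tree` the same way the test above does; the model name and base URL are placeholders rather than values taken from this patch, and any model that both backends support should work.

# Minimal sketch: start a server with FA3 for prefill and FlashInfer for decode,
# mirroring DEFAULT_SERVER_ARGS in test_hybrid_attn_backend.py.
# The model name and base URL below are placeholders.
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import popen_launch_server

process = popen_launch_server(
    "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    "http://127.0.0.1:30000",  # placeholder base URL
    timeout=600,
    other_args=[
        "--prefill-attention-backend",
        "fa3",
        "--decode-attention-backend",
        "flashinfer",
    ],
)
try:
    # Issue requests against the server here, e.g. via the OpenAI-compatible API.
    pass
finally:
    kill_process_tree(process.pid)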