diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index 2b6836d5c..6289fa357 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -84,7 +84,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be - **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase. -- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). Under long input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels will be used when flashinfer mla is turned off. +- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). Under long input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels will be used when flashinfer mla is turned off. Currently when using flashinfer mla wrapper and speculative decoding together, the `speculative_eagle_topk` parameter should be set to 1. - **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption. diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index 9e81acc6f..9af027bd1 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -11,9 +11,10 @@ More details can be found in https://docs.flashinfer.ai/api/mla.html from dataclasses import dataclass from functools import partial -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Union import torch +import triton from sglang.global_config import global_config from sglang.srt.layers.attention.base_attn_backend import AttentionBackend @@ -23,6 +24,7 @@ from sglang.srt.layers.attention.flashinfer_backend import ( from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.utils import is_flashinfer_available if TYPE_CHECKING: @@ -58,12 +60,16 @@ class FlashInferMLAAttnBackend(AttentionBackend): def __init__( self, model_runner: ModelRunner, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, + q_indptr_decode_buf: Optional[torch.Tensor] = None, ): super().__init__() # Parse constants self.max_context_len = model_runner.model_config.context_len self.device = model_runner.device + self.skip_prefill = skip_prefill global_config.enable_flashinfer_mla = True @@ -78,35 +84,51 @@ class FlashInferMLAAttnBackend(AttentionBackend): self.workspace_buffer = global_workspace_buffer max_bs = model_runner.req_to_token_pool.size - self.kv_indptr = torch.zeros( - (max_bs + 1,), dtype=torch.int32, device=model_runner.device - ) + if kv_indptr_buf is None: + self.kv_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + else: + self.kv_indptr = kv_indptr_buf - self.qo_indptr = torch.zeros( - (max_bs + 1,), dtype=torch.int32, device=model_runner.device - ) + if not self.skip_prefill: + self.qo_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) - self.q_indptr_decode = torch.arange( - 0, max_bs + 1, dtype=torch.int32, device=model_runner.device - ) + if q_indptr_decode_buf is None: + self.q_indptr_decode = torch.arange( + 0, max_bs + 1, dtype=torch.int32, device=model_runner.device + ) + else: + self.q_indptr_decode = q_indptr_decode_buf self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper( self.workspace_buffer, "NHD" ) - self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper( - self.workspace_buffer, - backend="auto", - ) + if not self.skip_prefill: + self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper( + self.workspace_buffer, + backend="auto", + ) + + # FlashinferMLA backend uses mla wrapper for target verify + self.prefill_wrapper_verify = BatchMLAPagedAttentionWrapper( + self.workspace_buffer, + backend="auto", + ) self.decode_wrapper = BatchMLAPagedAttentionWrapper( self.workspace_buffer, backend="auto" ) # Create indices updater - self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill( - model_runner, self - ) + if not skip_prefill: + self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill( + model_runner, self + ) + self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode( model_runner, self ) @@ -114,7 +136,7 @@ class FlashInferMLAAttnBackend(AttentionBackend): # Other metadata self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None self.decode_cuda_graph_metadata = {} - self.prefill_cuda_graph_metadata = {} + self.prefill_cuda_graph_metadata = {} # For verify def init_forward_metadata(self, forward_batch: ForwardBatch): if forward_batch.forward_mode.is_decode_or_idle(): @@ -126,6 +148,28 @@ class FlashInferMLAAttnBackend(AttentionBackend): init_metadata_replay=False, ) self.forward_metadata = DecodeMetadata(self.decode_wrapper) + elif forward_batch.forward_mode.is_draft_extend(): + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens=None, + prefill_wrapper_paged=self.prefill_wrapper_paged, + use_ragged=False, + spec_info=forward_batch.spec_info, + ) + self.forward_metadata = PrefillMetadata(self.prefill_wrapper_paged, False) + elif forward_batch.forward_mode.is_target_verify(): + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens=None, + prefill_wrapper_paged=self.prefill_wrapper_verify, + use_ragged=False, + spec_info=forward_batch.spec_info, + ) + self.forward_metadata = PrefillMetadata(self.prefill_wrapper_verify, False) else: prefix_lens = forward_batch.extend_prefix_lens extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) @@ -202,10 +246,33 @@ class FlashInferMLAAttnBackend(AttentionBackend): seq_lens_sum, decode_wrapper=decode_wrapper, init_metadata_replay=False, + spec_info=spec_info, ) self.decode_cuda_graph_metadata[bs] = decode_wrapper self.forward_metadata = DecodeMetadata(decode_wrapper) decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper) + elif forward_mode.is_target_verify(): + verify_wrapper = BatchMLAPagedAttentionWrapper( + self.workspace_buffer, + use_cuda_graph=True, + qo_indptr=self.cuda_graph_qo_indptr[: bs + 1], + kv_indptr=self.cuda_graph_kv_indptr[: bs + 1], + kv_indices=self.cuda_graph_kv_indices, + kv_len_arr=self.cuda_graph_kv_lens[:bs], + backend="auto", + ) + seq_lens_sum = seq_lens.sum().item() + self.indices_updater_prefill.update( + req_pool_indices, + seq_lens, + seq_lens_sum, + prefix_lens=None, + prefill_wrapper_paged=verify_wrapper, + use_ragged=False, + spec_info=spec_info, + ) + self.prefill_cuda_graph_metadata[bs] = verify_wrapper + self.forward_metadata = PrefillMetadata(verify_wrapper, False) else: raise ValueError(f"Invalid mode: {forward_mode=}") @@ -221,6 +288,7 @@ class FlashInferMLAAttnBackend(AttentionBackend): seq_lens_cpu: Optional[torch.Tensor], ): if forward_mode.is_decode_or_idle(): + assert seq_lens_cpu is not None kv_len_arr_cpu = seq_lens_cpu[:bs] self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum( kv_len_arr_cpu, dim=0 @@ -239,8 +307,19 @@ class FlashInferMLAAttnBackend(AttentionBackend): seq_lens_sum, decode_wrapper=self.decode_cuda_graph_metadata[bs], init_metadata_replay=True, + spec_info=spec_info, **self.fast_decode_kwargs, ) + elif forward_mode.is_target_verify(): + self.indices_updater_prefill.update( + req_pool_indices[:bs], + seq_lens[:bs], + seq_lens_sum, + prefix_lens=None, + prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs], + use_ragged=False, + spec_info=spec_info, + ) else: raise ValueError(f"Invalid forward mode: {forward_mode=}") @@ -254,7 +333,7 @@ class FlashInferMLAAttnBackend(AttentionBackend): v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, - save_kv_cache=True, + save_kv_cache: bool = True, ): cache_loc = forward_batch.out_cache_loc @@ -297,7 +376,7 @@ class FlashInferMLAAttnBackend(AttentionBackend): v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, - save_kv_cache=True, + save_kv_cache: bool = True, ): decode_wrapper = self.forward_metadata.decode_wrapper cache_loc = forward_batch.out_cache_loc @@ -349,6 +428,7 @@ class FlashInferMLAIndicesUpdaterDecode: seq_lens_sum: int, decode_wrapper: BatchMLAPagedAttentionWrapper, init_metadata_replay: bool = False, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None, **fast_decode_kwargs, ): decode_wrapper = decode_wrapper or self.decode_wrapper @@ -360,6 +440,7 @@ class FlashInferMLAIndicesUpdaterDecode: self.q_indptr, self.kv_indptr, init_metadata_replay, + spec_info, **fast_decode_kwargs, ) @@ -372,30 +453,33 @@ class FlashInferMLAIndicesUpdaterDecode: q_indptr: torch.Tensor, kv_indptr: torch.Tensor, init_metadata_replay: bool = False, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None, **fast_decode_kwargs, ): bs = len(req_pool_indices) q_indptr = q_indptr[: bs + 1] - kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = ( - torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda") - if not init_metadata_replay - else fast_decode_kwargs["kv_indices"] - ) - kv_lens = paged_kernel_lens.to(torch.int32) sm_scale = self.scaling + if spec_info is None: + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = ( + torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda") + if not init_metadata_replay + else fast_decode_kwargs["kv_indices"] + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.shape[1], + ) + else: + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - paged_kernel_lens, - kv_indptr, - None, - kv_indices, - self.req_to_token.shape[1], - ) if not init_metadata_replay: wrapper.plan( q_indptr, @@ -457,6 +541,7 @@ class FlashInferMLAIndicesUpdaterPrefill: prefix_lens: torch.Tensor, prefill_wrapper_paged: BatchMLAPagedAttentionWrapper, use_ragged: bool, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None, ): if use_ragged: paged_kernel_lens = prefix_lens @@ -476,6 +561,7 @@ class FlashInferMLAIndicesUpdaterPrefill: self.kv_indptr, self.qo_indptr, use_ragged, + spec_info, ) def call_begin_forward( @@ -490,29 +576,46 @@ class FlashInferMLAIndicesUpdaterPrefill: kv_indptr: torch.Tensor, qo_indptr: torch.Tensor, use_ragged: bool, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None, ): - bs = len(req_pool_indices) - kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.empty( - paged_kernel_lens_sum, - dtype=torch.int32, - device=req_pool_indices.device, - ) - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - paged_kernel_lens, - kv_indptr, - None, - kv_indices, - self.req_to_token.shape[1], - ) - - qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) - qo_indptr = qo_indptr[: bs + 1] + bs = len(seq_lens) sm_scale = self.scaling + if spec_info is None: + assert len(seq_lens) == len(req_pool_indices) + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + paged_kernel_lens_sum, + dtype=torch.int32, + device=req_pool_indices.device, + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.shape[1], + ) + qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) + qo_indptr = qo_indptr[: bs + 1] + custom_mask = None + else: + assert isinstance(spec_info, EagleDraftInput) or isinstance( + spec_info, EagleVerifyInput + ) + # TODO: Support topk > 1 with custom mask + kv_indices, kv_indptr, qo_indptr, custom_mask = ( + spec_info.generate_attn_arg_prefill( + req_pool_indices, + paged_kernel_lens, + paged_kernel_lens_sum, + self.req_to_token, + ) + ) + if use_ragged: # ragged prefill wrapper_ragged.begin_forward( @@ -543,6 +646,163 @@ class FlashInferMLAIndicesUpdaterPrefill: ) +class FlashInferMLAMultiStepDraftBackend: + """ + Wrap multiple flashinfer mla attention backends as one for multiple consecutive + draft decoding steps. + """ + + def __init__( + self, + model_runner: ModelRunner, + topk: int, + speculative_num_steps: int, + ): + from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices + + if topk > 1: + raise ValueError( + f"Currently Flashinfer MLA only supports topk=1 for speculative decoding" + ) + self.topk = topk + self.speculative_num_steps = speculative_num_steps + self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices + + max_bs = model_runner.req_to_token_pool.size * self.topk + self.kv_indptr = torch.zeros( + ( + self.speculative_num_steps, + max_bs + 1, + ), + dtype=torch.int32, + device=model_runner.device, + ) + self.q_indptr_decode = torch.arange( + 0, max_bs + 1, dtype=torch.int32, device=model_runner.device + ) + + self.attn_backends = [] + for i in range(self.speculative_num_steps): + self.attn_backends.append( + FlashInferMLAAttnBackend( + model_runner, + skip_prefill=True, + kv_indptr_buf=self.kv_indptr[i], + q_indptr_decode_buf=self.q_indptr_decode, + ) + ) + + self.max_context_len = self.attn_backends[0].max_context_len + + # Cached variables for generate_draft_decode_kv_indices + self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1] + + def common_template( + self, + forward_batch: ForwardBatch, + kv_indices_buffer: torch.Tensor, + call_fn: Callable, + ): + num_seqs = forward_batch.batch_size + bs = self.topk * num_seqs + seq_lens_sum = forward_batch.seq_lens_sum + + self.generate_draft_decode_kv_indices[ + (self.speculative_num_steps, num_seqs, self.topk) + ]( + forward_batch.req_pool_indices, + forward_batch.req_to_token_pool.req_to_token, + forward_batch.seq_lens, + kv_indices_buffer, + self.kv_indptr, + forward_batch.positions, + num_seqs, + self.topk, + self.pool_len, + kv_indices_buffer.shape[1], + self.kv_indptr.shape[1], + triton.next_power_of_2(num_seqs), + triton.next_power_of_2(self.speculative_num_steps), + triton.next_power_of_2(bs), + ) + + assert forward_batch.spec_info is not None + assert isinstance(forward_batch.spec_info, EagleDraftInput) + + for i in range(self.speculative_num_steps - 1): + forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1] + forward_batch.spec_info.kv_indices = kv_indices_buffer[i][ + : seq_lens_sum * self.topk + bs * (i + 1) + ] + call_fn(i, forward_batch) + + def init_forward_metadata(self, forward_batch: ForwardBatch): + kv_indices = torch.zeros( + ( + self.speculative_num_steps, + forward_batch.batch_size * self.topk * self.max_context_len, + ), + dtype=torch.int32, + device="cuda", + ) + + def call_fn(i, forward_batch): + assert forward_batch.spec_info is not None + assert isinstance(forward_batch.spec_info, EagleDraftInput) + forward_batch.spec_info.kv_indptr = ( + forward_batch.spec_info.kv_indptr.clone() + ) + forward_batch.spec_info.kv_indices = ( + forward_batch.spec_info.kv_indices.clone() + ) + self.attn_backends[i].init_forward_metadata(forward_batch) + + self.common_template(forward_batch, kv_indices, call_fn) + + def init_cuda_graph_state(self, max_bs: int): + self.cuda_graph_kv_indices = torch.zeros( + (self.speculative_num_steps, max_bs * self.max_context_len), + dtype=torch.int32, + device="cuda", + ) + + for i in range(self.speculative_num_steps): + self.attn_backends[i].init_cuda_graph_state( + max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i] + ) + + def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch): + def call_fn(i, forward_batch): + self.attn_backends[i].init_forward_metadata_capture_cuda_graph( + forward_batch.batch_size, + forward_batch.batch_size * self.topk, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + ) + + self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn) + + def init_forward_metadata_replay_cuda_graph( + self, forward_batch: ForwardBatch, bs: int + ): + def call_fn(i, forward_batch): + self.attn_backends[i].init_forward_metadata_replay_cuda_graph( + bs, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + seq_lens_sum=-1, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + seq_lens_cpu=forward_batch.decode_seq_lens_cpu, + ) + + self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn) + + def fast_mla_decode_plan( self, qo_indptr_cpu: torch.Tensor, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 13544007e..82c73ec94 100755 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -555,6 +555,8 @@ class DeepseekV2AttentionMLA(nn.Module): return ( not global_server_args_dict["flashinfer_mla_disable_ragged"] and forward_batch.forward_mode.is_extend() + and not forward_batch.forward_mode.is_target_verify() + and not forward_batch.forward_mode.is_draft_extend() and forward_batch.extend_prefix_lens.sum() == 0 ) else: diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index bd2fa6009..90d47cc0f 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -123,6 +123,16 @@ class EAGLEWorker(TpModelWorker): self.topk, self.speculative_num_steps, ) + elif self.server_args.attention_backend == "flashinfer_mla": + from sglang.srt.layers.attention.flashinfer_mla_backend import ( + FlashInferMLAMultiStepDraftBackend, + ) + + self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend( + self.model_runner, + self.topk, + self.speculative_num_steps, + ) else: raise ValueError( f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}" diff --git a/test/srt/test_mla_flashinfer.py b/test/srt/test_mla_flashinfer.py index 04586acc5..e7113d03d 100644 --- a/test/srt/test_mla_flashinfer.py +++ b/test/srt/test_mla_flashinfer.py @@ -1,6 +1,7 @@ import unittest from types import SimpleNamespace +import requests import torch from sglang.srt.utils import kill_process_tree @@ -100,5 +101,67 @@ class TestFlashinferMLANoRagged(unittest.TestCase): self.assertGreater(metrics["accuracy"], 0.62) +class TestFlashinferMLAMTP(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "lmsys/sglang-ci-dsv3-test" + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = ["--trust-remote-code"] + if torch.cuda.is_available() and torch.version.cuda: + other_args.extend( + [ + "--cuda-graph-max-bs", + "2", + "--disable-radix", + "--enable-torch-compile", + "--torch-compile-max-bs", + "1", + "--speculative-algorithm", + "EAGLE", + "--speculative-draft", + "lmsys/sglang-ci-dsv3-test-NextN", + "--speculative-num-steps", + "4", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + "--enable-flashinfer-mla", + ] + ) + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + requests.get(self.base_url + "/flush_cache") + + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.60) + + server_info = requests.get(self.base_url + "/get_server_info") + avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + print(f"{avg_spec_accept_length=}") + self.assertGreater(avg_spec_accept_length, 2.5) + + if __name__ == "__main__": unittest.main()