diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index 1dd9ab007..560b4deb2 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -192,5 +192,5 @@ Please consult the documentation below to learn more about the parameters you ma
 * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
 * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
 * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
-* `flashinfer_mla_disable_ragged`: Disable the use of the [ragged prefill](https://github.com/flashinfer-ai/flashinfer/blob/5751fc68f109877f6e0fc54f674cdcdef361af56/docs/tutorials/kv_layout.rst#L26) wrapper for the FlashInfer MLA attention backend. Ragged prefill increases throughput by computing MHA instead of paged MLA when there is no prefix match. Only use it when FlashInfer is being used as the MLA backend.
+* `flashinfer_mla_disable_ragged`: Disable the use of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only use it when FlashInfer is being used as the MLA backend.
 * `disable_chunked_prefix_cache`: Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend.
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 37b9f4d6e..8586006dc 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -428,25 +428,18 @@ class FlashInferAttnBackend(AttentionBackend):
                 v_scale=v_scale,
             )
         else:
+            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                causal=True,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+
             if self.forward_metadata.extend_no_prefix:
-                o = prefill_wrapper_paged.forward(
-                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                    causal=not layer.is_cross_attention,
-                    sm_scale=layer.scaling,
-                    logits_soft_cap=logits_soft_cap,
-                    k_scale=k_scale,
-                    v_scale=v_scale,
-                )
+                o = o1
             else:
-                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                    causal=True,
-                    sm_scale=layer.scaling,
-                    logits_soft_cap=logits_soft_cap,
-                )
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
                     q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index a43dd0f86..81afcb9da 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -348,7 +348,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         if self.forward_metadata.use_ragged:
             # ragged prefill
-            o = self.prefill_wrapper_ragged.forward(
+            o, _ = self.prefill_wrapper_ragged.forward_return_lse(
                 qall,
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
                 v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
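
Reviewer note: the `flashinfer_backend.py` hunk hoists the ragged `forward_return_lse` call out of the `else` branch so the ragged output and its log-sum-exp (LSE) are always computed; when there is no cached prefix the ragged result is used directly (`o = o1`), otherwise it is combined with the paged result over the prefix. The sketch below illustrates how two partial attention outputs can be merged via their LSE values, which is why the patch switches to the `*_return_lse` variants. This is a minimal reference implementation, not the actual flashinfer `merge_state` kernel; the function name `merge_attention_states` and the natural-log LSE convention are assumptions for illustration.

```python
# A minimal sketch (not sglang/flashinfer code) of merging two partial
# attention results via their log-sum-exp (LSE) values. o1/s1 could come from
# the ragged (new-token) pass and o2/s2 from the paged (prefix) pass.
import torch


def merge_attention_states(
    o1: torch.Tensor,  # [tokens, heads, head_dim], attention over key set A
    s1: torch.Tensor,  # [tokens, heads], LSE of the logits over key set A
    o2: torch.Tensor,  # [tokens, heads, head_dim], attention over key set B
    s2: torch.Tensor,  # [tokens, heads], LSE of the logits over key set B
) -> torch.Tensor:
    # Each partial output is already normalized by its own softmax denominator
    # exp(s_i); reweighting by exp(s_i) / (exp(s1) + exp(s2)) recovers the
    # attention over the union of both key sets. Subtract the elementwise max
    # LSE before exponentiating for numerical stability.
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max)
    w2 = torch.exp(s2 - s_max)
    denom = w1 + w2
    return o1 * (w1 / denom).unsqueeze(-1) + o2 * (w2 / denom).unsqueeze(-1)
```

Because the merge only needs `(output, lse)` pairs, the ragged pass over the new tokens and the paged pass over the cached prefix can be computed independently and combined exactly, which is what makes hoisting the ragged call out of the branch safe.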