From cbdfb77123e020aa6d45e423b283f9a3d96e4f96 Mon Sep 17 00:00:00 2001
From: Clay
Date: Sun, 20 Jul 2025 10:30:16 +0800
Subject: [PATCH] Enable FlashInfer to support encoder models and add head_dim
 padding workaround (#6230)

---
 .../srt/layers/attention/flashinfer_backend.py    | 11 ++++++++++-
 .../srt/models/test_encoder_embedding_models.py   | 17 +++++++++++++++--
 2 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index f65e533d9..c7da38ac5 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -25,6 +25,7 @@ from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -486,12 +487,20 @@ class FlashInferAttnBackend(AttentionBackend):
                 v_scale=layer.v_scale,
             )
         else:
+            causal = True
+            if layer.attn_type == AttentionType.ENCODER_ONLY:
+                save_kv_cache = False
+                causal = False
+
             if self.forward_metadata.extend_no_prefix:
+                # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions.
+                # The FlashInfer head_dim limitation itself is tracked here:
+                # https://github.com/flashinfer-ai/flashinfer/issues/1048
                 o = self.prefill_wrapper_ragged.forward(
                     q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     k.view(-1, layer.tp_k_head_num, layer.head_dim),
                     v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                    causal=True,
+                    causal=causal,
                     sm_scale=layer.scaling,
                     logits_soft_cap=logits_soft_cap,
                 )
diff --git a/test/srt/models/test_encoder_embedding_models.py b/test/srt/models/test_encoder_embedding_models.py
index bea5d4aff..dafaa72db 100644
--- a/test/srt/models/test_encoder_embedding_models.py
+++ b/test/srt/models/test_encoder_embedding_models.py
@@ -27,9 +27,9 @@ from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci
 
 MODELS = [("BAAI/bge-small-en", 1, 1e-5), ("BAAI/bge-m3", 1, 1e-5)]
 
-ATTENTION_BACKEND = ["torch_native", "triton"]
+ATTENTION_BACKEND = ["torch_native", "triton", "flashinfer"]
 BATCH_SIZE = [1, 2]
-TORCH_DTYPES = [torch.float32]
+TORCH_DTYPES = [torch.float32, torch.float16]
 
 sgl_to_st_ratio = []
 
@@ -126,6 +126,19 @@ class TestEncoderEmbeddingModels(CustomTestCase):
         for attention_backend in ATTENTION_BACKEND:
             for batch_size in BATCH_SIZE:
                 for torch_dtype in TORCH_DTYPES:
+                    # NOTE: FlashInfer currently has limitations with head_dim = 32 or
+                    # other dimensions.
+                    # The FlashInfer head_dim limitation itself is tracked here:
+                    # https://github.com/flashinfer-ai/flashinfer/issues/1048
+                    #
+                    # FlashInfer does not support torch.float32 for dtype_q, so skip it.
+                    if attention_backend == "flashinfer":
+                        if (
+                            model == "BAAI/bge-small-en"
+                            or torch_dtype == torch.float32
+                        ):
+                            continue
+
                     self.assert_close_prefill_logits(
                         DEFAULT_PROMPTS,
                         model,
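
Reviewer note (not part of the patch): the behavioral core of the backend change
is the causal/save_kv_cache gating in forward_extend. Below is a minimal,
standalone sketch of that gating, assuming a stand-in AttentionType enum (the
real one lives in sglang.srt.layers.radix_attention) and a hypothetical helper
name; it is an illustration of the flag logic, not sglang's actual API.

    from enum import Enum

    class AttentionType(Enum):
        # Stand-in for sglang.srt.layers.radix_attention.AttentionType.
        DECODER = "decoder"
        ENCODER_ONLY = "encoder_only"

    def resolve_ragged_prefill_flags(attn_type: AttentionType) -> tuple[bool, bool]:
        """Return (causal, save_kv_cache) for the ragged prefill path.

        Encoder-only layers attend bidirectionally and have no KV cache
        entries to reuse later, so both flags are disabled for them;
        decoder layers keep the causal mask and write to the KV cache.
        """
        if attn_type == AttentionType.ENCODER_ONLY:
            return False, False
        return True, True

    # Mirrors the patch: encoder-only -> causal=False, save_kv_cache=False.
    assert resolve_ragged_prefill_flags(AttentionType.ENCODER_ONLY) == (False, False)
    assert resolve_ragged_prefill_flags(AttentionType.DECODER) == (True, True)

With this patch applied, an encoder embedding model should be servable under the
FlashInfer backend, e.g. (flag names as of current sglang; verify against your
version):

    python -m sglang.launch_server --model-path BAAI/bge-m3 --is-embedding --attention-backend flashinfer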