Enable FlashInfer support for encoder models and add head_dim padding workaround (#6230)

Clay authored 2025-07-20 10:30:16 +08:00, committed by GitHub
parent 282eb59ff3
commit cbdfb77123
2 changed files with 25 additions and 3 deletions
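
The head_dim padding itself lands in the second changed file, which is not shown in this hunk. As a minimal sketch of the workaround, assuming a hypothetical pad_head_dim helper and using PyTorch's scaled_dot_product_attention as a stand-in for the FlashInfer ragged prefill kernel: zero-pad q/k/v from the unsupported head_dim (32) up to a supported size (64), pass the softmax scale computed from the original head_dim, then slice the padding off the output.

import math
import torch
import torch.nn.functional as F

def pad_head_dim(x: torch.Tensor, target: int = 64) -> torch.Tensor:
    """Hypothetical helper: zero-pad the last (head_dim) axis, e.g. 32 -> 64."""
    pad = target - x.shape[-1]
    return F.pad(x, (0, pad)) if pad > 0 else x

# [batch, heads, seq, head_dim]; head_dim = 32 is a size FlashInfer rejects.
q, k, v = (torch.randn(1, 4, 8, 32) for _ in range(3))
scale = 1.0 / math.sqrt(q.shape[-1])  # scale from the ORIGINAL head_dim

q_p, k_p, v_p = (pad_head_dim(t) for t in (q, k, v))
o_p = F.scaled_dot_product_attention(q_p, k_p, v_p, scale=scale)
o = o_p[..., : q.shape[-1]]  # slice the padding back off the output

# Zero-padded channels add nothing to the QK^T dot products and produce
# zero output channels, so the padded path matches the unpadded reference.
ref = F.scaled_dot_product_attention(q, k, v, scale=scale)
assert torch.allclose(o, ref, atol=1e-5)

Passing the scale derived from the original head_dim matters: the kernel's default would use the padded dimension and silently change the softmax temperature. In the diff below, layer.scaling is passed explicitly as sm_scale, so this stays correct.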


@@ -25,6 +25,7 @@ from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -486,12 +487,20 @@ class FlashInferAttnBackend(AttentionBackend):
                 v_scale=layer.v_scale,
             )
         else:
+            causal = True
+            if layer.attn_type == AttentionType.ENCODER_ONLY:
+                save_kv_cache = False
+                causal = False
+
             if self.forward_metadata.extend_no_prefix:
+                # NOTE: FlashInfer currently has limitations with head_dim = 32 and certain other dimensions.
+                # The FlashInfer head_dim limitation itself is tracked upstream here:
+                # https://github.com/flashinfer-ai/flashinfer/issues/1048
                 o = self.prefill_wrapper_ragged.forward(
                     q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     k.view(-1, layer.tp_k_head_num, layer.head_dim),
                     v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                    causal=True,
+                    causal=causal,
                     sm_scale=layer.scaling,
                     logits_soft_cap=logits_soft_cap,
                 )
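
For context on the new causal flag: encoder-only attention (AttentionType.ENCODER_ONLY) is bidirectional, so the ragged prefill must run with causal=False, and save_kv_cache is disabled because encoder outputs are consumed in a single pass rather than extended token by token. A minimal illustration of the two masking modes, using plain PyTorch SDPA rather than the FlashInfer wrapper:

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 4, 8, 64) for _ in range(3))

# Decoder-style extend: each position attends only to itself and
# earlier positions (causal mask).
o_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Encoder-only: fully bidirectional, every position attends to the
# whole sequence; nothing is decoded incrementally afterwards, which
# is why the backend also skips the KV cache write in this case.
o_bidir = F.scaled_dot_product_attention(q, k, v, is_causal=False)

assert not torch.allclose(o_causal, o_bidir)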