Enable FlashInfer support for encoder models and add a head_dim padding workaround (#6230)
@@ -25,6 +25,7 @@ from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -486,12 +487,20 @@ class FlashInferAttnBackend(AttentionBackend):
                 v_scale=layer.v_scale,
             )
         else:
+            causal = True
+            if layer.attn_type == AttentionType.ENCODER_ONLY:
+                save_kv_cache = False
+                causal = False
+
             if self.forward_metadata.extend_no_prefix:
+                # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
+                # The FlashInfer head_dim limitation itself is tracked here:
+                # https://github.com/flashinfer-ai/flashinfer/issues/1048
                 o = self.prefill_wrapper_ragged.forward(
                     q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     k.view(-1, layer.tp_k_head_num, layer.head_dim),
                     v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                    causal=True,
+                    causal=causal,
                     sm_scale=layer.scaling,
                     logits_soft_cap=logits_soft_cap,
                 )
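The title also mentions a head_dim padding workaround for the FlashInfer limitation referenced in the NOTE above (head_dim = 32 is not supported by the ragged prefill kernel), but the padding code itself is not part of this excerpt. Below is a minimal, standalone sketch in plain PyTorch with hypothetical shapes, not the PR's actual implementation (it also assumes a PyTorch version where F.scaled_dot_product_attention accepts the scale argument). It shows why zero-padding head_dim up to a supported size is numerically safe: zero-padding q and k leaves q @ k^T unchanged, and zero-padding v only appends zero output columns that can be sliced off, provided the softmax scale is still computed from the original head_dim.

```python
import torch
import torch.nn.functional as F

def pad_head_dim(x: torch.Tensor, target: int) -> torch.Tensor:
    """Zero-pad the last (head_dim) dimension up to `target`."""
    return F.pad(x, (0, target - x.shape[-1]))

# Hypothetical shapes: head_dim = 32 hits the FlashInfer limitation, so pad
# q/k/v up to a supported head_dim (64 here) and slice the result back.
heads, seq, head_dim, padded = 4, 16, 32, 64
q = torch.randn(1, heads, seq, head_dim)
k = torch.randn(1, heads, seq, head_dim)
v = torch.randn(1, heads, seq, head_dim)

# Reference attention at the original head_dim (encoder-style, non-causal).
ref = F.scaled_dot_product_attention(q, k, v, is_causal=False)

# Padded attention: the softmax scale must still come from the ORIGINAL
# head_dim, since the zero-padded columns contribute nothing to q @ k^T.
out = F.scaled_dot_product_attention(
    pad_head_dim(q, padded),
    pad_head_dim(k, padded),
    pad_head_dim(v, padded),
    is_causal=False,
    scale=head_dim ** -0.5,
)[..., :head_dim]  # drop the padded tail of the output

assert torch.allclose(ref, out, atol=1e-6)
```

Applied inside the backend, the same idea would wrap the prefill_wrapper_ragged.forward call: pad the q/k/v views before the call, keep sm_scale derived from the original head_dim (layer.scaling), and slice the output back to the original head_dim.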