Enable FlashInfer support for encoder models and add head_dim padding workaround (#6230)
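The head_dim padding part of the workaround is not visible in the diff excerpt below, so here is a minimal sketch of the general idea in plain PyTorch (not sglang or FlashInfer internals): zero-pad q/k/v along the head dimension up to a size the kernel accepts, keep the softmax scale tied to the original head_dim, and slice the output back. Names are illustrative only; the `scale` keyword needs PyTorch 2.1+.

    import torch
    import torch.nn.functional as F

    def padded_attention(q, k, v, target_head_dim=64):
        # Zero-padding q/k leaves q @ k^T unchanged, and zero-padded v columns
        # produce zero output columns, so slicing recovers the original result.
        # The scale is pinned to the original head_dim so padding does not alter it.
        head_dim = q.shape[-1]
        pad = target_head_dim - head_dim
        if pad > 0:
            q, k, v = (F.pad(t, (0, pad)) for t in (q, k, v))
        o = F.scaled_dot_product_attention(q, k, v, scale=head_dim**-0.5)
        return o[..., :head_dim]

    # smoke test: padded result matches the unpadded reference for head_dim=32
    q = torch.randn(1, 8, 16, 32)  # [batch, heads, seq, head_dim]
    k, v = torch.randn_like(q), torch.randn_like(q)
    ref = F.scaled_dot_product_attention(q, k, v)
    assert torch.allclose(padded_attention(q, k, v), ref, atol=1e-5)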
@@ -25,6 +25,7 @@ from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -486,12 +487,20 @@ class FlashInferAttnBackend(AttentionBackend):
                 v_scale=layer.v_scale,
             )
         else:
+            causal = True
+            if layer.attn_type == AttentionType.ENCODER_ONLY:
+                save_kv_cache = False
+                causal = False
+
             if self.forward_metadata.extend_no_prefix:
+                # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
+                # The FlashInfer head_dim limitation itself is tracked here:
+                # https://github.com/flashinfer-ai/flashinfer/issues/1048
                 o = self.prefill_wrapper_ragged.forward(
                     q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     k.view(-1, layer.tp_k_head_num, layer.head_dim),
                     v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                    causal=True,
+                    causal=causal,
                     sm_scale=layer.scaling,
                     logits_soft_cap=logits_soft_cap,
                 )
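The hunk above routes encoder-only layers (AttentionType.ENCODER_ONLY) through the same ragged prefill call with causal=False and without writing to the KV cache, since embedding models neither mask future tokens nor decode autoregressively. A plain-PyTorch illustration of what the causal flag changes (not sglang internals):

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 8, 4, 64)  # [batch, heads, seq, head_dim]
    k, v = torch.randn_like(q), torch.randn_like(q)

    o_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # decoder-style
    o_full = F.scaled_dot_product_attention(q, k, v, is_causal=False)    # encoder-style

    # The last position already attends to the whole sequence, so it matches;
    # earlier positions see extra context in the non-causal case and differ.
    assert torch.allclose(o_causal[..., -1, :], o_full[..., -1, :], atol=1e-5)
    assert not torch.allclose(o_causal[..., 0, :], o_full[..., 0, :])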
@@ -27,9 +27,9 @@ from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci
 
 MODELS = [("BAAI/bge-small-en", 1, 1e-5), ("BAAI/bge-m3", 1, 1e-5)]
 
-ATTENTION_BACKEND = ["torch_native", "triton"]
+ATTENTION_BACKEND = ["torch_native", "triton", "flashinfer"]
 BATCH_SIZE = [1, 2]
-TORCH_DTYPES = [torch.float32]
+TORCH_DTYPES = [torch.float32, torch.float16]
 sgl_to_st_ratio = []
 
 
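With "flashinfer" added to ATTENTION_BACKEND, the same backend can presumably be selected when serving one of these embedding models. A hypothetical usage sketch; the keyword names are assumed to mirror sglang's server arguments (attention_backend, is_embedding) and should be checked against the version in use:

    import sglang as sgl

    # Assumed kwargs, mirroring ServerArgs field names; verify against your sglang version.
    engine = sgl.Engine(
        model_path="BAAI/bge-m3",
        is_embedding=True,
        attention_backend="flashinfer",  # newly exercised by this test
        dtype="float16",                 # FlashInfer does not accept float32 queries
    )
    out = engine.encode("The quick brown fox jumps over the lazy dog")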
@@ -126,6 +126,19 @@ class TestEncoderEmbeddingModels(CustomTestCase):
         for attention_backend in ATTENTION_BACKEND:
             for batch_size in BATCH_SIZE:
                 for torch_dtype in TORCH_DTYPES:
+                    # NOTE: FlashInfer currently has limitations with head_dim = 32 or
+                    # other dimensions.
+                    # The FlashInfer head_dim limitation itself is tracked here:
+                    # https://github.com/flashinfer-ai/flashinfer/issues/1048
+                    #
+                    # Flashinfer does not support torch.float32 for dtype_q, so skip it
+                    if attention_backend == "flashinfer":
+                        if (
+                            model == "BAAI/bge-small-en"
+                            or torch_dtype == torch.float32
+                        ):
+                            continue
+
                     self.assert_close_prefill_logits(
                         DEFAULT_PROMPTS,
                         model,
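Taken together, the test now sweeps MODELS × ATTENTION_BACKEND × BATCH_SIZE × TORCH_DTYPES and skips the FlashInfer combinations that hit the head_dim and float32 limits. A standalone sketch of the resulting grid (constants copied from the test file above):

    import itertools
    import torch

    MODELS = [("BAAI/bge-small-en", 1, 1e-5), ("BAAI/bge-m3", 1, 1e-5)]
    ATTENTION_BACKEND = ["torch_native", "triton", "flashinfer"]
    BATCH_SIZE = [1, 2]
    TORCH_DTYPES = [torch.float32, torch.float16]

    for (model, *_), backend, batch_size, dtype in itertools.product(
        MODELS, ATTENTION_BACKEND, BATCH_SIZE, TORCH_DTYPES
    ):
        # Same skip rule as the inner loop above: bge-small-en (head_dim=32) and
        # float32 queries are not supported by the FlashInfer backend.
        if backend == "flashinfer" and (model == "BAAI/bge-small-en" or dtype == torch.float32):
            continue
        print(model, backend, batch_size, dtype)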