[2/2] Support MHA prefill with FlashAttention 4. (#10937)
Co-authored-by: Hieu Pham <hyhieu@gmail.com>
@@ -53,7 +53,7 @@ dependencies = [
     "scipy",
     "sentencepiece",
     "setproctitle",
-    "sgl-kernel==0.3.14.post1",
+    "sgl-kernel==0.3.15",
     "soundfile==0.13.1",
     "tiktoken",
     "timm==1.0.16",
@@ -65,7 +65,7 @@ tracing = [
 
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.14.post1",
+    "sgl-kernel==0.3.15",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
@@ -711,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.14",
+            "0.3.15",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
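
The runtime gate above goes through `assert_pkg_version`. As a point of reference, here is a minimal sketch of what such a check can look like, assuming an `importlib.metadata` lookup plus a `packaging` comparison; this is an assumption about the helper's behavior, not sglang's actual implementation.

```python
# Hedged sketch of a version gate in the spirit of assert_pkg_version;
# the helper below is illustrative, not sglang's real code.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def assert_pkg_version_sketch(pkg: str, min_version: str, message: str) -> None:
    """Raise if `pkg` is missing or older than `min_version`."""
    try:
        installed = Version(version(pkg))
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {message}")
    if installed < Version(min_version):
        raise RuntimeError(
            f"{pkg}=={installed} is older than required {min_version}. {message}"
        )


# Mirrors the call in _set_envs_and_config after this change.
assert_pkg_version_sketch(
    "sgl-kernel",
    "0.3.15",
    "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
)
```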
@@ -129,9 +129,6 @@ def create_flashattention_v3_backend(runner):
 
 @register_attention_backend("fa4")
 def create_flashattention_v4_backend(runner):
-    assert (
-        runner.use_mla_backend
-    ), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
     from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
 
     return FlashAttentionBackend(runner, fa_impl_ver=4)
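
`@register_attention_backend("fa4")` points at a name-to-factory registry. A self-contained sketch of that pattern follows; the registry dict and decorator internals are assumptions for illustration, not copied from sglang.

```python
# Sketch of a name -> factory registry behind @register_attention_backend.
from typing import Callable, Dict

ATTENTION_BACKENDS: Dict[str, Callable] = {}


def register_attention_backend(name: str) -> Callable:
    def decorator(factory: Callable) -> Callable:
        ATTENTION_BACKENDS[name] = factory
        return factory

    return decorator


@register_attention_backend("fa4")
def create_flashattention_v4_backend(runner):
    # With this PR the MLA-only assert is gone, so MHA models can also
    # select "fa4"; the real factory returns FlashAttentionBackend(runner,
    # fa_impl_ver=4). A stand-in tuple keeps this sketch runnable.
    return ("FlashAttentionBackend", runner, 4)


print(ATTENTION_BACKENDS["fa4"]("runner"))  # dispatch by backend name
```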
@@ -754,7 +754,6 @@ class FlashAttentionBackend(AttentionBackend):
 
         # Use Flash Attention for prefill
         if not self.use_mla:
-            assert self.fa_impl_ver in [3], "Only FA3 support here"
            # Do multi-head attention
             key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
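
Dropping `assert self.fa_impl_ver in [3]` is what actually opens the MHA prefill branch to FA4. The effect, sketched below with stub kernels; both `*_prefill` functions are hypothetical stand-ins, not real kernel calls.

```python
# Effect of removing the FA3-only assert on the MHA prefill branch.
def fa3_prefill(q, k, v):
    return q  # stub for the FA3 varlen prefill kernel


def fa4_prefill(q, k, v):
    return q  # stub for the FA4 varlen prefill kernel


def mha_prefill(q, k, v, fa_impl_ver: int):
    # Previously: assert fa_impl_ver in [3], "Only FA3 support here"
    assert fa_impl_ver in (3, 4), f"unsupported fa_impl_ver: {fa_impl_ver}"
    kernel = fa4_prefill if fa_impl_ver == 4 else fa3_prefill
    return kernel(q, k, v)
```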
@@ -1746,16 +1746,10 @@ class ModelRunner:
 
     def _get_attention_backend(self):
         """Init attention kernel backend."""
-        self.decode_attention_backend_str = (
-            self.server_args.decode_attention_backend
-            if self.server_args.decode_attention_backend
-            else self.server_args.attention_backend
-        )
-        self.prefill_attention_backend_str = (
-            self.server_args.prefill_attention_backend
-            if self.server_args.prefill_attention_backend
-            else self.server_args.attention_backend
-        )
+        self.prefill_attention_backend_str, self.decode_attention_backend_str = (
+            self.server_args.get_attention_backends()
+        )
 
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
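
When the resolved prefill and decode backends differ, ModelRunner falls into the HybridAttnBackend path. A minimal sketch of that routing idea, assuming a `forward(batch, is_decode)` entry point that the real class may not share:

```python
# Hedged sketch of per-phase routing in the spirit of HybridAttnBackend;
# the interface below is assumed, not sglang's actual API.
class HybridAttnBackendSketch:
    def __init__(self, prefill_backend, decode_backend):
        self.prefill_backend = prefill_backend
        self.decode_backend = decode_backend

    def forward(self, batch, is_decode: bool):
        # Route each phase to its own backend, e.g. fa4 prefill + fa3 decode.
        backend = self.decode_backend if is_decode else self.prefill_backend
        return backend.forward(batch)
```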
@@ -464,6 +464,19 @@ class ServerArgs:
     enable_pdmux: bool = False
     sm_group_num: int = 3
 
+    def get_attention_backends(server_args):
+        prefill_attention_backend_str = (
+            server_args.prefill_attention_backend
+            if server_args.prefill_attention_backend
+            else server_args.attention_backend
+        )
+        decode_attention_backend_str = (
+            server_args.decode_attention_backend
+            if server_args.decode_attention_backend
+            else server_args.attention_backend
+        )
+        return prefill_attention_backend_str, decode_attention_backend_str
+
     def __post_init__(self):
         """
         Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
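
The new helper's fallback is easy to demonstrate in isolation: a per-phase flag wins when set, otherwise the phase inherits `attention_backend`. The `ToyArgs` dataclass below is a stand-in for ServerArgs, which carries many more fields.

```python
# Standalone demo of the get_attention_backends fallback logic.
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class ToyArgs:
    attention_backend: Optional[str] = None
    prefill_attention_backend: Optional[str] = None
    decode_attention_backend: Optional[str] = None

    def get_attention_backends(self) -> Tuple[Optional[str], Optional[str]]:
        prefill = self.prefill_attention_backend or self.attention_backend
        decode = self.decode_attention_backend or self.attention_backend
        return prefill, decode


# FA4 for MHA prefill (new in this PR), falling back to fa3 for decode:
print(ToyArgs(attention_backend="fa3",
              prefill_attention_backend="fa4").get_attention_backends())
# ('fa4', 'fa3')
```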
@@ -748,20 +761,28 @@ class ServerArgs:
         hf_config = self.get_hf_config()
         model_arch = hf_config.architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
-            if self.attention_backend is None:
+            if (
+                self.attention_backend is None
+                and self.prefill_attention_backend is None
+                and self.decode_attention_backend is None
+            ):
                 if is_cuda() and is_sm100_supported():
                     self.attention_backend = "trtllm_mha"
                 elif is_cuda() and is_sm90_supported():
                     self.attention_backend = "fa3"
                 else:
                     self.attention_backend = "triton"
-            supported_backends = ["triton", "trtllm_mha", "fa3"]
             logger.info(
                 f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
             )
+            supported_backends = ["triton", "trtllm_mha", "fa3", "fa4"]
+            prefill_attn_backend, decode_attn_backend = self.get_attention_backends()
             assert (
-                self.attention_backend in supported_backends
-            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
+                prefill_attn_backend in supported_backends
+                and decode_attn_backend in supported_backends
+            ), (
+                f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got the following backends\n"
+                f"- Prefill: {prefill_attn_backend}\n"
+                f"- Decode: {decode_attn_backend}\n"
+            )
 
             if is_sm100_supported():
                 if not self.enable_dp_attention:
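
The net effect for GptOssForCausalLM: validation now checks the resolved per-phase backends instead of the single `attention_backend`, and `fa4` joins the supported set, so a split such as fa4 prefill with trtllm_mha decode passes. A toy restatement of the widened check, not the real code path:

```python
# Toy restatement of the new GptOss backend validation.
supported_backends = ["triton", "trtllm_mha", "fa3", "fa4"]


def validate_gpt_oss(prefill_attn_backend: str, decode_attn_backend: str) -> None:
    assert (
        prefill_attn_backend in supported_backends
        and decode_attn_backend in supported_backends
    ), (
        f"GptOssForCausalLM requires one of {supported_backends} attention backend, "
        f"but got prefill={prefill_attn_backend!r}, decode={decode_attn_backend!r}"
    )


validate_gpt_oss("fa4", "trtllm_mha")  # passes after this PR
```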