Support using FA4 (FlashAttention v4) on DeepSeek on Blackwell (#9928)
This commit is contained in:
@@ -666,6 +666,13 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
|
||||
os.environ["TRTLLM_ENABLE_PDL"] = "1"
|
||||
|
||||
if os.environ.get("CUTE_DSL_LOG_LEVEL") is None:
|
||||
# Default to warning level, to avoid too many logs
|
||||
os.environ["CUTE_DSL_LOG_LEVEL"] = "30"
|
||||
if os.environ.get("CUTE_DSL_LOG_TO_CONSOLE") is None:
|
||||
# Need to set log to console, otherwise the log level won't take effect
|
||||
os.environ["CUTE_DSL_LOG_TO_CONSOLE"] = "1"
|
||||
|
||||
# Can also be passed as argument
|
||||
os.environ["SGLANG_RUN_ID"] = (
|
||||
f"sglang-run-{time.time()}-{random.randint(0, 100000000)}"
|
||||
|
||||
@@ -305,6 +305,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
speculative_step_id=0,
|
||||
topk=0,
|
||||
speculative_num_steps=0,
|
||||
fa_impl_ver=3,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -338,6 +339,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
)
|
||||
self.speculative_step_id = speculative_step_id
|
||||
|
||||
self.fa_impl_ver = fa_impl_ver
|
||||
|
||||
# Local attention settings
|
||||
self.attention_chunk_size = (
|
||||
model_runner.attention_chunk_size
|
||||
@@ -712,6 +715,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
# For fa3 interface version compatibility, we put new fields into conditional keyword args
|
||||
kwargs = {}
|
||||
if self.fa_impl_ver != 3:
|
||||
kwargs["ver"] = self.fa_impl_ver
|
||||
if sinks is not None:
|
||||
kwargs["sinks"] = sinks
|
||||
|
||||
@@ -738,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
# Use Flash Attention for prefill
|
||||
if not self.use_mla:
|
||||
assert self.fa_impl_ver in [3], "Only FA3 support here"
|
||||
# Do multi-head attention
|
||||
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
|
||||
layer.layer_id
|
||||
@@ -830,6 +836,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
softmax_scale=layer.scaling,
|
||||
causal=False,
|
||||
return_softmax_lse=True,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
# MHA for extend part of sequence without attending prefix kv cache
|
||||
@@ -844,6 +851,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
softmax_scale=layer.scaling,
|
||||
causal=True,
|
||||
return_softmax_lse=forward_batch.mha_return_lse,
|
||||
**kwargs,
|
||||
)
|
||||
if forward_batch.mha_return_lse:
|
||||
output, lse, *rest = output
|
||||
@@ -851,6 +859,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
return output, lse
|
||||
return output
|
||||
else:
|
||||
assert self.fa_impl_ver in [3], "Only FA3 support here"
|
||||
# Do absorbed multi-latent attention
|
||||
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
|
||||
layer.layer_id
|
||||
@@ -939,6 +948,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
k_rope: Optional[torch.Tensor] = None,
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.fa_impl_ver in [3], "Only FA3 support decoding"
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
if save_kv_cache:
|
||||
@@ -985,6 +995,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
# For fa3 interface version compatibility, we put new fields into conditional keyword args
|
||||
kwargs = {}
|
||||
if self.fa_impl_ver != 3:
|
||||
kwargs["ver"] = self.fa_impl_ver
|
||||
if sinks is not None:
|
||||
kwargs["sinks"] = sinks
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ class HybridAttnBackend(AttentionBackend):
|
||||
self.model_runner = model_runner
|
||||
self.prefill_backend = prefill_backend
|
||||
self.decode_backend = decode_backend
|
||||
self.data_type = model_runner.kv_cache_dtype
|
||||
|
||||
def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
|
||||
"""
|
||||
|
||||
@@ -516,6 +516,7 @@ class ModelRunner:
|
||||
"aiter",
|
||||
"flashinfer",
|
||||
"fa3",
|
||||
"fa4",
|
||||
"triton",
|
||||
"flashmla",
|
||||
"cutlass_mla",
|
||||
@@ -1800,6 +1801,15 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
return FlashAttentionBackend(self)
|
||||
elif backend_str == "fa4":
|
||||
assert (
|
||||
self.use_mla_backend
|
||||
), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
|
||||
from sglang.srt.layers.attention.flashattention_backend import (
|
||||
FlashAttentionBackend,
|
||||
)
|
||||
|
||||
return FlashAttentionBackend(self, fa_impl_ver=4)
|
||||
elif backend_str == "cutlass_mla":
|
||||
from sglang.srt.layers.attention.cutlass_mla_backend import (
|
||||
CutlassMLABackend,
|
||||
|
||||
@@ -1124,6 +1124,9 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||
else:
|
||||
return _dispatch_mla_subtype()
|
||||
elif attention_backend == "fa4":
|
||||
# TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
|
||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||
elif attention_backend == "trtllm_mla":
|
||||
original_mode = getattr(forward_batch, "_original_forward_mode", None)
|
||||
if (
|
||||
|
||||
@@ -96,6 +96,7 @@ ATTENTION_BACKEND_CHOICES = [
|
||||
# NVIDIA specific
|
||||
"cutlass_mla",
|
||||
"fa3",
|
||||
"fa4",
|
||||
"flashinfer",
|
||||
"flashmla",
|
||||
"trtllm_mla",
|
||||
|
||||
Reference in New Issue
Block a user