support using fa4 on deepseek on blackwell (#9928)
This commit is contained in:
@@ -666,6 +666,13 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
|
if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
|
||||||
os.environ["TRTLLM_ENABLE_PDL"] = "1"
|
os.environ["TRTLLM_ENABLE_PDL"] = "1"
|
||||||
|
|
||||||
|
if os.environ.get("CUTE_DSL_LOG_LEVEL") is None:
|
||||||
|
# Default to warning level, to avoid too many logs
|
||||||
|
os.environ["CUTE_DSL_LOG_LEVEL"] = "30"
|
||||||
|
if os.environ.get("CUTE_DSL_LOG_TO_CONSOLE") is None:
|
||||||
|
# Need to set log to console, otherwise the log level won't take effect
|
||||||
|
os.environ["CUTE_DSL_LOG_TO_CONSOLE"] = "1"
|
||||||
|
|
||||||
# Can also be passed as argument
|
# Can also be passed as argument
|
||||||
os.environ["SGLANG_RUN_ID"] = (
|
os.environ["SGLANG_RUN_ID"] = (
|
||||||
f"sglang-run-{time.time()}-{random.randint(0, 100000000)}"
|
f"sglang-run-{time.time()}-{random.randint(0, 100000000)}"
|
||||||
|
|||||||
@@ -305,6 +305,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
speculative_step_id=0,
|
speculative_step_id=0,
|
||||||
topk=0,
|
topk=0,
|
||||||
speculative_num_steps=0,
|
speculative_num_steps=0,
|
||||||
|
fa_impl_ver=3,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -338,6 +339,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
self.speculative_step_id = speculative_step_id
|
self.speculative_step_id = speculative_step_id
|
||||||
|
|
||||||
|
self.fa_impl_ver = fa_impl_ver
|
||||||
|
|
||||||
# Local attention settings
|
# Local attention settings
|
||||||
self.attention_chunk_size = (
|
self.attention_chunk_size = (
|
||||||
model_runner.attention_chunk_size
|
model_runner.attention_chunk_size
|
||||||
@@ -712,6 +715,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
# For fa3 interface version compatibility, we put new fields into conditional keyword args
|
# For fa3 interface version compatibility, we put new fields into conditional keyword args
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
if self.fa_impl_ver != 3:
|
||||||
|
kwargs["ver"] = self.fa_impl_ver
|
||||||
if sinks is not None:
|
if sinks is not None:
|
||||||
kwargs["sinks"] = sinks
|
kwargs["sinks"] = sinks
|
||||||
|
|
||||||
@@ -738,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
# Use Flash Attention for prefill
|
# Use Flash Attention for prefill
|
||||||
if not self.use_mla:
|
if not self.use_mla:
|
||||||
|
assert self.fa_impl_ver in [3], "Only FA3 support here"
|
||||||
# Do multi-head attention
|
# Do multi-head attention
|
||||||
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
|
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
|
||||||
layer.layer_id
|
layer.layer_id
|
||||||
@@ -830,6 +836,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
softmax_scale=layer.scaling,
|
softmax_scale=layer.scaling,
|
||||||
causal=False,
|
causal=False,
|
||||||
return_softmax_lse=True,
|
return_softmax_lse=True,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# MHA for extend part of sequence without attending prefix kv cache
|
# MHA for extend part of sequence without attending prefix kv cache
|
||||||
@@ -844,6 +851,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
softmax_scale=layer.scaling,
|
softmax_scale=layer.scaling,
|
||||||
causal=True,
|
causal=True,
|
||||||
return_softmax_lse=forward_batch.mha_return_lse,
|
return_softmax_lse=forward_batch.mha_return_lse,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
if forward_batch.mha_return_lse:
|
if forward_batch.mha_return_lse:
|
||||||
output, lse, *rest = output
|
output, lse, *rest = output
|
||||||
@@ -851,6 +859,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
return output, lse
|
return output, lse
|
||||||
return output
|
return output
|
||||||
else:
|
else:
|
||||||
|
assert self.fa_impl_ver in [3], "Only FA3 support here"
|
||||||
# Do absorbed multi-latent attention
|
# Do absorbed multi-latent attention
|
||||||
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
|
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
|
||||||
layer.layer_id
|
layer.layer_id
|
||||||
@@ -939,6 +948,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_rope: Optional[torch.Tensor] = None,
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
sinks: Optional[torch.Tensor] = None,
|
sinks: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
assert self.fa_impl_ver in [3], "Only FA3 support decoding"
|
||||||
if k is not None:
|
if k is not None:
|
||||||
assert v is not None
|
assert v is not None
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
@@ -985,6 +995,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
# For fa3 interface version compatibility, we put new fields into conditional keyword args
|
# For fa3 interface version compatibility, we put new fields into conditional keyword args
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
if self.fa_impl_ver != 3:
|
||||||
|
kwargs["ver"] = self.fa_impl_ver
|
||||||
if sinks is not None:
|
if sinks is not None:
|
||||||
kwargs["sinks"] = sinks
|
kwargs["sinks"] = sinks
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ class HybridAttnBackend(AttentionBackend):
|
|||||||
self.model_runner = model_runner
|
self.model_runner = model_runner
|
||||||
self.prefill_backend = prefill_backend
|
self.prefill_backend = prefill_backend
|
||||||
self.decode_backend = decode_backend
|
self.decode_backend = decode_backend
|
||||||
|
self.data_type = model_runner.kv_cache_dtype
|
||||||
|
|
||||||
def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
|
def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -516,6 +516,7 @@ class ModelRunner:
|
|||||||
"aiter",
|
"aiter",
|
||||||
"flashinfer",
|
"flashinfer",
|
||||||
"fa3",
|
"fa3",
|
||||||
|
"fa4",
|
||||||
"triton",
|
"triton",
|
||||||
"flashmla",
|
"flashmla",
|
||||||
"cutlass_mla",
|
"cutlass_mla",
|
||||||
@@ -1800,6 +1801,15 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return FlashAttentionBackend(self)
|
return FlashAttentionBackend(self)
|
||||||
|
elif backend_str == "fa4":
|
||||||
|
assert (
|
||||||
|
self.use_mla_backend
|
||||||
|
), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
|
||||||
|
from sglang.srt.layers.attention.flashattention_backend import (
|
||||||
|
FlashAttentionBackend,
|
||||||
|
)
|
||||||
|
|
||||||
|
return FlashAttentionBackend(self, fa_impl_ver=4)
|
||||||
elif backend_str == "cutlass_mla":
|
elif backend_str == "cutlass_mla":
|
||||||
from sglang.srt.layers.attention.cutlass_mla_backend import (
|
from sglang.srt.layers.attention.cutlass_mla_backend import (
|
||||||
CutlassMLABackend,
|
CutlassMLABackend,
|
||||||
|
|||||||
@@ -1124,6 +1124,9 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||||
else:
|
else:
|
||||||
return _dispatch_mla_subtype()
|
return _dispatch_mla_subtype()
|
||||||
|
elif attention_backend == "fa4":
|
||||||
|
# TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
|
||||||
|
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||||
elif attention_backend == "trtllm_mla":
|
elif attention_backend == "trtllm_mla":
|
||||||
original_mode = getattr(forward_batch, "_original_forward_mode", None)
|
original_mode = getattr(forward_batch, "_original_forward_mode", None)
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -96,6 +96,7 @@ ATTENTION_BACKEND_CHOICES = [
|
|||||||
# NVIDIA specific
|
# NVIDIA specific
|
||||||
"cutlass_mla",
|
"cutlass_mla",
|
||||||
"fa3",
|
"fa3",
|
||||||
|
"fa4",
|
||||||
"flashinfer",
|
"flashinfer",
|
||||||
"flashmla",
|
"flashmla",
|
||||||
"trtllm_mla",
|
"trtllm_mla",
|
||||||
|
|||||||
@@ -4,9 +4,15 @@
|
|||||||
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0.
|
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0.
|
||||||
|
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
import cuda.bindings.driver as cuda
|
import cuda.bindings.driver as cuda
|
||||||
import cutlass
|
import cutlass
|
||||||
import cutlass.cute as cute
|
import cutlass.cute as cute
|
||||||
@@ -20,6 +26,22 @@ def maybe_contiguous(x):
|
|||||||
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
|
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
|
||||||
|
|
||||||
|
|
||||||
|
def _reason_recompile(compile_key, jit_func):
|
||||||
|
compile_cache = jit_func.compile_cache
|
||||||
|
compile_key_map = jit_func.compile_key_map
|
||||||
|
if not compile_cache:
|
||||||
|
return "not compiled yet"
|
||||||
|
for k, v in compile_cache.items():
|
||||||
|
if k == compile_key:
|
||||||
|
continue
|
||||||
|
if len(k) != len(compile_key):
|
||||||
|
continue
|
||||||
|
for i in range(len(k)):
|
||||||
|
if k[i] != compile_key[i]:
|
||||||
|
return f"diff at '{compile_key_map[i]}': {k[i]} vs {compile_key[i]} "
|
||||||
|
return "unknown reason"
|
||||||
|
|
||||||
|
|
||||||
torch2cute_dtype_map = {
|
torch2cute_dtype_map = {
|
||||||
torch.float16: cutlass.Float16,
|
torch.float16: cutlass.Float16,
|
||||||
torch.bfloat16: cutlass.BFloat16,
|
torch.bfloat16: cutlass.BFloat16,
|
||||||
@@ -254,6 +276,9 @@ def _flash_attn_fwd(
|
|||||||
compute_capability,
|
compute_capability,
|
||||||
)
|
)
|
||||||
if compile_key not in _flash_attn_fwd.compile_cache:
|
if compile_key not in _flash_attn_fwd.compile_cache:
|
||||||
|
logger.info(
|
||||||
|
f"Compiling FA4 kernel with reason: {_reason_recompile(compile_key, _flash_attn_fwd)}"
|
||||||
|
)
|
||||||
if compute_capability == 9:
|
if compute_capability == 9:
|
||||||
assert page_table is None, "paged KV not supported on SM 9.0"
|
assert page_table is None, "paged KV not supported on SM 9.0"
|
||||||
# fa_fwd = FlashAttentionForwardSm80(
|
# fa_fwd = FlashAttentionForwardSm80(
|
||||||
@@ -335,8 +360,85 @@ def _flash_attn_fwd(
|
|||||||
|
|
||||||
|
|
||||||
_flash_attn_fwd.compile_cache = {}
|
_flash_attn_fwd.compile_cache = {}
|
||||||
|
# Human-readable labels for the positions of a compile key.
# `_reason_recompile` indexes this list by tuple position to name the field
# that differs between a cached key and the key being compiled.
# NOTE(review): presumably the order must mirror the fields of `compile_key`
# built in `_flash_attn_fwd` — the tuple is not fully visible here; confirm.
_flash_attn_fwd.compile_key_map = [
|
||||||
|
"dtype",
|
||||||
|
"head_dim",
|
||||||
|
"head_dim_v",
|
||||||
|
"qhead_per_kvhead",
|
||||||
|
"causal",
|
||||||
|
"softcap is not None",
|
||||||
|
"lse is None",
|
||||||
|
"cu_seqlens_q is None",
|
||||||
|
"cu_seqlens_k is None",
|
||||||
|
"seqused_q is None",
|
||||||
|
"seqused_k is None",
|
||||||
|
"page_table is not None",
|
||||||
|
"window_size_left is not None",
|
||||||
|
"window_size_right is not None",
|
||||||
|
"learnable_sink is not None",
|
||||||
|
"m_block_size",
|
||||||
|
"n_block_size",
|
||||||
|
"num_threads",
|
||||||
|
"pack_gqa",
|
||||||
|
"compute_capability",
|
||||||
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def warmup_flash_attn(f):
    """Decorator that runs one-time warmup passes before the first real call.

    On the first invocation of the wrapped function, ``f`` is called once for
    every combination of ``return_softmax_lse`` and ``causal`` (each
    ``False``/``True``) on copies of the caller's arguments. Every later call
    goes straight to ``f``.

    - Warmups run sequentially and memory is released after each one.
    - Caller data is never handed to a warmup call: tensors are cloned and
      every other argument is deep-copied.
    - Extend ``combos`` to warm up additional flag dimensions.

    NOTE(review): presumably each combination pre-compiles a distinct kernel
    variant (``causal`` and ``lse is None`` are compile-key fields) — confirm.

    Args:
        f: Function to wrap. It must accept ``return_softmax_lse`` and
            ``causal`` as keyword arguments. Passing either one positionally
            makes the warmup call raise ``TypeError`` (duplicate argument).

    Returns:
        A wrapper with the same call signature and return value as ``f`` that
        also keeps ``f``'s name and docstring.
    """
    import functools  # function-scope import keeps this block self-contained

    warmed_up = False

    def _clone(value):
        """Clone tensors so storage is not shared; deep-copy everything else."""
        if isinstance(value, torch.Tensor):
            return value.clone()
        return copy.deepcopy(value)

    def _run_warmups(args, kwargs):
        """Run the warmup calls one by one, freeing memory after each."""
        # Warmup combinations for return_softmax_lse and causal
        combos = [
            dict(return_softmax_lse=False, causal=False),
            dict(return_softmax_lse=False, causal=True),
            dict(return_softmax_lse=True, causal=False),
            dict(return_softmax_lse=True, causal=True),
        ]

        for combo in combos:
            # Copy straight from the caller's arguments: keeping an extra
            # "base" copy alive for the whole loop would only raise the peak
            # memory footprint without changing what each warmup sees.
            warm_args = tuple(_clone(a) for a in args)
            warm_kwargs = {k: _clone(v) for k, v in kwargs.items()}
            warm_kwargs.update(combo)
            with torch.cuda.stream(torch.cuda.current_stream()):
                f(*warm_args, **warm_kwargs)
            del warm_args, warm_kwargs
            # Collect first so that anything freed by the garbage collector
            # is already back in the caching allocator when it is emptied.
            gc.collect()
            torch.cuda.empty_cache()

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        nonlocal warmed_up
        if not warmed_up:
            logger.info("Running flash_attn_varlen_func warmup passes...")
            _run_warmups(args, kwargs)
            # Only marked done after every warmup succeeded.
            warmed_up = True
        return f(*args, **kwargs)

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@warmup_flash_attn
|
||||||
def flash_attn_varlen_func(
|
def flash_attn_varlen_func(
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user