support using fa4 on deepseek on blackwell (#9928)
This commit is contained in:
@@ -4,9 +4,15 @@
|
||||
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0.
|
||||
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
@@ -20,6 +26,22 @@ def maybe_contiguous(x):
|
||||
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
|
||||
|
||||
|
||||
def _reason_recompile(compile_key, jit_func):
|
||||
compile_cache = jit_func.compile_cache
|
||||
compile_key_map = jit_func.compile_key_map
|
||||
if not compile_cache:
|
||||
return "not compiled yet"
|
||||
for k, v in compile_cache.items():
|
||||
if k == compile_key:
|
||||
continue
|
||||
if len(k) != len(compile_key):
|
||||
continue
|
||||
for i in range(len(k)):
|
||||
if k[i] != compile_key[i]:
|
||||
return f"diff at '{compile_key_map[i]}': {k[i]} vs {compile_key[i]} "
|
||||
return "unknown reason"
|
||||
|
||||
|
||||
torch2cute_dtype_map = {
|
||||
torch.float16: cutlass.Float16,
|
||||
torch.bfloat16: cutlass.BFloat16,
|
||||
@@ -254,6 +276,9 @@ def _flash_attn_fwd(
|
||||
compute_capability,
|
||||
)
|
||||
if compile_key not in _flash_attn_fwd.compile_cache:
|
||||
logger.info(
|
||||
f"Compiling FA4 kernel with reason: {_reason_recompile(compile_key, _flash_attn_fwd)}"
|
||||
)
|
||||
if compute_capability == 9:
|
||||
assert page_table is None, "paged KV not supported on SM 9.0"
|
||||
# fa_fwd = FlashAttentionForwardSm80(
|
||||
@@ -335,8 +360,85 @@ def _flash_attn_fwd(
|
||||
|
||||
|
||||
_flash_attn_fwd.compile_cache = {}
|
||||
_flash_attn_fwd.compile_key_map = [
|
||||
"dtype",
|
||||
"head_dim",
|
||||
"head_dim_v",
|
||||
"qhead_per_kvhead",
|
||||
"causal",
|
||||
"softcap is not None",
|
||||
"lse is None",
|
||||
"cu_seqlens_q is None",
|
||||
"cu_seqlens_k is None",
|
||||
"seqused_q is None",
|
||||
"seqused_k is None",
|
||||
"page_table is not None",
|
||||
"window_size_left is not None",
|
||||
"window_size_right is not None",
|
||||
"learnable_sink is not None",
|
||||
"m_block_size",
|
||||
"n_block_size",
|
||||
"num_threads",
|
||||
"pack_gqa",
|
||||
"compute_capability",
|
||||
]
|
||||
|
||||
|
||||
def warmup_flash_attn(f):
|
||||
"""
|
||||
Decorator for flash_attn_varlen_func:
|
||||
- On the first call, run several warmup passes with different flag combinations
|
||||
- Warmups are executed sequentially to minimize peak GPU memory usage
|
||||
- Does not modify user-provided tensors (clones data)
|
||||
- Easy to extend with more compile-key dimensions
|
||||
"""
|
||||
done = False
|
||||
|
||||
def _clone_args(args, kwargs):
|
||||
"""Clone tensor arguments to avoid sharing storage; deepcopy for others."""
|
||||
|
||||
def maybe_clone(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x.clone()
|
||||
return copy.deepcopy(x)
|
||||
|
||||
return tuple(maybe_clone(a) for a in args), {
|
||||
k: maybe_clone(v) for k, v in kwargs.items()
|
||||
}
|
||||
|
||||
def _run_warmups(args, kwargs):
|
||||
"""Run warmup calls sequentially and release memory after each."""
|
||||
base_args, base_kwargs = _clone_args(args, kwargs)
|
||||
|
||||
# Warmup combinations for return_softmax_lse and causal
|
||||
combos = [
|
||||
dict(return_softmax_lse=False, causal=False),
|
||||
dict(return_softmax_lse=False, causal=True),
|
||||
dict(return_softmax_lse=True, causal=False),
|
||||
dict(return_softmax_lse=True, causal=True),
|
||||
]
|
||||
|
||||
for combo in combos:
|
||||
wa, wk = _clone_args(base_args, base_kwargs)
|
||||
wk.update(combo)
|
||||
with torch.cuda.stream(torch.cuda.current_stream()):
|
||||
f(*wa, **wk)
|
||||
del wa, wk
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
nonlocal done
|
||||
if not done:
|
||||
logger.info("Running flash_attn_varlen_func warmup passes...")
|
||||
_run_warmups(args, kwargs)
|
||||
done = True
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@warmup_flash_attn
|
||||
def flash_attn_varlen_func(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user