[NVIDIA] FA3/FA4 Fix (#11606)
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
@@ -45,7 +45,7 @@ def flash_attn_with_kvcache(
|
||||
qv=None,
|
||||
rotary_cos=None,
|
||||
rotary_sin=None,
|
||||
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
|
||||
cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
|
||||
cache_batch_idx: Optional[torch.Tensor] = None,
|
||||
cache_leftpad: Optional[torch.Tensor] = None,
|
||||
page_table: Optional[torch.Tensor] = None,
|
||||
@@ -59,6 +59,7 @@ def flash_attn_with_kvcache(
|
||||
softmax_scale=None,
|
||||
causal=False,
|
||||
window_size=(-1, -1), # -1 means infinite context window
|
||||
attention_chunk: Optional[int] = None,
|
||||
softcap=0.0, # 0.0 means deactivated
|
||||
rotary_interleaved=True,
|
||||
scheduler_metadata=None,
|
||||
@@ -137,6 +138,7 @@ def flash_attn_with_kvcache(
|
||||
Default to 1 / sqrt(headdim).
|
||||
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
||||
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
|
||||
attention_chunk: Optional[int]. If not None, splits the query into chunks of this size to save memory. If None, it is passed to the kernel as 0.
|
||||
softcap: float. Anything > 0 activates softcapping attention.
|
||||
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
|
||||
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
|
||||
@@ -216,6 +218,7 @@ def flash_attn_with_kvcache(
|
||||
]
|
||||
rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
|
||||
rotary_seqlens = maybe_contiguous(rotary_seqlens)
|
||||
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
||||
|
||||
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
|
||||
q,
|
||||
@@ -245,6 +248,7 @@ def flash_attn_with_kvcache(
|
||||
causal,
|
||||
window_size[0],
|
||||
window_size[1],
|
||||
attention_chunk,
|
||||
softcap,
|
||||
rotary_interleaved,
|
||||
scheduler_metadata,
|
||||
@@ -263,10 +267,11 @@ def flash_attn_varlen_func(
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
max_seqlen_q=None,
|
||||
max_seqlen_k=None,
|
||||
seqused_q=None,
|
||||
seqused_k=None,
|
||||
page_table=None,
|
||||
softmax_scale=None,
|
||||
causal=False,
|
||||
qv=None,
|
||||
@@ -274,6 +279,7 @@ def flash_attn_varlen_func(
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
window_size=(-1, -1),
|
||||
attention_chunk=0,
|
||||
softcap=0.0,
|
||||
num_splits=1,
|
||||
pack_gqa=None,
|
||||
@@ -293,25 +299,18 @@ def flash_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
# max_seqlen_q,
|
||||
# max_seqlen_k,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
seqused_q=seqused_q,
|
||||
seqused_k=seqused_k,
|
||||
page_table=page_table,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
# qv=qv,
|
||||
# q_descale=q_descale,
|
||||
# k_descale=k_descale,
|
||||
# v_descale=v_descale,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
# num_splits=num_splits,
|
||||
pack_gqa=pack_gqa,
|
||||
# sm_margin=sm_margin,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
learnable_sink=sinks,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
)
|
||||
|
||||
if not is_fa3_supported():
|
||||
@@ -319,10 +318,15 @@ def flash_attn_varlen_func(
|
||||
"flash_attn at sgl-kernel is only supported on sm90 and above"
|
||||
)
|
||||
|
||||
# FA3 requires max_seqlen_q and max_seqlen_k
|
||||
if max_seqlen_q is None or max_seqlen_k is None:
|
||||
raise ValueError("max_seqlen_q and max_seqlen_k are required for FA3")
|
||||
|
||||
if softmax_scale is None:
|
||||
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
|
||||
-0.5
|
||||
)
|
||||
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
||||
|
||||
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
|
||||
q,
|
||||
@@ -352,6 +356,7 @@ def flash_attn_varlen_func(
|
||||
causal,
|
||||
window_size[0],
|
||||
window_size[1],
|
||||
attention_chunk,
|
||||
softcap,
|
||||
is_rotary_interleaved=False,
|
||||
scheduler_metadata=None,
|
||||
|
||||
Reference in New Issue
Block a user