[NVIDIA] BUMP FA3 (#11444)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: ishandhanani <82981111+ishandhanani@users.noreply.github.com>
This commit is contained in:
Johnny
2025-10-13 18:30:57 +02:00
committed by GitHub
parent f35f120d70
commit b8c430f1ce
4 changed files with 75 additions and 66 deletions

View File

@@ -43,7 +43,7 @@ def flash_attn_with_kvcache(
qv=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
@@ -57,6 +57,7 @@ def flash_attn_with_kvcache(
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
attention_chunk: Optional[int] = None,
softcap=0.0, # 0.0 means deactivated
rotary_interleaved=True,
scheduler_metadata=None,
@@ -135,6 +136,7 @@ def flash_attn_with_kvcache(
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
attention_chunk: Optional[int]. If not None, splits the query into chunks of this size to save memory. None is converted to 0 before being passed to the kernel.
softcap: float. Anything > 0 activates softcapping attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
@@ -214,6 +216,7 @@ def flash_attn_with_kvcache(
]
rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
rotary_seqlens = maybe_contiguous(rotary_seqlens)
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
q,
@@ -243,6 +246,7 @@ def flash_attn_with_kvcache(
causal,
window_size[0],
window_size[1],
attention_chunk,
softcap,
rotary_interleaved,
scheduler_metadata,
@@ -272,6 +276,7 @@ def flash_attn_varlen_func(
k_descale=None,
v_descale=None,
window_size=(-1, -1),
attention_chunk: Optional[int] = None,
softcap=0.0,
num_splits=1,
pack_gqa=None,
@@ -321,6 +326,7 @@ def flash_attn_varlen_func(
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
-0.5
)
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
q,
@@ -350,6 +356,7 @@ def flash_attn_varlen_func(
causal,
window_size[0],
window_size[1],
attention_chunk,
softcap,
is_rotary_interleaved=False,
scheduler_metadata=None,