[Feat] Add sparse attn to sgl-kernel (#5327)
This commit is contained in:
175
sgl-kernel/python/sgl_kernel/sparse_flash_attn.py
Normal file
175
sgl-kernel/python/sgl_kernel/sparse_flash_attn.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def maybe_contiguous(x):
|
||||
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
|
||||
|
||||
|
||||
def sparse_attn_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
block_count,
|
||||
block_offset,
|
||||
column_count,
|
||||
column_index,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=False,
|
||||
softcap=0.0, # 0.0 means deactivated
|
||||
alibi_slopes=None,
|
||||
deterministic=False,
|
||||
return_attn_probs=False,
|
||||
*,
|
||||
return_softmax_lse=False,
|
||||
out=None,
|
||||
):
|
||||
"""Compute attention with vertical and slash sparsity patterns.
|
||||
Most Arguments are the same with the flash_attn_func interface, except for 4 extra args:
|
||||
block_count and block_offset for slash sparsity patterns, and
|
||||
column_count and column_index for vertical sparsity patterns.
|
||||
For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
|
||||
|
||||
Arguments:
|
||||
q: (batch_size, seqlen, nheads, headdim)
|
||||
k: (batch_size, seqlen, nheads_k, headdim)
|
||||
v: (batch_size, seqlen, nheads_k, headdim)
|
||||
block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
|
||||
block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
|
||||
column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
|
||||
column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
|
||||
dropout_p: float. Dropout probability.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(headdim).
|
||||
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
||||
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
|
||||
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
|
||||
is added to the attention score of query i and key j.
|
||||
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
|
||||
which is slightly slower and uses more memory. The forward pass is always deterministic.
|
||||
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
||||
testing only. The returned probabilities are not guaranteed to be correct
|
||||
(they might not have the right scaling).
|
||||
Return:
|
||||
out: (batch_size, seqlen, nheads, headdim).
|
||||
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
|
||||
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
||||
normalization factor).
|
||||
"""
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
|
||||
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
|
||||
out, softmax_lse = torch.ops.sgl_kernel.fwd_sparse.default(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
block_count,
|
||||
block_offset,
|
||||
column_count,
|
||||
column_index,
|
||||
out,
|
||||
alibi_slopes,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal,
|
||||
softcap,
|
||||
return_attn_probs and dropout_p > 0,
|
||||
None,
|
||||
)
|
||||
return (out, softmax_lse) if return_softmax_lse else out
|
||||
|
||||
|
||||
def sparse_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
block_count,
|
||||
block_offset,
|
||||
column_count,
|
||||
column_index,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=False,
|
||||
softcap=0.0, # 0.0 means deactivated
|
||||
alibi_slopes=None,
|
||||
deterministic=False,
|
||||
return_attn_probs=False,
|
||||
*,
|
||||
return_softmax_lse=False,
|
||||
out=None,
|
||||
):
|
||||
"""Compute attention with vertical and slash sparsity patterns.
|
||||
Most Arguments are the same with the flash_attn_varlen_func interface, except for 4 extra args:
|
||||
block_count and block_offset for slash sparsity patterns, and
|
||||
column_count and column_index for vertical sparsity patterns.
|
||||
For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
|
||||
|
||||
Arguments:
|
||||
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
|
||||
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
|
||||
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
|
||||
block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
|
||||
block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
|
||||
column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
|
||||
column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
|
||||
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
||||
of the sequences in the batch, used to index into q.
|
||||
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
||||
of the sequences in the batch, used to index into kv.
|
||||
max_seqlen_q: int. Maximum query sequence length in the batch.
|
||||
max_seqlen_k: int. Maximum key sequence length in the batch.
|
||||
dropout_p: float. Dropout probability.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(headdim).
|
||||
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
||||
softcap: float. Anything > 0 activates softcapping attention.
|
||||
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
|
||||
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
|
||||
is added to the attention score of query i and key j.
|
||||
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
|
||||
which is slightly slower and uses more memory. The forward pass is always deterministic.
|
||||
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
||||
testing only. The returned probabilities are not guaranteed to be correct
|
||||
(they might not have the right scaling).
|
||||
Return:
|
||||
out: (total, nheads, headdim).
|
||||
softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
|
||||
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
||||
normalization factor).
|
||||
"""
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
|
||||
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
|
||||
out, softmax_lse = torch.ops.sgl_kernel.varlen_fwd_sparse.default(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
block_count,
|
||||
block_offset,
|
||||
column_count,
|
||||
column_index,
|
||||
out,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
None,
|
||||
alibi_slopes,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
softcap,
|
||||
return_attn_probs and dropout_p > 0,
|
||||
None,
|
||||
)
|
||||
return (out, softmax_lse) if return_softmax_lse else out
|
||||
Reference in New Issue
Block a user