from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn


def maybe_contiguous(x):
    # The sparse attention kernels require the last dimension to be contiguous
    # (stride(-1) == 1); copy only when that is not already the case.
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def sparse_attn_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Compute attention with vertical and slash sparsity patterns.

    Most arguments are the same as in the flash_attn_func interface, except for four extra
    args: block_count and block_offset for slash sparsity patterns, and column_count and
    column_index for vertical sparsity patterns.
    For more details please refer to Appendix C.4.2 of the paper https://arxiv.org/abs/2407.02490.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        softcap: float. Anything > 0 activates softcapping attention; 0.0 deactivates it.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
            normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops.sgl_kernel.fwd_sparse.default(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        softcap,
        return_attn_probs and dropout_p > 0,
        None,
    )
    return (out, softmax_lse) if return_softmax_lse else out
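

# Illustrative, non-authoritative usage sketch (not part of the kernel API). It only
# shows the expected shapes of the sparsity metadata and how sparse_attn_func is
# called; the index values below are placeholders (every query block attends to the
# first key block and the first key column). In practice block_offset and column_index
# come from a sparsity-pattern search (e.g. MInference-style vertical-slash estimation),
# BLOCK_M / NNZ_S / NNZ_V must match the kernel's tiling, the metadata dtype is assumed
# to be int32, and a CUDA device with the compiled sgl_kernel extension is required.
def _sparse_attn_usage_sketch():
    assert torch.cuda.is_available(), "sparse attention kernels require CUDA"
    batch_size, seqlen, nheads, headdim = 1, 256, 2, 64
    block_m, nnz_s, nnz_v = 64, 1, 1  # assumed tiling / nonzero counts
    num_blocks = (seqlen + block_m - 1) // block_m  # cdiv(seqlen, BLOCK_M)

    q = torch.randn(batch_size, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Per query block: how many slash blocks / vertical columns are kept, and which ones.
    block_count = torch.full((batch_size, nheads, num_blocks), nnz_s, device="cuda", dtype=torch.int32)
    block_offset = torch.zeros((batch_size, nheads, num_blocks, nnz_s), device="cuda", dtype=torch.int32)
    column_count = torch.full((batch_size, nheads, num_blocks), nnz_v, device="cuda", dtype=torch.int32)
    column_index = torch.zeros((batch_size, nheads, num_blocks, nnz_v), device="cuda", dtype=torch.int32)

    out, lse = sparse_attn_func(
        q, k, v, block_count, block_offset, column_count, column_index,
        causal=True, return_softmax_lse=True,
    )
    return out, lse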


def sparse_attn_varlen_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Compute attention with vertical and slash sparsity patterns.

    Most arguments are the same as in the flash_attn_varlen_func interface, except for four
    extra args: block_count and block_offset for slash sparsity patterns, and column_count
    and column_index for vertical sparsity patterns.
    For more details please refer to Appendix C.4.2 of the paper https://arxiv.org/abs/2407.02490.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        softcap: float. Anything > 0 activates softcapping attention; 0.0 deactivates it.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
            normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops.sgl_kernel.varlen_fwd_sparse.default(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        softcap,
        return_attn_probs and dropout_p > 0,
        None,
    )
    return (out, softmax_lse) if return_softmax_lse else out
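

# Hedged helper sketch (not part of the library API): builds the cu_seqlens_q / cu_seqlens_k
# tensors expected by sparse_attn_varlen_func from a list of per-sequence lengths. Following
# the docstring above, they are int32 prefix sums with a leading zero, so sequence i occupies
# rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed (total, nheads, headdim) q/k/v tensors.
# For example, _cu_seqlens_sketch([128, 200, 77]) -> tensor([0, 128, 328, 405]).
def _cu_seqlens_sketch(seqlens: List[int], device: str = "cuda") -> torch.Tensor:
    lengths = torch.tensor(seqlens, dtype=torch.int32, device=device)
    zero = torch.zeros(1, dtype=torch.int32, device=lengths.device)
    return torch.cat([zero, torch.cumsum(lengths, dim=0, dtype=torch.int32)])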