first commit

This commit is contained in:
2026-03-10 13:31:25 +08:00
parent ba974cecfa
commit b62b889355
2604 changed files with 438977 additions and 0 deletions

View File

View File

@@ -0,0 +1,12 @@
__version__ = "2.7.2.post1"
# Use relative import to support build-from-source installation in vLLM
from .flash_attn_interface import (
flash_attn_varlen_func,
flash_attn_with_kvcache,
get_scheduler_metadata,
sparse_attn_func,
sparse_attn_varlen_func,
is_fa_version_supported,
fa_version_unsupported_reason
)

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,614 @@
# Copyright (c) 2023, Tri Dao.
from typing import Optional, Union, Tuple, List
import torch
import torch.nn as nn
# isort: off
# We need to import the CUDA kernels after importing torch
# Use relative import to support build-from-source installation in vLLM
try:
from . import _vllm_fa2_C # noqa: F401
FA2_UNAVAILABLE_REASON = None
FA2_AVAILABLE = True
except ImportError as e:
FA2_UNAVAILABLE_REASON = str(e)
FA2_AVAILABLE = False
try:
from . import _vllm_fa3_C # noqa: F401
FA3_UNAVAILABLE_REASON = None
FA3_AVAILABLE = True
except ImportError as e:
FA3_UNAVAILABLE_REASON = str(e)
FA3_AVAILABLE = False
# isort: on
DEFAULT_FA_VERSION = 2
def _is_fa2_supported(device = None) -> Tuple[bool, Optional[str]]:
if not FA2_AVAILABLE:
return False, f"FA2 is unavaible due to: {FA2_UNAVAILABLE_REASON}"
if torch.cuda.get_device_capability(device)[0] < 8:
return False, \
"FA2 is only supported on devices with compute capability >= 8"
return True, None
def _is_fa3_supported(device = None) -> Tuple[bool, Optional[str]]:
if not FA3_AVAILABLE:
return False, f"FA3 is unavaible due to: {FA3_UNAVAILABLE_REASON}"
if torch.cuda.get_device_capability(device)[0] < 8 \
or torch.cuda.get_device_capability(device)[0] >= 10 \
or torch.cuda.get_device_capability(device) == (8, 6) \
or torch.cuda.get_device_capability(device) == (8, 9):
return False, \
"FA3 is only supported on devices with compute capability >= 8" \
" excluding 8.6 and 8.9 and Blackwell archs (>=10)"
return True, None
def is_fa_version_supported(fa_version: int, device = None) -> bool:
assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
if fa_version == 2:
return _is_fa2_supported(device)[0]
elif fa_version == 3:
return _is_fa3_supported(device)[0]
def fa_version_unsupported_reason(fa_version: int, device = None) \
-> Optional[str]:
assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
if fa_version == 2:
return _is_fa2_supported(device)[1]
elif fa_version == 3:
return _is_fa3_supported(device)[1]
#
# For vLLM we only care about `flash_attn_varlen_func` and
# `flash_attn_with_kvcache` so we only maintain wrappers for these two.
#
def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
# NOTE only used in FA3
def get_scheduler_metadata(
batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
cache_seqlens: torch.Tensor,
qkv_dtype=torch.bfloat16,
headdim_v=None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k_new: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
page_size: Optional[int] = None,
max_seqlen_k_new=0,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
has_softcap=False,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
):
cache_seqlens = maybe_contiguous(cache_seqlens)
if headdim_v is None:
headdim_v = headdim
scheduler_metadata = torch.ops._vllm_fa3_C.get_scheduler_metadata(
batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
qkv_dtype,
cache_seqlens,
cu_seqlens_q,
None, # cu_seqlens_k
cu_seqlens_k_new,
None, # seqused_q
cache_leftpad,
page_size,
max_seqlen_k_new,
causal,
window_size[0], window_size[1],
has_softcap,
num_splits,
pack_gqa,
sm_margin,
)
return scheduler_metadata
def flash_attn_varlen_func(
q,
k,
v,
max_seqlen_q,
cu_seqlens_q,
max_seqlen_k,
cu_seqlens_k=None, # only used for non-paged prefill
seqused_k=None,
q_v=None,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size: Optional[List[int]] = None,
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
block_table=None,
return_softmax_lse=False,
out=None,
# FA3 Only
scheduler_metadata=None,
q_descale=None,
k_descale=None,
v_descale=None,
num_splits: int = 0,
# Version selector
fa_version: int = DEFAULT_FA_VERSION,
s_aux=None,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
assert cu_seqlens_k is not None or seqused_k is not None, \
"cu_seqlens_k or seqused_k must be provided"
assert cu_seqlens_k is None or seqused_k is None, \
"cu_seqlens_k and seqused_k cannot be provided at the same time"
assert block_table is None or seqused_k is not None, \
"seqused_k must be provided if block_table is provided"
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
# custom op does not support non-tuple input
real_window_size: Tuple[int, int]
if window_size is None:
real_window_size = (-1, -1)
else:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1])
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q)
if fa_version == 2:
if scheduler_metadata is not None and q_descale is not None \
and k_descale is not None and v_descale is not None:
raise NotImplementedError(
"FA2 does not support scheduler_metadata, q_descale, "
"k_descale, v_descale"
)
if s_aux is not None:
raise NotImplementedError("FA2 does not support s_aux")
if num_splits > 1:
raise NotImplementedError("FA2 does not support num_splits > 1")
out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
q, k, v,
out,
cu_seqlens_q,
# cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
# still wants it so we pass all zeros
dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k,
seqused_k,
None,
block_table,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
False,
causal,
real_window_size[0],
real_window_size[1],
softcap,
return_softmax_lse and dropout_p > 0,
None,
)
elif fa_version == 3:
assert alibi_slopes is None, "Alibi is not supported in FA3"
out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
q, k, v,
None, None, # k_new, v_new
q_v,
out,
cu_seqlens_q,
cu_seqlens_k, # cu_seqlens_k
None, # cu_seqlens_k_new
None, seqused_k, # seqused_q, seqused_k
max_seqlen_q, max_seqlen_k,
block_table,
None, # kv_batch_idx
None, # leftpad_k
None, None, None, # rotary_cos, rotary_sin, seqlens_rotary
q_descale, k_descale, v_descale,
softmax_scale,
causal,
real_window_size[0], real_window_size[1],
softcap,
True, # rotary_interleaved
scheduler_metadata,
num_splits,
None, # pack_gqa
0, # sm_margin
s_aux # s_aux
)
else:
raise ValueError(f"Unsupported FA version: {fa_version}")
return (out, softmax_lse) if return_softmax_lse else out
def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
rotary_interleaved=True,
alibi_slopes=None,
num_splits=0,
return_softmax_lse=False,
*,
out=None,
# FA3 Only
scheduler_metadata=None,
q_descale=None,
k_descale=None,
v_descale=None,
# Version selector
fa_version: int = DEFAULT_FA_VERSION,
s_aux=None,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
the previous step, and update them with the new keys/values from the current step, and do
attention with the updated cache, all in 1 kernel.
If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
For example, the KV cache could be pre-allocated with the max sequence length, and you can use
cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Note: Does not support backward pass.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
page_block_size must be a multiple of 256.
v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
k with k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
KV cache.
block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
If the indices are not distinct, and k and v are provided, the values updated in the cache
might come from any of the duplicate indices.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
(i.e. GPT-NeoX style).
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
to automatically determine the number of splits.
Don't change this unless you know what you are doing.
return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if cache_seqlens is not None and isinstance(cache_seqlens, int):
cache_seqlens = torch.full(
(k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
)
cache_seqlens = maybe_contiguous(cache_seqlens)
cache_batch_idx = maybe_contiguous(cache_batch_idx)
block_table = maybe_contiguous(block_table)
if s_aux is not None:
raise NotImplementedError("FA2 does not support s_aux")
if scheduler_metadata is not None and q_descale is not None \
and k_descale is not None and v_descale is not None:
raise NotImplementedError(
"FA2 does not support scheduler_metadata, q_descale, "
"k_descale, v_descale"
)
out, softmax_lse = torch.ops._vllm_fa2_C.fwd_kvcache(
q, k_cache, v_cache,
k, v, # k_new, v_new
cache_seqlens,
rotary_cos,
rotary_sin,
cache_batch_idx,
cache_leftpad,
block_table,
alibi_slopes,
out,
softmax_scale,
causal,
window_size[0],
window_size[1],
softcap,
rotary_interleaved,
num_splits,
)
return (out, softmax_lse) if return_softmax_lse else out
def sparse_attn_func(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
dropout_p=0.0,
softmax_scale=None,
causal=False,
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
*,
return_softmax_lse=False,
out=None,
):
"""Compute attention with vertical and slash sparsity patterns.
Most Arguments are the same with the flash_attn_func interface, except for 4 extra args:
block_count and block_offset for slash sparsity patterns, and
column_count and column_index for vertical sparsity patterns.
For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse = torch.ops._vllm_fa2_C.fwd_sparse(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
out,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
softcap,
return_attn_probs and dropout_p > 0,
None,
)
return (out, softmax_lse) if return_softmax_lse else out
def sparse_attn_varlen_func(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=False,
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
*,
return_softmax_lse=False,
out=None,
):
"""Compute attention with vertical and slash sparsity patterns.
Most Arguments are the same with the flash_attn_varlen_func interface, except for 4 extra args:
block_count and block_offset for slash sparsity patterns, and
column_count and column_index for vertical sparsity patterns.
For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
softcap: float. Anything > 0 activates softcapping attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd_sparse(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
out,
cu_seqlens_q,
cu_seqlens_k,
None,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
False,
causal,
softcap,
return_attn_probs and dropout_p > 0,
None,
)
return (out, softmax_lse) if return_softmax_lse else out

View File

View File

@@ -0,0 +1,530 @@
# Adapted from https://github.com/vllm-project/flash-attention/blob/main/flash_attn/layers/rotary.py
# Modified lines are marked with `# modified from original` comment
# Copyright (c) 2023, Tri Dao.
import math
from typing import Optional, Tuple, Union
import torch
from einops import rearrange, repeat
from ..ops.triton.rotary import apply_rotary # modified from original
def rotate_half(x, interleaved=False):
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
return torch.cat(
[x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
dim=-1,
)
class ApplyRotaryEmb(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x,
cos,
sin,
interleaved=False,
inplace=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
):
out = apply_rotary(
x,
cos,
sin,
seqlen_offsets=seqlen_offsets,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
interleaved=interleaved,
inplace=inplace,
)
if isinstance(seqlen_offsets, int):
ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward
ctx.seqlen_offsets = seqlen_offsets
else:
ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
ctx.seqlen_offsets = None
ctx.interleaved = interleaved
ctx.inplace = inplace
ctx.max_seqlen = max_seqlen
return out if not inplace else x
@staticmethod
def backward(ctx, do):
seqlen_offsets = ctx.seqlen_offsets
if seqlen_offsets is None:
cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
else:
cos, sin, cu_seqlens = ctx.saved_tensors
# TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
# "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
if not ctx.interleaved and not ctx.inplace:
do = do.clone()
dx = apply_rotary(
do,
cos,
sin,
seqlen_offsets=seqlen_offsets,
cu_seqlens=cu_seqlens,
max_seqlen=ctx.max_seqlen,
interleaved=ctx.interleaved,
inplace=ctx.inplace,
conjugate=True,
)
return dx, None, None, None, None, None, None, None
def apply_rotary_emb(
x,
cos,
sin,
interleaved=False,
inplace=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
):
"""
Arguments:
x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim)
cos, sin: (seqlen_rotary, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
inplace: if True, apply rotary embedding in-place.
seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
Most commonly used in inference when we have KV cache.
cu_seqlens: (batch + 1,) or None
max_seqlen: int
Return:
out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim)
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
return ApplyRotaryEmb.apply(
x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
)
# For backward compatibility
apply_rotary_emb_func = apply_rotary_emb
class ApplyRotaryEmbQKV_(torch.autograd.Function):
@staticmethod
def forward(
ctx,
qkv,
cos,
sin,
cos_k=None,
sin_k=None,
interleaved=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
num_heads_q: Union[int] = None,
):
if cos_k is None and sin_k is None and qkv.is_contiguous():
# Call 1 kernel instead of 2 kernels
# We need qkv to be contiguous so that when we reshape to combine (3, nheads)
# dimensions, we get the same tensor
if qkv.dim() == 5:
batch, seqlen, three, nheads, headdim = qkv.shape
assert three == 3
# qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
else:
assert qkv.dim() == 4
assert num_heads_q is not None
num_heads_k = (qkv.shape[2] - num_heads_q) // 2
assert qkv.shape[2] == num_heads_q + 2 * num_heads_k
qk = qkv[:, :, :num_heads_q + num_heads_k]
apply_rotary(
qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
)
else:
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
if qkv.dim() == 5:
q, k = qkv[:, :, 0], qkv[:, :, 1]
else:
assert qkv.dim() == 4
assert num_heads_q is not None
num_heads_k = (qkv.shape[2] - num_heads_q) // 2
assert qkv.shape[2] == num_heads_q + 2 * num_heads_k
q, k = qkv[:, :, :num_heads_q], qkv[:, :, num_heads_q : num_heads_q + num_heads_k]
apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
ctx.save_for_backward(cos, sin, cos_k, sin_k)
if isinstance(seqlen_offsets, int):
ctx.save_for_backward(cos, sin, cos_k, sin_k)
ctx.seqlen_offsets = seqlen_offsets
else:
ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
ctx.seqlen_offsets = None
ctx.interleaved = interleaved
ctx.num_heads_q = num_heads_q
return qkv
@staticmethod
def backward(ctx, dqkv):
seqlen_offsets = ctx.seqlen_offsets
if seqlen_offsets is None:
cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
else:
cos, sin, cos_k, sin_k = ctx.saved_tensors
if cos_k is None and sin_k is None and dqkv.is_contiguous():
# Call 1 kernel instead of 2 kernels
# We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
# dimensions, we get the same tensor
if dqkv.dim() == 5:
dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
else:
assert dqkv.dim() == 4
assert ctx.num_heads_q is not None
num_heads_k = (dqkv.shape[2] - ctx.num_heads_q) // 2
assert dqkv.shape[2] == ctx.num_heads_q + 2 * num_heads_k
dqk = dqkv[:, :, : ctx.num_heads_q + num_heads_k]
apply_rotary(
dqk,
cos,
sin,
seqlen_offsets=seqlen_offsets,
interleaved=ctx.interleaved,
inplace=True,
conjugate=True,
)
else:
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
if dqkv.dim() == 5:
dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
else:
assert dqkv.dim() == 4
assert ctx.num_heads_q is not None
num_heads_k = (dqkv.shape[2] - ctx.num_heads_q) // 2
assert dqkv.shape[2] == ctx.num_heads_q + 2 * num_heads_k
dq = dqkv[:, :, : ctx.num_heads_q]
dk = dqkv[:, :, ctx.num_heads_q : ctx.num_heads_q + num_heads_k]
apply_rotary(
dq,
cos,
sin,
seqlen_offsets,
interleaved=ctx.interleaved,
inplace=True,
conjugate=True,
)
apply_rotary(
dk,
cos_k,
sin_k,
seqlen_offsets,
interleaved=ctx.interleaved,
inplace=True,
conjugate=True,
)
return dqkv, None, None, None, None, None, None, None
def apply_rotary_emb_qkv_(
qkv,
cos,
sin,
cos_k=None,
sin_k=None,
interleaved=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
num_heads_q: Optional[int] = None,
):
"""
Arguments:
qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim).
If qkv has shape (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA),
then num_heads_q must be provided.
cos, sin: (seqlen, rotary_dim / 2)
cos_k, sin_k: (seqlen, rotary_dim / 2), optional
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
1st half and 2nd half (GPT-NeoX style).
seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
Most commonly used in inference when we have KV cache.
Return:
qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim)
rotary_dim must be <= headdim
Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
"""
return ApplyRotaryEmbQKV_.apply(
qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, num_heads_q
)
class ApplyRotaryEmbKV_(torch.autograd.Function):
@staticmethod
def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
batch, seqlen, two, nheads, headdim = kv.shape
assert two == 2
k = kv[:, :, 0]
apply_rotary(
k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
)
if isinstance(seqlen_offsets, int):
ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward
ctx.seqlen_offsets = seqlen_offsets
else:
ctx.save_for_backward(cos, sin, seqlen_offsets)
ctx.seqlen_offsets = None
ctx.interleaved = interleaved
return kv
@staticmethod
def backward(ctx, dkv):
seqlen_offsets = ctx.seqlen_offsets
if seqlen_offsets is None:
cos, sin, seqlen_offsets = ctx.saved_tensors
else:
cos, sin = ctx.saved_tensors
apply_rotary(
dkv[:, :, 0],
cos,
sin,
seqlen_offsets=seqlen_offsets,
interleaved=ctx.interleaved,
inplace=True,
conjugate=True,
)
return dkv, None, None, None, None
apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
def apply_rotary_emb_kv_(
kv,
cos,
sin,
interleaved=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
):
"""
Arguments:
kv: (batch_size, seqlen, 2, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
1st half and 2nd half (GPT-NeoX style).
seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
Most commonly used in inference when we have KV cache.
Return:
kv: (batch_size, seqlen, 2, nheads, headdim)
rotary_dim must be <= headdim
Apply rotary embedding *inplace* to the first rotary_dim of K.
"""
return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets)
class RotaryEmbedding(torch.nn.Module):
"""
The rotary position embeddings from RoFormer_ (Su et. al).
A crucial insight from the method is that the query and keys are
transformed by rotation matrices which depend on the relative positions.
Other implementations are available in the Rotary Transformer repo_ and in
GPT-NeoX_, GPT-NeoX was an inspiration
.. _RoFormer: https://arxiv.org/abs/2104.09864
.. _repo: https://github.com/ZhuiyiTechnology/roformer
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
"""
def __init__(
self,
dim: int,
base=10000.0,
interleaved=False,
scale_base=None,
pos_idx_in_fp32=True,
device=None,
):
"""
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
otherwise they might be in lower precision.
This option was added because previously (before 2023-07-02), when we construct
the position indices, we use the dtype of self.inv_freq. In most cases this would
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
self.inv_freq would be bf16, and the position indices are also in bf16.
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
embeddings for some positions will coincide.
To maintain compatibility with models previously trained in pure bf16,
we add this option.
"""
super().__init__()
self.dim = dim
self.base = float(base)
self.pos_idx_in_fp32 = pos_idx_in_fp32
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = self._compute_inv_freq(device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.interleaved = interleaved
self.scale_base = scale_base
scale = (
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
if scale_base is not None
else None
)
self.register_buffer("scale", scale, persistent=False)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
def _compute_inv_freq(self, device=None):
return 1.0 / (
self.base
** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
)
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
# Reset the tables if the sequence length has changed,
# if we're on a new device (possibly due to tracing for instance),
# or if we're switching from inference mode to training
if (
seqlen > self._seq_len_cached
or self._cos_cached is None
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
or (self.training and self._cos_cached.is_inference())
):
self._seq_len_cached = seqlen
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
if self.pos_idx_in_fp32:
t = torch.arange(seqlen, device=device, dtype=torch.float32)
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
# will be large. Having it in bf16 will lose a lot of precision and cause the
# cos & sin output to change significantly.
# We want to recompute self.inv_freq if it was not loaded in fp32
if self.inv_freq.dtype != torch.float32:
inv_freq = self._compute_inv_freq(device=device)
else:
inv_freq = self.inv_freq
else:
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
inv_freq = self.inv_freq
# Don't do einsum, it converts fp32 to fp16 under AMP
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, inv_freq)
if self.scale is None:
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
else:
power = (
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
- seqlen // 2
) / self.scale_base
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
# We want the multiplication by scale to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
def forward(
self,
qkv: torch.Tensor,
kv: Optional[torch.Tensor] = None,
seqlen_offset: Union[int, torch.Tensor] = 0,
max_seqlen: Optional[int] = None,
num_heads_q: Optional[int] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
qkv: (batch, seqlen, 3, nheads, headdim) or (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim)
if kv is none, else it's just q of shape (batch, seqlen, nheads, headdim).
If qkv has shape (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA),
then num_heads_q must be provided.
kv: (batch, seqlen, 2, nheads, headdim)
seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
Most commonly used in inference when we have KV cache.
If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
should pass in max_seqlen, which will update the cos / sin cache up to that length.
Apply rotary embedding *inplace* to qkv and / or kv.
"""
seqlen = qkv.shape[1]
if max_seqlen is not None:
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
elif isinstance(seqlen_offset, int):
self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
if kv is None:
if self.scale is None:
return apply_rotary_emb_qkv_(
qkv,
self._cos_cached,
self._sin_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
num_heads_q=num_heads_q,
)
else:
return apply_rotary_emb_qkv_(
qkv,
self._cos_cached,
self._sin_cached,
self._cos_k_cached,
self._sin_k_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
num_heads_q=num_heads_q,
)
else:
q = qkv
q = apply_rotary_emb_func(
q,
self._cos_cached,
self._sin_cached,
interleaved=self.interleaved,
inplace=True,
seqlen_offsets=seqlen_offset,
)
if self.scale is None:
kv = apply_rotary_emb_kv_(
kv,
self._cos_cached,
self._sin_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
else:
kv = apply_rotary_emb_kv_(
kv,
self._cos_k_cached,
self._sin_k_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
return q, kv

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,229 @@
# Copy from https://github.com/vllm-project/flash-attention/blob/main/flash_attn/ops/triton/rotary.py
# Copyright (c) 2023, Tri Dao.
from typing import Optional, Union
import torch
import triton
import triton.language as tl
@triton.jit
def rotary_kernel(
OUT, # Pointers to matrices
X,
COS,
SIN,
CU_SEQLENS,
SEQLEN_OFFSETS, # this could be int or a pointer
# Matrix dimensions
seqlen,
rotary_dim,
seqlen_ro,
# strides
stride_out_batch,
stride_out_seqlen,
stride_out_nheads,
stride_out_headdim,
stride_x_batch,
stride_x_seqlen,
stride_x_nheads,
stride_x_headdim,
# Meta-parameters
BLOCK_K: tl.constexpr,
IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
IS_VARLEN: tl.constexpr,
INTERLEAVED: tl.constexpr,
CONJUGATE: tl.constexpr,
BLOCK_M: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_head = tl.program_id(axis=1)
pid_batch = tl.program_id(axis=2)
rotary_dim_half = rotary_dim // 2
if not IS_VARLEN:
X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
else:
start_idx = tl.load(CU_SEQLENS + pid_batch)
seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
if pid_m * BLOCK_M >= seqlen:
return
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
if not IS_SEQLEN_OFFSETS_TENSOR:
rm_cs = rm + SEQLEN_OFFSETS
else:
rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
rk = tl.arange(0, BLOCK_K)
rk_half = tl.arange(0, BLOCK_K // 2)
if not INTERLEAVED:
# Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
cos = tl.load(
COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
).to(tl.float32)
sin = tl.load(
SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
).to(tl.float32)
x0 = tl.load(
X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
).to(tl.float32)
x1 = tl.load(
X + rotary_dim_half * stride_x_headdim,
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
if CONJUGATE:
sin = -sin
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
# write back result
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
tl.store(
OUT + rotary_dim_half * stride_out_headdim,
o1,
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
)
else:
# We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
# Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
# Loading x0 will be fast but x1 will be slow.
# Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
# Then we do the calculation and use tl.where to pick put the right outputs for the even
# and for the odd indices.
rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
rk_repeat = tl.arange(0, BLOCK_K) // 2
X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
cos = tl.load(
COS,
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
other=1.0,
).to(tl.float32)
sin = tl.load(
SIN,
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
tl.float32
)
x1 = tl.load(
X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
).to(tl.float32)
if CONJUGATE:
sin = -sin
x0_cos = x0 * cos
x1_sin = x1 * sin
out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
def apply_rotary(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
interleaved=False,
inplace=False,
conjugate=False,
) -> torch.Tensor:
"""
Arguments:
x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim).
cos: (seqlen_ro, rotary_dim / 2)
sin: (seqlen_ro, rotary_dim / 2)
seqlen_offsets: integer or integer tensor of size (batch,)
cu_seqlens: (batch + 1,) or None
max_seqlen: int
Returns:
y: (batch, seqlen, nheads, headdim)
"""
is_varlen = cu_seqlens is not None
if not is_varlen:
batch, seqlen, nheads, headdim = x.shape
else:
assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
total_seqlen, nheads, headdim = x.shape
batch_p_1 = cu_seqlens.shape[0]
batch = batch_p_1 - 1
seqlen = max_seqlen
seqlen_ro, rotary_dim = cos.shape
assert sin.shape == cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
assert headdim <= 256, "Only support headdim <= 256"
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
assert (
cos.dtype == sin.dtype
), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
assert (
x.dtype == cos.dtype
), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
cos, sin = cos.contiguous(), sin.contiguous()
if isinstance(seqlen_offsets, torch.Tensor):
assert seqlen_offsets.shape == (batch,)
assert seqlen_offsets.dtype in [torch.int32, torch.int64]
seqlen_offsets = seqlen_offsets.contiguous()
else:
assert seqlen_offsets + seqlen <= seqlen_ro
output = torch.empty_like(x) if not inplace else x
if rotary_dim < headdim and not inplace:
output[..., rotary_dim:].copy_(x[..., rotary_dim:])
BLOCK_K = (
32
if rotary_dim <= 32
else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
)
grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), nheads, batch) # noqa
BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 128 else 4)
# Need this, otherwise Triton tries to launch from cuda:0 and we get
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
with torch.cuda.device(x.device.index):
rotary_kernel[grid](
output, # data ptrs
x,
cos,
sin,
cu_seqlens,
seqlen_offsets,
seqlen, # shapes
rotary_dim,
seqlen_ro,
output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
output.stride(-3), # seqlen_stride or total_seqlen_stride
output.stride(-2), # nheads_stride
output.stride(-1), # headdim_stride
x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
x.stride(-3), # seqlen stride or total_seqlen_stride
x.stride(-2), # nheads stride
x.stride(-1), # headdim stride
BLOCK_K,
isinstance(seqlen_offsets, torch.Tensor),
is_varlen,
interleaved,
conjugate,
BLOCK_M,
num_warps=2 if rotary_dim <= 64 else 4,
)
return output