first commit
This commit is contained in:
0
vllm/vllm_flash_attn/.gitkeep
Normal file
0
vllm/vllm_flash_attn/.gitkeep
Normal file
12
vllm/vllm_flash_attn/__init__.py
Normal file
12
vllm/vllm_flash_attn/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
__version__ = "2.7.2.post1"
|
||||
|
||||
# Use relative import to support build-from-source installation in vLLM
|
||||
from .flash_attn_interface import (
|
||||
flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache,
|
||||
get_scheduler_metadata,
|
||||
sparse_attn_func,
|
||||
sparse_attn_varlen_func,
|
||||
is_fa_version_supported,
|
||||
fa_version_unsupported_reason
|
||||
)
|
||||
BIN
vllm/vllm_flash_attn/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/vllm_flash_attn/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so
Normal file
BIN
vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so
Normal file
Binary file not shown.
BIN
vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so
Normal file
BIN
vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so
Normal file
Binary file not shown.
614
vllm/vllm_flash_attn/flash_attn_interface.py
Normal file
614
vllm/vllm_flash_attn/flash_attn_interface.py
Normal file
@@ -0,0 +1,614 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
from typing import Optional, Union, Tuple, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# isort: off
|
||||
# We need to import the CUDA kernels after importing torch
|
||||
# Use relative import to support build-from-source installation in vLLM
|
||||
|
||||
try:
|
||||
from . import _vllm_fa2_C # noqa: F401
|
||||
FA2_UNAVAILABLE_REASON = None
|
||||
FA2_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
FA2_UNAVAILABLE_REASON = str(e)
|
||||
FA2_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from . import _vllm_fa3_C # noqa: F401
|
||||
FA3_UNAVAILABLE_REASON = None
|
||||
FA3_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
FA3_UNAVAILABLE_REASON = str(e)
|
||||
FA3_AVAILABLE = False
|
||||
|
||||
# isort: on
|
||||
|
||||
DEFAULT_FA_VERSION = 2
|
||||
|
||||
def _is_fa2_supported(device=None) -> Tuple[bool, Optional[str]]:
    """Report whether the FlashAttention-2 kernels can be used on ``device``.

    Args:
        device: Forwarded unchanged to ``torch.cuda.get_device_capability``.

    Returns:
        ``(True, None)`` when FA2 is usable, otherwise ``(False, reason)``
        where ``reason`` is a human-readable explanation.
    """
    if not FA2_AVAILABLE:
        # The compiled extension failed to import; surface the import error.
        return False, f"FA2 is unavailable due to: {FA2_UNAVAILABLE_REASON}"
    if torch.cuda.get_device_capability(device)[0] < 8:
        return False, \
            "FA2 is only supported on devices with compute capability >= 8"
    return True, None
|
||||
|
||||
def _is_fa3_supported(device=None) -> Tuple[bool, Optional[str]]:
    """Report whether the FlashAttention-3 kernels can be used on ``device``.

    Args:
        device: Forwarded unchanged to ``torch.cuda.get_device_capability``.

    Returns:
        ``(True, None)`` when FA3 is usable, otherwise ``(False, reason)``
        where ``reason`` is a human-readable explanation.
    """
    if not FA3_AVAILABLE:
        # The compiled extension failed to import; surface the import error.
        return False, f"FA3 is unavailable due to: {FA3_UNAVAILABLE_REASON}"

    # Query the device once and reuse the result for every comparison.
    capability = tuple(torch.cuda.get_device_capability(device))
    major = capability[0]
    if major < 8 or major >= 10 or capability in ((8, 6), (8, 9)):
        return False, \
            "FA3 is only supported on devices with compute capability >= 8" \
            " excluding 8.6 and 8.9 and Blackwell archs (>=10)"
    return True, None
|
||||
|
||||
def is_fa_version_supported(fa_version: int, device=None) -> bool:
    """Return whether the given FlashAttention major version can be used.

    Args:
        fa_version: FlashAttention major version; must be 2 or 3.
        device: Passed through to the per-version support check.

    Returns:
        The boolean half of the per-version support check.

    Raises:
        AssertionError: If ``fa_version`` is neither 2 nor 3.
    """
    assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
    if fa_version == 2:
        return _is_fa2_supported(device)[0]
    elif fa_version == 3:
        return _is_fa3_supported(device)[0]
|
||||
|
||||
def fa_version_unsupported_reason(fa_version: int, device=None) \
        -> Optional[str]:
    """Return why the given FlashAttention version cannot be used, if so.

    Args:
        fa_version: FlashAttention major version; must be 2 or 3.
        device: Passed through to the per-version support check.

    Returns:
        The reason string from the per-version support check, which is
        ``None`` when that version is supported.

    Raises:
        AssertionError: If ``fa_version`` is neither 2 nor 3.
    """
    assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
    if fa_version == 2:
        return _is_fa2_supported(device)[1]
    elif fa_version == 3:
        return _is_fa3_supported(device)[1]
|
||||
|
||||
#
|
||||
# For vLLM we only care about `flash_attn_varlen_func` and
|
||||
# `flash_attn_with_kvcache` so we only maintain wrappers for these two.
|
||||
#
|
||||
|
||||
|
||||
def maybe_contiguous(x):
    """Return ``x`` with a unit-stride last dimension, copying only if needed.

    ``None`` is passed through unchanged so optional tensors can be handled
    uniformly by callers. A tensor whose last-dimension stride is already 1
    is returned as-is (no copy).
    """
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x
|
||||
|
||||
# NOTE only used in FA3
|
||||
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Compute scheduler metadata via the ``_vllm_fa3_C`` custom op.

    The arguments are forwarded to
    ``torch.ops._vllm_fa3_C.get_scheduler_metadata`` with these adjustments:

    * ``cache_seqlens`` is made contiguous in its last dimension first.
    * ``headdim_v`` falls back to ``headdim`` when it is ``None``.
    * ``window_size`` is unpacked into its two elements (left, right).
    * The op's ``cu_seqlens_k`` and ``seqused_q`` slots are always ``None``.

    Returns:
        Whatever the custom op returns, unchanged.
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    if headdim_v is None:
        headdim_v = headdim
    return torch.ops._vllm_fa3_C.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv,
        headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
|
||||
|
||||
|
||||
def flash_attn_varlen_func(
    q,
    k,
    v,
    max_seqlen_q,
    cu_seqlens_q,
    max_seqlen_k,
    cu_seqlens_k=None,  # only used for non-paged prefill
    seqused_k=None,
    q_v=None,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size: Optional[List[int]] = None,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    return_softmax_lse=False,
    out=None,
    # FA3 Only
    scheduler_metadata=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    num_splits: int = 0,
    # Version selector
    fa_version: int = DEFAULT_FA_VERSION,
    s_aux=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
           Exactly one of cu_seqlens_k and seqused_k must be given.
        seqused_k: alternative to cu_seqlens_k; required when block_table is given.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
            None is treated as (-1, -1).
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
            Not supported when fa_version == 3.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
            (Not forwarded to either kernel by this wrapper.)
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
            (Not forwarded to either kernel by this wrapper.)
        block_table: optional; forwarded to the kernel. Requires seqused_k.
        out: optional output tensor forwarded to the kernel.
        scheduler_metadata, q_descale, k_descale, v_descale, q_v, s_aux: FA3-only inputs,
            forwarded to the FA3 kernel. s_aux is rejected on the FA2 path.
        num_splits: int. Forwarded to the FA3 kernel; values > 1 are rejected on the FA2 path.
        fa_version: 2 or 3. Selects which compiled kernel is called.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    Raises:
        AssertionError: on inconsistent cu_seqlens_k / seqused_k / block_table, a window_size
            whose length is not 2, or alibi_slopes with fa_version == 3.
        NotImplementedError: when FA3-only features are requested with fa_version == 2.
        ValueError: when fa_version is neither 2 nor 3.
    """
    assert cu_seqlens_k is not None or seqused_k is not None, \
        "cu_seqlens_k or seqused_k must be provided"
    assert cu_seqlens_k is None or seqused_k is None, \
        "cu_seqlens_k and seqused_k cannot be provided at the same time"
    assert block_table is None or seqused_k is not None, \
        "seqused_k must be provided if block_table is provided"

    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    # custom op does not support non-tuple input
    real_window_size: Tuple[int, int]
    if window_size is None:
        real_window_size = (-1, -1)
    else:
        assert len(window_size) == 2
        real_window_size = (window_size[0], window_size[1])
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]

    if fa_version == 2:
        # NOTE(review): this guard only fires when *all four* FA3-only
        # arguments are supplied together. It is kept as-is because callers
        # may rely on passing a subset with FA2 — confirm whether `or` was
        # intended before tightening it.
        if scheduler_metadata is not None and q_descale is not None \
                and k_descale is not None and v_descale is not None:
            raise NotImplementedError(
                "FA2 does not support scheduler_metadata, q_descale, "
                "k_descale, v_descale"
            )
        if s_aux is not None:
            raise NotImplementedError("FA2 does not support s_aux")
        if num_splits > 1:
            raise NotImplementedError("FA2 does not support num_splits > 1")

        if cu_seqlens_k is None:
            # cu_seqlens_k is not used since we use seqused_k, but
            # flash_api.cpp still wants a tensor, so pass an uninitialized
            # placeholder shaped like cu_seqlens_q. Allocated only here so
            # the FA3 path and the cu_seqlens_k path pay nothing for it.
            fa2_cu_seqlens_k = torch.empty_like(cu_seqlens_q)
        else:
            fa2_cu_seqlens_k = cu_seqlens_k

        out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
            q, k, v,
            out,
            cu_seqlens_q,
            fa2_cu_seqlens_k,
            seqused_k,
            None,
            block_table,
            alibi_slopes,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            False,
            causal,
            real_window_size[0],
            real_window_size[1],
            softcap,
            return_softmax_lse and dropout_p > 0,
            None,
        )
    elif fa_version == 3:
        assert alibi_slopes is None, "Alibi is not supported in FA3"
        out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
            q, k, v,
            None, None,  # k_new, v_new
            q_v,
            out,
            cu_seqlens_q,
            cu_seqlens_k,  # cu_seqlens_k
            None,  # cu_seqlens_k_new
            None, seqused_k,  # seqused_q, seqused_k
            max_seqlen_q, max_seqlen_k,
            block_table,
            None,  # kv_batch_idx
            None,  # leftpad_k
            None, None, None,  # rotary_cos, rotary_sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal,
            real_window_size[0], real_window_size[1],
            softcap,
            True,  # rotary_interleaved
            scheduler_metadata,
            num_splits,
            None,  # pack_gqa
            0,  # sm_margin
            s_aux  # s_aux
        )
    else:
        raise ValueError(f"Unsupported FA version: {fa_version}")
    return (out, softmax_lse) if return_softmax_lse else out
|
||||
|
||||
|
||||
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
    *,
    out=None,
    # FA3 Only
    scheduler_metadata=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    # Version selector
    fa_version: int = DEFAULT_FA_VERSION,
    s_aux=None,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    NOTE(review): this wrapper always calls the FA2 kernel
    (``torch.ops._vllm_fa2_C.fwd_kvcache``); ``fa_version`` is accepted but never
    consulted, and the FA3-only arguments are never forwarded. Confirm whether an
    FA3 branch is expected here.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is expanded to a tensor of length k_cache.shape[0].
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).

    Raises:
        AssertionError: if k_cache or v_cache is not contiguous in its last dimension.
        NotImplementedError: if s_aux is given, or if scheduler_metadata and all three
            descale tensors are given together.
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    # isinstance() is already False for None, so no separate None check.
    if isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)

    if s_aux is not None:
        raise NotImplementedError("FA2 does not support s_aux")
    # NOTE(review): only fires when all four FA3-only arguments are supplied
    # together; kept as-is because callers may pass a subset — confirm intent.
    if scheduler_metadata is not None and q_descale is not None \
            and k_descale is not None and v_descale is not None:
        raise NotImplementedError(
            "FA2 does not support scheduler_metadata, q_descale, "
            "k_descale, v_descale"
        )

    out, softmax_lse = torch.ops._vllm_fa2_C.fwd_kvcache(
        q, k_cache, v_cache,
        k, v,  # k_new, v_new
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        cache_leftpad,
        block_table,
        alibi_slopes,
        out,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        rotary_interleaved,
        num_splits,
    )
    return (out, softmax_lse) if return_softmax_lse else out
|
||||
|
||||
|
||||
def sparse_attn_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Compute attention with vertical and slash sparsity patterns.
    Most Arguments are the same with the flash_attn_func interface, except for 4 extra args:
    block_count and block_offset for slash sparsity patterns, and
    column_count and column_index for vertical sparsity patterns.
    For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        softcap: float. Forwarded to the kernel; 0.0 means deactivated.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
            (Accepted but not forwarded to the kernel by this wrapper.)
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        out: optional output tensor forwarded to the kernel.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops._vllm_fa2_C.fwd_sparse(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        softcap,
        return_attn_probs and dropout_p > 0,
        None,
    )
    return (out, softmax_lse) if return_softmax_lse else out
|
||||
|
||||
|
||||
def sparse_attn_varlen_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Compute attention with vertical and slash sparsity patterns.
    Most Arguments are the same with the flash_attn_varlen_func interface, except for 4 extra args:
    block_count and block_offset for slash sparsity patterns, and
    column_count and column_index for vertical sparsity patterns.
    For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
            (Accepted but not forwarded to the kernel by this wrapper.)
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        out: optional output tensor forwarded to the kernel.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd_sparse(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        softcap,
        return_attn_probs and dropout_p > 0,
        None,
    )
    return (out, softmax_lse) if return_softmax_lse else out
|
||||
0
vllm/vllm_flash_attn/layers/__init__.py
Normal file
0
vllm/vllm_flash_attn/layers/__init__.py
Normal file
BIN
vllm/vllm_flash_attn/layers/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/vllm_flash_attn/layers/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/vllm_flash_attn/layers/__pycache__/rotary.cpython-310.pyc
Normal file
BIN
vllm/vllm_flash_attn/layers/__pycache__/rotary.cpython-310.pyc
Normal file
Binary file not shown.
530
vllm/vllm_flash_attn/layers/rotary.py
Normal file
530
vllm/vllm_flash_attn/layers/rotary.py
Normal file
@@ -0,0 +1,530 @@
|
||||
# Adapted from https://github.com/vllm-project/flash-attention/blob/main/flash_attn/layers/rotary.py
|
||||
# Modified lines are marked with `# modified from original` comment
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from ..ops.triton.rotary import apply_rotary # modified from original
|
||||
|
||||
|
||||
def rotate_half(x, interleaved=False):
    """Rotate pairs of features in the last dimension of ``x``.

    With ``interleaved=False`` the last dimension is split into a first and a
    second half ``(x1, x2)`` and ``(-x2, x1)`` is returned concatenated.
    With ``interleaved=True`` the pairs are the even/odd elements instead,
    and the result keeps them interleaved.
    """
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        # Stack as (..., d, 2) then merge the trailing two axes back into one.
        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
|
||||
|
||||
|
||||
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Applies the rotation to the first ``rotary_dim`` features of ``x`` and
    passes the remaining features through unchanged.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    # Expand cos/sin to rotary_dim and insert a singleton axis before it.
    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )
|
||||
|
||||
|
||||
class ApplyRotaryEmb(torch.autograd.Function):
    """Autograd wrapper around the ``apply_rotary`` kernel."""

    @staticmethod
    def forward(
        ctx,
        x,
        cos,
        sin,
        interleaved=False,
        inplace=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        """Run ``apply_rotary`` on ``x`` and stash what backward needs.

        Returns ``x`` itself when ``inplace`` is true, otherwise the kernel's
        output tensor.
        """
        out = apply_rotary(
            x,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=interleaved,
            inplace=inplace,
        )
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
            # None marks that the offsets live in saved_tensors instead.
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        ctx.max_seqlen = max_seqlen
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        """Run ``apply_rotary`` with ``conjugate=True`` on the incoming gradient.

        Only ``x`` receives a gradient; the other seven forward inputs get
        ``None``.
        """
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cu_seqlens = ctx.saved_tensors
        # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
        # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
        if not ctx.interleaved and not ctx.inplace:
            do = do.clone()
        dx = apply_rotary(
            do,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=ctx.interleaved,
            inplace=ctx.inplace,
            conjugate=True,
        )
        return dx, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Arguments:
        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Return:
        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of x.
    """
    # autograd.Function.apply takes positional arguments only.
    return ApplyRotaryEmb.apply(
        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
    )
|
||||
|
||||
|
||||
# For backward compatibility
|
||||
apply_rotary_emb_func = apply_rotary_emb
|
||||
|
||||
|
||||
class ApplyRotaryEmbQKV_(torch.autograd.Function):
    """Autograd function rotating the Q and K parts of a packed QKV tensor in place.

    ``qkv`` is either 5-D ``(batch, seqlen, 3, nheads, headdim)`` or 4-D
    ``(batch, seqlen, num_heads_q + 2 * num_heads_k, headdim)``; the 4-D layout
    requires ``num_heads_q``. The V part is never passed to ``apply_rotary``.
    ``backward`` applies the conjugate rotation to the incoming gradient, also
    in place, and returns that same gradient tensor.
    """

    @staticmethod
    def _infer_num_heads_k(packed, num_heads_q):
        """Return the K head count of a 4-D packed tensor, validating the layout."""
        assert packed.dim() == 4
        assert num_heads_q is not None
        num_heads_k = (packed.shape[2] - num_heads_q) // 2
        assert packed.shape[2] == num_heads_q + 2 * num_heads_k
        return num_heads_k

    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cos_k=None,
        sin_k=None,
        interleaved=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        num_heads_q: Optional[int] = None,
    ):
        """Rotate Q and K inside ``qkv`` in place and return ``qkv``.

        When no separate K tables are given and ``qkv`` is contiguous, Q and K
        are rotated with a single ``apply_rotary`` call; otherwise Q and K are
        rotated separately (K with ``cos_k`` / ``sin_k``, falling back to
        ``cos`` / ``sin`` when they are None).
        """
        if cos_k is None and sin_k is None and qkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels.
            # We need qkv to be contiguous so that when we reshape to combine
            # the (3, nheads) dimensions, we get a view of the same tensor.
            if qkv.dim() == 5:
                batch, seqlen, three, nheads, headdim = qkv.shape
                assert three == 3
                qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
            else:
                num_heads_k = ApplyRotaryEmbQKV_._infer_num_heads_k(qkv, num_heads_q)
                qk = qkv[:, :, : num_heads_q + num_heads_k]
            apply_rotary(
                qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            if qkv.dim() == 5:
                q, k = qkv[:, :, 0], qkv[:, :, 1]
            else:
                num_heads_k = ApplyRotaryEmbQKV_._infer_num_heads_k(qkv, num_heads_q)
                q = qkv[:, :, :num_heads_q]
                k = qkv[:, :, num_heads_q : num_heads_q + num_heads_k]
            apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
            apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
        # Save exactly once, after cos_k / sin_k have their final values.
        # An int offset cannot go through save_for_backward, so it is kept as a
        # plain ctx attribute; ctx.seqlen_offsets is None when the offsets are
        # among the saved tensors instead.
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin, cos_k, sin_k)
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.num_heads_q = num_heads_q
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        """Apply the conjugate rotation to the Q and K parts of ``dqkv`` in place.

        Returns ``dqkv`` as the gradient for ``qkv`` and None for the seven
        other forward inputs.
        """
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cos_k, sin_k = ctx.saved_tensors
        if cos_k is None and sin_k is None and dqkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels.
            # We need dqkv to be contiguous so that when we reshape to combine
            # the (3, nheads) dimensions, we get a view of the same tensor.
            if dqkv.dim() == 5:
                # Native reshape, matching forward (no einops needed here).
                batch, seqlen, _, _, headdim = dqkv.shape
                dqk = dqkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
            else:
                num_heads_k = ApplyRotaryEmbQKV_._infer_num_heads_k(dqkv, ctx.num_heads_q)
                dqk = dqkv[:, :, : ctx.num_heads_q + num_heads_k]
            apply_rotary(
                dqk,
                cos,
                sin,
                seqlen_offsets=seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            if dqkv.dim() == 5:
                dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
            else:
                num_heads_k = ApplyRotaryEmbQKV_._infer_num_heads_k(dqkv, ctx.num_heads_q)
                dq = dqkv[:, :, : ctx.num_heads_q]
                dk = dqkv[:, :, ctx.num_heads_q : ctx.num_heads_q + num_heads_k]
            apply_rotary(
                dq,
                cos,
                sin,
                seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
            apply_rotary(
                dk,
                cos_k,
                sin_k,
                seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        return dqkv, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
def apply_rotary_emb_qkv_(
    qkv,
    cos,
    sin,
    cos_k=None,
    sin_k=None,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    num_heads_q: Optional[int] = None,
):
    """Apply rotary embedding *inplace* to the first ``rotary_dim`` of Q and K.

    ``rotary_dim`` must be <= headdim.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim) or
            (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim).
            With the second layout (e.g. MQA / GQA), ``num_heads_q`` must be
            provided.
        cos, sin: (seqlen, rotary_dim / 2)
        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
        interleaved: if True, rotate pairs of even and odd dimensions
            (GPT-J style) instead of the 1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is
            shifted by this amount. Most commonly used in inference with a
            KV cache.
    Return:
        qkv: same shape as the input ``qkv``.
    """
    # Positional order follows ApplyRotaryEmbQKV_.forward (after ctx).
    forward_args = (
        qkv,
        cos,
        sin,
        cos_k,
        sin_k,
        interleaved,
        seqlen_offsets,
        num_heads_q,
    )
    return ApplyRotaryEmbQKV_.apply(*forward_args)
|
||||
|
||||
|
||||
class ApplyRotaryEmbKV_(torch.autograd.Function):
    """Autograd function rotating the K half of a packed KV tensor in place.

    ``kv`` must be 5-D with a size-2 third axis; index 0 on that axis is the
    slice handed to ``apply_rotary``, index 1 is left untouched.
    """

    @staticmethod
    def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
        """Rotate ``kv[:, :, 0]`` in place and return ``kv``."""
        # Unpacking enforces a 5-D input; only the size of the K/V axis is used.
        _, _, kv_axis, _, _ = kv.shape
        assert kv_axis == 2
        apply_rotary(
            kv[:, :, 0],
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            interleaved=interleaved,
            inplace=True,
        )
        ctx.interleaved = interleaved
        if not isinstance(seqlen_offsets, int):
            # Tensor offsets travel with the saved tensors; the None marker
            # tells backward to look for them there.
            ctx.save_for_backward(cos, sin, seqlen_offsets)
            ctx.seqlen_offsets = None
        else:
            # Can't save int with save_for_backward
            ctx.save_for_backward(cos, sin)
            ctx.seqlen_offsets = seqlen_offsets
        return kv

    @staticmethod
    def backward(ctx, dkv):
        """Apply the conjugate rotation to ``dkv[:, :, 0]`` in place.

        Returns ``dkv`` for ``kv`` and None for the four other forward inputs.
        """
        saved = ctx.saved_tensors
        if ctx.seqlen_offsets is None:
            cos, sin, offsets = saved
        else:
            cos, sin = saved
            offsets = ctx.seqlen_offsets
        apply_rotary(
            dkv[:, :, 0],
            cos,
            sin,
            seqlen_offsets=offsets,
            interleaved=ctx.interleaved,
            inplace=True,
            conjugate=True,
        )
        return dkv, None, None, None, None
|
||||
|
||||
|
||||
# The earlier module-level alias of the same name was dropped: this `def`
# rebound the name immediately, so the alias was never observable.
def apply_rotary_emb_kv_(
    kv,
    cos,
    sin,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
):
    """Apply rotary embedding *inplace* to the first ``rotary_dim`` of K.

    ``rotary_dim`` must be <= headdim.

    Arguments:
        kv: (batch_size, seqlen, 2, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions
            (GPT-J style) instead of the 1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in K is shifted by
            this amount. Most commonly used in inference with a KV cache.
    Return:
        kv: (batch_size, seqlen, 2, nheads, headdim)
    """
    return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets)
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings from RoFormer_ (Su et al.).

    The key idea is that queries and keys are transformed by rotation matrices
    that depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style)
            instead of the 1st half and 2nd half (GPT-NeoX style).
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are
            built in fp32; otherwise they use the dtype of self.inv_freq and
            might be in lower precision.
            Before 2023-07-02 the indices always used the dtype of self.inv_freq.
            That is usually fp32, but for a model trained in pure bf16 (not mixed
            precision) self.inv_freq is bf16 and so were the indices. Because of
            the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions coincide. The option exists to stay
            compatible with models previously trained in pure bf16.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        self.interleaved = interleaved
        self.scale_base = scale_base
        # Inverse frequencies: non-trainable and not part of the state dict.
        self.register_buffer("inv_freq", self._compute_inv_freq(device), persistent=False)
        # Per-channel XPos scale; stays None when XPos is disabled.
        if scale_base is None:
            xpos_scale = None
        else:
            even_channels = torch.arange(0, dim, 2, device=device, dtype=torch.float32)
            xpos_scale = (even_channels + 0.4 * dim) / (1.4 * dim)
        self.register_buffer("scale", xpos_scale, persistent=False)

        # Lazily filled cos / sin tables (see _update_cos_sin_cache).
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        """Return ``1 / base ** (i / dim)`` for even ``i`` in ``[0, dim)``, in fp32."""
        exponents = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
        return 1.0 / (self.base ** exponents)

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """Rebuild the cached cos / sin tables for ``seqlen`` positions if needed.

        The tables are kept when they already cover ``seqlen``, live on
        ``device`` with ``dtype``, and are not inference tensors while the
        module is in training mode.
        """
        cache_usable = (
            seqlen <= self._seq_len_cached
            and self._cos_cached is not None
            and self._cos_cached.device == device
            and self._cos_cached.dtype == dtype
            and not (self.training and self._cos_cached.is_inference())
        )
        if cache_usable:
            return

        self._seq_len_cached = seqlen
        # Position indices should be fp32 rather than self.inv_freq.dtype: the
        # model may be loaded in bf16 and arange values get large, so bf16
        # would lose a lot of precision. For compatibility there is an option
        # to use the dtype of self.inv_freq anyway.
        if self.pos_idx_in_fp32:
            positions = torch.arange(seqlen, device=device, dtype=torch.float32)
            # inv_freq is multiplied with the positions, so it must be fp32 too;
            # recompute it if the buffer was not loaded in fp32.
            if self.inv_freq.dtype == torch.float32:
                inv_freq = self.inv_freq
            else:
                inv_freq = self._compute_inv_freq(device=device)
        else:
            positions = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            inv_freq = self.inv_freq
        # Don't do einsum, it converts fp32 to fp16 under AMP
        angles = torch.outer(positions, inv_freq)
        cos_table = torch.cos(angles)
        sin_table = torch.sin(angles)

        if self.scale is None:
            self._cos_cached = cos_table.to(dtype)
            self._sin_cached = sin_table.to(dtype)
            return

        # XPos: the exponent is centred on the middle of the sequence.
        centred = torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
        power = (centred - seqlen // 2) / self.scale_base
        scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
        # The multiplication by scale happens before the cast to ``dtype``.
        self._cos_cached = (cos_table * scale).to(dtype)
        self._sin_cached = (sin_table * scale).to(dtype)
        self._cos_k_cached = (cos_table / scale).to(dtype)
        self._sin_k_cached = (sin_table / scale).to(dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        kv: Optional[torch.Tensor] = None,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
        num_heads_q: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) or
            (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim) if kv is None,
            else it's just q of shape (batch, seqlen, nheads, headdim).
            If qkv has shape (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim)
            (e.g. MQA / GQA), then num_heads_q must be provided.
        kv: (batch, seqlen, 2, nheads, headdim)
        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
            should pass in max_seqlen, which will update the cos / sin cache up to that length.
        Apply rotary embedding *inplace* to qkv and / or kv.
        """
        seqlen = qkv.shape[1]
        # With tensor offsets and no max_seqlen the cache is left as it is.
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)

        use_xpos = self.scale is not None

        if kv is None:
            # Packed QKV: K gets its own tables only under XPos.
            cos_k = self._cos_k_cached if use_xpos else None
            sin_k = self._sin_k_cached if use_xpos else None
            return apply_rotary_emb_qkv_(
                qkv,
                self._cos_cached,
                self._sin_cached,
                cos_k,
                sin_k,
                interleaved=self.interleaved,
                seqlen_offsets=seqlen_offset,
                num_heads_q=num_heads_q,
            )

        # Separate q and packed kv.
        q = apply_rotary_emb_func(
            qkv,
            self._cos_cached,
            self._sin_cached,
            interleaved=self.interleaved,
            inplace=True,
            seqlen_offsets=seqlen_offset,
        )
        if use_xpos:
            k_cos, k_sin = self._cos_k_cached, self._sin_k_cached
        else:
            k_cos, k_sin = self._cos_cached, self._sin_cached
        kv = apply_rotary_emb_kv_(
            kv,
            k_cos,
            k_sin,
            interleaved=self.interleaved,
            seqlen_offsets=seqlen_offset,
        )
        return q, kv
|
||||
1
vllm/vllm_flash_attn/ops/triton/__init__.py
Normal file
1
vllm/vllm_flash_attn/ops/triton/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
Binary file not shown.
Binary file not shown.
229
vllm/vllm_flash_attn/ops/triton/rotary.py
Normal file
229
vllm/vllm_flash_attn/ops/triton/rotary.py
Normal file
@@ -0,0 +1,229 @@
|
||||
# Copy from https://github.com/vllm-project/flash-attention/blob/main/flash_attn/ops/triton/rotary.py
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def rotary_kernel(
    OUT,  # Pointers to matrices
    X,
    COS,
    SIN,
    CU_SEQLENS,
    SEQLEN_OFFSETS,  # this could be int or a pointer
    # Matrix dimensions
    seqlen,
    rotary_dim,
    seqlen_ro,
    # strides
    stride_out_batch,
    stride_out_seqlen,
    stride_out_nheads,
    stride_out_headdim,
    stride_x_batch,
    stride_x_seqlen,
    stride_x_nheads,
    stride_x_headdim,
    # Meta-parameters
    BLOCK_K: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    # Each program rotates BLOCK_M sequence rows of one head of one sequence.
    # Program ids: axis 0 = row block, axis 1 = head, axis 2 = batch entry.
    row_block = tl.program_id(axis=0)
    head_id = tl.program_id(axis=1)
    batch_id = tl.program_id(axis=2)
    rotary_dim_half = rotary_dim // 2

    if IS_VARLEN:
        # Packed layout: the sequence start and length come from CU_SEQLENS.
        start_idx = tl.load(CU_SEQLENS + batch_id)
        seqlen = tl.load(CU_SEQLENS + batch_id + 1) - start_idx
        x_base = X + start_idx * stride_x_seqlen + head_id * stride_x_nheads
        out_base = OUT + start_idx * stride_out_seqlen + head_id * stride_out_nheads
    else:
        x_base = X + batch_id * stride_x_batch + head_id * stride_x_nheads
        out_base = OUT + batch_id * stride_out_batch + head_id * stride_out_nheads

    # Row blocks entirely past the end of this sequence have nothing to do.
    if row_block * BLOCK_M >= seqlen:
        return
    rm = row_block * BLOCK_M + tl.arange(0, BLOCK_M)
    # Rows into the cos / sin tables are shifted by the sequence offset.
    if IS_SEQLEN_OFFSETS_TENSOR:
        rm_cs = rm + tl.load(SEQLEN_OFFSETS + batch_id)
    else:
        rm_cs = rm + SEQLEN_OFFSETS

    row_in_seq = rm[:, None] < seqlen
    row_in_table = rm_cs[:, None] < seqlen_ro

    if INTERLEAVED:
        # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
        # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
        # Loading x0 will be fast but x1 will be slow.
        # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
        # Then we do the calculation and use tl.where to pick out the right outputs for the even
        # and for the odd indices.
        rk = tl.arange(0, BLOCK_K)
        rk_swap = rk + ((rk + 1) % 2) * 2 - 1  # 1, 0, 3, 2, 5, 4, ...
        rk_repeat = tl.arange(0, BLOCK_K) // 2
        table_offs = rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]
        table_mask = row_in_table & (rk_repeat[None, :] < rotary_dim_half)
        # Masked-out table entries act as the identity rotation (cos=1, sin=0).
        cos = tl.load(COS + table_offs, mask=table_mask, other=1.0).to(tl.float32)
        sin = tl.load(SIN + table_offs, mask=table_mask, other=0.0).to(tl.float32)
        x_rows = rm[:, None] * stride_x_seqlen
        x0 = tl.load(
            x_base + (x_rows + rk[None, :] * stride_x_headdim),
            mask=row_in_seq & (rk[None, :] < rotary_dim),
            other=0.0,
        ).to(tl.float32)
        x1 = tl.load(
            x_base + (x_rows + rk_swap[None, :] * stride_x_headdim),
            mask=row_in_seq & (rk_swap[None, :] < rotary_dim),
            other=0.0,
        ).to(tl.float32)
        if CONJUGATE:
            sin = -sin
        x0_cos = x0 * cos
        x1_sin = x1 * sin
        out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
        out_ptrs = out_base + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
        tl.store(out_ptrs, out, mask=row_in_seq & (rk[None, :] < rotary_dim))
    else:
        # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
        rk_half = tl.arange(0, BLOCK_K // 2)
        col_ok = rk_half[None, :] < rotary_dim_half
        table_offs = rm_cs[:, None] * rotary_dim_half + rk_half[None, :]
        table_mask = row_in_table & col_ok
        # Masked-out table entries act as the identity rotation (cos=1, sin=0).
        cos = tl.load(COS + table_offs, mask=table_mask, other=1.0).to(tl.float32)
        sin = tl.load(SIN + table_offs, mask=table_mask, other=0.0).to(tl.float32)
        data_mask = row_in_seq & col_ok
        x_ptrs = x_base + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
        # Both halves are loaded before anything is stored.
        x0 = tl.load(x_ptrs, mask=data_mask, other=0.0).to(tl.float32)
        x1 = tl.load(
            x_ptrs + rotary_dim_half * stride_x_headdim,
            mask=data_mask,
            other=0.0,
        ).to(tl.float32)
        if CONJUGATE:
            sin = -sin
        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos
        # write back result
        out_ptrs = out_base + (
            rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim
        )
        tl.store(out_ptrs, o0, mask=data_mask)
        tl.store(out_ptrs + rotary_dim_half * stride_out_headdim, o1, mask=data_mask)
|
||||
|
||||
|
||||
def apply_rotary(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
    interleaved=False,
    inplace=False,
    conjugate=False,
) -> torch.Tensor:
    """Validate the inputs and launch ``rotary_kernel`` on ``x``.

    Arguments:
        x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim).
        cos: (seqlen_ro, rotary_dim / 2)
        sin: (seqlen_ro, rotary_dim / 2)
        seqlen_offsets: integer or integer tensor of size (batch,)
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int, required when cu_seqlens is given
        interleaved, conjugate: forwarded to the kernel as compile-time flags.
        inplace: if True the result is written into ``x`` itself.
    Returns:
        y: ``x`` itself when ``inplace`` is True, otherwise a new tensor
            created with ``torch.empty_like(x)``.
    """
    is_varlen = cu_seqlens is not None
    if is_varlen:
        assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
        total_seqlen, nheads, headdim = x.shape
        batch = cu_seqlens.shape[0] - 1
        seqlen = max_seqlen
    else:
        batch, seqlen, nheads, headdim = x.shape

    seqlen_ro, rotary_dim_half = cos.shape
    assert sin.shape == cos.shape
    # The tables hold one entry per rotated pair of features.
    rotary_dim = 2 * rotary_dim_half
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"

    assert (
        cos.dtype == sin.dtype
    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert (
        x.dtype == cos.dtype
    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    cos = cos.contiguous()
    sin = sin.contiguous()
    offsets_are_tensor = isinstance(seqlen_offsets, torch.Tensor)
    if offsets_are_tensor:
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        seqlen_offsets = seqlen_offsets.contiguous()
    else:
        assert seqlen_offsets + seqlen <= seqlen_ro

    output = x if inplace else torch.empty_like(x)
    if rotary_dim < headdim and not inplace:
        # Features past rotary_dim are copied through unchanged.
        output[..., rotary_dim:].copy_(x[..., rotary_dim:])

    # Smallest of 32 / 64 / 128 that covers rotary_dim, else 256.
    if rotary_dim <= 32:
        BLOCK_K = 32
    elif rotary_dim <= 64:
        BLOCK_K = 64
    elif rotary_dim <= 128:
        BLOCK_K = 128
    else:
        BLOCK_K = 256
    BLOCK_M = 8 if (not interleaved and rotary_dim <= 128) else 4

    def grid(META):
        # (row blocks, heads, batch entries)
        return (triton.cdiv(seqlen, META["BLOCK_M"]), nheads, batch)

    # Need this, otherwise Triton tries to launch from cuda:0 and we get
    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
    with torch.cuda.device(x.device.index):
        rotary_kernel[grid](
            output,  # data ptrs
            x,
            cos,
            sin,
            cu_seqlens,
            seqlen_offsets,
            seqlen,  # shapes
            rotary_dim,
            seqlen_ro,
            0 if is_varlen else output.stride(0),  # batch stride, unused for varlen
            output.stride(-3),  # seqlen stride or total_seqlen stride
            output.stride(-2),  # nheads stride
            output.stride(-1),  # headdim stride
            0 if is_varlen else x.stride(0),  # batch stride, unused for varlen
            x.stride(-3),  # seqlen stride or total_seqlen stride
            x.stride(-2),  # nheads stride
            x.stride(-1),  # headdim stride
            BLOCK_K,
            offsets_are_tensor,
            is_varlen,
            interleaved,
            conjugate,
            BLOCK_M,
            num_warps=2 if rotary_dim <= 64 else 4,
        )
    return output
|
||||
Reference in New Issue
Block a user