Files
enginex-bi_series-vllm/pkgs/xformers/ops/rope_padded.py
2025-08-05 19:02:46 +08:00

189 lines
6.6 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple
import torch
from xformers.ops.fmha.attn_bias import ( # type: ignore
BlockDiagonalCausalWithOffsetPaddedKeysMask,
)
from .. import _is_triton_available
def rope_padded(
    xq: torch.Tensor,
    xk: torch.Tensor,
    xv: torch.Tensor,
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
    attn_bias: BlockDiagonalCausalWithOffsetPaddedKeysMask,
    *,
    theta: float = 10000.0,
    out_q: Optional[torch.Tensor] = None,
    adjacents: bool = True,
    internal_dtype: str = "",
) -> torch.Tensor:
    """
    Performs RoPE (rotary embeddings) and kv-cache emplacement for a heterogeneous
    batch for inference in the style given by
    BlockDiagonalCausalWithOffsetPaddedKeysMask.
    The batch is concatenated along the sequence dimension, so the
    actual dim-0 length of all tensors is 1.
    xq, xk and xv should be (1, slen, n_heads, dim), where
    xq's n_heads can differ from xk and xv.
    This function places the roped xk in the right place in cache_k, and
    xv (unmodified) in the right place in cache_v, and returns out_q
    (the roped xq) such that things are ready to call
        xformers.ops.memory_efficient_attention(
            out_q, cache_k, cache_v, attn_bias=attn_bias
        )
    This functionality is experimental. Its API might be changed without warnings.
    Use it at your own risk.

    Arguments:
        xq: tensor of queries to apply rope to
        xk: tensor of keys to apply rope to
        xv: tensor of values to copy into cache_v
        cache_k: cache of keys, MODIFIED IN PLACE
        cache_v: cache of values, MODIFIED IN PLACE
        attn_bias: details the layout of caches.
            Used to determine frequencies for the
            RoPE calculation as well as the locations in cache_k and cache_v
            to write to. Must be on the device.
        theta: passed to the kernel as the RoPE base (theta) used when
            computing rotation frequencies.
        out_q: optional preallocated output for the roped queries. It must
            have the same shape as xq and a contiguous last dimension. If
            None, a new tensor is allocated with xq.new_empty. Passing out_q
            while gradient mode is enabled is rejected.
        adjacents: If True, the inputs are in adjacent pairs along the final dim axis.
            This is like the released LLaMA model.
            If False, the dim axis is split in two equal pieces.
            I.e. the features are ordered with all the real parts before all
            the imaginary parts. This matches HuggingFace, e.g.
            https://github.com/huggingface/transformers/blob/
            f143037789288ba532dada934a118e648e715738/
            src/transformers/models/llama/modeling_llama.py#L126-L130
        internal_dtype: set to "f32" or "f64" to enforce dtype in the calculation.
            The default "" leaves the choice to the kernel.

    Returns:
        out_q, holding the roped queries.

    Raises:
        ValueError: if gradients would be required, if internal_dtype is not
            one of "", "f32", "f64", or if any tensor has an unexpected
            shape or a non-contiguous last (head) dimension.
    """
    # The Triton kernel writes its results in place and is not recorded by
    # autograd, so any gradient-requiring input is rejected. A caller-supplied
    # out_q is also rejected while grad mode is on (conservative guard).
    if torch.is_grad_enabled() and (
        xq.requires_grad
        or xk.requires_grad
        or xv.requires_grad
        or cache_k.requires_grad
        or cache_v.requires_grad
        or out_q is not None
    ):
        raise ValueError("Gradients not supported.")
    # Validate pure-Python arguments up front, before importing triton or
    # moving attn_bias metadata between devices. `raise` (not `assert`) so the
    # check survives `python -O`.
    if internal_dtype not in ("", "f32", "f64"):
        raise ValueError(
            f"internal_dtype must be one of '', 'f32', 'f64'; got {internal_dtype!r}"
        )
    assert _is_triton_available()
    import triton

    from .triton.rope_padded_kernels import _rope_padded_kernel

    n_total_queries = attn_bias.q_seqinfo.seqstart_py[-1]
    cache_length = attn_bias.k_seqinfo.seqstart_py[-1]

    if xq.ndim != 4:
        raise ValueError("unexpected q shape")
    bsz, q_len, n_q_heads, dim = xq.shape
    if bsz != 1:
        raise ValueError(
            "Expected batch size dimension to be 1 "
            "as batches should be concatenated."
        )
    if q_len != n_total_queries:
        raise ValueError(
            f"xq has {q_len} queries but attn_bias describes {n_total_queries}"
        )

    if xk.ndim != 4:
        raise ValueError("unexpected k shape")
    n_kv_heads = xk.shape[2]
    kv_shape = (1, n_total_queries, n_kv_heads, dim)
    cache_shape = (1, cache_length, n_kv_heads, dim)
    if xk.shape != kv_shape:
        raise ValueError("unexpected k shape")
    if xv.shape != kv_shape:
        raise ValueError("unexpected v shape")
    if cache_k.shape != cache_shape:
        raise ValueError("unexpected cache_k length")
    if cache_v.shape != cache_shape:
        raise ValueError("unexpected cache_v length")

    xq_stride = xq.stride()
    xk_stride = xk.stride()
    xv_stride = xv.stride()
    cache_k_stride = cache_k.stride()
    cache_v_stride = cache_v.stride()
    # The kernel only receives strides for dims 1 and 2, so the feature
    # dimension of every tensor must be densely packed.
    for label, stride in (
        ("q", xq_stride),
        ("k", xk_stride),
        ("v", xv_stride),
        ("cache_k", cache_k_stride),
        ("cache_v", cache_v_stride),
    ):
        if stride[3] != 1:
            raise ValueError(f"Each {label} head must be contiguous")

    # The kernel grid's last axis enumerates q heads, then k heads, then v heads.
    n_total_heads = n_q_heads + 2 * n_kv_heads
    v_start = n_total_heads - n_kv_heads
    k_start = n_q_heads

    if out_q is None:
        out_q = xq.new_empty(1, n_total_queries, n_q_heads, dim)
    elif out_q.shape != xq.shape:
        raise ValueError("Unexpected shape of out_q")
    out_q_stride = out_q.stride()
    if out_q_stride[3] != 1:
        raise ValueError("Each out_q head must be contiguous")

    logical_bsz = len(attn_bias.q_seqinfo.seqstart_py) - 1
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // xq.element_size()
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(dim))
    BLOCK_SIZE = max(BLOCK_SIZE, 128)
    BLOCK_SIZE = min(BLOCK_SIZE, 4096)
    # heuristics for number of warps
    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)

    device = xq.device
    # Move these to the right device, like fmha does.
    attn_bias.k_seqinfo.to(device)
    attn_bias.q_seqinfo.to(device)
    seqstartq = attn_bias.q_seqinfo.seqstart
    seqstartk = attn_bias.k_seqinfo.seqstart
    seqlenk = attn_bias.k_seqinfo.seqlen

    # experiment with the order of dims here.
    _rope_padded_kernel[(logical_bsz, attn_bias.q_seqinfo.max_seqlen, n_total_heads)](
        xq,
        xk,
        xv,
        out_q,
        cache_k,
        cache_v,
        seqstartq,
        seqstartk,
        seqlenk,
        theta,
        k_start,
        v_start,
        dim,
        xq_stride[1],
        xq_stride[2],
        xk_stride[1],
        xk_stride[2],
        xv_stride[1],
        xv_stride[2],
        cache_k_stride[1],
        cache_k_stride[2],
        cache_v_stride[1],
        cache_v_stride[2],
        seqstartq.stride(0),
        seqstartk.stride(0),
        seqlenk.stride(0),
        out_q_stride[1],
        out_q_stride[2],
        internal_dtype,
        const_batch_strides=False,
        cache_padding_length=0,
        seqlenk_shift=0,
        BLOCK_SIZE=BLOCK_SIZE,
        adjacents=adjacents,
        num_warps=num_warps,
    )
    return out_q