# Files
# enginex-bi_series-vllm/pkgs/xformers/ops/fmha/decoder.py

# 109 lines
# 3.7 KiB
# Python
# Raw Normal View History

# 2025-08-05 19:02:46 +08:00
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, List, Optional, Set, Tuple
import numpy as np
import torch
from ..common import get_xformers_operator, register_operator
from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
from .common import AttentionFwOpBase, Context, Inputs
@register_operator
class FwOp(AttentionFwOpBase):
    """Forward-only attention operator for single-query decoding.

    Wraps the ``efficient_attention_forward_decoder`` kernel. When the
    attention bias is a ``BlockDiagonalCausalWithOffsetPaddedKeysMask``,
    :meth:`not_supported_reasons` rejects inputs unless:

    * query and key are 4-dimensional (BMHK), with one formal batch element;
    * the last dimension of the query is 128;
    * key and value have a contiguous last dimension;
    * every sequence contributes exactly one query, with no empty lanes;
    * the key padding does not exceed 8192.

    Gradients are not supported.

    :Deprecated:

        This operator is deprecated and should not be used in new code
    """

    # NOTE(review): the deprecation notice above is carried over from the
    # previous docstring; confirm it still applies to this operator.

    OPERATOR = get_xformers_operator("efficient_attention_forward_decoder")
    SUPPORTED_DEVICES = {"cuda"}
    SUPPORTED_DTYPES = {torch.bfloat16, torch.half, torch.float32}
    CUDA_MINIMUM_COMPUTE_CAPABILITY = (7, 0)
    SUPPORTED_MAX_K: float = 128
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {BlockDiagonalCausalWithOffsetPaddedKeysMask}
    SUPPORTS_DROPOUT = False
    SUPPORTS_CUSTOM_SCALE = True
    NAME = "decoderF"

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Return the reasons why this operator cannot handle ``d``.

        Starts from the reasons reported by the parent class and, when the
        bias is a ``BlockDiagonalCausalWithOffsetPaddedKeysMask``, appends
        one message per violated decoder-specific constraint. An empty list
        means no objection was found.
        """
        reasons = super(FwOp, cls).not_supported_reasons(d)

        bias = d.attn_bias
        if not isinstance(bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
            # Any other bias type is left to the parent-class checks.
            return reasons

        # Layout checks, evaluated and reported in this fixed order.
        layout_checks = (
            (
                not (d.query.ndim == 4 and d.key.ndim == 4),
                "Inputs must be BMHK. BMK not supported",
            ),
            (d.query.shape[0] != 1, "One formal batch element expected"),
            (d.query.shape[-1] != 128, "Only head_dim==128 for now."),
            (d.key.stride(-1) != 1, "expect keys to have last dim contiguous"),
            (
                d.value.stride(-1) != 1,
                "expect values to have last dim contiguous",
            ),
        )
        for violated, message in layout_checks:
            if violated:
                reasons.append(message)

        query_info = bias.q_seqinfo
        query_starts = query_info.seqstart_py
        if query_info.max_seqlen != 1:
            reasons.append("decoding expects one query")
        elif d.query.shape[1] != len(query_starts) - 1:
            # Only checked once every sequence is known to hold one query.
            reasons.append("empty lanes not supported yet")

        if bias.k_seqinfo.padding > 8192:
            reasons.append("key padding exceeds 8192")

        return reasons

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the decoder kernel on ``inp`` and return ``(output, None)``.

        Raises:
            NotImplementedError: if ``needs_gradient`` is true.
        """
        if needs_gradient:
            raise NotImplementedError("gradient")

        bias = inp.attn_bias
        assert isinstance(bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)

        # Move the sequence bookkeeping to the query's device before use.
        device = inp.query.device
        key_info = bias.k_seqinfo
        key_info.to(device)
        bias.q_seqinfo.to(device)

        # A zero stride over dim 2 of the key is treated as multiquery:
        # only the first entry along that dim is kept for key and value.
        if inp.key.stride(2) == 0:
            flat_key, flat_value = inp.key[0, :, :1], inp.value[0, :, :1]
        else:
            flat_key, flat_value = inp.key[0], inp.value[0]

        # Split the leading dimension into rows of `padding` entries each.
        per_sequence = (-1, key_info.padding)
        key = flat_key.unflatten(0, per_sequence)
        value = flat_value.unflatten(0, per_sequence)

        # Fall back to 1/sqrt(last key dim) when no custom scale is given.
        scale = inp.scale
        if scale is None:
            scale = 1.0 / np.sqrt(key.shape[-1])

        result = cls.OPERATOR(
            query=inp.query[0, :, None],
            key=key,
            value=value,
            seq_positions=key_info.seqlen,
            scale=scale,
        )
        return result, None