148 lines
5.0 KiB
Python
148 lines
5.0 KiB
Python
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||
|
|
#
|
||
|
|
# This source code is licensed under the BSD license found in the
|
||
|
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
|
||
|
|
|
||
|
|
import textwrap
|
||
|
|
from collections import deque
|
||
|
|
from typing import List, Sequence, Type, TypeVar
|
||
|
|
|
||
|
|
from . import attn_bias, cutlass, decoder, flash, small_k, triton, triton_splitk
|
||
|
|
from .common import AttentionBwOpBase, AttentionFwOpBase, Inputs
|
||
|
|
|
||
|
|
|
||
|
|
def _is_cutlass_fwd_faster_than_flash(inp: Inputs) -> bool:
    """Tell whether the cutlass forward operator should be preferred over flash.

    This heuristic is a stub: ``inp`` is not inspected and the answer is
    always ``False``.
    """
    return False
|
||
|
|
|
||
|
|
|
||
|
|
def _is_triton_fwd_fastest(inp: Inputs) -> bool:
    """Tell whether the triton forward operator should be tried first.

    Placeholder heuristic: ``inp`` is ignored and ``False`` is always returned.
    """
    # TODO: fill out
    return False
|
||
|
|
|
||
|
|
|
||
|
|
# Constrained type variable: either a forward-operator class or a
# backward-operator class. `_run_priority_list` uses it so that the operator
# it returns has the same kind as the entries of the list it was given.
T = TypeVar("T", Type[AttentionFwOpBase], Type[AttentionBwOpBase])
|
||
|
|
|
||
|
|
|
||
|
|
def _format_inputs_description(inp: Inputs) -> str:
    """Summarize attention inputs as a multi-line, human-readable string.

    One line is produced for each of ``query``, ``key`` and ``value`` (shape as
    a tuple, then dtype), followed by the type of ``attn_bias`` and the value
    of ``p``. Lines are newline-separated with no trailing newline.
    """
    description_lines = [
        f"query : shape={tuple(inp.query.shape)} ({inp.query.dtype})",
        f"key : shape={tuple(inp.key.shape)} ({inp.key.dtype})",
        f"value : shape={tuple(inp.value.shape)} ({inp.value.dtype})",
        f"attn_bias : {type(inp.attn_bias)}",
        f"p : {inp.p}",
    ]
    return "\n".join(description_lines)
|
||
|
|
|
||
|
|
|
||
|
|
def _ensure_op_supports_or_raise(exc_type, name: str, op, inp: Inputs) -> None:
    """Raise ``exc_type`` when ``op`` reports that it cannot handle ``inp``.

    Args:
        exc_type: Exception class (or any callable) invoked with the message.
        name: Label used for the operator in the first line of the message.
        op: Operator exposing ``not_supported_reasons(inp)`` and ``NAME``.
        inp: Inputs to check.

    Raises:
        exc_type: If ``op.not_supported_reasons(inp)`` is non-empty. The
            message lists the inputs and every reported reason.
    """
    unsupported_because = op.not_supported_reasons(inp)
    if unsupported_because:
        message = "\n".join(
            (
                f"Operator `{name}` does not support inputs:",
                textwrap.indent(_format_inputs_description(inp), " "),
                _format_not_supported_reasons(op, unsupported_because),
            )
        )
        raise exc_type(message)
|
||
|
|
|
||
|
|
|
||
|
|
def _format_not_supported_reasons(op, reasons: List[str]) -> str:
|
||
|
|
return f"`{op.NAME}` is not supported because:\n " + "\n ".join(reasons)
|
||
|
|
|
||
|
|
|
||
|
|
def _run_priority_list(name: str, priority_list: Sequence[T], inp: Inputs) -> T:
    """Return the first operator in ``priority_list`` that supports ``inp``.

    Args:
        name: Label of the operation, used only in the error message.
        priority_list: Candidate operators, tried in order.
        inp: Inputs passed to each candidate's ``not_supported_reasons``.

    Returns:
        The first candidate whose ``not_supported_reasons(inp)`` is empty.

    Raises:
        NotImplementedError: If every candidate reports at least one reason.
            The message describes the inputs and lists each candidate's
            reasons.
    """
    rejections: List[List[str]] = []
    for candidate in priority_list:
        reasons = candidate.not_supported_reasons(inp)
        if not reasons:
            return candidate
        rejections.append(reasons)

    # Nothing matched: explain what was tried and why each candidate failed.
    sections = [
        f"No operator found for `{name}` with inputs:",
        textwrap.indent(_format_inputs_description(inp), " "),
    ]
    sections.extend(
        _format_not_supported_reasons(candidate, reasons)
        for candidate, reasons in zip(priority_list, rejections)
    )
    raise NotImplementedError("\n".join(sections))
|
||
|
|
|
||
|
|
|
||
|
|
def _dispatch_fw_priority_list(
    inp: Inputs, needs_gradient: bool
) -> Sequence[Type[AttentionFwOpBase]]:
    """Build the ordered list of forward operators to try for ``inp``.

    The list starts as flash, triton, cutlass, small_k and is then reordered
    and extended by the heuristics below. Each ``appendleft`` puts an operator
    at the front, so later heuristics take precedence over earlier ones.

    Args:
        inp: Attention inputs inspected by the heuristics.
        needs_gradient: When ``False``, ``decoder.FwOp`` and
            ``triton_splitk.FwOp`` may additionally be placed at the front.

    Returns:
        A ``deque`` of forward operator classes, highest priority first.
    """
    # Default priority order.
    priority_list_ops = deque(
        [
            flash.FwOp,
            triton.FwOp,
            cutlass.FwOp,
            small_k.FwOp,
        ]
    )
    # Move an operator to the front when its heuristic says it is faster.
    if _is_cutlass_fwd_faster_than_flash(inp):
        priority_list_ops.remove(cutlass.FwOp)
        priority_list_ops.appendleft(cutlass.FwOp)
    if _is_triton_fwd_fastest(inp):
        priority_list_ops.remove(triton.FwOp)
        priority_list_ops.appendleft(triton.FwOp)
    if not needs_gradient:
        # Treated as multi-query / grouped-query when the key has more than
        # 3 dims and its second-to-last dim has stride 0 with size > 1.
        mqa_or_gqa = (
            inp.key.ndim > 3 and inp.key.stride(-2) == 0 and inp.key.shape[-2] > 1
        )
        if not mqa_or_gqa:
            # With multiquery, cutlass is sometimes faster than decoder
            # but it's not currently clear when.
            priority_list_ops.appendleft(decoder.FwOp)
        # Split-KV is useful with MQA
        # for short Q-seqlen / long K-seqlen
        if mqa_or_gqa and inp.query.shape[1] <= 32 and inp.key.shape[1] >= 256:
            # Stays 0 (split-K not considered) for any other query rank.
            parallelism_BH = 0
            if inp.query.ndim == 3:  # BMK
                parallelism_BH = inp.query.shape[0]
            elif inp.query.ndim == 4:  # BMHK
                parallelism_BH = inp.query.shape[0] * inp.query.shape[2]
            elif inp.query.ndim == 5:  # BMGHK
                parallelism_BH = inp.query.shape[0] * inp.query.shape[2]
            # Only prefer split-K when this product is below 64.
            if parallelism_BH > 0 and parallelism_BH < 64:
                priority_list_ops.appendleft(triton_splitk.FwOp)
                # Without variable seqlen flash is fastest
                # NOTE(review): nesting of this block under the split-K branch
                # was reconstructed from de-indented source - confirm.
                if not isinstance(inp.attn_bias, attn_bias.BlockDiagonalMask):
                    priority_list_ops.remove(flash.FwOp)
                    priority_list_ops.appendleft(flash.FwOp)

    return priority_list_ops
|
||
|
|
|
||
|
|
|
||
|
|
def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]:
    """Select the operator used for the forward pass.

    Selection is currently hard-wired: ``flash.FwOp`` is returned for every
    input, and neither ``inp`` nor ``needs_gradient`` is consulted. No support
    check is performed here, so this function does not raise.

    Args:
        inp: Attention inputs (currently unused).
        needs_gradient: Whether a backward pass will follow (currently unused).

    Returns:
        Type[AttentionFwOpBase]: Always ``flash.FwOp``.
    """
    # Priority-based selection is disabled. To restore it, return
    #     _run_priority_list(
    #         "memory_efficient_attention_forward",
    #         _dispatch_fw_priority_list(inp, needs_gradient),
    #         inp,
    #     )
    # instead of the fixed operator below.
    forward_op: Type[AttentionFwOpBase] = flash.FwOp
    return forward_op
|
||
|
|
|
||
|
|
|
||
|
|
def _is_cutlassB_faster_than_flash(inp: Inputs) -> bool:
    """Tell whether the cutlass backward operator should be preferred over flash.

    This heuristic is a stub: ``inp`` is not inspected and the answer is
    always ``False``.
    """
    return False
|
||
|
|
|
||
|
|
|
||
|
|
def _dispatch_bw(inp: Inputs) -> Type[AttentionBwOpBase]:
    """Select the operator used for the backward pass.

    Selection is currently hard-wired: ``flash.BwOp`` is returned for every
    input and ``inp`` is not consulted. No support check is performed here,
    so this function does not raise.

    Args:
        inp: Attention inputs (currently unused).

    Returns:
        Type[AttentionBwOpBase]: Always ``flash.BwOp``.
    """
    # Priority-based selection is disabled, so the candidate list is no longer
    # built on every call. For reference, the order it used was:
    #     flash.BwOp, cutlass.BwOp, small_k.BwOp (deprecated)
    # triton.BwOp was excluded because of CUDA illegal memory issues and race
    # conditions. A (also disabled) heuristic moved cutlass.BwOp to the front
    # when `_is_cutlassB_faster_than_flash(inp)` was true.
    # To restore dispatching, rebuild that list as `priority_list_ops` and
    # return
    #     _run_priority_list(
    #         "memory_efficient_attention_backward", priority_list_ops, inp
    #     )
    return flash.BwOp
|