First commit
This commit is contained in:
147
pkgs/xformers/ops/fmha/dispatch.py
Normal file
147
pkgs/xformers/ops/fmha/dispatch.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import textwrap
|
||||
from collections import deque
|
||||
from typing import List, Sequence, Type, TypeVar
|
||||
|
||||
from . import attn_bias, cutlass, decoder, flash, small_k, triton, triton_splitk
|
||||
from .common import AttentionBwOpBase, AttentionFwOpBase, Inputs
|
||||
|
||||
|
||||
def _is_cutlass_fwd_faster_than_flash(inp: Inputs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _is_triton_fwd_fastest(inp: Inputs) -> bool:
|
||||
# TODO: fill out
|
||||
return False
|
||||
|
||||
|
||||
T = TypeVar("T", Type[AttentionFwOpBase], Type[AttentionBwOpBase])
|
||||
|
||||
|
||||
def _format_inputs_description(inp: Inputs) -> str:
|
||||
return f"""query : shape={tuple(inp.query.shape)} ({inp.query.dtype})
|
||||
key : shape={tuple(inp.key.shape)} ({inp.key.dtype})
|
||||
value : shape={tuple(inp.value.shape)} ({inp.value.dtype})
|
||||
attn_bias : {type(inp.attn_bias)}
|
||||
p : {inp.p}"""
|
||||
|
||||
|
||||
def _ensure_op_supports_or_raise(exc_type, name: str, op, inp: Inputs) -> None:
|
||||
reasons = op.not_supported_reasons(inp)
|
||||
if not reasons:
|
||||
return
|
||||
raise exc_type(
|
||||
f"""Operator `{name}` does not support inputs:
|
||||
{textwrap.indent(_format_inputs_description(inp), ' ')}
|
||||
{_format_not_supported_reasons(op, reasons)}"""
|
||||
)
|
||||
|
||||
|
||||
def _format_not_supported_reasons(op, reasons: List[str]) -> str:
|
||||
return f"`{op.NAME}` is not supported because:\n " + "\n ".join(reasons)
|
||||
|
||||
|
||||
def _run_priority_list(name: str, priority_list: Sequence[T], inp: Inputs) -> T:
|
||||
not_supported_reasons: List[List[str]] = []
|
||||
for op in priority_list:
|
||||
not_supported = op.not_supported_reasons(inp)
|
||||
if not not_supported:
|
||||
return op
|
||||
not_supported_reasons.append(not_supported)
|
||||
|
||||
# Let's write a nice message explaining what we tried and why it's not supported
|
||||
msg = f"""No operator found for `{name}` with inputs:
|
||||
{textwrap.indent(_format_inputs_description(inp), ' ')}"""
|
||||
for op, not_supported in zip(priority_list, not_supported_reasons):
|
||||
msg += "\n" + _format_not_supported_reasons(op, not_supported)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def _dispatch_fw_priority_list(
|
||||
inp: Inputs, needs_gradient: bool
|
||||
) -> Sequence[Type[AttentionFwOpBase]]:
|
||||
priority_list_ops = deque(
|
||||
[
|
||||
flash.FwOp,
|
||||
triton.FwOp,
|
||||
cutlass.FwOp,
|
||||
small_k.FwOp,
|
||||
]
|
||||
)
|
||||
if _is_cutlass_fwd_faster_than_flash(inp):
|
||||
priority_list_ops.remove(cutlass.FwOp)
|
||||
priority_list_ops.appendleft(cutlass.FwOp)
|
||||
if _is_triton_fwd_fastest(inp):
|
||||
priority_list_ops.remove(triton.FwOp)
|
||||
priority_list_ops.appendleft(triton.FwOp)
|
||||
if not needs_gradient:
|
||||
mqa_or_gqa = (
|
||||
inp.key.ndim > 3 and inp.key.stride(-2) == 0 and inp.key.shape[-2] > 1
|
||||
)
|
||||
if not mqa_or_gqa:
|
||||
# With multiquery, cutlass is sometimes faster than decoder
|
||||
# but it's not currently clear when.
|
||||
priority_list_ops.appendleft(decoder.FwOp)
|
||||
# Split-KV is useful with MQA
|
||||
# for short Q-seqlen / long K-seqlen
|
||||
if mqa_or_gqa and inp.query.shape[1] <= 32 and inp.key.shape[1] >= 256:
|
||||
parallelism_BH = 0 # BMK
|
||||
if inp.query.ndim == 3:
|
||||
parallelism_BH = inp.query.shape[0]
|
||||
elif inp.query.ndim == 4: # BMHK
|
||||
parallelism_BH = inp.query.shape[0] * inp.query.shape[2]
|
||||
elif inp.query.ndim == 5: # BMGHK
|
||||
parallelism_BH = inp.query.shape[0] * inp.query.shape[2]
|
||||
if parallelism_BH > 0 and parallelism_BH < 64:
|
||||
priority_list_ops.appendleft(triton_splitk.FwOp)
|
||||
# Without variable seqlen flash is fastest
|
||||
if not isinstance(inp.attn_bias, attn_bias.BlockDiagonalMask):
|
||||
priority_list_ops.remove(flash.FwOp)
|
||||
priority_list_ops.appendleft(flash.FwOp)
|
||||
|
||||
return priority_list_ops
|
||||
|
||||
|
||||
def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]:
|
||||
"""Computes the best operator for forward
|
||||
|
||||
Raises:
|
||||
NotImplementedError: if not operator was found
|
||||
|
||||
Returns:
|
||||
AttentionOp: The best operator for the configuration
|
||||
"""
|
||||
# return _run_priority_list(
|
||||
# "memory_efficient_attention_forward",
|
||||
# _dispatch_fw_priority_list(inp, needs_gradient),
|
||||
# inp,
|
||||
# )
|
||||
return flash.FwOp
|
||||
|
||||
|
||||
def _is_cutlassB_faster_than_flash(inp: Inputs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _dispatch_bw(inp: Inputs) -> Type[AttentionBwOpBase]:
|
||||
priority_list_ops: List[Type[AttentionBwOpBase]] = [
|
||||
flash.BwOp,
|
||||
cutlass.BwOp,
|
||||
# CUDA illegal memory issues, race conditions etc..
|
||||
# triton.BwOp,
|
||||
# Deprecated
|
||||
small_k.BwOp,
|
||||
]
|
||||
# if _is_cutlassB_faster_than_flash(inp):
|
||||
# priority_list_ops.remove(cutlass.BwOp)
|
||||
# priority_list_ops.insert(0, cutlass.BwOp)
|
||||
# return _run_priority_list(
|
||||
# "memory_efficient_attention_backward", priority_list_ops, inp
|
||||
# )
|
||||
return flash.BwOp
|
||||
Reference in New Issue
Block a user