init
This commit is contained in:
15
vllm/attention/__init__.py
Normal file
15
vllm/attention/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

# Public API of the vllm.attention package.
__all__ = [
    "Attention",
    "AttentionBackend",
    "AttentionMetadata",
    "AttentionType",
    "get_attn_backend",
]
|
||||
0
vllm/attention/backends/__init__.py
Normal file
0
vllm/attention/backends/__init__.py
Normal file
204
vllm/attention/backends/abstract.py
Normal file
204
vllm/attention/backends/abstract.py
Normal file
@@ -0,0 +1,204 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
|
||||
|
||||
class AttentionType:
    """Kinds of attention supported by the backends.

    Values are plain strings (not an Enum) so they remain compatible
    with `torch.compile`.
    """
    # Decoder attention between previous layer Q/K/V.
    DECODER = "decoder"
    # Encoder attention between previous layer Q/K/V (encoder-decoder models).
    ENCODER = "encoder"
    # Encoder attention between previous layer Q/K/V (encoder-only models).
    ENCODER_ONLY = "encoder_only"
    # Attention between decoder Q and encoder K/V (encoder-decoder models).
    ENCODER_DECODER = "encoder_decoder"
|
||||
|
||||
|
||||
class AttentionBackend(ABC):
    """Abstract class for attention backends."""
    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    # Whether this backend supports receiving pre-quantized query input.
    # If True, the attention layer will handle query quantization instead
    # of the backend, allowing torch.compile to fuse quantization with
    # previous operations.
    # Needs to be worked through for all backends
    # https://github.com/vllm-project/vllm/issues/25584
    supports_quant_query_input: bool = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the unique name identifying this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        """Return the AttentionImpl subclass implementing this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the AttentionMetadata subclass used by this backend."""
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        """Instantiate this backend's metadata class with the given args."""
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        """Return the metadata-builder class (untyped to avoid an import)."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> Tuple[int, ...]:
        """Return the shape of the KV-cache tensor for this backend."""
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> Tuple[int, ...]:
        """Return the preferred stride order of the KV cache.

        Optional: backends without a preference leave this unimplemented.
        """
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        """Return (module, qualified name), e.g. for serialization/lookup."""
        return (cls.__module__, cls.__qualname__)
|
||||
|
||||
|
||||
class AttentionMetadata:
    """Base (marker) class for per-step attention metadata.

    Backends define their own subclasses via
    `AttentionBackend.get_metadata_cls()`.
    """
    pass


# TypeVar bound to AttentionMetadata so implementations can be generic
# over their concrete metadata type.
T = TypeVar("T", bound=AttentionMetadata)
|
||||
|
||||
|
||||
class AttentionLayer(Protocol):
    """Structural (duck-typed) interface an attention layer must expose to
    an `AttentionImpl`: the q/k/v scale tensors (plus host-side float
    copies) and a `forward` method."""

    # Quantization scales on device.
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    # Host-side copies for backends that need scales on CPU.
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    # Scale applied to attention probabilities.
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        ...
|
||||
|
||||
|
||||
class AttentionImpl(ABC, Generic[T]):
    """Abstract base for a backend's attention implementation, generic over
    the backend's concrete `AttentionMetadata` type `T`."""

    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # some attention backends might not always want to return lse
    # even if they can return lse (for efficiency reasons)
    need_to_return_lse_for_decode: bool = False

    # Decode-context-parallel group size and this rank's position in it;
    # populated in __new__ (defaults to 1/0 when DCP is not initialized).
    dcp_world_size: int
    dcp_rank: int

    def __new__(cls, *args, **kwargs):
        # use __new__ so that all subclasses will call this
        self = super().__new__(cls)
        try:
            from vllm.distributed.parallel_state import get_dcp_group
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        # Only return lse when DCP actually needs it and the impl supports it.
        self.need_to_return_lse_for_decode = self.dcp_world_size > 1 \
            and self.can_return_lse_for_decode
        return self

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention for one forward step.

        `output` is an optional pre-allocated buffer (see
        `AttentionBackend.accept_output_buffer`); `output_scale` /
        `output_block_scale` support fused output quantization.
        """
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: QuantKey):
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        :param quant_key: QuantKey object that describes the quantization op
        :return: is fusion supported for this type of quantization
        """
        return False
|
||||
|
||||
|
||||
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    """AttentionImpl variant for MLA (Multi-head Latent Attention) backends,
    whose forward takes latent/compressed inputs instead of plain K/V."""

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError
|
||||
|
||||
|
||||
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True when the KV-cache dtype string selects a quantized cache.

    "auto" means the cache follows the model dtype (unquantized); any other
    value selects an explicit (quantized) cache dtype.
    """
    is_unquantized = kv_cache_dtype == "auto"
    return not is_unquantized
|
||||
33
vllm/attention/backends/utils.py
Normal file
33
vllm/attention/backends/utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention backend utils"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Sentinel slot index used for padded entries.
PAD_SLOT_ID = -1


@dataclass
class MLADims:
    """Head/latent dimensions of an MLA (Multi-head Latent Attention) model,
    mirroring the fields of the HF text config."""
    q_lora_rank: Optional[int]
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    v_head_dim: int


def get_mla_dims(model_config: ModelConfig) -> MLADims:
    """Collect the MLA dimensions from ``model_config.hf_text_config``.

    ``q_lora_rank`` is optional in HF configs, so it falls back to None;
    the remaining fields are read as required attributes.
    """
    cfg = model_config.hf_text_config
    return MLADims(
        q_lora_rank=getattr(cfg, "q_lora_rank", None),
        kv_lora_rank=cfg.kv_lora_rank,
        qk_nope_head_dim=cfg.qk_nope_head_dim,
        qk_rope_head_dim=cfg.qk_rope_head_dim,
        v_head_dim=cfg.v_head_dim,
    )
|
||||
645
vllm/attention/layer.py
Normal file
645
vllm/attention/layer.py
Normal file
@@ -0,0 +1,645 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer."""
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
|
||||
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from vllm.utils import GiB_bytes, direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
# Tri-state cache for xformers availability: None until
# check_xformers_availability() runs, then True/False.
USE_XFORMERS_OPS = None
try:
    # Tag used to mark the unified attention custom ops as unsafe to
    # capture in CUDA graphs; only present in newer torch builds.
    tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
except AttributeError:
    # Older torch: no such tag; register ops untagged.
    tag_cudagraph_unsafe = ()  # type: ignore[assignment]
|
||||
|
||||
|
||||
def check_xformers_availability():
    """Return whether xformers ops can be used, caching the answer.

    The result is memoized in the module-level ``USE_XFORMERS_OPS``;
    the fallback warning is emitted only on the first call.
    """
    global USE_XFORMERS_OPS
    if USE_XFORMERS_OPS is not None:
        return USE_XFORMERS_OPS

    if current_platform.is_cuda() and current_platform.has_device_capability(
            100):
        # Xformers FA is not compatible with B200
        USE_XFORMERS_OPS = False
    else:
        try:
            from importlib.util import find_spec

            # BUGFIX: find_spec returns None (it does not raise) when the
            # parent package is installed but the submodule is missing, so
            # the result must be checked explicitly instead of treating
            # "no exception" as availability.
            USE_XFORMERS_OPS = find_spec("xformers.ops") is not None
        except ImportError:
            # Raised when the top-level "xformers" package is absent.
            USE_XFORMERS_OPS = False

    # the warning only needs to be shown once
    if not USE_XFORMERS_OPS:
        logger.warning("Xformers is not available, falling back.")

    return USE_XFORMERS_OPS
|
||||
|
||||
|
||||
def check_upstream_fa_availability(dtype: torch.dtype):
    """Return whether upstream flash-attn 2 is usable for ``dtype``.

    Requires fp16/bf16 on a CUDA device of compute capability >= 8.0;
    otherwise upstream FA is not considered at all.
    """
    dtype_ok = dtype in (torch.float16, torch.bfloat16)
    if not (dtype_ok and current_platform.is_cuda()
            and current_platform.has_device_capability(80)):
        return False
    from transformers.utils import is_flash_attn_2_available
    return is_flash_attn_2_available()
|
||||
|
||||
|
||||
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        use_sparse: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        attn_backend: Optional[type[AttentionBackend]] = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            # Defaults when no cache config is provided.
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, \
            f"num_heads ({num_heads}) is not " \
            f"divisible by num_kv_heads ({num_kv_heads})"

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn doesn't support quantizing the kv-cache only
        # but requires q to be quantized as well.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep q/k/v_scale on host (cpu) memory for attention
        # backends that require the scales to be on host instead of on device.
        # e.g. Flashinfer
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        # The output scale on host memory. This should be the input scale of
        # the quant op after this attention layer.
        self._o_scale_float: Optional[float] = None

        self.use_mla = use_mla
        self.use_sparse = use_sparse
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None and not isinstance(
                quant_method, UnquantizedLinearMethod):
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(head_size,
                                                 dtype,
                                                 kv_cache_dtype,
                                                 block_size,
                                                 use_mla=use_mla,
                                                 has_sink=self.has_sink,
                                                 use_sparse=use_sparse)
        else:
            # Caller pinned a specific backend; skip auto-selection.
            self.attn_backend = attn_backend

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             logits_soft_cap, attn_type,
                             kv_sharing_target_layer_name, **extra_impl_args)
        self.backend = backend_name_to_enum(self.attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([]) for _ in range(get_current_vllm_config(
            ).parallel_config.pipeline_parallel_size)
        ]

        try:
            self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT,
                                        dtype=torch.float32)
            self.k_range = torch.tensor(envs.K_SCALE_CONSTANT,
                                        dtype=torch.float32)
            self.v_range = torch.tensor(envs.V_SCALE_CONSTANT,
                                        dtype=torch.float32)
        except torch.cuda.OutOfMemoryError as e:
            logger.error(
                "Failed to initialize attention q/k/v range constants: %s", e)
            if torch.cuda.is_available():
                logger.debug("CUDA device: %s", torch.cuda.current_device())
                logger.debug("Allocated: %.2f GiB",
                             torch.cuda.memory_allocated() / GiB_bytes)
                logger.debug("Reserved: %.2f GiB",
                             torch.cuda.memory_reserved() / GiB_bytes)
            raise RuntimeError(
                "Failed to initialize q/k/v range constants. "
                "This may be caused by insufficient memory to allocate "
                "kv cache.") from e

        # for attn backends supporting query quantization
        self.query_quant = None
        if self.kv_cache_dtype.startswith(
                "fp8") and self.attn_backend.supports_quant_query_input:
            self.query_quant = QuantFP8(static=True,
                                        group_shape=GroupShape.PER_TENSOR)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.
        """
        if self.calculate_kv_scales:
            attn_metadata = get_forward_context().attn_metadata
            if attn_metadata.enable_kv_scales_calculation:
                self.calc_kv_scales(query, key, value)

        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
            query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            output_shape = (output_shape
                            if output_shape is not None else query.shape)
            output = torch.zeros(output_shape,
                                 dtype=output_dtype,
                                 device=query.device)
            hidden_size = output_shape[-1]
            # We skip reshaping query, key and value tensors for the MLA
            # backend since these tensors have different semantics and are
            # processed differently.
            if not self.use_mla:
                # Reshape the query, key, and value tensors.
                # NOTE(woosuk): We do this outside the custom op to minimize the
                # CPU overheads from the non-CUDA-graph regions.
                query = query.view(-1, self.num_heads, self.head_size)
                output = output.view(-1, self.num_heads, self.head_size)
                if key is not None:
                    key = key.view(-1, self.num_kv_heads, self.head_size)
                if value is not None:
                    value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    # Per-layer metadata: pick this layer's entry.
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(self,
                                  query,
                                  key,
                                  value,
                                  self_kv_cache,
                                  attn_metadata,
                                  output=output)
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name)
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(self, query, key, value,
                                         self_kv_cache, attn_metadata)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name)

    def calc_kv_scales(self, query, key, value):
        """Derive q/k/v scales from the current activations' max magnitude
        and cache them (device tensors + host floats)."""
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Summarize head/scale/backend config for `repr(module)`."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Post-weight-loading hook: delegate to the impl if it has one and
        normalize backend-specific weight dtypes."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

        # FlashInfer requires attention sinks to be float32
        if (self.backend == _Backend.FLASHINFER
                and hasattr(self.impl, 'sinks')):
            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
            assert isinstance(self.impl, FlashInferImpl)
            if (self.impl.sinks is not None
                    and self.impl.sinks.dtype != torch.float32):
                self.impl.sinks = self.impl.sinks.to(torch.float32)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the AttentionBackend class this layer was built with."""
        return self.attn_backend
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0, \
            f"num_heads ({self.num_heads}) is not " \
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        # GQA/MQA repetition factor (1 for plain MHA).
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Determine the attention backend
        backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)

        # Some auto-selected backends can be upgraded
        # to upstream flash attention if available.
        # If vllm native fa is selected, we use it directly.
        use_upstream_fa = False
        if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
                dtype):
            backend = _Backend.FLASH_ATTN
            use_upstream_fa = True

        if current_platform.is_rocm() or current_platform.is_xpu():
            # currently, only torch_sdpa is supported on rocm/xpu
            self.attn_backend = _Backend.TORCH_SDPA
        else:
            # Fall back to SDPA for any backend this layer can't run.
            self.attn_backend = backend if backend in {
                _Backend.TORCH_SDPA,
                _Backend.XFORMERS,
                _Backend.PALLAS,
                _Backend.ROCM_AITER_FA,
                _Backend.FLASH_ATTN,
            } else _Backend.TORCH_SDPA

        if (self.attn_backend == _Backend.XFORMERS
                and not check_xformers_availability()):
            self.attn_backend = _Backend.TORCH_SDPA

        if self.attn_backend == _Backend.FLASH_ATTN:
            if use_upstream_fa:
                from flash_attn import flash_attn_varlen_func
                self._flash_attn_varlen_func = flash_attn_varlen_func
            else:
                from vllm.vllm_flash_attn import flash_attn_varlen_func
                self._flash_attn_varlen_func = flash_attn_varlen_func

        logger.info_once(
            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
            f"use_upstream_fa: {use_upstream_fa}")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)
        """
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.attn_backend == _Backend.FLASH_ATTN:
            # Varlen API: flatten the batch and describe it with cu_seqlens.
            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                                        step=q_len,
                                        dtype=torch.int32,
                                        device=query.device)
            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
                                        step=kv_len,
                                        dtype=torch.int32,
                                        device=key.device)

            out = self._flash_attn_varlen_func(
                query.flatten(0, 1),
                key.flatten(0, 1),
                value.flatten(0, 1),
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=kv_len,
                softmax_scale=self.scale,
            )
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # SDPA expects (batch, heads, seq, head_size).
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.PALLAS:
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention
            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.ROCM_AITER_FA:
            from aiter import flash_attn_varlen_func

            # ROCm Flash Attention expects (batch, seq, heads, head_dim)
            out = flash_attn_varlen_func(query,
                                         key,
                                         value,
                                         softmax_scale=self.scale)
        else:
            # ViT attention hasn't supported this backend yet
            raise NotImplementedError(
                f"ViT attention hasn't supported {self.attn_backend} "
                f"backend yet.")

        return out.reshape(bsz, q_len, -1)
|
||||
|
||||
|
||||
def wait_for_kv_layer_from_connector(layer_name: str):
    """Wait until the KV-transfer connector has finished loading the KV
    cache for ``layer_name``; no-op when no v1 KV-transfer group exists or
    no attention metadata is set."""
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    connector = get_kv_transfer_group()

    attn_metadata = get_forward_context().attn_metadata
    if attn_metadata is None:
        return
    assert isinstance(attn_metadata, dict)
    connector.wait_for_layer_load(layer_name)
|
||||
|
||||
|
||||
def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
):
    """Hand the KV cache of ``layer_name`` to the KV-transfer connector;
    no-op when no v1 KV-transfer group exists or no attention metadata is
    set."""
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    connector = get_kv_transfer_group()

    attn_metadata = get_forward_context().attn_metadata
    if attn_metadata is None:
        return
    assert isinstance(attn_metadata, dict)
    connector.save_kv_layer(layer_name, kv_cache_layer,
                            attn_metadata[layer_name])
|
||||
|
||||
|
||||
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Custom-op entry point for attention that returns a fresh output
    tensor.

    Resolves the `Attention` layer registered under ``layer_name`` from the
    forward context, runs its impl against that layer's KV cache, and
    coordinates with the KV-transfer connector before and after.
    """
    wait_for_kv_layer_from_connector(layer_name)

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        # Per-layer metadata: pick this layer's entry.
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    output = self.impl.forward(self, query, key, value, kv_cache,
                               attn_metadata)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return output
|
||||
|
||||
|
||||
def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake (meta) implementation for tracing/compilation: attention output
    has the same shape and dtype as the query and is contiguous."""
    result = torch.empty_like(query)
    return result.contiguous()
|
||||
|
||||
|
||||
# Register unified_attention as an opaque custom op so torch.compile treats
# the whole attention call as a single node; the fake impl provides the
# output shape/dtype during tracing.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
    tags=tag_cudagraph_unsafe,
)
|
||||
|
||||
|
||||
def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Custom-op entry point: run attention for `layer_name`, writing the
    result into the caller-provided `output` tensor in place."""
    wait_for_kv_layer_from_connector(layer_name)

    ctx: ForwardContext = get_forward_context()
    layer = ctx.no_compile_layers[layer_name]
    kv_cache = layer.kv_cache[ctx.virtual_engine]

    metadata = ctx.attn_metadata
    if isinstance(metadata, dict):
        # Metadata is keyed per layer; select this layer's entry.
        metadata = metadata[layer_name]

    layer.impl.forward(layer,
                       query,
                       key,
                       value,
                       kv_cache,
                       metadata,
                       output=output,
                       output_scale=output_scale,
                       output_block_scale=output_block_scale)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Fake impl: the real op mutates `output` in place and returns
    nothing, so tracing needs no new tensors."""
# Register the in-place variant. `mutates_args` declares the tensors this
# op writes in place so torch.compile tracks the aliasing correctly.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
    tags=tag_cudagraph_unsafe,
)
0
vllm/attention/layers/__init__.py
Normal file
0
vllm/attention/layers/__init__.py
Normal file
93
vllm/attention/layers/chunked_local_attention.py
Normal file
93
vllm/attention/layers/chunked_local_attention.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from typing import ClassVar, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata)
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig, QuantizationConfig
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport, CommonAttentionMetadata,
|
||||
make_local_attention_virtual_batches, subclass_attention_backend)
|
||||
|
||||
from ..layer import Attention
|
||||
|
||||
|
||||
@functools.lru_cache
def create_chunked_local_attention_backend(
    underlying_attn_backend: AttentionBackend,
    attention_chunk_size: int,
    block_size: int,
) -> type[AttentionBackend]:
    """Wrap `underlying_attn_backend` with a metadata builder that splits
    each request into chunk-local virtual batches of `attention_chunk_size`
    tokens.

    Cached (lru_cache) so layers sharing the same
    (backend, chunk_size, block_size) triple reuse one generated class.
    """
    prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"

    underlying_builder = underlying_attn_backend.get_builder_cls()

    class ChunkedLocalAttentionBuilder(underlying_builder):  # type: ignore
        # Virtual batching reshapes the batch every step, so CUDA graphs
        # are never reusable for this wrapper.
        cudagraph_support: ClassVar[AttentionCGSupport] = \
            AttentionCGSupport.NEVER

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            # Rewrite the metadata into per-chunk virtual batches before
            # delegating to the wrapped builder.
            common_attn_metadata = make_local_attention_virtual_batches(
                attention_chunk_size, common_attn_metadata, block_size)
            return super().build(common_prefix_len, common_attn_metadata,
                                 fast_build)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=ChunkedLocalAttentionBuilder)

    return attn_backend
class ChunkedLocalAttention(Attention):
    """Attention layer whose attention span is limited to fixed-size local
    chunks of `attention_chunk_size` tokens (v1 engine: implemented by
    wrapping the selected backend; v0: handled inside the backends)."""

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 attention_chunk_size: int,
                 num_kv_heads: Optional[int] = None,
                 alibi_slopes: Optional[List[float]] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 kv_sharing_target_layer_name: Optional[str] = None,
                 prefix: str = ""):
        dtype = torch.get_default_dtype()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            # Fallbacks used when no cache config is provided
            # (e.g. in tests); 16 matches the other layer wrappers here.
            kv_cache_dtype = "auto"
            block_size = 16

        if envs.VLLM_USE_V1:
            underlying_attn_backend = get_attn_backend(head_size, dtype,
                                                       kv_cache_dtype,
                                                       block_size)

            attn_backend = create_chunked_local_attention_backend(
                underlying_attn_backend, attention_chunk_size, block_size)
        else:
            # in v0 the local attention is handled inside the backends
            attn_backend = None

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=attn_backend)
162
vllm/attention/layers/cross_attention.py
Normal file
162
vllm/attention/layers/cross_attention.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from copy import copy
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||
subclass_attention_backend)
|
||||
from vllm.v1.kv_cache_interface import CrossAttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_max_encoder_len(vllm_config: "VllmConfig") -> int:
|
||||
"""Gets the max number of encoder input tokens from the config.
|
||||
"""
|
||||
sc = vllm_config.scheduler_config
|
||||
assert sc and isinstance(sc.max_num_encoder_input_tokens, int), \
|
||||
"max_num_encoder_input_tokens must be int for enc-dec models"
|
||||
return sc.max_num_encoder_input_tokens
|
||||
|
||||
|
||||
def _get_cross_slot_mapping(encoder_seq_lens: np.ndarray,
                            block_table_tensor: torch.Tensor,
                            kv_cache_spec: CrossAttentionSpec,
                            device: torch.device) -> torch.Tensor:
    """Get cross-attention slot mappings.

    For every request with a non-zero encoder length, map each encoder
    token position to its flat slot index (block_number * block_size +
    offset) in the cross-attention KV cache. Returns the concatenated
    int64 slot mapping for all active requests (empty tensor if none).
    """

    block_size = kv_cache_spec.block_size
    slot_mappings = []

    # Find indices with non-zero encoder sequence lengths
    # The majority of parallel requests will be running the
    # decoder, so this list should be relatively small.
    active_indices = np.nonzero(encoder_seq_lens)[0]

    for req_index in active_indices:
        encoder_seq_len = encoder_seq_lens[req_index].item()

        # Calculate the number of blocks needed for this request
        num_blocks_needed = cdiv(encoder_seq_len, block_size)

        # Get the block IDs for this request from the tensor
        req_block_ids = block_table_tensor[req_index]

        # Get only the blocks we need (first num_blocks_needed blocks)
        needed_block_ids = req_block_ids[:num_blocks_needed]

        # All needed blocks are allocated
        # Per-token block index and offset within that block.
        i_values = torch.arange(encoder_seq_len,
                                dtype=torch.int64,
                                device=device)
        block_indices = i_values // block_size
        block_offsets = i_values % block_size
        block_numbers = needed_block_ids[block_indices]
        slot_mapping = block_numbers * block_size + block_offsets

        slot_mappings.append(slot_mapping)

    if slot_mappings:
        return torch.cat(slot_mappings)
    else:
        return torch.empty(0, dtype=torch.int64, device=device)
@functools.lru_cache
def create_cross_attention_backend(
        underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
    """Wrap `underlying_attn_backend` with a builder that rewrites metadata
    for encoder-decoder cross-attention: non-causal masking, sequence
    lengths fixed to the max encoder length, and slot mapping computed
    over the encoder-side KV cache. Cached per backend class.
    """
    prefix = "CrossAttention_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class CrossAttentionBuilder(underlying_builder):  # type: ignore

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            # Shallow-copy so the shared decoder metadata is not mutated.
            new_metadata = copy(common_attn_metadata)
            new_metadata.causal = False
            max_encoder_len = _get_max_encoder_len(self.vllm_config)
            new_metadata.max_seq_len = max_encoder_len

            # Every request attends over the (padded) encoder length.
            new_metadata.seq_lens = torch.full(
                (new_metadata.num_reqs, ),
                max_encoder_len,
                dtype=torch.int32,
                device=self.device,
            )
            new_metadata.seq_lens_cpu = torch.full(
                (new_metadata.num_reqs, ),
                max_encoder_len,
                dtype=torch.int32,
                device="cpu",
            )
            new_metadata.slot_mapping = _get_cross_slot_mapping(
                new_metadata.encoder_seq_lens, new_metadata.block_table_tensor,
                self.kv_cache_spec, self.device)
            return super().build(common_prefix_len, new_metadata, fast_build)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=CrossAttentionBuilder)

    return attn_backend
class CrossAttention(Attention):
    """
    Cross-attention for encoder-decoder models.
    Handles attention between decoder queries and encoder keys/values.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 cache_config: Optional[CacheConfig] = None,
                 attn_type: Optional[str] = None,
                 **kwargs):
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            # Fallbacks used when no cache config is provided
            # (e.g. in tests); 16 matches the other layer wrappers here.
            kv_cache_dtype = "auto"
            block_size = 16

        if envs.VLLM_USE_V1:
            underlying_attn_backend = get_attn_backend(head_size, dtype,
                                                       kv_cache_dtype,
                                                       block_size)

            attn_backend = create_cross_attention_backend(
                underlying_attn_backend)
        else:
            # in v0 cross attention is handled inside the backends
            attn_backend = None

        # Callers may pass attn_type explicitly, but only ENCODER_DECODER
        # is valid for this layer.
        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_DECODER, (
                "CrossAttention only supports AttentionType.ENCODER_DECODER")

        super().__init__(num_heads=num_heads,
                         head_size=head_size,
                         scale=scale,
                         cache_config=cache_config,
                         attn_backend=attn_backend,
                         attn_type=AttentionType.ENCODER_DECODER,
                         **kwargs)
86
vllm/attention/layers/encoder_only_attention.py
Normal file
86
vllm/attention/layers/encoder_only_attention.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from copy import copy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||
subclass_attention_backend)
|
||||
|
||||
|
||||
@functools.lru_cache
def create_encoder_only_attention_backend(
    underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
    """Derive a backend whose metadata builder marks attention as
    non-causal; cached per wrapped backend class."""
    base_builder = underlying_attn_backend.get_builder_cls()

    class EncoderOnlyAttentionBuilder(base_builder):  # type: ignore

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            # Encoder-only attention is bidirectional: drop causal masking
            # (on a copy, so shared metadata is untouched) before handing
            # off to the wrapped builder.
            metadata = copy(common_attn_metadata)
            metadata.causal = False
            return super().build(common_prefix_len, metadata, fast_build)

    return subclass_attention_backend(
        name_prefix="EncoderOnlyAttention_",
        attention_backend_cls=underlying_attn_backend,
        builder_cls=EncoderOnlyAttentionBuilder)
class EncoderOnlyAttention(Attention):
    """
    Encoder attention is a special case that doesn't need a KV Cache.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 cache_config: Optional[CacheConfig] = None,
                 attn_type: Optional[str] = None,
                 **kwargs):
        dtype = torch.get_default_dtype()

        # Resolve cache parameters, falling back to defaults when no
        # cache config was supplied.
        if cache_config is None:
            kv_cache_dtype, block_size = "auto", 16
        else:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size

        if envs.VLLM_USE_V1:
            # v1: wrap the selected backend so metadata is built non-causal.
            attn_backend = create_encoder_only_attention_backend(
                get_attn_backend(head_size, dtype, kv_cache_dtype,
                                 block_size))
        else:
            # in v0 encoder only attention is handled inside the backends
            attn_backend = None

        # Callers may pass attn_type explicitly, but only ENCODER_ONLY is
        # valid for this layer.
        assert attn_type is None or attn_type == AttentionType.ENCODER_ONLY, \
            "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"

        super().__init__(num_heads=num_heads,
                         head_size=head_size,
                         scale=scale,
                         cache_config=cache_config,
                         attn_backend=attn_backend,
                         attn_type=AttentionType.ENCODER_ONLY,
                         **kwargs)
0
vllm/attention/ops/__init__.py
Normal file
0
vllm/attention/ops/__init__.py
Normal file
405
vllm/attention/ops/chunked_prefill_paged_decode.py
Normal file
405
vllm/attention/ops/chunked_prefill_paged_decode.py
Normal file
@@ -0,0 +1,405 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Authors:
|
||||
# - Burkhard Ringlein <ngl@zurich.ibm.com>
|
||||
# - Jan van Lunteren <jvl@zurich.ibm.com>
|
||||
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
|
||||
# - Thomas Parnell <tpa@zurich.ibm.com>
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .prefix_prefill import context_attention_fwd
|
||||
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
@triton.jit
def cdiv_fn(x, y):
    # Ceiling division: smallest integer >= x / y (for positive y).
    return (x + y - 1) // y
@triton.jit
def kernel_paged_attention_2d(
        output_ptr,  # [num_tokens, num_query_heads, head_size]
        query_ptr,  # [num_tokens, num_query_heads, head_size]
        key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
        value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
        sink_ptr,  # [num_query_heads]
        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
        seq_lens_ptr,  # [num_seqs]
        alibi_slopes_ptr,  # [num_query_heads]
        scale,  # float32
        k_scale,  # float32
        v_scale,  # float32
        out_scale_inv,
        num_query_heads: tl.constexpr,  # int
        num_queries_per_kv: tl.constexpr,  # int
        num_queries_per_kv_padded: tl.constexpr,  # int
        block_table_stride: tl.int64,  # int
        query_stride_0: tl.int64,  # int
        query_stride_1: tl.int64,  # int, should be equal to head_size
        output_stride_0: tl.int64,  # int
        output_stride_1: tl.int64,  # int, should be equal to head_size
        BLOCK_SIZE: tl.constexpr,  # int
        HEAD_SIZE: tl.constexpr,  # int
        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
        USE_ALIBI_SLOPES: tl.constexpr,  # bool
        SLIDING_WINDOW: tl.constexpr,  # int
        x: tl.constexpr,  # int
        stride_k_cache_0: tl.int64,  # int
        stride_k_cache_1: tl.int64,  # int
        stride_k_cache_2: tl.int64,  # int
        stride_k_cache_3: tl.int64,  # int
        stride_k_cache_4: tl.int64,  # int
        stride_v_cache_0: tl.int64,  # int
        stride_v_cache_1: tl.int64,  # int
        stride_v_cache_2: tl.int64,  # int
        stride_v_cache_3: tl.int64,  # int
        filter_by_query_len: tl.constexpr,  # bool
        query_start_len_ptr,  # [num_seqs+1]
        USE_SINKS: tl.constexpr,  # bool
        USE_FP8: tl.constexpr,
        FP8_MIN: tl.constexpr = float8_info.min,
        FP8_MAX: tl.constexpr = float8_info.max):
    # Paged-attention decode kernel: one program per (sequence, kv head),
    # processing all query heads that share this kv head in one tile and
    # streaming the paged KV cache block by block (online softmax).
    seq_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)

    if filter_by_query_len:
        cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
        cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx +
                                              1)
        cur_batch_query_len = cur_batch_in_all_stop_index \
            - cur_batch_in_all_start_index
        # Skip prefill requests (query length > 1): those are handled by
        # context_attention_fwd; this kernel only does single-token decode.
        if cur_batch_query_len > 1:
            return
    else:
        cur_batch_in_all_start_index = seq_idx

    # The query heads mapped to this kv head (padded to a power of two).
    query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
        0, num_queries_per_kv_padded)

    query_offset = (cur_batch_in_all_start_index * query_stride_0 +
                    query_head_idx[:, None] * query_stride_1)

    # Mask off padded head lanes and heads beyond num_query_heads.
    head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
    head_mask = head_mask & (query_head_idx < num_query_heads)

    dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
                        0).to(tl.int1)

    # Q : (num_queries_per_kv, HEAD_SIZE,)
    Q = tl.load(
        query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
        mask=dim_mask[None, :] & head_mask[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    # Running softmax max (M): sinks, if used, seed the max/denominator.
    if not USE_SINKS:
        M = tl.full([num_queries_per_kv_padded],
                    float("-inf"),
                    dtype=tl.float32)
    else:
        M = tl.load(
            sink_ptr + query_head_idx,
            mask=head_mask,
            other=float("-inf"),
        ).to(dtype=tl.float32)

    L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
    acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
                   dtype=tl.float32)

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx,
                              mask=head_mask,
                              other=0.0)

    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)

    # iterate through tiles
    for j in range(0, num_blocks):

        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)

        offs_n = tl.arange(0, BLOCK_SIZE)
        offs_d = tl.arange(0, HEAD_SIZE_PADDED)

        v_offset = (physical_block_idx * stride_v_cache_0 +
                    kv_head_idx * stride_v_cache_1 +
                    offs_d[None, :] * stride_v_cache_2 +
                    offs_n[:, None] * stride_v_cache_3)

        # Key cache is tiled along the head dim in groups of `x` elements.
        k_offset = (physical_block_idx * stride_k_cache_0 +
                    kv_head_idx * stride_k_cache_1 +
                    (offs_d[:, None] // x) * stride_k_cache_2 +
                    offs_n[None, :] * stride_k_cache_3 +
                    (offs_d[:, None] % x) * stride_k_cache_4)

        # K : (HEAD_SIZE, BLOCK_SIZE)
        K_load = tl.load(key_cache_ptr + k_offset,
                         mask=dim_mask[:, None],
                         other=0.0)

        if K_load.dtype.is_fp8():
            K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (BLOCK_SIZE, HEAD_SIZE)
        V_load = tl.load(value_cache_ptr + v_offset,
                         mask=dim_mask[None, :],
                         other=0.0)

        if V_load.dtype.is_fp8():
            V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
        seq_mask = seq_offset[None, :] < boundary

        # S : (num_queries_per_kv, BLOCK_SIZE,)
        S = tl.where(head_mask[:, None] & seq_mask, 0.0,
                     float("-inf")).to(tl.float32)
        S += scale * tl.dot(Q, K)

        context_len = seq_len - 1

        if SLIDING_WINDOW > 0:
            S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S,
                         -10000)

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

        # compute running maximum
        # m_j : (num_queries_per_kv,)
        m_j = tl.maximum(M, tl.max(S, axis=1))

        # P : (num_queries_per_kv, BLOCK_SIZE,)
        P = tl.exp(S - m_j[:, None])

        # l_j : (num_queries_per_kv,)
        l_j = tl.sum(P, axis=1)

        # alpha : (num_queries_per_kv, )
        alpha = tl.exp(M - m_j)

        # acc : (num_queries_per_kv, BLOCK_SIZE,)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (num_queries_per_kv, BLOCK_SIZE,)
        acc += tl.dot(P.to(V.dtype), V)

    # epilogue
    acc = acc / L[:, None]
    if USE_FP8:
        acc = acc * tl.load(out_scale_inv)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)

    output_offset = (cur_batch_in_all_start_index * output_stride_0 +
                     query_head_idx * output_stride_1)

    tl.store(
        output_ptr + output_offset[:, None] +
        tl.arange(0, HEAD_SIZE_PADDED)[None, :],
        acc,
        mask=dim_mask[None, :] & head_mask[:, None],
    )
def chunked_prefill_paged_decode(
    query,
    key,
    value,
    output,
    kv_cache_dtype,
    key_cache,
    value_cache,
    block_table,
    query_start_loc,
    seq_lens,
    max_seq_len,
    max_query_len,
    k_scale,
    v_scale,
    alibi_slopes=None,
    sliding_window=None,
    sm_scale=None,
    output_scale=None,
    # Optional tensor for sinks
    sinks=None,
):
    """Run attention for a mixed batch: chunked-prefill requests via
    `context_attention_fwd`, then single-token decode requests via a
    paged-attention kernel (ROCm custom op when available, otherwise the
    Triton kernel above). Writes results into `output` in place.

    `query`/`key`/`value` are laid out as [num_tokens, num_heads,
    head_size]; `query_start_loc` ([num_seqs+1]) delimits each request's
    tokens and `seq_lens` gives total (context + query) lengths.
    """
    if sm_scale is None:
        # BUGFIX: default softmax scale is 1/sqrt(head_size). The layout
        # established below (head_size = query.shape[2]) makes shape[-1]
        # the head size; the previous query.shape[1] was the number of
        # query heads, giving a wrong default scale.
        sm_scale = 1.0 / (query.shape[-1]**0.5)

    use_alibi_slopes = alibi_slopes is not None

    # Normalize "no sliding window" to 0, which the kernels treat as off.
    if sliding_window is None or sliding_window <= 0:
        sliding_window = 0

    # Prefill portion: requests with more than one query token. The
    # decode kernel below skips those (filter_by_query_len / skip_decode).
    if max_query_len > 1:
        context_attention_fwd(
            q=query,
            k=key,
            v=value,
            o=output,
            kv_cache_dtype=kv_cache_dtype,
            k_cache=key_cache,
            v_cache=value_cache,
            b_loc=block_table,
            b_start_loc=query_start_loc,
            b_seq_len=seq_lens,
            max_seq_len=max_seq_len,
            max_input_len=max_query_len,
            k_scale=k_scale,
            v_scale=v_scale,
            alibi_slopes=alibi_slopes,
            sliding_window=sliding_window,
            sm_scale=sm_scale,
            skip_decode=True,
            fp8_out_scale=output_scale,
            sinks=sinks,
        )

    block_size = value_cache.shape[3]
    num_seqs = len(seq_lens)
    num_query_heads = query.shape[1]
    num_kv_heads = key.shape[1]
    num_queries_per_kv = query.shape[1] // key.shape[1]
    head_size = query.shape[2]

    # Conversion of FP8 Tensor from uint8 storage to
    # appropriate torch.dtype for interpretation by Triton
    if "fp8" in kv_cache_dtype:
        assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
        assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]

        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            target_dtype = current_platform.fp8_dtype()
        elif kv_cache_dtype == "fp8_e5m2":
            target_dtype = torch.float8_e5m2
        else:
            raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)

        key_cache = key_cache.view(target_dtype)
        value_cache = value_cache.view(target_dtype)

    # tl.dot needs at least 16 rows; pad the per-kv-head query group.
    num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
                                    16)

    from vllm.platforms.rocm import use_rocm_custom_paged_attention
    use_custom = use_rocm_custom_paged_attention(
        query.dtype,
        head_size,
        block_size,
        num_queries_per_kv,
        max_seq_len,
        sliding_window,
        kv_cache_dtype,
        alibi_slopes,
        sinks,
    )
    if use_custom:
        # ROCm custom paged attention splits each sequence into fixed-size
        # partitions and reduces across them.
        _PARTITION_SIZE_ROCM = 256
        max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                              _PARTITION_SIZE_ROCM)
        assert _PARTITION_SIZE_ROCM % block_size == 0
        total_num_seq = block_table.shape[0]
        tmp_output = torch.empty(
            size=(total_num_seq, num_query_heads, max_num_partitions,
                  head_size),
            dtype=query.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(total_num_seq, num_query_heads, max_num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

        ops.paged_attention_rocm(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale=sm_scale,
            block_tables=block_table,
            seq_lens=seq_lens,
            query_start_loc=query_start_loc,
            block_size=block_size,
            max_seq_len=max_seq_len,
            alibi_slopes=alibi_slopes,
            kv_cache_dtype=kv_cache_dtype,
            k_scale=k_scale,
            v_scale=v_scale,
            fp8_out_scale=output_scale,
        )
    else:
        # Triton fallback: one program per (sequence, kv head).
        kernel_paged_attention_2d[(
            num_seqs,
            num_kv_heads,
        )](
            output_ptr=output,
            query_ptr=query,
            key_cache_ptr=key_cache,
            value_cache_ptr=value_cache,
            sink_ptr=sinks,
            block_tables_ptr=block_table,
            seq_lens_ptr=seq_lens,
            alibi_slopes_ptr=alibi_slopes,
            scale=sm_scale,
            k_scale=k_scale,
            v_scale=v_scale,
            out_scale_inv=1.0 /
            output_scale if output_scale is not None else 1.0,
            num_query_heads=num_query_heads,
            num_queries_per_kv=num_queries_per_kv,
            num_queries_per_kv_padded=num_queries_per_kv_padded,
            block_table_stride=block_table.stride(0),
            query_stride_0=query.stride(0),
            query_stride_1=query.stride(1),
            output_stride_0=output.stride(0),
            output_stride_1=output.stride(1),
            BLOCK_SIZE=block_size,
            HEAD_SIZE=head_size,
            HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
            USE_ALIBI_SLOPES=use_alibi_slopes,
            SLIDING_WINDOW=sliding_window,
            x=key_cache.shape[4],
            stride_k_cache_0=key_cache.stride(0),
            stride_k_cache_1=key_cache.stride(1),
            stride_k_cache_2=key_cache.stride(2),
            stride_k_cache_3=key_cache.stride(3),
            stride_k_cache_4=key_cache.stride(4),
            stride_v_cache_0=value_cache.stride(0),
            stride_v_cache_1=value_cache.stride(1),
            stride_v_cache_2=value_cache.stride(2),
            stride_v_cache_3=value_cache.stride(3),
            filter_by_query_len=True,
            query_start_len_ptr=query_start_loc,
            USE_SINKS=sinks is not None,
            USE_FP8=output_scale is not None,
        )
345
vllm/attention/ops/common.py
Normal file
345
vllm/attention/ops/common.py
Normal file
@@ -0,0 +1,345 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr,
                                vlse_ptr, outputs_stride_B, outputs_stride_H,
                                outputs_stride_D, lses_stride_N, lses_stride_B,
                                lses_stride_H, lse_idx, HEAD_DIM: tl.constexpr,
                                N_ROUNDED: tl.constexpr):
    """
    Apply the all-gathered lses to correct each local rank's attention
    output. we still need perform a cross-rank reduction to obtain the
    final attention output.

    One program per (batch, head): combines the N per-rank log-sum-exps
    into a global lse, then rescales this rank's output by
    exp(local_lse - global_lse).

    Args:
        outputs_ptr (triton.PointerType):
            Pointer to input tensor of shape [ B, H, D ]
        lses_ptr (triton.PointerType):
            Pointer to input tensor of shape [ N, B, H ]
        new_output_ptr (triton.PointerType):
            Pointer to output tensor of shape [ B, H, D ]
        vlse_ptr (triton.PointerType):
            Pointer to output tensor of shape [ B, H ]
    """
    batch_idx = tl.program_id(axis=0).to(tl.int64)
    head_idx = tl.program_id(axis=1).to(tl.int64)
    d_offsets = tl.arange(0, HEAD_DIM)
    num_n_offsets = tl.arange(0, N_ROUNDED)

    # shape = [N]
    lse_offsets = num_n_offsets * lses_stride_N + batch_idx * \
        lses_stride_B + head_idx * lses_stride_H

    # calc final lse
    lse = tl.load(lses_ptr + lse_offsets)
    # Sanitize NaN/+inf entries (lse != lse detects NaN) so they drop out
    # of the log-sum-exp instead of poisoning it.
    lse = tl.where((lse != lse) | (lse == float('inf')), -float('inf'), lse)
    # Numerically stable log-sum-exp over the N ranks.
    lse_max = tl.max(lse, axis=0)
    lse -= lse_max
    lse_exp = tl.exp(lse)
    lse_acc = tl.sum(lse_exp, axis=0)
    lse = tl.log(lse_acc)
    lse += lse_max

    lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
    tl.store(vlse_ptr + lse_offsets, lse)

    # shape = [D]
    output_offsets = batch_idx * outputs_stride_B + \
        head_idx * outputs_stride_H + \
        d_offsets * outputs_stride_D

    # correct output
    lse_offset = lse_idx * lses_stride_N + batch_idx * \
        lses_stride_B + head_idx * lses_stride_H
    lse_tmp = tl.load(lses_ptr + lse_offset)
    lse_finally = lse_tmp - lse
    lse_finally = tl.where(
        (lse_finally != lse_finally) | (lse_finally == float('inf')),
        -float('inf'), lse_finally)
    factor = tl.exp(lse_finally)
    output = tl.load(outputs_ptr + output_offsets)
    output = output * factor

    tl.store(new_output_ptr + output_offsets, output)
class CPTritonContext:
    """Caches a compiled Triton kernel so repeat launches skip the JIT."""

    def __init__(self):
        # Compiled kernel handle; None until the first launch.
        self.inner_kernel = None

    def call_kernel(self, kernel, grid, *regular_args, **const_args):
        """Launch `kernel` on `grid`, compiling with the constexpr args
        only on the first call; later calls reuse the cached kernel and
        pass runtime args only."""
        cached = self.inner_kernel
        if cached is not None:
            # Constexpr args were baked in at compile time.
            cached[grid](*regular_args)
        else:
            self.inner_kernel = kernel[grid](*regular_args, **const_args)
def correct_attn_out(
        out: torch.Tensor, lses: torch.Tensor, cp_rank: int,
        ctx: "CPTritonContext | None") -> tuple[torch.Tensor, torch.Tensor]:
    """Correct the attention output using the all-gathered lses.

    Args:
        out: Tensor of shape [ B, H, D ]
        lses: Tensor of shape [ N, B, H ]
        cp_rank: Current rank in the context-parallel group
        ctx: Triton context to avoid recompilation (created if None)

    Returns:
        Tuple of (out, lse) with corrected attention and final log-sum-exp.
    """
    if ctx is None:
        ctx = CPTritonContext()

    # Per-(batch, head) global lse written by the kernel.
    lse = torch.empty_like(lses[0])

    # One program per (batch, head); the trailing 1 is unused.
    grid = (out.shape[0], out.shape[1], 1)
    # `out` is passed as both input and output pointer: corrected in place.
    regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(),
                    cp_rank)
    const_args = {
        "HEAD_DIM": out.shape[-1],
        "N_ROUNDED": lses.shape[0],
    }

    ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args,
                    **const_args)
    return out, lse
def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
                     cp_attn_lse: torch.Tensor,
                     cp_group: GroupCoordinator,
                     ctx: "CPTritonContext | None" = None):
    """All-gather LSEs, correct the local partial output, reduce-scatter.

    Args:
        cp_attn_out: [ B, H, D ] partial attention output of this rank.
        cp_attn_lse: [ B, H ] log-sum-exp of this rank.
        cp_group: Context-parallel communication group.
        ctx: Triton context reused across calls to avoid recompilation.
            (The previous annotation ``ctx: CPTritonContext = None`` was
            inconsistent with its None default; it is now explicitly
            optional.)

    Returns:
        The corrected output reduce-scattered along dim 1, or
        ``cp_attn_out`` unchanged when the group has a single rank.
    """
    if cp_group.world_size == 1:
        return cp_attn_out

    if ctx is None:
        ctx = CPTritonContext()

    lses = torch.empty((cp_group.world_size, ) + cp_attn_lse.shape,
                       dtype=cp_attn_lse.dtype,
                       device=cp_attn_lse.device)

    # all_gather requires a contiguous tensor; the gathered result is
    # viewed as [ world_size, B, H ].
    cp_attn_lse = cp_attn_lse.contiguous()
    lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
    out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
    out = cp_group.reduce_scatter(out, dim=1)
    return out
|
||||
|
||||
|
||||
@triton.jit
def _pack_seq_kernel(
        x_ptr,  # [N, D]
        out_ptr,  # [B, Lmax, D]
        lengths_ptr,  # *i32, [B]
        N: tl.constexpr,
        D: tl.constexpr,
        Lmax: tl.constexpr,
        PAD_VALUE: tl.constexpr,
        BLOCK_T: tl.constexpr,  # timesteps per program
        BLOCK_D: tl.constexpr  # features per program
):
    """Copy one (batch, time-block, feature-block) tile of a flat [N, D]
    tensor into its padded [B, Lmax, D] slot, writing PAD_VALUE where the
    time index is past the sequence's length."""
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # Compute start index and sequence length from cumulative lengths.
    # NOTE: this is an O(B) serial scan per program; fine for small batch
    # counts, which is presumably the intended regime.
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    # valid time positions for this block
    t_mask = off_t < Lmax

    # compute input row indices for valid (b, t)
    in_row = in_start + off_t
    valid_row = (off_t < seq_len) & t_mask

    # Pointers
    # x_ptr: row-major [N, D]
    x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]

    # out_ptr: row-major [B, Lmax, D]
    out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:,
                                                   None] * D + off_d[None, :]

    # Initialize with PAD (cast will occur as needed based on out_ptr dtype)
    d_mask = off_d[None, :] < D
    pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
    tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)

    # Load & write only where within seq_len; the masked store overwrites
    # the PAD values written above for valid positions.
    x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask)
    tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask)
|
||||
|
||||
|
||||
def pack_seq_triton(x: torch.Tensor,
                    lengths: torch.Tensor,
                    pad_value: float = -float('inf'),
                    block_t: int = 64,
                    block_d: int = 64) -> torch.Tensor:
    """Pack variable-length sequences into one padded, batched tensor.

    Args:
        x: [N, ...] tensor holding all tokens back to back (N total tokens).
        lengths: [B] per-sequence token counts; their sum is N.
        pad_value: fill value for positions beyond each sequence's length.
        block_t: Triton block size along the time dimension.
        block_d: Triton block size along the feature dimension.

    Returns:
        [B, Lmax, ...] tensor where Lmax = max(lengths).
    """
    # Flatten trailing dimensions so the kernel only sees [N, D].
    in_shape = x.shape
    multi_dim = len(in_shape) > 2
    if multi_dim:
        n_tokens = in_shape[0]
        flat = x.reshape(n_tokens, -1)
        feat = flat.shape[1]
    else:
        n_tokens, feat = x.shape
        flat = x

    num_seqs = lengths.numel()
    max_len = int(lengths.max().item())

    # Per-sequence start offsets are derived inside the kernel from lengths.
    packed = torch.empty((num_seqs, max_len, feat),
                         device=x.device,
                         dtype=x.dtype)

    launch_grid = (num_seqs, triton.cdiv(max_len, block_t),
                   triton.cdiv(feat, block_d))
    _pack_seq_kernel[launch_grid](flat,
                                  packed,
                                  lengths.int(),
                                  n_tokens,
                                  feat,
                                  max_len,
                                  PAD_VALUE=float(pad_value),
                                  BLOCK_T=block_t,
                                  BLOCK_D=block_d,
                                  num_warps=4,
                                  num_stages=2)

    # Restore the original trailing dimensions.
    if multi_dim:
        packed = packed.reshape((num_seqs, max_len) + in_shape[1:])

    return packed
|
||||
|
||||
|
||||
@triton.jit
def _unpack_seq_triton_kernel(
        packed_ptr,  # [B, Lmax, D]
        out_ptr,  # [N, D]
        lengths_ptr,  # *i32, [B]
        B: tl.constexpr,
        Lmax: tl.constexpr,
        D: tl.constexpr,
        BLOCK_T: tl.constexpr,  # timesteps per program
        BLOCK_D: tl.constexpr  # features per program
):
    """Inverse of ``_pack_seq_kernel``: copy the valid (non-padded) rows of
    one (batch, time-block, feature-block) tile from the padded [B, Lmax, D]
    tensor back into the flat [N, D] tensor."""
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # bounds: compute start from cumulative lengths
    # (O(B) serial scan per program, same as the pack kernel)
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    # valid time positions for this block
    t_mask = off_t < Lmax
    valid_row = (off_t < seq_len) & t_mask

    # compute output row indices for valid (b, t)
    out_row = in_start + off_t

    # Pointers
    # packed_ptr: row-major [B, Lmax, D]
    packed_row_ptr = packed_ptr + (pid_b * Lmax +
                                   off_t)[:, None] * D + off_d[None, :]

    # out_ptr: row-major [N, D]
    out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]

    # Load from packed tensor and store to output; padding rows are masked
    # out entirely, so the output receives only real tokens.
    d_mask = off_d[None, :] < D
    packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask)
    tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask)
|
||||
|
||||
|
||||
def unpack_seq_triton(packed_tensor: torch.Tensor,
                      lengths: torch.Tensor,
                      block_t: int = 64,
                      block_d: int = 64) -> torch.Tensor:
    """Invert ``pack_seq_triton``: gather valid rows back into a flat tensor.

    Args:
        packed_tensor: [B, Lmax, ...] tensor produced by pack_seq_triton.
        lengths: [B] per-sequence valid lengths.
        block_t: Triton block size along the time dimension.
        block_d: Triton block size along the feature dimension.

    Returns:
        [N, ...] tensor with N = sum(lengths); padding rows are dropped.
    """
    # Flatten trailing dimensions so the kernel only sees [B, Lmax, D].
    in_shape = packed_tensor.shape
    multi_dim = len(in_shape) > 3
    if multi_dim:
        num_seqs, max_len = in_shape[:2]
        flat = packed_tensor.reshape(num_seqs, max_len, -1)
        feat = flat.shape[2]
    else:
        num_seqs, max_len, feat = packed_tensor.shape
        flat = packed_tensor

    # Total number of real tokens across all sequences.
    total_tokens = int(lengths.sum().item())

    unpacked = torch.empty((total_tokens, feat),
                           device=packed_tensor.device,
                           dtype=packed_tensor.dtype)

    launch_grid = (num_seqs, triton.cdiv(max_len, block_t),
                   triton.cdiv(feat, block_d))
    _unpack_seq_triton_kernel[launch_grid](flat,
                                           unpacked,
                                           lengths.int(),
                                           num_seqs,
                                           max_len,
                                           feat,
                                           BLOCK_T=block_t,
                                           BLOCK_D=block_d,
                                           num_warps=4,
                                           num_stages=2)

    # Restore the original trailing dimensions.
    if multi_dim:
        unpacked = unpacked.reshape((total_tokens, ) + in_shape[2:])

    return unpacked
|
||||
192
vllm/attention/ops/flashmla.py
Normal file
192
vllm/attention/ops/flashmla.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from typing import Optional, Tuple

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Probe the optional compiled FlashMLA modules once. They are only built for
# CUDA; on other platforms, or when the build skipped them, the availability
# flags stay False. (Previously the `is_cuda()` check was duplicated for
# each module; a single check keeps the two probes consistent.)
_flashmla_C_AVAILABLE = False
_flashmla_extension_C_AVAILABLE = False
if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401
        _flashmla_C_AVAILABLE = True
    except ImportError:
        pass
    try:
        import vllm._flashmla_extension_C  # noqa: F401
        _flashmla_extension_C_AVAILABLE = True
    except ImportError:
        pass
|
||||
|
||||
|
||||
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """Report whether the FlashMLA kernels can be used on this platform.

    Returns:
        Tuple of (is_supported_flag, unsupported_reason); the reason is
        None exactly when FlashMLA is supported.
    """
    reason: Optional[str] = None
    if not current_platform.is_cuda():
        reason = "FlashMLA is only supported on CUDA devices."
    elif current_platform.get_device_capability()[0] != 9:
        reason = "FlashMLA is only supported on Hopper devices."
    elif not _flashmla_C_AVAILABLE:
        reason = (
            "vllm._flashmla_C is not available, likely was not "
            "compiled due to insufficient nvcc version or a supported arch "
            "(only sm90a currently) was not in the list of target arches to "
            "compile for.")
    return reason is None, reason
|
||||
|
||||
|
||||
def get_mla_metadata(
        cache_seqlens: torch.Tensor,
        num_q_tokens_per_head_k: int,
        num_heads_k: int,
        num_heads_q: Optional[int] = None,
        is_fp8_kvcache: bool = False,
        topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute tile-scheduling metadata for the FlashMLA decode kernel.

    Args:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_q_tokens_per_head_k:
            num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
        num_heads_k: The number of k heads.
        num_heads_q: The number of q heads; optional unless sparse
            attention is enabled.
        is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
        topk: If not None, sparse attention is enabled and only the tokens
            listed in the `indices` array passed to the decode kernel will
            be attended to.

    Returns:
        tile_scheduler_metadata:
            (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    metadata = torch.ops._flashmla_C.get_mla_decoding_metadata(
        cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
        is_fp8_kvcache, topk)
    return metadata
|
||||
|
||||
|
||||
def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the FlashMLA decode kernel against a paged KV cache.

    Args:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata:
            (num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
            returned by get_mla_metadata.
        num_splits:
            (batch_size + 1), torch.int32, returned by get_mla_metadata.
        softmax_scale: Scale of QK^T before softmax; defaults to
            1 / sqrt(head_dim).
        causal: Whether to apply a causal attention mask.
        descale_q: (batch_size), torch.float32 descaling factors for Q
            (fp8 quantization).
        descale_k: (batch_size), torch.float32 descaling factors for K
            (fp8 quantization).
        is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
            For the FP8 KV-cache format, please refer to README.md.
        indices: (batch_size, seq_len_q, topk), torch.int32. If not None,
            sparse attention is enabled and only the listed tokens are
            attended to. Invalid indices must be -1 or >= total_seq_len_kv;
            see README.md for how to set up `indices`.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1]**(-0.5)

    sparse = indices is not None
    if sparse:
        # NOTE (zyongye): sparse attention is also causal, since it only
        # attends to earlier tokens, but here `causal` must not be set.
        assert not causal, \
            "causal must be `false` if sparse attention is enabled."
    assert (descale_q is None) == (
        descale_k is None
    ), "descale_q and descale_k should be both None or both not None"

    # Dense decode with an fp8 query goes through the extension kernel;
    # every other combination (including sparse) uses the main kernel.
    if not sparse and q.element_size() == 1:
        return torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
            q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
            causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
    return torch.ops._flashmla_C.fwd_kvcache_mla(
        q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
        causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache, indices)
|
||||
|
||||
|
||||
def flash_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sparse attention prefill kernel.

    Args:
        q: [s_q, h_q, d_qk], bfloat16.
        kv: [s_kv, h_kv, d_qk], bfloat16.
        indices: [s_q, h_kv, topk], int32. Invalid indices should be set
            to -1 or numbers >= s_kv.
        sm_scale: float softmax scale.
        d_v: The dimension of value vectors. Can only be 512.

    Returns:
        A (output, max_logits, lse) triple; see README.md for the precise
        definitions of the three tensors.
        output: [s_q, h_q, d_v], bfloat16.
        max_logits: [s_q, h_q], float.
        lse: [s_q, h_q], float, 2-based log-sum-exp.
    """
    return torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices, sm_scale,
                                                    d_v)
|
||||
|
||||
|
||||
#
|
||||
# TODO: Add fake functions
|
||||
#
|
||||
# @register_fake("_flashmla_C::get_mla_metadata")
|
||||
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# return ....
|
||||
#
|
||||
# @register_fake("_flashmla_C::fwd_kvcache_mla")
|
||||
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# return ....
|
||||
#
|
||||
43
vllm/attention/ops/merge_attn_states.py
Normal file
43
vllm/attention/ops/merge_attn_states.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: Optional[torch.Tensor] = None,
) -> None:
    """Merge prefix and suffix partial attention results into `output`.

    Dispatches to the custom CUDA kernel when it supports the tensors'
    dtype and head size, otherwise to the Triton implementation. Both
    write into `output` (and `output_lse`, when given) in place.

    Args:
        output: [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] destination tensor.
        prefix_output / suffix_output: partial attention outputs to merge.
        prefix_lse / suffix_lse: log-sum-exp values of the two partials.
        output_lse: optional destination for the merged log-sum-exp.
    """

    # NOTE(DefTruth): Currently, the custom merge_attn_states CUDA kernel
    # does not support FP8 dtypes; fall back to the Triton kernel for them.
    def supported_dtypes(o: torch.Tensor) -> bool:
        return o.dtype in [torch.float32, torch.half, torch.bfloat16]

    # NOTE(DefTruth): Currently, the custom merge_attn_states CUDA
    # kernel load/store 128b(16 bytes) per memory issue within
    # thread. Namely, the headsize(headdim) must be multiple of
    # pack_size (float32 -> 4, half/bfloat16 -> 8).
    def supported_headdim(o: torch.Tensor) -> bool:
        headdim = o.shape[2]  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
        if o.dtype == torch.float32:
            return headdim % 4 == 0
        return headdim % 8 == 0

    # Alias the implementations so the local imports do not shadow this
    # function's own name, and dispatch through a single call site.
    if (current_platform.is_cuda() and supported_dtypes(output)
            and supported_headdim(output)):
        from vllm._custom_ops import merge_attn_states as merge_impl
    else:
        from vllm.attention.ops.triton_merge_attn_states import (
            merge_attn_states as merge_impl)
    return merge_impl(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse, output_lse)
|
||||
262
vllm/attention/ops/paged_attn.py
Normal file
262
vllm/attention/ops/paged_attn.py
Normal file
@@ -0,0 +1,262 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from vllm import _custom_ops as ops
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
|
||||
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
|
||||
@dataclass
class PagedAttentionMetadata:
    """Metadata for PagedAttention: per-batch sequence lengths and the
    block tables mapping each sequence to its physical KV-cache blocks."""
    # (batch_size,). The length of sequences (entire tokens seen so far) per
    # sequence.
    seq_lens_tensor: Optional[torch.Tensor]
    # Maximum sequence length in the batch. 0 if it is prefill-only batch.
    max_decode_seq_len: int
    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]
|
||||
|
||||
|
||||
class PagedAttention:
    """Static helpers wrapping the paged-attention custom ops: KV-cache
    shaping/splitting, cache writes, and the decode/prefix forward kernels.
    """

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the head sizes the paged-attention kernels support."""
        return [32, 64, 80, 96, 112, 120, 128, 192, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> Tuple[int, ...]:
        """Shape of the combined KV cache: index 0 holds keys, 1 values;
        each block's K/V entries are stored flattened.
        (cache_dtype_str is accepted for interface parity but unused here.)
        """
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split the combined cache into key and value views (no copy).

        The key cache is viewed with the head dimension split by `x`, the
        number of elements that fit in 16 bytes for this dtype — the layout
        the CUDA kernels expect.
        """
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> None:
        """Scatter new key/value tokens into the paged caches at the slots
        given by `slot_mapping` (flattened to one index per token)."""
        ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping.flatten(),
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> torch.Tensor:
        """Decode-phase attention over the paged KV cache.

        Chooses between the V1 and V2 paged-attention kernels with a
        heuristic (see NOTE below) and returns the attention output with
        the same shape as `query`.
        """
        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            block_size = value_cache.size(-1)
            assert (blocksparse_block_size > 0 and
                    blocksparse_block_size % block_size == 0), \
                (f"{blocksparse_block_size=} needs to be a multiple of"
                 f"{block_size=} used in block_tables.")

        output = torch.empty_like(query)
        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        # Ceil-divide the longest sequence into fixed-size partitions.
        max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
                              _PARTITION_SIZE)
        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
        use_v1 = (max_seq_len <= 8192
                  and (max_num_partitions == 1 or num_seqs * num_heads > 512))

        if use_v1:
            # Run PagedAttention V1.
            ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                tp_rank,
                blocksparse_local_blocks,
                blocksparse_vert_stride,
                blocksparse_block_size,
                blocksparse_head_sliding_step,
            )
        else:
            # Run PagedAttention V2: per-partition partial results are
            # computed into scratch buffers, then reduced into `output`.
            assert _PARTITION_SIZE % block_size == 0
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                tp_rank,
                blocksparse_local_blocks,
                blocksparse_vert_stride,
                blocksparse_block_size,
                blocksparse_head_sliding_step,
            )
        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache_dtype: str,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        query_start_loc: torch.Tensor,
        seq_lens_tensor: torch.Tensor,
        max_query_len: int,
        alibi_slopes: Optional[torch.Tensor],
        sliding_window: Optional[int],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> torch.Tensor:
        """Prefill-with-prefix attention via the Triton context-attention
        kernel; returns the output with the same shape as `query`."""
        output = torch.empty_like(query)
        # max_seq_len is not needed by the kernel entry point used here.
        max_seq_len = None
        context_attention_fwd(
            query,
            key,
            value,
            output,
            kv_cache_dtype,
            key_cache,
            value_cache,
            block_tables,
            # query_start_loc is (batch_size + 1,)
            query_start_loc,
            seq_lens_tensor,
            max_seq_len,
            max_query_len,
            k_scale,
            v_scale,
            alibi_slopes,
            sliding_window,
        )
        return output

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Copy the blocks listed in `src_to_dst` from one KV cache to
        another (key cache first, then value cache)."""
        src_key_cache = src_kv_cache[0]
        dst_key_cache = dst_kv_cache[0]
        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)

        src_value_cache = src_kv_cache[1]
        dst_value_cache = dst_kv_cache[1]
        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks within every layer's KV cache according to the
        (src, dst) pairs in `src_to_dists`."""
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
124
vllm/attention/ops/pallas_kv_cache_update.py
Normal file
124
vllm/attention/ops/pallas_kv_cache_update.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
|
||||
import jax
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import tpu as pltpu
|
||||
|
||||
from vllm.utils import cdiv
|
||||
|
||||
|
||||
def _kv_cache_update_kernel(
    # Prefetch
    slices_ref,  # [3, padded_num_slices], list of (kv_cache_start,
    # new_kv_start, slice_len)
    num_slices_ref,  # [1]
    # Input
    new_kv_hbm_ref,  # [num_tokens, num_combined_kv_heads, head_dim]
    kv_cache_hbm_ref,  # [total_num_pages * page_size, num_combined_kv_heads,
    # head_dim]
    # Output
    _,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    # Scratch
    scratch,  # [num_slices_per_block, page_size, num_combined_kv_heads,
    # head_dim]
    sem,
):
    """Pallas TPU kernel: copy this program's batch of slices from the new
    KV tensor into the KV cache, staging them through VMEM scratch with
    async DMA copies (HBM -> scratch, then scratch -> HBM)."""
    async_copies = []
    block_idx = pl.program_id(0)
    num_slices_per_block = scratch.shape[0]

    # Copy from new_kv_hbm_ref to scratch. Slices past num_slices are
    # clamped to zero length so the copy is a no-op for them.
    for i in range(num_slices_per_block):
        offset_i = i + block_idx * num_slices_per_block
        new_kv_start = jax.lax.select(offset_i < num_slices_ref[0],
                                      slices_ref[1, offset_i], 0)
        length = jax.lax.select(offset_i < num_slices_ref[0],
                                slices_ref[2, offset_i], 0)
        async_copy = pltpu.make_async_copy(
            new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
            scratch.at[i, pl.ds(0, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)

    # Wait for all inbound copies before reading the scratch buffer.
    for async_copy in async_copies:
        async_copy.wait()

    # Copy from scratch to kv_cache_hbm_ref, using the per-slice cache
    # start offsets (again zero-length for out-of-range slices).
    async_copies.clear()
    for i in range(num_slices_per_block):
        offset_i = i + block_idx * num_slices_per_block
        kv_cache_start = jax.lax.select(offset_i < num_slices_ref[0],
                                        slices_ref[0, offset_i], 0)
        length = jax.lax.select(offset_i < num_slices_ref[0],
                                slices_ref[2, offset_i], 0)
        async_copy = pltpu.make_async_copy(
            scratch.at[i, pl.ds(0, length), ...],
            kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)
    for async_copy in async_copies:
        async_copy.wait()
|
||||
|
||||
|
||||
@functools.partial(
    jax.jit,
    static_argnames=["page_size", "num_slices_per_block"],
)
def kv_cache_update(
    new_kv: jax.Array,  # [total_num_token, num_combined_kv_heads, head_dim]
    slices: jax.
    Array,  # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
    kv_cache: jax.
    Array,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    num_kv_update_slices: jax.Array,  # [1]
    *,
    page_size: int = 32,
    num_slices_per_block: int = 8,
):
    """Write `new_kv` slices into the paged KV cache via a Pallas kernel.

    Returns the updated KV cache array; the input `kv_cache` buffer is
    donated to the output via input_output_aliases, so the update happens
    without a full-cache copy.
    """
    _, num_combined_kv_heads, head_dim = new_kv.shape
    assert kv_cache.shape[1] == num_combined_kv_heads
    assert kv_cache.shape[2] == head_dim
    # TPU lane width constraint for the chosen layout.
    assert head_dim % 128 == 0
    # TODO: Add dynamic check to make sure that the all the slice lengths are
    # smaller or equal to page_size

    # Both inputs stay in HBM ("ANY"); the kernel DMAs them explicitly.
    in_specs = [
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ]

    out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
    out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]

    # slices / num_kv_update_slices are scalar-prefetch operands so the
    # kernel can read them before the grid loop body runs.
    scalar_prefetches = [slices, num_kv_update_slices]
    # VMEM staging buffer: one page-sized slot per slice in a block.
    scratch = pltpu.VMEM(
        (num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
        new_kv.dtype,
    )

    scratch_shapes = [
        scratch,
        pltpu.SemaphoreType.DMA,
    ]

    kernel = pl.pallas_call(
        _kv_cache_update_kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=len(scalar_prefetches),
            in_specs=in_specs,
            out_specs=out_specs,
            # One grid step per block of num_slices_per_block slices.
            grid=(cdiv(num_kv_update_slices[0], num_slices_per_block), ),
            scratch_shapes=scratch_shapes,
        ),
        out_shape=out_shape,
        # Alias the kv_cache input (after the prefetch operands and new_kv)
        # to output 0 so the cache is updated in place.
        input_output_aliases={len(scalar_prefetches) + 1: 0},
    )

    return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
|
||||
928
vllm/attention/ops/prefix_prefill.py
Normal file
928
vllm/attention/ops/prefix_prefill.py
Normal file
@@ -0,0 +1,928 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
|
||||
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
# Static kernels parameters
|
||||
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
|
||||
NUM_WARPS = 4 if current_platform.is_rocm() else 8
|
||||
|
||||
# To check compatibility
|
||||
IS_TURING = current_platform.get_device_capability() == (7, 5)
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
# Here's an example autotuner config for this kernel. This config does provide
|
||||
# a performance improvement, but dramatically increases first call latency in
|
||||
# triton 3.2. Because of this tradeoff, it's currently commented out.
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \
|
||||
# "num_unroll_cache": 4, \
|
||||
# "num_unroll_request": 1 } | \
|
||||
# ({"kpack": 2, "waves_per_eu": 2} \
|
||||
# if current_platform.is_rocm() else {}), \
|
||||
# num_warps=4, \
|
||||
# num_stages=1)
|
||||
# ],
|
||||
# key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"]
|
||||
# )
|
||||
# Prefix-prefill attention kernel. Each program instance handles one
# (batch, head, query-tile) triple and runs a flash-attention-style online
# softmax in two phases:
#   1) over the cached context KV (prefix) — no causal mask, one KV-cache
#      block (BLOCK_SIZE tokens) per iteration, gathered via the block table
#      B_Loc;
#   2) over the new in-request K/V — with a causal mask.
# Output is written per query token; when USE_FP8, it is scaled and clamped
# to the fp8 representable range before the store.
@triton.jit
def _fwd_kernel(Q,
                K,
                V,
                K_cache,
                V_cache,
                sink_ptr,  # per-head sink logit, used only when USE_SINKS
                B_Loc,
                sm_scale,
                k_scale,
                v_scale,
                out_scale_inv,
                B_Start_Loc,
                B_Seqlen,
                x: tl.constexpr,
                Out,
                stride_b_loc_b,
                stride_b_loc_s,
                stride_qbs,
                stride_qh,
                stride_qd,
                stride_kbs,
                stride_kh,
                stride_kd,
                stride_vbs,
                stride_vh,
                stride_vd,
                stride_obs,
                stride_oh,
                stride_od,
                stride_k_cache_bs,
                stride_k_cache_h,
                stride_k_cache_d,
                stride_k_cache_bl: tl.constexpr,
                stride_k_cache_x,
                stride_v_cache_bs,
                stride_v_cache_h,
                stride_v_cache_d,
                stride_v_cache_bl,
                num_queries_per_kv: tl.constexpr,
                IN_PRECISION: tl.constexpr,
                BLOCK_M: tl.constexpr,
                BLOCK_DMODEL: tl.constexpr,
                BLOCK_DMODEL_PADDED: tl.constexpr,
                BLOCK_SIZE: tl.constexpr,
                BLOCK_N: tl.constexpr,
                SLIDING_WINDOW: tl.constexpr,
                num_unroll_cache: tl.constexpr,
                num_unroll_request: tl.constexpr,
                SKIP_DECODE: tl.constexpr,
                USE_SINKS: tl.constexpr,
                USE_FP8: tl.constexpr,
                MAX_Q_LEN: tl.constexpr = 0,
                MAX_CTX_LEN: tl.constexpr = 0,
                FP8_MIN: tl.constexpr = float8_info.min,
                FP8_MAX: tl.constexpr = float8_info.max):

    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    # GQA/MQA: several query heads share one KV head.
    cur_kv_head = cur_head // num_queries_per_kv

    # Query length is derived from consecutive B_Start_Loc entries; the
    # context (prefix) length is whatever remains of the total sequence.
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
    cur_batch_query_len = (cur_batch_in_all_stop_index -
                           cur_batch_in_all_start_index)
    cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len

    # Single-token requests are decode steps; skip them when requested.
    if SKIP_DECODE and cur_batch_query_len == 1:
        return

    # start position inside of the query
    # generally, N goes over kv, while M goes over query_len
    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    # [BLOCK_SIZE]; starts at 0
    offs_bs_n = tl.arange(0, BLOCK_SIZE)
    # [N]; starts at 0
    offs_n = tl.arange(0, BLOCK_N)
    # [D]; starts at 0
    offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
    # [M]; starts at current position in query
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # [M,D]
    off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)

    # Masks off the padding lanes when the head size is not a power of two.
    dim_mask = tl.where(
        tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
        0).to(tl.int1)  # [D]

    q = tl.load(Q + off_q,
                mask=dim_mask[None, :] &
                (offs_m[:, None] < cur_batch_query_len),
                other=0.0)  # [M,D]

    # initialize pointer to m and l
    if not USE_SINKS:
        m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    else:
        # Seed the running max with the head's sink logit so the sink
        # participates in the softmax.
        m_i = tl.load(
            sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64),
            mask=(offs_m < cur_batch_query_len),
            other=float("-inf"),
        ).to(dtype=tl.float32)

    # l_i = 1.0: with sinks this is the sink's own exp(sink - m_i) = 1 unit
    # of softmax mass; without sinks the first alpha = exp(-inf - m) = 0
    # zeroes it out, so the value is harmless.
    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)  # [M,D]

    # compute query against context (no causal mask here)
    for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \
            loop_unroll_factor=num_unroll_cache):
        start_n = tl.multiple_of(start_n, BLOCK_SIZE)
        # -- compute qk ----
        # Physical block id for this logical context block.
        bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                     (start_n // BLOCK_SIZE) * stride_b_loc_s).to(tl.int64)
        # [D,BLOCK_SIZE]
        # K cache layout is [num_blocks, num_kv_heads, head_size/x, block, x].
        off_k = (
            bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
            (offs_d[:, None] // x) * stride_k_cache_d +
            ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl +
            (offs_d[:, None] % x) * stride_k_cache_x)

        # [BLOCK_SIZE,D]
        off_v = (bn[:, None] * stride_v_cache_bs +
                 cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 offs_bs_n[:, None] * stride_v_cache_bl)

        # Masked load only on the ragged last block or when the head dim is
        # padded; the common case takes the cheaper unmasked load.
        if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
                BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
            k_load = tl.load(
                K_cache + off_k,
                mask=dim_mask[:, None] &
                ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len),
                other=0.0)  # [D,N]
        else:
            k_load = tl.load(K_cache + off_k)

        # Dequantize fp8 cache entries to the query dtype.
        if k_load.dtype.is_fp8():
            k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
        else:
            k = k_load

        qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32)  # [M,N]
        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
        qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk,
                      float("-inf"))
        qk *= sm_scale
        if SLIDING_WINDOW > 0:
            # (cur_batch_ctx_len + offs_m[:, None]) are the positions of
            # Q entries in sequence
            # (start_n + offs_bs_n[None, :]) are the positions of
            # KV entries in sequence
            # So the condition makes sure each entry in Q only attends
            # to KV entries not more than SLIDING_WINDOW away.
            #
            # We can't use -inf here, because the
            # sliding window may lead to the entire row being masked.
            # This then makes m_ij contain -inf, which causes NaNs in
            # exp().
            qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
                          (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk,
                          -10000)

        # compute running maximum
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        # Rescale previous accumulator to the new running max.
        alpha = tl.exp(m_i - m_ij)
        acc = acc * alpha[:, None]

        # update acc
        if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
                BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
            v_load = tl.load(
                V_cache + off_v,
                mask=dim_mask[None, :] &
                ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len),
                other=0.0)  # [N,D]
        else:
            v_load = tl.load(V_cache + off_v)

        if v_load.dtype.is_fp8():
            v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
        else:
            v = v_load
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
        # # update m_i and l_i
        l_i = l_i * alpha + l_ij
        m_i = m_ij

    # Phase 2: attend to the new (in-request) K/V, stored contiguously.
    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)
    off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
             offs_d[None, :] * stride_vd)
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # block_mask is 0 when we're already past the current query length
    block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

    # compute query against itself (with causal mask)
    for start_n in tl.range(0, \
            block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \
            loop_unroll_factor=num_unroll_request):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_kbs,
                    mask=dim_mask[:, None] &
                    ((start_n + offs_n[None, :]) < cur_batch_query_len),
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
        qk *= sm_scale
        # apply causal mask
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))
        if SLIDING_WINDOW > 0:
            qk = tl.where(
                offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
                qk, -10000)

        # compute running maximum
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        alpha = tl.exp(m_i - m_ij)
        acc = acc * alpha[:, None]

        # update acc
        v = tl.load(v_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_vbs,
                    mask=dim_mask[None, :] &
                    ((start_n + offs_n[:, None]) < cur_batch_query_len),
                    other=0.0)
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
        # update m_i and l_i
        l_i = l_i * alpha + l_ij
        m_i = m_ij

    # Final softmax normalization.
    acc = acc / l_i[:, None]

    # initialize pointers to output
    off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
             cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    if USE_FP8:
        # Scale and clamp into the fp8 representable range before storing.
        acc = acc * tl.load(out_scale_inv)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
    tl.store(out_ptrs,
             acc,
             mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
    return
|
||||
|
||||
|
||||
# Flash-attention-v2-style variant of the prefix-prefill kernel: same
# two-phase structure as _fwd_kernel (context KV via the block table, then
# in-request K/V with a causal mask), but with per-token block-table
# lookups, no fp8/sliding-window/sink support, and explicit per-batch
# context lengths (B_Ctxlen).
#
# NOTE(review): unlike _fwd_kernel, the accumulator is stored WITHOUT the
# final 1/l_i normalization (see the commented-out `acc /= l_i` below) —
# presumably the caller normalizes; confirm before reuse.
@triton.jit
def _fwd_kernel_flash_attn_v2(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        B_Start_Loc,
        B_Seqlen,
        B_Ctxlen,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    # GQA/MQA head mapping.
    cur_kv_head = cur_head // num_queries_per_kv

    cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)

    # Query length == seq_len - ctx_len throughout this kernel.
    q = tl.load(Q + off_q,
                mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
                other=0.0)

    # # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # Phase 1: context KV from the paged cache (no causal mask).
    for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        # One block-table lookup per KV token (gather-style).
        bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                     ((start_n + offs_n) // block_size) * stride_b_loc_s,
                     mask=(start_n + offs_n) < cur_batch_ctx_len,
                     other=0).to(tl.int64)
        off_k = (
            bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
            (offs_d[:, None] // x) * stride_k_cache_d +
            ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
            (offs_d[:, None] % x) * stride_k_cache_x)
        off_v = (bn[:, None] * stride_v_cache_bs +
                 cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
        k = tl.load(K_cache + off_k,
                    mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                    other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                      float("-inf"))
        qk *= sm_scale

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(V_cache + off_v,
                    mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                    other=0.0)

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # Phase 2: new in-request K/V, stored contiguously.
    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)
    off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
             offs_d[None, :] * stride_vd)
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # Zero iterations when this query tile is past the query length.
    block_mask = tl.where(
        block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_kbs,
                    mask=(start_n + offs_n[None, :])
                    < cur_batch_seq_len - cur_batch_ctx_len,
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        # Causal mask: each query attends only to itself and earlier tokens.
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(v_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_vbs,
                    mask=(start_n + offs_n[:, None])
                    < cur_batch_seq_len - cur_batch_ctx_len,
                    other=0.0)

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # acc /= l_i[:, None]
    # initialize pointers to output
    off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
             cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    tl.store(out_ptrs,
             acc,
             mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
    return
|
||||
|
||||
|
||||
# ALiBi variant of the prefix-prefill kernel: same two-phase online-softmax
# structure (paged context KV, then causal in-request K/V), with a linear
# position-difference bias (alibi_slope * (k_pos - q_pos)) added to the
# logits instead of a sliding window. Supports fp8 KV caches (dequantized
# via k_scale/v_scale) and a padded head dimension.
@triton.jit
def _fwd_kernel_alibi(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        k_scale,
        v_scale,
        B_Start_Loc,
        B_Seqlen,
        Alibi_slopes,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        IN_PRECISION: tl.constexpr,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,  # head size
        BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
        BLOCK_N: tl.constexpr,
        SKIP_DECODE: tl.constexpr,
):
    # attn_bias[]
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    # GQA/MQA head mapping.
    cur_kv_head = cur_head // num_queries_per_kv

    # cur_batch_seq_len: the length of prompts
    # cur_batch_ctx_len: the length of prefix
    # cur_batch_in_all_start_index: the start id of the dim=0
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
    cur_batch_query_len = (cur_batch_in_all_stop_index -
                           cur_batch_in_all_start_index)
    cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len

    # Single-token requests are decode steps; skip them when requested.
    if SKIP_DECODE and cur_batch_query_len == 1:
        return

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)

    # Masks off padding lanes when head size is not a power of two.
    dim_mask = tl.where(
        tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)

    q = tl.load(Q + off_q,
                mask=dim_mask[None, :] &
                (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
                other=0.0)

    # # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)

    # Phase 1: context KV. Query positions start after the prefix, K
    # positions start at 0.
    alibi_slope = tl.load(Alibi_slopes + cur_head)
    alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
    alibi_start_k = 0
    for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                     ((start_n + offs_n) // block_size) * stride_b_loc_s,
                     mask=(start_n + offs_n) < cur_batch_ctx_len,
                     other=0).to(tl.int64)
        off_k = (
            bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
            (offs_d[:, None] // x) * stride_k_cache_d +
            ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
            (offs_d[:, None] % x) * stride_k_cache_x)
        off_v = (bn[:, None] * stride_v_cache_bs +
                 cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
        k_load = tl.load(K_cache + off_k,
                         mask=dim_mask[:, None] &
                         ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                         other=0.0)  # [D,N]

        # Dequantize fp8 cache entries to the query dtype.
        if k_load.dtype.is_fp8():
            k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
        else:
            k = k_load

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
        qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                      float("-inf"))
        qk *= sm_scale

        # load alibi
        # Bias is <= 0 for valid (k_pos <= q_pos) pairs; -inf elsewhere
        # doubles as the causal/validity mask.
        alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                 alibi_start_q[:, None]) * alibi_slope
        alibi = tl.where(
            (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
            float("-inf"))
        qk += alibi
        alibi_start_k += BLOCK_N

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v_load = tl.load(V_cache + off_v,
                         mask=dim_mask[None, :] &
                         ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                         other=0.0)
        if v_load.dtype.is_fp8():
            v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
        else:
            v = v_load
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision='ieee')
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # Phase 2: new in-request K/V, stored contiguously.
    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)
    off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
             offs_d[None, :] * stride_vd)
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    block_mask = tl.where(
        block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

    # init alibi
    # K positions now continue from the end of the prefix.
    alibi_slope = tl.load(Alibi_slopes + cur_head)
    alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
    alibi_start_k = cur_batch_ctx_len
    # # init debugger
    # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
    # offset_db_k = tl.arange(0, BLOCK_N)
    # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=dim_mask[:, None] & ((start_n + offs_n[None, :])
                                      < cur_batch_seq_len - cur_batch_ctx_len),
            other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.dot(q, k, acc=qk, input_precision='ieee')
        qk *= sm_scale
        # Causal mask over the new tokens.
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))

        # load alibi
        alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                 alibi_start_q[:, None]) * alibi_slope
        alibi = tl.where(
            (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
            float("-inf"))
        qk += alibi
        alibi_start_k += BLOCK_N

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=dim_mask[None, :] & ((start_n + offs_n[:, None])
                                      < cur_batch_seq_len - cur_batch_ctx_len),
            other=0.0)
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision='ieee')
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # Final softmax normalization.
    acc = acc / l_i[:, None]

    # initialize pointers to output
    off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
             cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    tl.store(out_ptrs,
             acc,
             mask=dim_mask[None, :] &
             (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
    return
|
||||
|
||||
|
||||
@torch.inference_mode()
def context_attention_fwd(q,
                          k,
                          v,
                          o,
                          kv_cache_dtype: str,
                          k_cache,
                          v_cache,
                          b_loc,
                          b_start_loc,
                          b_seq_len,
                          max_seq_len,
                          max_input_len,
                          k_scale: torch.Tensor,
                          v_scale: torch.Tensor,
                          alibi_slopes=None,
                          sliding_window=None,
                          sm_scale=None,
                          skip_decode=False,
                          fp8_out_scale=None,
                          sinks=None):
    """Launch the prefix-prefill (chunked-prefill) attention kernel.

    Attends each request's new query tokens to its cached prefix KV
    (gathered through the block table ``b_loc``) and to the new in-request
    K/V with a causal mask, writing the result into ``o`` in place.

    Dispatches to ``_fwd_kernel_alibi`` when ``alibi_slopes`` is given,
    otherwise to ``_fwd_kernel``.

    Args:
        q, k, v: new-token query/key/value, shape ``[num_tokens, heads, d]``.
        o: output tensor, written in place; same layout as ``q``.
        kv_cache_dtype: "auto", "fp8", "fp8_e4m3" or "fp8_e5m2".
        k_cache: ``[num_blocks, num_kv_heads, head_size/x, block_size, x]``.
        v_cache: ``[num_blocks, num_kv_heads, head_size, block_size]``.
        b_loc: per-batch block table mapping logical to physical blocks.
        b_start_loc: cumulative query-token offsets, length ``batch + 1``.
        b_seq_len: total (prefix + query) length per batch entry.
        k_scale, v_scale: dequantization scales for fp8 KV caches.
        sliding_window: window size; ``None`` or ``<= 0`` disables it.
        sm_scale: softmax scale; defaults to ``1/sqrt(head_size)``.
        skip_decode: skip single-token (decode) requests in the kernel.
        fp8_out_scale: when set, output is scaled/clamped to fp8 range.
        sinks: per-head attention-sink logits (not supported with alibi).

    Raises:
        ValueError: on an unsupported FP8 cache dtype, or a uint8 cache
            combined with ``kv_cache_dtype="auto"``.
    """
    q_dtype_is_f32 = q.dtype is torch.float32

    # Turing does have tensor core for float32 multiplication
    # use ieee as fallback for triton kernels work. There is also
    # warning on vllm/config.py to inform users this fallback
    # implementation
    IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None

    # Conversion of FP8 Tensor from uint8 storage to
    # appropriate torch.dtype for interpretation by Triton
    if "fp8" in kv_cache_dtype:
        assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
        assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]

        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            target_dtype = current_platform.fp8_dtype()
        elif kv_cache_dtype == "fp8_e5m2":
            target_dtype = torch.float8_e5m2
        else:
            # BUGFIX: was ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
            # which raises with a tuple instead of a formatted message.
            raise ValueError(f"Unsupported FP8 dtype: {kv_cache_dtype}")

        k_cache = k_cache.view(target_dtype)
        v_cache = v_cache.view(target_dtype)

    # BUGFIX: the original condition was
    #   k_cache.dtype == torch.uint8 or v_cache.dtype == torch.uint8 and ...
    # `and` binds tighter than `or`, so the kv_cache_dtype == "auto" check
    # only guarded the v_cache half and any uint8 k_cache raised.
    if ((k_cache.dtype == torch.uint8 or v_cache.dtype == torch.uint8)
            and kv_cache_dtype == "auto"):
        raise ValueError("kv_cache_dtype='auto' unsupported for "
                         "FP8 KV Cache prefill kernel")

    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    # round up Lk to a power of 2 - this is required for Triton block size
    Lk_padded = triton.next_power_of_2(Lk)

    if sm_scale is None:
        sm_scale = 1.0 / (Lq**0.5)
    batch, head = b_seq_len.shape[0], q.shape[1]
    num_queries_per_kv = q.shape[1] // k.shape[1]

    assert batch + 1 == len(b_start_loc)

    # 0 means "disable"
    if sliding_window is None or sliding_window <= 0:
        sliding_window = 0

    if alibi_slopes is not None:
        assert sinks is None, "Sinks arg is not supported with alibi"
        assert fp8_out_scale is None, "FP8 output not supported with alibi"
        # need to reduce num. blocks when using fp32
        # due to increased use of GPU shared memory
        # if q.dtype is torch.float32:
        BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
        # batch, head,
        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
        _fwd_kernel_alibi[grid](
            q,
            k,
            v,
            k_cache,
            v_cache,
            b_loc,
            sm_scale,
            k_scale,
            v_scale,
            b_start_loc,
            b_seq_len,
            alibi_slopes,
            v_cache.shape[3],
            k_cache.shape[4],
            o,
            b_loc.stride(0),
            b_loc.stride(1),
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            k_cache.stride(
                4),  #[num_blocks, num_kv_heads, head_size/x, block_size, x]
            v_cache.stride(0),
            v_cache.stride(1),
            v_cache.stride(2),
            v_cache.stride(
                3),  #[num_blocks, num_kv_heads, head_size, block_size]
            num_queries_per_kv=num_queries_per_kv,
            IN_PRECISION=IN_PRECISION,
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_DMODEL_PADDED=Lk_padded,
            BLOCK_N=BLOCK,
            SKIP_DECODE=skip_decode,
            num_warps=NUM_WARPS,
            num_stages=1,
        )
        return

    max_seq_len = 0 if max_seq_len is None else max_seq_len
    extra_kargs = {}
    if current_platform.is_rocm():
        extra_kargs = {"kpack": 1, "waves_per_eu": 2}

    grid = lambda META: (batch, head,
                         triton.cdiv(max_input_len, META["BLOCK_M"]))
    _fwd_kernel[grid](
        q,
        k,
        v,
        k_cache,
        v_cache,
        sinks,
        b_loc,
        sm_scale,
        k_scale,
        v_scale,
        1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0,
        b_start_loc,
        b_seq_len,
        k_cache.shape[4],
        o,
        b_loc.stride(0),
        b_loc.stride(1),
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        k_cache.stride(0),
        k_cache.stride(1),
        k_cache.stride(2),
        k_cache.stride(3),
        k_cache.stride(
            4),  #[num_blocks, num_kv_heads, head_size/x, block_size, x]
        v_cache.stride(0),
        v_cache.stride(1),
        v_cache.stride(2),
        v_cache.stride(3),  #[num_blocks, num_kv_heads, head_size, block_size]
        BLOCK_SIZE=v_cache.shape[3],
        num_queries_per_kv=num_queries_per_kv,
        IN_PRECISION=IN_PRECISION,
        BLOCK_DMODEL=Lk,
        BLOCK_DMODEL_PADDED=Lk_padded,
        SLIDING_WINDOW=sliding_window,
        SKIP_DECODE=skip_decode,
        USE_FP8=fp8_out_scale is not None,
        BLOCK_M=128,
        BLOCK_N=64,
        num_unroll_cache=4,
        num_unroll_request=1,
        num_warps=4,
        num_stages=1,
        USE_SINKS=sinks is not None,
        **extra_kargs)
    return
|
||||
104
vllm/attention/ops/rocm_aiter_mla.py
Normal file
104
vllm/attention/ops/rocm_aiter_mla.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
|
||||
|
||||
def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
                           max_block_per_batch: int,
                           device: torch.device) -> tuple[torch.Tensor, ...]:
    """Allocate reusable metadata buffers for AITER MLA decode.

    Returns ``(paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens,
    qo_indptr)`` sized for ``max_batch_size`` sequences with up to
    ``max_block_per_batch`` KV-cache blocks each. Index buffers are
    zero-initialized; every last-page length starts at ``block_size``.
    """
    total_block_slots = max_batch_size * max_block_per_batch
    kv_indices = torch.zeros(total_block_slots,
                             dtype=torch.int32,
                             device=device)
    kv_indptr = torch.zeros(max_batch_size + 1,
                            dtype=torch.int32,
                            device=device)
    # NOTE(review): allocated on the default device, unlike the other
    # buffers — presumably intentional (host-side lengths); confirm with
    # callers before changing.
    last_page_lens = torch.full((max_batch_size, ),
                                block_size,
                                dtype=torch.int32)
    qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
    return kv_indices, kv_indptr, last_page_lens, qo_indptr
|
||||
|
||||
|
||||
def aiter_mla_decode_fwd(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    sm_scale: float,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    logit_cap: float = 0.0,
):
    """Run MLA decode attention via the registered vLLM custom op.

    Writes the attention output into ``o`` in place; the KV buffer is
    reshaped to the ``(tokens, 1, 1, head_dim)`` layout the op expects.
    """
    kv_view = kv_buffer.view(-1, 1, 1, q.shape[-1])
    torch.ops.vllm.rocm_aiter_mla_decode_fwd(q,
                                             kv_view,
                                             o,
                                             qo_indptr,
                                             max_seqlen_qo,
                                             kv_indptr,
                                             kv_indices,
                                             kv_last_page_lens,
                                             sm_scale=sm_scale,
                                             logit_cap=logit_cap)
|
||||
|
||||
|
||||
def mla_decode_fwd_impl(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    """Real implementation of the ``rocm_aiter_mla_decode_fwd`` custom op.

    Defers to AITER's fused MLA decode kernel; ``o`` is mutated in place.
    """
    # Imported lazily so this module still loads on builds without AITER.
    from aiter.mla import mla_decode_fwd

    # KV buffer is flattened to (tokens, 1, 1, head_dim) for the kernel.
    mla_decode_fwd(q,
                   kv_buffer.view(-1, 1, 1, q.shape[-1]),
                   o,
                   qo_indptr,
                   kv_indptr,
                   kv_indices,
                   kv_last_page_lens,
                   max_seqlen_qo,
                   sm_scale=sm_scale,
                   logit_cap=logit_cap)
|
||||
|
||||
|
||||
def mla_decode_fwd_fake(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    """Fake (meta) implementation used for tracing/compilation.

    The real op mutates ``o`` in place and returns nothing, so the fake
    needs no computation — shapes are already fully determined by the
    caller-provided output tensor.
    """
    pass
|
||||
|
||||
|
||||
# Register the AITER MLA decode forward as a torch custom op on ROCm so it
# can be traced/compiled; the op mutates its output tensor `o` in place.
if current_platform.is_rocm():
    if is_torch_equal_or_newer("2.7.0"):
        # torch >= 2.7 handles stride ordering itself; no tag needed.
        tags = ()
    else:
        # BUGFIX: the original line ended with a trailing comma, which made
        # `tags` a nested tuple ((Tag,),) instead of the expected flat
        # tuple of torch.Tag values.
        tags = (torch.Tag.needs_fixed_stride_order, )
    direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd",
                              op_func=mla_decode_fwd_impl,
                              mutates_args=["o"],
                              fake_impl=mla_decode_fwd_fake,
                              tags=tags)
|
||||
102
vllm/attention/ops/rocm_aiter_paged_attn.py
Normal file
102
vllm/attention/ops/rocm_aiter_paged_attn.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import aiter as rocm_aiter
|
||||
import torch
|
||||
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class AITERPagedAttention(PagedAttention):
    """PagedAttention variant that routes quantized KV caches to AMD AITER
    kernels, falling back to the stock ``PagedAttention`` implementation for
    all other cache dtypes."""

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> None:
        """Write new key/value tokens into the paged KV cache.

        For int8/fp8 caches the AITER per-token-quant kernel is used;
        otherwise this defers to the base-class implementation.
        """
        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
            # Unquantized (or unsupported-quant) path: plain paged write.
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache, slot_mapping,
                                                kv_cache_dtype, k_scale,
                                                v_scale)
        else:
            # Reinterpret the cache storage as the quantized torch dtype
            # before handing it to the AITER kernel (view, not copy).
            kv_cache_torch_dtype = (FP8_DTYPE
                                    if "fp8" in kv_cache_dtype else torch.int8)
            key_cache = key_cache.view(kv_cache_torch_dtype)
            value_cache = value_cache.view(kv_cache_torch_dtype)

            rocm_aiter.reshape_and_cache_with_pertoken_quant(
                key, value, key_cache, value_cache, k_scale, v_scale,
                slot_mapping.flatten(), True)

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> torch.Tensor:
        """Decode-phase attention over the paged KV cache.

        Quantized (int8/fp8) caches go through AITER's ``pa_fwd_asm``;
        everything else falls back to ``PagedAttention.forward_decode``.
        Returns the attention output tensor (same shape as ``query``).
        """
        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
            return PagedAttention.forward_decode(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                block_tables=block_tables,
                seq_lens=seq_lens,
                max_seq_len=max_seq_len,
                kv_cache_dtype=kv_cache_dtype,
                num_kv_heads=num_kv_heads,
                scale=scale,
                alibi_slopes=alibi_slopes,
                k_scale=k_scale,
                v_scale=v_scale,
                tp_rank=tp_rank,
                blocksparse_local_blocks=blocksparse_local_blocks,
                blocksparse_vert_stride=blocksparse_vert_stride,
                blocksparse_block_size=blocksparse_block_size,
                blocksparse_head_sliding_step=blocksparse_head_sliding_step)

        # Reinterpret fp8 cache storage as the platform's fp8 dtype.
        if "fp8" in kv_cache_dtype:
            key_cache = key_cache.view(current_platform.fp8_dtype())
            value_cache = value_cache.view(current_platform.fp8_dtype())

        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            # NOTE(review): only validates the blocksparse config here — the
            # AITER call below does not receive the blocksparse params.
            block_size = value_cache.size(-1)
            assert (blocksparse_block_size > 0 and
                    blocksparse_block_size % block_size == 0), \
                (f"{blocksparse_block_size=} needs to be a multiple of"
                 f"{block_size=} used in block_tables.")

        output = torch.empty_like(query)
        # Block size taken from the cache layout's 4th dim; assumes the AITER
        # paged cache layout — TODO confirm against cache allocation.
        block_size = value_cache.shape[3]
        max_num_blocks_per_seq = cdiv(max_seq_len, block_size)

        # AITER assembly paged-attention kernel writes into `output` in place.
        rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables,
                              seq_lens, max_num_blocks_per_seq, k_scale,
                              v_scale, output)
        return output
|
||||
691
vllm/attention/ops/triton_decode_attention.py
Normal file
691
vllm/attention/ops/triton_decode_attention.py
Normal file
@@ -0,0 +1,691 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
|
||||
# which was originally adapted from
|
||||
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
|
||||
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
|
||||
|
||||
# Changes:
|
||||
# - Add support for page size >= 1.
|
||||
|
||||
# Copyright 2025 vLLM Team
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Memory-efficient attention for decoding.
|
||||
It supports page size >= 1.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from packaging import version
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
# True on ROCm (HIP); used throughout this module to pick AMD-friendly
# block sizes, warp counts, and compiler kwargs.
is_hip_ = current_platform.is_rocm()

logger = logging.getLogger(__name__)

# Only print the following warnings when triton version < 3.2.0.
# The issue won't affect performance or accuracy.
if version.parse(triton.__version__) < version.parse('3.2.0'):
    logger.warning(
        "The following error message 'operation scheduled before its operands' "
        "can be ignored.")
|
||||
|
||||
|
||||
# Device-side tanh built from sigmoid: tanh(x) = 2*sigmoid(2x) - 1.
@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1
|
||||
|
||||
|
||||
# Stage-1 decode kernel (MHA path): each program handles one (batch, head,
# kv-split) triple, runs online softmax over its KV slice, and writes the
# partial output plus a log-sum-exp term to Att_Out for stage-2 reduction.
@triton.jit
def _fwd_kernel_stage1(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    # Grid: (batch, q-head, kv-split).
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    split_kv_id = tl.program_id(2)

    # Map the query head to its KV head (GQA grouping).
    cur_kv_head = cur_head // kv_group_num

    # Lk/Lv are the true head dims; BLOCK_DMODEL/BLOCK_DV are their
    # next-power-of-2 padding, so mask off the padded tail.
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lk
    mask_dv = offs_dv < Lv
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_req_idx = cur_batch

    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
    q = tl.load(Q + off_q, mask=mask_d, other=0.0)

    # This program's contiguous slice of the KV sequence.
    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                              cur_batch_seq_len)

    # Online-softmax state: running max, running sum, output accumulator.
    e_max = -float("inf")
    e_sum = 0.0
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    if split_kv_end > split_kv_start:
        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
            offs_n = start_n + tl.arange(0, BLOCK_N)
            # Translate logical token positions into physical cache slots
            # via the request's page table.
            kv_page_number = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
                offs_n // PAGE_SIZE,
                mask=offs_n < split_kv_end,
                other=0,
            )
            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
            offs_buf_k = (kv_loc[:, None] * stride_buf_kbs +
                          cur_kv_head * stride_buf_kh + offs_d[None, :])
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
                other=0.0,
            )
            # Single query row: qk via broadcasted multiply + reduce.
            qk = tl.sum(q[None, :] * k, 1)
            qk *= sm_scale

            # Optional soft-capping of logits.
            if logit_cap > 0:
                qk = logit_cap * tanh(qk / logit_cap)

            qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))

            offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
                          cur_kv_head * stride_buf_vh + offs_dv[None, :])
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                other=0.0,
            )

            # Online softmax update: rescale the running accumulator by
            # exp(old_max - new_max) before adding this tile's contribution.
            n_e_max = tl.maximum(tl.max(qk, 0), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max)
            acc *= re_scale
            acc += tl.sum(p[:, None] * v, 0)

            e_sum = e_sum * re_scale + tl.sum(p, 0)
            e_max = n_e_max

    # Partial (normalized-within-split) output for this split.
    offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
                  split_kv_id * stride_mid_os + offs_dv)

    tl.store(
        Att_Out + offs_mid_o,
        acc / e_sum,
        mask=(mask_dv),
    )

    # Log-sum-exp stored at offset Lv so stage 2 can recombine splits.
    offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
                    split_kv_id * stride_mid_os + Lv)

    tl.store(
        Att_Out + offs_mid_o_1,
        e_max + tl.log(e_sum),
    )
|
||||
|
||||
|
||||
def _decode_att_m_fwd(
    q,
    k_buffer,
    v_buffer,
    att_out,
    Req_to_tokens,
    B_Seqlen,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap,
):
    """Launch stage 1 of decode attention (MHA path, kv_group_num == 1).

    Writes per-split partial outputs and log-sum-exp terms into ``att_out``
    of shape (batch, head, num_kv_splits, Lv + 1) — TODO confirm layout
    against the allocator.
    """
    # Smaller KV tile on HIP.
    BLOCK = 64 if not is_hip_ else 8

    NUM_KV_SPLITS = num_kv_splits
    Lk = k_buffer.shape[-1]
    Lv = v_buffer.shape[-1]

    batch, head_num = q.shape[0], q.shape[1]

    grid = (batch, head_num, NUM_KV_SPLITS)
    kv_group_num = q.shape[1] // k_buffer.shape[-2]

    num_warps = 4
    if kv_group_num != 1:
        num_warps = 1 if is_hip_ else 2

    # Head dims padded to powers of two for the kernel; masked inside.
    BLOCK_DMODEL = triton.next_power_of_2(Lk)
    BLOCK_DV = triton.next_power_of_2(Lv)

    _fwd_kernel_stage1[grid](
        q,
        k_buffer,
        v_buffer,
        sm_scale,
        Req_to_tokens,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        k_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        v_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        v_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=kv_group_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DV=BLOCK_DV,
        BLOCK_N=BLOCK,
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        num_warps=num_warps,
        num_stages=2,
        Lk=Lk,
        Lv=Lv,
    )
|
||||
|
||||
|
||||
# Stage-1 decode kernel (grouped GQA/MQA/MLA path): each program handles a
# tile of up to BLOCK_H query heads that share one KV head, for one
# (batch, kv-split) pair. BLOCK_DPE > 0 splits the head dim into a "model"
# part and a rotary/positional-extra part (MLA-style), matmul'ed separately.
@triton.jit
def _fwd_grouped_kernel_stage1(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    # Grid: (batch, head-tile, kv-split).
    cur_batch = tl.program_id(0)
    cur_head_id = tl.program_id(1)
    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
    split_kv_id = tl.program_id(2)

    # Number of query heads actually covered by this tile.
    if kv_group_num > BLOCK_H:
        VALID_BLOCK_H: tl.constexpr = BLOCK_H
    else:
        VALID_BLOCK_H: tl.constexpr = kv_group_num
    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
    mask_h = mask_h & (cur_head < q_head_num)

    # Mask padded head-dim lanes (BLOCK_* are next-power-of-2 of Lk/Lv).
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lk
    mask_dv = offs_dv < Lv
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_req_idx = cur_batch

    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[
        None, :]
    q = tl.load(Q + offs_q,
                mask=(mask_h[:, None]) & (mask_d[None, :]),
                other=0.0)

    # Optional positional-extra slice of the query (dims beyond BLOCK_DMODEL).
    if BLOCK_DPE > 0:
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        mask_dpe = offs_dpe < Lk
        off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh +
                   offs_dpe[None, :])
        qpe = tl.load(Q + off_qpe,
                      mask=(mask_h[:, None]) & (mask_dpe[None, :]),
                      other=0.0)

    # This program's contiguous slice of the KV sequence.
    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                              cur_batch_seq_len)

    # Per-head online-softmax state.
    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

    if split_kv_end > split_kv_start:
        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
            offs_n = start_n + tl.arange(0, BLOCK_N)
            # Page-table lookup: logical position -> physical cache slot.
            kv_page_number = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
                offs_n // PAGE_SIZE,
                mask=offs_n < split_kv_end,
                other=0,
            )
            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
            # K loaded transposed (dim, token) for tl.dot with q (head, dim).
            offs_buf_k = (kv_loc[None, :] * stride_buf_kbs +
                          cur_kv_head * stride_buf_kh + offs_d[:, None])
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
                other=0.0,
            )
            qk = tl.dot(q, k.to(q.dtype))
            if BLOCK_DPE > 0:
                offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs +
                                cur_kv_head * stride_buf_kh +
                                offs_dpe[:, None])
                kpe = tl.load(
                    K_Buffer + offs_buf_kpe,
                    mask=(offs_n[None, :] < split_kv_end) &
                    (mask_dpe[:, None]),
                    other=0.0,
                )
                qk += tl.dot(qpe, kpe.to(qpe.dtype))
            qk *= sm_scale

            # Optional soft-capping of logits.
            if logit_cap > 0:
                qk = logit_cap * tanh(qk / logit_cap)

            qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end),
                          qk, float("-inf"))

            offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
                          cur_kv_head * stride_buf_vh + offs_dv[None, :])
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                other=0.0,
            )

            # Online softmax: rescale accumulator by exp(old_max - new_max)
            # before folding in this tile's contribution.
            n_e_max = tl.maximum(tl.max(qk, 1), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            acc *= re_scale[:, None]
            acc += tl.dot(p.to(v.dtype), v)

            e_sum = e_sum * re_scale + tl.sum(p, 1)
            e_max = n_e_max

    # Partial (normalized-within-split) output for this split.
    offs_mid_o = (cur_batch * stride_mid_ob +
                  cur_head[:, None] * stride_mid_oh +
                  split_kv_id * stride_mid_os + offs_dv[None, :])

    tl.store(
        Att_Out + offs_mid_o,
        acc / e_sum[:, None],
        mask=(mask_h[:, None]) & (mask_dv[None, :]),
    )

    # Log-sum-exp stored at offset Lv for the stage-2 reduction.
    offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
                    split_kv_id * stride_mid_os + Lv)

    tl.store(
        Att_Out + offs_mid_o_1,
        e_max + tl.log(e_sum),
        mask=mask_h,
    )
|
||||
|
||||
|
||||
def _decode_grouped_att_m_fwd(
    q,
    k_buffer,
    v_buffer,
    att_out,
    Req_to_tokens,
    B_Seqlen,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap,
):
    """Launch stage 1 of decode attention (grouped GQA/MQA/MLA path).

    Writes per-split partial outputs and log-sum-exp terms into ``att_out``
    for the stage-2 reduction.
    """
    BLOCK = 32
    Lk = k_buffer.shape[-1]
    Lv = v_buffer.shape[-1]

    # [TODO] work around shmem limit on MI3xx
    if is_hip_ and Lk >= 576:
        BLOCK = 16

    # MLA head dims (512+64 or 256+32): split into model + positional-extra
    # parts; otherwise pad to power of two and disable the DPE path.
    if Lk == 576:
        BLOCK_DMODEL = 512
        BLOCK_DPE = 64
    elif Lk == 288:
        BLOCK_DMODEL = 256
        BLOCK_DPE = 32
    else:
        BLOCK_DMODEL = triton.next_power_of_2(Lk)
        BLOCK_DPE = 0
    BLOCK_DV = triton.next_power_of_2(Lv)

    batch, head_num = q.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k_buffer.shape[-2]

    BLOCK_H = 16
    NUM_KV_SPLITS = num_kv_splits
    # One program per (batch, head-tile, kv-split); a tile covers up to
    # BLOCK_H query heads sharing a KV head.
    grid = (
        batch,
        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
        NUM_KV_SPLITS,
    )

    extra_kargs = {}
    num_stages = 2
    if is_hip_:
        # https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {
            "waves_per_eu": 1,
            "matrix_instr_nonkdim": 16,
            "kpack": 2
        }
        num_stages = 1

    _fwd_grouped_kernel_stage1[grid](
        q,
        k_buffer,
        v_buffer,
        sm_scale,
        Req_to_tokens,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        k_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        v_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        v_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        BLOCK_N=BLOCK,
        BLOCK_H=BLOCK_H,
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        num_warps=4,
        num_stages=num_stages,
        Lk=Lk,
        Lv=Lv,
        **extra_kargs,
    )
|
||||
|
||||
|
||||
# Stage-2 kernel: merge the per-split partial results produced by stage 1
# into the final output `o` and per-head log-sum-exp `lse`, using the same
# online-softmax recombination across splits.
@triton.jit
def _fwd_kernel_stage2(
    Mid_O,
    o,
    lse,
    B_Seqlen,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_obs,
    stride_oh,
    stride_lse_bs,
    NUM_KV_SPLITS: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
):
    # Grid: (batch, head).
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)

    offs_d = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lv

    # Running softmax state over splits.
    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    # Partial output lives at [0, Lv); the split's log-sum-exp at offset Lv.
    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv

    for split_kv_id in range(0, NUM_KV_SPLITS):
        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                                  cur_batch_seq_len)

        # Skip splits that covered no tokens (their stage-1 slots are unread).
        if split_kv_end > split_kv_start:
            tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os,
                         mask=mask_d,
                         other=0.0)
            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
            n_e_max = tl.maximum(tlogic, e_max)

            # Rescale previous accumulator, then weight this split's partial
            # output by exp(lse_split - new_max).
            old_scale = tl.exp(e_max - n_e_max)
            acc *= old_scale
            exp_logic = tl.exp(tlogic - n_e_max)
            acc += exp_logic * tv

            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max

    tl.store(
        o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
        acc / e_sum,
        mask=mask_d,
    )
    # Final per-(batch, head) log-sum-exp over the whole sequence.
    lse_val = e_max + tl.log(e_sum)
    tl.store(
        lse + cur_batch * stride_lse_bs + cur_head,
        lse_val,
    )
|
||||
|
||||
|
||||
def _decode_softmax_reducev_fwd(
    logits,
    q,
    o,
    lse,
    v_buffer,
    b_seq_len,
    num_kv_splits,
):
    """Launch stage 2: reduce stage-1 per-split partials in ``logits`` into
    the final attention output ``o`` and log-sum-exp ``lse``."""
    batch, head_num = q.shape[0], q.shape[1]
    Lv = v_buffer.shape[-1]
    BLOCK_DV = triton.next_power_of_2(Lv)

    NUM_KV_SPLITS = num_kv_splits

    extra_kargs = {}
    if is_hip_:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {
            "waves_per_eu": 4,
            "matrix_instr_nonkdim": 16,
            "kpack": 2
        }

    # One program per (batch, head).
    grid = (batch, head_num)
    _fwd_kernel_stage2[grid](
        logits,
        o,
        lse,
        b_seq_len,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        lse.stride(0),
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        BLOCK_DV=BLOCK_DV,
        Lv=Lv,
        num_warps=4,
        num_stages=2,
        **extra_kargs,
    )
|
||||
|
||||
|
||||
def decode_attention_fwd_normal(
    q,
    k_buffer,
    v_buffer,
    o,
    lse,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    """Two-stage decode attention for the plain MHA case.

    Stage 1 computes per-KV-split partial results into ``attn_logits``;
    stage 2 reduces the splits, writing the final output to ``o`` and the
    log-sum-exp to ``lse``.
    """
    # Stage 1: split-KV partial attention.
    _decode_att_m_fwd(q, k_buffer, v_buffer, attn_logits, req_to_token,
                      b_seq_len, num_kv_splits, sm_scale, page_size,
                      logit_cap)
    # Stage 2: reduce the splits into the final output.
    _decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
                                num_kv_splits)
|
||||
|
||||
|
||||
def decode_attention_fwd_grouped(
    q,
    k_buffer,
    v_buffer,
    o,
    lse,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    """Two-stage decode attention for grouped heads (GQA/MQA/MLA).

    Stage 1 computes per-KV-split partial results into ``attn_logits``;
    stage 2 reduces the splits, writing the final output to ``o`` and the
    log-sum-exp to ``lse``.
    """
    # Stage 1: grouped split-KV partial attention.
    _decode_grouped_att_m_fwd(q, k_buffer, v_buffer, attn_logits,
                              req_to_token, b_seq_len, num_kv_splits,
                              sm_scale, page_size, logit_cap)
    # Stage 2: reduce the splits into the final output.
    _decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
                                num_kv_splits)
|
||||
|
||||
|
||||
def decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    lse,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size=1,
    logit_cap=0.0,
):
    """Dispatch decode attention to the MHA or grouped implementation.

    ``kv_group_num == 1`` means one query head per KV head (plain MHA);
    anything larger is GQA/MQA/MLA and takes the grouped kernel.
    """
    assert num_kv_splits == attn_logits.shape[2]
    kv_group_num = q.shape[1] // v_buffer.shape[-2]

    run = (decode_attention_fwd_normal
           if kv_group_num == 1 else decode_attention_fwd_grouped)
    run(
        q,
        k_buffer,
        v_buffer,
        o,
        lse,
        req_to_token,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
        page_size,
        logit_cap,
    )
|
||||
984
vllm/attention/ops/triton_flash_attention.py
Normal file
984
vllm/attention/ops/triton_flash_attention.py
Normal file
@@ -0,0 +1,984 @@
|
||||
#!/usr/bin/env python
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
|
||||
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
|
||||
(https://tridao.me/publications/flash2/flash2.pdf)
|
||||
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
|
||||
|
||||
Features supported:
|
||||
|
||||
1) Fwd with causal masking
|
||||
2) Any sequence lengths without padding (currently fwd kernel only)
|
||||
3) Support for different sequence lengths for q and k
|
||||
4) Nested tensor API currently does not support dropout or bias.
|
||||
|
||||
Not currently supported:
|
||||
|
||||
1) Non power of two head dims
|
||||
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
# Avoid misleading ROCm warning.
|
||||
if current_platform.is_rocm():
    from vllm.platforms.rocm import on_gfx1x
else:

    # Idiom fix (PEP 8 E731): use a def instead of a lambda assignment so
    # the stub has a proper __name__ and docstring. Behavior is unchanged:
    # never a gfx1x GPU on non-ROCm platforms.
    def on_gfx1x(*args, **kwargs) -> bool:
        """Stub for non-ROCm platforms: always False."""
        return False


# Default compute dtype used by the kernels below.
torch_dtype: tl.constexpr = torch.float16
|
||||
|
||||
|
||||
# Device-side ceiling division: ceil(x / y) for non-negative integers.
@triton.jit
def cdiv_fn(x, y):
    return (x + y - 1) // y
|
||||
|
||||
|
||||
# Element-wise maximum; thin alias over tl.math.max for use in kernels.
@triton.jit
def max_fn(x, y):
    return tl.math.max(x, y)
|
||||
|
||||
|
||||
# Per-element Philox RNG offsets for a (m, n) dropout tile: base offset plus
# a row-major (row * stride + col) index grid.
@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
    ms = tl.arange(0, m)
    ns = tl.arange(0, n)
    return philox_offset + ms[:, None] * stride + ns[None, :]
|
||||
|
||||
|
||||
# Uniform [0, 1) random numbers for a (m, n) dropout tile, seeded by the
# Philox counter so results are deterministic per (seed, offset).
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
    rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
                                  stride).to(tl.uint32)
    # TODO: use tl.randint for better performance
    return tl.rand(philox_seed, rng_offsets)
|
||||
|
||||
|
||||
# Boolean keep-mask for dropout: True where the element survives
# (uniform sample > dropout_p).
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
    rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
                             stride)
    rng_keep = rng_output > dropout_p
    return rng_keep
|
||||
|
||||
|
||||
# Load a tile through a block pointer with optional boundary checking on
# either axis: `first`/`second` select whether dim 0 / dim 1 need checking,
# and `pad` is the padding option for out-of-bounds lanes (e.g. "zero").
@triton.jit
def load_fn(block_ptr, first, second, pad):
    if first and second:
        tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
    elif first:
        tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
    elif second:
        tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
    else:
        # Fast path: no boundary checks needed, plain load.
        tensor = tl.load(block_ptr)
    return tensor
|
||||
|
||||
|
||||
# Inner loop of the flash-attention forward kernel: iterates over K/V tiles
# in [block_min, block_max), maintaining the online-softmax state
# (acc, l_i, m_i) for one BLOCK_M tile of queries. Works in exp2 space
# (probabilities are 2^(qk - max)); the caller is expected to fold log2(e)
# into the scaling — TODO confirm against the enclosing kernel.
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    q,
    K_block_ptr,
    V_block_ptr,
    start_m,
    actual_seqlen_k,
    dropout_p,
    philox_seed,
    batch_philox_offset,
    encoded_softmax_block_ptr,
    block_min,
    block_max,
    offs_n_causal,
    masked_blocks,
    n_extra_tokens,
    bias_ptr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    OFFS_M: tl.constexpr,
    OFFS_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    MASK_STEPS: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    RETURN_ENCODED_SOFTMAX: tl.constexpr,
    PADDED_HEAD: tl.constexpr,
    USE_FP8: tl.constexpr,
    qk_scale,
    p_descale,
):
    # loop over k, v, and update accumulator
    for start_n in range(block_min, block_max, BLOCK_N):
        # For padded blocks, we will overrun the tensor size if
        # we load all BLOCK_N. For others, the blocks are all within range.
        k = load_fn(
            K_block_ptr,
            PADDED_HEAD,
            MASK_STEPS and (n_extra_tokens != 0),
            "zero",
        )
        if PRE_LOAD_V:
            # Optionally prefetch V before computing qk (autotuned choice).
            v = load_fn(
                V_block_ptr,
                MASK_STEPS and (n_extra_tokens != 0),
                PADDED_HEAD,
                "zero",
            )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # We start from end of seqlen_k so only the first iteration would need
        # to be checked for padding if it is not a multiple of block_n
        # TODO: This can be optimized to only be true for the padded block.
        if MASK_STEPS:  # noqa: SIM102
            # If this is the last block / iteration, we want to
            # mask if the sequence length is not a multiple of block size
            # a solution is to always do BLOCK_M // BLOCK_N + 1 steps
            # if not is_modulo_mn. last step might get wasted but that is okay.
            # check if this masking works for that case.
            if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
                boundary_m = tl.full([BLOCK_M],
                                     actual_seqlen_k,
                                     dtype=tl.int32)
                size_n = start_n + OFFS_N[None, :]
                mask = size_n < boundary_m[:, None]
                qk = tl.where(mask, qk, float("-inf"))
        if IS_CAUSAL:
            # Mask out keys beyond each query's causal boundary.
            causal_boundary = start_n + offs_n_causal
            causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
            qk = tl.where(causal_mask, qk, float("-inf"))
        # -- compute qk ----
        qk += tl.dot(q, k)
        if USE_FP8:
            # Undo fp8 quantization scaling on the logits.
            qk *= qk_scale
        if bias_ptr is not None:
            bias = load_fn(bias_ptr, False, MASK_STEPS
                           and (n_extra_tokens != 0), "zero")
            # While bias is added after multiplying qk with sm_scale, our
            # optimization to use 2^x instead of e^x results in an additional
            # scale factor of log2(e) which we must also multiply the bias with.
            qk += bias * 1.44269504089
        # Online softmax in exp2 space: subtract the new running max.
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)

        # CAVEAT: Must update l_ij before applying dropout
        l_ij = tl.sum(p, 1)
        if ENABLE_DROPOUT:
            philox_offset = (batch_philox_offset +
                             start_m * BLOCK_M * actual_seqlen_k + start_n -
                             BLOCK_N)
            keep = dropout_mask(
                philox_seed,
                philox_offset,
                dropout_p,
                BLOCK_M,
                BLOCK_N,
                actual_seqlen_k,
            )
            if RETURN_ENCODED_SOFTMAX:
                # Encode dropped positions as negated probabilities so the
                # backward pass can recover the mask.
                tl.store(
                    encoded_softmax_block_ptr,
                    tl.where(keep, p,
                             -p).to(encoded_softmax_block_ptr.type.element_ty),
                )
            p = tl.where(keep, p, 0.0)
        elif RETURN_ENCODED_SOFTMAX:
            tl.store(
                encoded_softmax_block_ptr,
                p.to(encoded_softmax_block_ptr.type.element_ty),
            )
        # -- update output accumulator --
        alpha = tl.math.exp2(m_i - m_ij)
        acc = acc * alpha[:, None]
        if not PRE_LOAD_V:
            v = load_fn(
                V_block_ptr,
                MASK_STEPS and (n_extra_tokens != 0),
                PADDED_HEAD,
                "zero",
            )
        # -- update m_i and l_i
        l_i = l_i * alpha + l_ij
        # update m_i and l_i
        m_i = m_ij

        if USE_FP8:
            # Descale probabilities before the fp8 P·V matmul.
            p *= p_descale

        acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)

        # Advance all block pointers to the next K/V tile.
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
                                                   (0, BLOCK_N))
    return acc, l_i, m_i
|
||||
|
||||
|
||||
def get_cdna_autotune_configs():
    """Return (configs, keys) used to autotune attn_fwd on CDNA (MI-series) GPUs."""
    # Each tuple is (BLOCK_M, BLOCK_N, waves_per_eu, PRE_LOAD_V, num_warps);
    # num_stages is 1 for every variant.
    variants = [
        (256, 64, 2, False, 8),
        (128, 128, 2, False, 4),
        (256, 128, 2, False, 8),
        (128, 64, 1, False, 4),
        (128, 64, 3, True, 4),
        (128, 64, 3, False, 4),
        (64, 64, 4, False, 8),
        (32, 32, 4, False, 8),
        # TODO: This config fails with head_size not pow2 with data mismatches.
        # (32, 16, 1, False, 4),
        # Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
        # (16, 16, 1, False, 4),
    ]
    configs = [
        triton.Config(
            {
                'BLOCK_M': block_m,
                'BLOCK_N': block_n,
                'waves_per_eu': waves,
                'PRE_LOAD_V': pre_load_v
            },
            num_stages=1,
            num_warps=warps)
        for block_m, block_n, waves, pre_load_v, warps in variants
    ]
    return configs, ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
|
||||
|
||||
|
||||
def get_rdna_autotune_configs():
    """Return (configs, keys) used to autotune attn_fwd on RDNA (gfx1x) GPUs."""
    # Each tuple is (BLOCK_M, BLOCK_N, waves_per_eu, PRE_LOAD_V, num_warps);
    # num_stages is 1 for every variant.
    variants = [
        (32, 32, 4, False, 2),
        (32, 32, 2, False, 2),
        (32, 16, 4, False, 2),
        (32, 16, 2, False, 2),
        # Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
        # (16, 16, 4, False, 2),
        # (16, 16, 2, False, 2),
        # # Fall-back config.
        # (16, 16, 1, False, 2),
    ]
    configs = [
        triton.Config(
            {
                'BLOCK_M': block_m,
                'BLOCK_N': block_n,
                'waves_per_eu': waves,
                'PRE_LOAD_V': pre_load_v
            },
            num_stages=1,
            num_warps=warps)
        for block_m, block_n, waves, pre_load_v, warps in variants
    ]
    return configs, ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
|
||||
|
||||
|
||||
def get_autotune_configs():
    """Select the autotune (configs, keys) pair for the current GPU family."""
    # RDNA (gfx1x) consumer cards need the smaller tile sizes; everything else
    # is assumed to be a CDNA accelerator.
    return get_rdna_autotune_configs() if on_gfx1x() \
        else get_cdna_autotune_configs()
|
||||
|
||||
|
||||
# Resolve the tuning space once at import time; the @triton.autotune decorator
# on attn_fwd below captures these module-level values.
autotune_configs, autotune_keys = get_autotune_configs()

# Numeric range of the platform's FP8 dtype; used for clamping and as the
# default FP8_MIN/FP8_MAX constexpr arguments of attn_fwd.
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
# Flash-attention forward kernel (varlen-capable, MQA/GQA-aware, optional
# FP8 inputs/outputs, optional additive bias and dropout). Each program
# instance computes one BLOCK_M tile of queries for one (head, batch) pair,
# streaming over K/V in BLOCK_N tiles via _attn_fwd_inner: first the fully
# unmasked tiles, then the tiles needing causal/padding masks.
@triton.autotune(
    configs=autotune_configs,
    key=autotune_keys,
)
@triton.jit
def attn_fwd(
    Q,
    K,
    V,
    bias,
    sm_scale,
    q_scale,
    k_scale,
    v_scale,
    p_scale,
    p_descale,
    o_descale,
    L,  # LSE output; currently unused (stores are commented out below)
    Out,
    stride_qz: tl.int64,
    stride_qh: tl.int64,
    stride_qm: tl.int64,
    stride_qk: tl.int64,
    stride_kz: tl.int64,
    stride_kh: tl.int64,
    stride_kn: tl.int64,
    stride_kk: tl.int64,
    stride_vz: tl.int64,
    stride_vh: tl.int64,
    stride_vk: tl.int64,
    stride_vn: tl.int64,
    stride_oz: tl.int64,
    stride_oh: tl.int64,
    stride_om: tl.int64,
    stride_on: tl.int64,
    stride_bz: tl.int64,
    stride_bh: tl.int64,
    stride_bm: tl.int64,
    stride_bn: tl.int64,
    cu_seqlens_q,
    cu_seqlens_k,
    dropout_p,
    philox_seed,
    philox_offset_base,
    encoded_softmax,
    HQ: tl.constexpr,
    HK: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr,
    MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    USE_FP8: tl.constexpr,
    USE_FP8_OUT: tl.constexpr,
    BLOCK_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BIAS_TYPE: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    RETURN_ENCODED_SOFTMAX: tl.constexpr,
    FP8_MIN: tl.constexpr = float8_info.min,
    FP8_MAX: tl.constexpr = float8_info.max,
):
    # Grid: (query blocks, query heads, batch).
    start_m = tl.program_id(0)
    off_h_q = tl.program_id(1)
    off_z = tl.program_id(2)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    if VARLEN:
        # Packed sequences: derive this batch element's span from the
        # cumulative-length tables.
        cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
        cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
        seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
        # We have a one-size-fits-all grid in id(0). Some seqlens might be too
        # small for all start_m so for those we return early.
        if start_m * BLOCK_M > seqlen_q:
            return
        cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
        cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
        seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
    else:
        cu_seqlens_q_start = 0
        cu_seqlens_k_start = 0
        seqlen_q = MAX_SEQLENS_Q
        seqlen_k = MAX_SEQLENS_K

    # Now we compute whether we need to exit early due to causal masking.
    # This is because for seqlen_q > seqlen_k, M rows of the attn scores
    # are completely masked, resulting in 0s written to the output, and
    # inf written to LSE. We don't need to do any GEMMs in this case.
    # This block of code determines what N is, and if this WG is operating
    # on those M rows.
    n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
    if IS_CAUSAL:
        # If seqlen_q == seqlen_k, the attn scores are a square matrix.
        # If seqlen_q != seqlen_k, attn scores are rectangular which means
        # the causal mask boundary is bottom right aligned, and ends at either
        # the top edge (seqlen_q < seqlen_k) or left edge.
        # This captures the decrease in n_blocks if we have a rectangular attn
        # matrix
        n_blocks_seqlen = cdiv_fn(
            (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
        # This is what adjusts the block_max for the current WG, only
        # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
        n_blocks = min(n_blocks, n_blocks_seqlen)
        # If we have no blocks after adjusting for seqlen deltas, this WG is
        # part of the blocks that are all 0. We exit early.
        if n_blocks <= 0:
            o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
                        off_h_q * stride_oh)
            O_block_ptr = tl.make_block_ptr(
                base=Out + o_offset,
                shape=(seqlen_q, BLOCK_DMODEL),
                strides=(stride_om, stride_on),
                offsets=(start_m * BLOCK_M, 0),
                block_shape=(BLOCK_M, BLOCK_DMODEL),
                order=(1, 0),
            )
            acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
            # We still need to write 0s to the result
            # tl.store(O_block_ptr,
            #          acc.to(Out.type.element_ty), boundary_check=(0,1))
            # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
            #          + offs_m
            # We store inf to LSE, not -inf because in the bwd pass,
            # we subtract this
            # from qk which makes it -inf, such that exp(qk - inf) = 0
            # for these masked blocks.
            # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
            # tl.store(l_ptrs, l)
            # TODO: Should dropout and return encoded softmax be handled here?
            return

    # If MQA / GQA, set the K and V head offsets appropriately.
    GROUP_SIZE: tl.constexpr = HQ // HK
    off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q

    # Number of K tokens in the last (ragged) BLOCK_N tile, if any.
    n_extra_tokens = 0
    if seqlen_k < BLOCK_N:
        n_extra_tokens = BLOCK_N - seqlen_k
    elif seqlen_k % BLOCK_N:
        n_extra_tokens = seqlen_k % BLOCK_N
    # True when head_size was padded up to the next supported power of two.
    padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL

    # Compute pointers for all the tensors used in this kernel.
    q_offset = (off_z * stride_qz + off_h_q * stride_qh +
                cu_seqlens_q_start * stride_qm)
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    k_offset = (off_z * stride_kz + off_h_k * stride_kh +
                cu_seqlens_k_start * stride_kn)
    # K is addressed transposed (d, n) so qk = q @ k needs no extra transpose.
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    v_offset = (off_z * stride_vz + off_h_k * stride_vh +
                cu_seqlens_k_start * stride_vk)
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    if BIAS_TYPE != 0:
        bias_ptr = tl.make_block_ptr(
            base=bias + off_h_q * stride_bh,
            shape=(seqlen_q, seqlen_k),
            strides=(stride_bm, stride_bn),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
    else:
        bias_ptr = None
    if ENABLE_DROPOUT:
        batch_philox_offset = philox_offset_base \
                              + (off_z * HQ + off_h_q) \
                              * seqlen_q * seqlen_k
    else:
        batch_philox_offset = 0
    # We can ask to return the dropout mask without actually doing any dropout.
    # In this case, we return an invalid pointer to indicate the mask is not
    # valid.
    # TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
    if RETURN_ENCODED_SOFTMAX:
        encoded_softmax_block_ptr = tl.make_block_ptr(
            base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
            shape=(seqlen_q, seqlen_k),
            strides=(seqlen_k, 1),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
    else:
        encoded_softmax_block_ptr = 0
    # initialize pointer to m and l (online-softmax running max and sum)
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use 2^x in the loop as we do not
    # have native e^x support in HW.
    qk_scale = sm_scale * 1.44269504089
    # Q is loaded once at the beginning and shared by all N blocks.
    q = load_fn(Q_block_ptr, True, padded_head, "zero")
    if not USE_FP8:
        # Pre-scale Q so the inner loop's qk product is already scaled.
        q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
        acc_scale = 1.0
    else:
        # FP8 path: fold the Q/K dequant scales into qk_scale instead, and
        # dequantize the accumulator at the end with p_scale * v_scale.
        qk_scale *= q_scale * k_scale
        acc_scale = p_scale * v_scale

    # Here we compute how many full and masked blocks we have.
    padded_block_k = n_extra_tokens != 0
    is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
    if IS_CAUSAL:
        # There are always at least BLOCK_M // BLOCK_N masked blocks.
        # Additionally there might be one more due to dissimilar seqlens.
        masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
    else:
        # Padding on Q does not need to be masked in the FA loop.
        masked_blocks = padded_block_k
    # if IS_CAUSAL, not is_modulo_mn does not always result in an additional
    # block. In this case we might exceed n_blocks so pick the min.
    masked_blocks = min(masked_blocks, n_blocks)
    n_full_blocks = n_blocks - masked_blocks
    block_min = 0
    block_max = n_blocks * BLOCK_N
    # Compute for full blocks. Here we set causal to false regardless of its
    # value because there is no masking. Similarly we do not need padding.
    if n_full_blocks > 0:
        block_max = (n_blocks - masked_blocks) * BLOCK_N
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            encoded_softmax_block_ptr,
            # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
            block_min,
            block_max,
            0,
            0,
            0,
            bias_ptr,
            # IS_CAUSAL, ....
            False,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            offs_m,
            offs_n,
            # _, MASK_STEPS, ...
            PRE_LOAD_V,
            False,
            ENABLE_DROPOUT,
            RETURN_ENCODED_SOFTMAX,
            padded_head,
            USE_FP8,
            qk_scale,
            p_descale,
        )
        block_min = block_max
        block_max = n_blocks * BLOCK_N

    tl.debug_barrier()
    # Remaining blocks, if any, are the masked / partial ones (causal boundary
    # and/or ragged K padding).
    if masked_blocks > 0:
        offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
        K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
            # NOTE(review): this advances by n_full_blocks elements, while the
            # inner loop advances by (0, BLOCK_N) per block — looks like it
            # should be n_full_blocks * BLOCK_N; verify before relying on
            # RETURN_ENCODED_SOFTMAX.
            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
                                                   (0, n_full_blocks))
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            encoded_softmax_block_ptr,
            block_min,
            block_max,
            offs_n_causal,
            masked_blocks,
            n_extra_tokens,
            bias_ptr,
            IS_CAUSAL,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            offs_m,
            offs_n,
            # _, MASK_STEPS, ...
            PRE_LOAD_V,
            True,
            ENABLE_DROPOUT,
            RETURN_ENCODED_SOFTMAX,
            padded_head,
            USE_FP8,
            qk_scale,
            p_descale,
        )
    # epilogue

    if USE_FP8:
        acc *= acc_scale
    # Final softmax normalization by the accumulated row sums.
    acc = acc / l_i[:, None]
    if ENABLE_DROPOUT:
        acc = acc / (1 - dropout_p)
    # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
    # then we have one block with a row of all NaNs which come from computing
    # softmax over a row of all -infs (-inf - inf = NaN). We check for that here
    # and store 0s where there are NaNs as these rows should've been zeroed out.
    end_m_idx = (start_m + 1) * BLOCK_M
    start_m_idx = start_m * BLOCK_M
    causal_start_idx = seqlen_q - seqlen_k
    if USE_FP8_OUT:
        # Quantize the output: apply the inverse output scale and clamp to
        # the FP8 representable range before the dtype cast below.
        acc *= o_descale
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
    acc = acc.to(Out.type.element_ty)
    if IS_CAUSAL:  # noqa: SIM102
        if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
            out_mask_boundary = tl.full((BLOCK_DMODEL, ),
                                        causal_start_idx,
                                        dtype=tl.int32)
            mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
            out_ptrs_mask = (mask_m_offsets[:, None]
                             >= out_mask_boundary[None, :])
            z = tl.zeros((1, ), tl.float32)
            acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
    # write back LSE
    # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
    # If seqlen_q not multiple of BLOCK_M, we need to mask out the last
    # few rows. This is only true for the last M block. For others,
    # overflow_size will be -ve
    # overflow_size = end_m_idx - seqlen_q
    # if overflow_size > 0:
    #     boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
    #     # This is a > check because mask being 0 blocks the store.
    #     l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
    #     tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
    # else:
    #     tl.store(l_ptrs, m_i + tl.math.log2(l_i))

    # write back O
    o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
                off_h_q * stride_oh)
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # Need boundary check on this to make sure the padding from the
    # Q and KV tensors in both dims are not part of what we store back.
    # TODO: Do the boundary check optionally.
    tl.store(O_block_ptr, acc, boundary_check=(0, 1))
|
||||
|
||||
|
||||
def check_args(
    q,
    k,
    v,
    o,
    varlen=True,
    max_seqlens=None,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
):
    """Sanity-check the attention inputs before launching the kernel.

    Raises AssertionError when shapes, dtypes, or sequence-length metadata
    are inconsistent; returns None otherwise.
    """
    assert q.dim() == k.dim() and q.dim() == v.dim()
    if varlen:
        # Packed (varlen) layout: [total_tokens, nheads, head_size].
        assert q.dim() == 3
        _, nheads_q, head_size = q.shape
        _, nheads_k, _ = k.shape
        assert cu_seqlens_q is not None
        assert cu_seqlens_k is not None
        assert len(cu_seqlens_q) == len(cu_seqlens_k)
    else:
        # Batched layout: [batch, nheads, seqlen, head_size].
        assert q.dim() == 4
        _, nheads_q, _, head_size = q.shape
        _, nheads_k, _, _ = k.shape
        assert max_seqlens > 0
    assert k.shape == v.shape
    assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
    # TODO: Change assert if we support qkl f8 and v f16
    assert q.dtype == k.dtype and q.dtype == v.dtype
    assert head_size <= 256
    assert o.shape == q.shape
    # GQA/MQA requires the query-head count to be a multiple of KV heads.
    assert nheads_q % nheads_k == 0
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
    """Autograd wrapper around the Triton flash-attention forward kernel.

    Forward-only (no backward is defined here); varlen (packed) layout only.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        o,  # optional pre-allocated output; allocated here when None
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlens_q,
        max_seqlens_k,
        causal=False,
        sm_scale=1.0,
        bias=None,
        fp8_scales=None,  # (q_scale, k_scale, v_scale, p_scale) enables FP8
        fp8_out_scale=None,  # scalar tensor; enables FP8 output quantization
    ):
        if fp8_scales is not None:
            use_fp8 = True
            (q_scale, k_scale, v_scale, p_scale) = fp8_scales
            float8 = current_platform.fp8_dtype()

            # Quantize a tensor to the platform FP8 dtype unless it already
            # is FP8 (then it is assumed to be pre-scaled by the caller).
            def check_and_convert(t, scale):
                if t.dtype != float8:
                    descale = 1.0 / scale
                    ts = (t * descale).clamp(min=float8_info.min,
                                             max=float8_info.max)
                    return ts.to(float8)
                else:
                    return t

            q = check_and_convert(q, q_scale)
            k = check_and_convert(k, k_scale)
            v = check_and_convert(v, v_scale)
        else:
            use_fp8 = False
            q_scale = k_scale = v_scale = p_scale = 1.0

        if o is None:
            # Output takes V's dtype (matters when q was converted to FP8).
            o = torch.empty_like(q, dtype=v.dtype)

        check_args(
            q,
            k,
            v,
            o,
            varlen=True,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
        )
        # NOTE(review): only the varlen branch is live; the else branch is
        # dead code kept for a future non-varlen path.
        if True:  # varlen
            total_q, nheads_q, head_size = q.shape
            total_k, nheads_k, _ = k.shape
            batch = len(cu_seqlens_q) - 1
            # Varlen has no batch dim, so the z-stride is 0; the kernel
            # expects (z, h, seq, d) stride order.
            q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
            k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
            v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
            o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
        else:
            batch, seqlen_q, nheads_q, head_size = q.shape
            _, seqlen_k, nheads_k, _ = k.shape
            q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
            k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
            v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
            o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))

        # Get closest power of 2 over or equal to 32.
        unpadded_head_dims = {32, 64, 128, 256}
        if head_size not in unpadded_head_dims:
            padded_d_model = None
            # Relies on set iteration visiting these small ints in ascending
            # order, so the first larger element is the next power of two.
            for i in unpadded_head_dims:
                if i > head_size:
                    padded_d_model = i
                    break
            assert padded_d_model is not None
        else:
            padded_d_model = head_size

        # Grid: (query blocks, query heads, batch); BLOCK_M is autotuned.
        grid = lambda META: (
            triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
            nheads_q,
            batch,
        )

        encoded_softmax = None

        # Seed the RNG so we get reproducible results for testing.
        philox_seed = 0x1BF52
        philox_offset = 0x1D4B42

        if bias is not None:
            bias_strides = (
                bias.stride(0),
                bias.stride(1),
                bias.stride(2),
                bias.stride(3),
            )
        else:
            bias_strides = (0, 0, 0, 0)

        p_descale = 1.0 / p_scale
        o_descale = 1.0 / fp8_out_scale.item(
        ) if fp8_out_scale is not None else 1.0

        # NOTE(review): on gfx1x, MAX_SEQLENS_* are passed as 0 — presumably
        # a tuning/compilation workaround; they are only used when
        # VARLEN=False, which this path never hits. Verify.
        arg_max_seqlens_q = 0 if on_gfx1x() else max_seqlens_q
        arg_max_seqlens_k = 0 if on_gfx1x() else max_seqlens_k

        attn_fwd[grid](
            q,
            k,
            v,
            bias,
            sm_scale,
            q_scale,
            k_scale,
            v_scale,
            p_scale,
            p_descale,
            o_descale,
            None,  # L (LSE) output is unused in this forward-only path
            o,
            *q_strides,
            *k_strides,
            *v_strides,
            *o_strides,
            *bias_strides,
            cu_seqlens_q,
            cu_seqlens_k,
            dropout_p=0.0,
            philox_seed=philox_seed,
            philox_offset_base=philox_offset,
            encoded_softmax=encoded_softmax,
            HQ=nheads_q,
            HK=nheads_k,
            ACTUAL_BLOCK_DMODEL=head_size,
            MAX_SEQLENS_Q=arg_max_seqlens_q,
            MAX_SEQLENS_K=arg_max_seqlens_k,
            IS_CAUSAL=causal,
            VARLEN=True,
            BLOCK_DMODEL=padded_d_model,
            BIAS_TYPE=0 if bias is None else 1,
            ENABLE_DROPOUT=False,
            RETURN_ENCODED_SOFTMAX=False,
            USE_FP8=use_fp8,
            USE_FP8_OUT=fp8_out_scale is not None,
        )

        # Stash launch parameters for a (potential) backward pass.
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = head_size
        ctx.causal = causal
        ctx.dropout_p = 0.0
        ctx.philox_seed = philox_seed
        ctx.philox_offset = philox_offset
        ctx.encoded_softmax = encoded_softmax
        ctx.return_encoded_softmax = False
        return o, encoded_softmax
||||
|
||||
|
||||
triton_attention = _attention.apply
|
||||
97
vllm/attention/ops/triton_merge_attn_states.py
Normal file
97
vllm/attention/ops/triton_merge_attn_states.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
# can be used to combine partial attention results (in the split-KV case)
|
||||
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: Optional[torch.Tensor] = None,
) -> None:
    """Merge two partial attention results (split-KV case) into `output`.

    Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005.
    Writes the merged result in place; also fills `output_lse` when given.
    """
    num_tokens, num_query_heads, head_size = output.shape
    # Triton needs a power-of-two tile along the head dimension.
    padded_head_size = triton.next_power_of_2(head_size)
    grid = (num_tokens, num_query_heads)

    # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
    merge_attn_states_kernel[grid](
        output,
        output_lse,
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        head_size,
        padded_head_size,
        output_lse is not None,
    )
|
||||
|
||||
|
||||
# One program per (token, head): merge a prefix and suffix partial attention
# result using their log-sum-exps, in a numerically stable way.
@triton.jit
def merge_attn_states_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    output_lse,  # [NUM_HEADS, NUM_TOKENS]
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse,  # [NUM_HEADS, NUM_TOKENS]
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse,  # [NUM_HEADS, NUM_TOKENS]
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,  # HEAD_SIZE rounded up to a power of 2
    OUTPUT_LSE: tl.constexpr,  # whether to also store the merged LSE
):
    token_idx = tl.program_id(0)
    num_tokens = tl.num_programs(0)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
    s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)

    # FA2 and FA3 have different behavior for when the sum-exp is 0, this namely
    # arises with 0 len seqlens. FA3 returns -inf here while FA2 returns inf.
    # If we see an inf assume FA2 and convert inf to -inf for consistency
    # and correctness. Inf generally doesn't make sense in this context outside
    # of undefined-behavior/FA2-case, so I think this a safe assumption.
    p_lse = float('-inf') if p_lse == float('inf') else p_lse
    s_lse = float('-inf') if s_lse == float('inf') else s_lse

    # Subtract the max before exponentiating to avoid overflow.
    max_lse = tl.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    # Will reuse precomputed Exp values for scale factor computation.
    p_se = tl.exp(p_lse)
    s_se = tl.exp(s_lse)
    out_se = (p_se + s_se)

    if OUTPUT_LSE:
        out_lse = tl.log(out_se) + max_lse
        tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse)

    # Masked load of one head vector (padded to the power-of-2 tile).
    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE +
                    head_idx * HEAD_SIZE + head_arange,
                    mask=head_mask)
    s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE +
                    head_idx * HEAD_SIZE + head_arange,
                    mask=head_mask)

    # NOTE(woosuk): Be careful with the numerical stability.
    # We should compute the scale first, and then multiply it with the output.
    # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
    p_scale = p_se / out_se
    s_scale = s_se / out_se
    out = p_out * p_scale + s_out * s_scale
    tl.store(output + token_idx * num_heads * HEAD_SIZE +
             head_idx * HEAD_SIZE + head_arange,
             out,
             mask=head_mask)
|
||||
175
vllm/attention/ops/triton_reshape_and_cache_flash.py
Normal file
175
vllm/attention/ops/triton_reshape_and_cache_flash.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
# One program per (token, tile): scatter a token's K/V vectors into the paged
# KV cache slot given by slot_mapping, optionally dividing by the FP8 scales.
@triton.jit
def reshape_and_cache_kernel_flash(
    key_ptr,  # [num_tokens, num_heads, head_size]
    value_ptr,  # [num_tokens, num_heads, head_size]
    key_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
    value_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
    slot_mapping_ptr,  # [num_tokens]
    k_scale,  # float32
    v_scale,  # float32
    # strides
    key_stride: tl.int64,
    value_stride: tl.int64,
    block_stride: tl.int64,
    page_stride: tl.int64,
    num_heads: tl.constexpr,
    head_size: tl.constexpr,
    block_size: tl.constexpr,
    # FP8 flags
    FP8_KV_CACHE: tl.constexpr,
    # tune parameters
    TILE_SIZE: tl.constexpr,
):

    token_idx = tl.program_id(axis=0)
    # Cast to int64 so the address arithmetic below cannot overflow int32.
    slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
    if slot_idx < 0:
        # Padding token that should be ignored.
        return

    # Each tile covers TILE_SIZE contiguous elements of the flattened
    # (num_heads * head_size) vector for this token.
    tile_i = tl.program_id(axis=1)
    tile_offs = tl.arange(0, TILE_SIZE)
    tile_pos = tile_i * TILE_SIZE + tile_offs

    # Decompose the flat slot into (cache block, offset within block).
    block_idx = slot_idx // block_size
    block_offset = slot_idx % block_size

    src_key_idx = token_idx * key_stride
    src_value_idx = token_idx * value_stride

    tgt_idx = block_idx * block_stride + block_offset * page_stride

    # [TILE_SIZE]
    key_load = tl.load(key_ptr + src_key_idx + tile_pos,
                       mask=tile_pos < (num_heads * head_size))
    if FP8_KV_CACHE:
        if key_load.dtype.is_fp8():
            key_tile = key_load
        else:
            # tl.store will do the correct implicit cast to fp8,
            # based on the key_cache_ptr.dtype.element_ty
            key_tile = key_load / tl.load(k_scale)
    else:
        key_tile = key_load

    # [TILE_SIZE]
    value_load = tl.load(value_ptr + src_value_idx + tile_pos,
                         mask=tile_pos < (num_heads * head_size))
    if FP8_KV_CACHE:
        if value_load.dtype.is_fp8():
            value_tile = value_load
        else:
            # tl.store will do the correct implicit cast to fp8,
            # based on the value_cache_ptr.dtype.element_ty
            value_tile = value_load / tl.load(v_scale)
    else:
        value_tile = value_load

    tl.store(
        key_cache_ptr + tgt_idx + tile_pos,
        key_tile,
        mask=tile_pos < (num_heads * head_size),
    )
    tl.store(
        value_cache_ptr + tgt_idx + tile_pos,
        value_tile,
        mask=tile_pos < (num_heads * head_size),
    )
    return
|
||||
|
||||
|
||||
def triton_reshape_and_cache_flash(
    key: torch.Tensor,  # [num_tokens, num_heads, head_size]
    value: torch.Tensor,  # [num_tokens, num_heads, head_size]
    # [num_blocks, block_size, num_heads, head_size]
    key_cache: torch.Tensor,
    # [num_blocks, block_size, num_heads, head_size]
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,  # [num_tokens]
    kv_cache_dtype: str,  # "auto", "fp8"
    k_scale: torch.Tensor,  # float32
    v_scale: torch.Tensor,  # float32
) -> None:
    """Scatter new K/V token entries into the paged KV cache.

    Host-side launcher for ``reshape_and_cache_kernel_flash``: derives
    strides and launch heuristics, reinterprets fp8 caches so the kernel's
    ``tl.store`` writes the correct element type, and launches one program
    per (token, tile-of-head-elements).

    Args:
        key/value: new per-token K/V, shape [num_tokens, num_heads, head_size].
        key_cache/value_cache: paged caches,
            shape [num_blocks, block_size, num_heads, head_size].
        slot_mapping: flat cache slot per token; negative entries are
            padding and are skipped by the kernel.
        kv_cache_dtype: "auto" (store as-is) or an "fp8*" variant
            (scale by 1/k_scale, 1/v_scale and store as fp8).
        k_scale/v_scale: single-element float32 scale tensors.
    """
    num_tokens = key.shape[0]
    num_heads = key.shape[1]
    head_size = key.shape[2]
    block_size = key_cache.shape[1]
    n = num_heads * head_size

    key_stride = key.stride()[0]
    value_stride = value.stride()[0]
    block_stride = key_cache.stride()[0]
    page_stride = key_cache.stride()[1]

    head_stride = key_cache.stride()[2]
    # The kernel indexes the (num_heads * head_size) span as one flat,
    # contiguous run, so heads must be densely packed.
    assert head_stride == head_size, "only contiguous heads are supported"

    assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), \
        f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
    kv_cache_torch_dtype = current_platform.fp8_dtype() if \
        kv_cache_dtype.startswith("fp8") else key_cache.dtype

    if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith(
            "fp8"):
        # to avoid erroneous implicit cast in triton kernel (tl.store to uint8)
        # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
        key_cache = key_cache.view(kv_cache_torch_dtype)
        value_cache = value_cache.view(kv_cache_torch_dtype)
        # BUGFIX: the original compared the *string* kv_cache_dtype against
        # torch.uint8, which is always True (vacuous assert). The intent is
        # to reject a uint8-typed fp8 cache, i.e. the torch dtype.
        assert kv_cache_torch_dtype != torch.uint8, \
            "explicit fp8 cast and store to "\
            "uint8 is not supported by triton reshape_and_cache_flash"

    FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
    # NOTE(review): torch.uint8 appears in this allow-list while the assert
    # above rejects it on the view path — kept as-is for platforms whose
    # fp8_dtype() is a real float8 type; confirm against callers.
    assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
        torch.float8_e4m3fn, torch.float8_e5m2, torch.uint8,
        torch.float8_e4m3fnuz], \
        "unsupported dtype of KV cache tensor, got "\
        f"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " \
        "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."

    # heuristics instead of autotuning
    TILE_SIZE = min(2048, triton.next_power_of_2(n))
    if torch.version.hip or torch.version.xpu:
        num_stages = 4
        num_warps = 8
    else:  # cuda
        num_stages = 10
        num_warps = 16
        if torch.cuda.get_device_capability(key.device)[0] < 9:
            # pre-Hopper GPUs have less shared memory / fewer registers
            TILE_SIZE = min(512, TILE_SIZE)

    # TODO(ngl): maybe replace with static launch grid to avoid overhead if
    # using cudagraphs
    grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))

    reshape_and_cache_kernel_flash[grid](
        key_ptr=key,
        value_ptr=value,
        key_cache_ptr=key_cache,
        value_cache_ptr=value_cache,
        slot_mapping_ptr=slot_mapping,
        k_scale=k_scale,
        v_scale=v_scale,
        # strides
        key_stride=key_stride,
        value_stride=value_stride,
        block_stride=block_stride,
        page_stride=page_stride,
        num_heads=num_heads,
        head_size=head_size,
        block_size=block_size,
        # FP8 flags
        FP8_KV_CACHE=FP8_KV_CACHE,
        # autotune parameters
        TILE_SIZE=TILE_SIZE,
        num_warps=num_warps,
        num_stages=num_stages,
    )
|
||||
894
vllm/attention/ops/triton_unified_attention.py
Normal file
894
vllm/attention/ops/triton_unified_attention.py
Normal file
@@ -0,0 +1,894 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Authors:
|
||||
# - Burkhard Ringlein <ngl@zurich.ibm.com>
|
||||
# - Jan van Lunteren <jvl@zurich.ibm.com>
|
||||
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
|
||||
# - Thomas Parnell <tpa@zurich.ibm.com>
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
logger = init_logger(__name__)
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
@triton.jit
def cdiv_fn(x, y):
    """Ceiling division: smallest integer quotient >= x / y."""
    rounded_up = x + y - 1
    return rounded_up // y
|
||||
|
||||
|
||||
@triton.jit
def apply_softcap(S, x):
    """Soft-cap attention scores: returns x * tanh(S / x).

    tanh is expressed via exponentials:
    tanh(t) = (e^t - e^-t) / (e^t + e^-t).
    """
    t = S / x
    e_pos = tl.exp(t)
    e_neg = tl.exp(-t)
    tanh_t = (e_pos - e_neg) / (e_pos + e_neg)
    return x * tanh_t
|
||||
|
||||
|
||||
@triton.jit
def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,
                 BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr):
    """Binary-search the sequence that owns `target_idx`.

    `query_start_len_ptr` holds the cumulative query-start offsets
    ([num_seqs + 1] entries). In q-block mode, entry i is mapped to
    (start_i // BLOCK_Q + i) before comparison, so `target_idx` is
    interpreted as a global q-block index instead of a token index.
    Returns the index of the last entry whose (mapped) value is
    <= target_idx.
    """
    lo: tl.int32 = 0
    hi = num_seqs
    # Invariant: answer lies in [lo - 1, hi - 1].
    while lo < hi:
        probe = (lo + hi) // 2
        start = tl.load(query_start_len_ptr + probe)
        if use_q_block_mode:
            key = start // BLOCK_Q + probe
        else:
            key = start
        if key <= target_idx:
            lo = probe + 1
        else:
            hi = probe
    return lo - 1
|
||||
|
||||
|
||||
@triton.jit
def kernel_unified_attention_2d(
        output_ptr,  # [num_tokens, num_query_heads, head_size]
        query_ptr,  # [num_tokens, num_query_heads, head_size]
        key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
        value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
        sink_ptr,  # [num_query_heads]
        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
        seq_lens_ptr,  # [num_seqs]
        alibi_slopes_ptr,  # [num_query_heads]
        qq_bias_ptr,  # [num_query_tokens, num_query_tokens]
        scale,  # float32
        k_scale,  # float32
        v_scale,  # float32
        out_scale,  # float32
        softcap,  # float32
        num_query_heads: tl.constexpr,  # int
        num_queries_per_kv: tl.constexpr,  # int
        block_table_stride: tl.int64,  # int
        query_stride_0: tl.int64,  # int
        query_stride_1: tl.int64,  # int, should be equal to head_size
        output_stride_0: tl.int64,  # int
        output_stride_1: tl.int64,  # int, should be equal to head_size
        qq_bias_stride_0: tl.int64,  # int
        BLOCK_SIZE: tl.constexpr,  # int
        TILE_SIZE: tl.constexpr,  # int must be power of 2
        HEAD_SIZE: tl.constexpr,  # int
        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
        USE_ALIBI_SLOPES: tl.constexpr,  # bool
        USE_QQ_BIAS: tl.constexpr,  # bool
        USE_SOFTCAP: tl.constexpr,  # bool
        USE_SINKS: tl.constexpr,  # bool
        SLIDING_WINDOW: tl.constexpr,  # int
        stride_k_cache_0: tl.int64,  # int
        stride_k_cache_1: tl.int64,  # int
        stride_k_cache_2: tl.int64,  # int
        stride_k_cache_3: tl.constexpr,  # int
        stride_v_cache_0: tl.int64,  # int
        stride_v_cache_1: tl.int64,  # int
        stride_v_cache_2: tl.int64,  # int
        stride_v_cache_3: tl.constexpr,  # int
        query_start_len_ptr,  # [num_seqs+1]
        BLOCK_Q: tl.constexpr,  # int
        num_seqs: tl.int32,
        BLOCK_M: tl.constexpr,  # int
        USE_FP8: tl.constexpr,  # bool
        FP8_MIN: tl.constexpr = float8_info.min,
        FP8_MAX: tl.constexpr = float8_info.max,
):
    """Single-pass (2D-grid) causal paged attention with online softmax.

    Grid: axis 0 = global q-block index (BLOCK_Q query tokens x
    num_queries_per_kv head replicas per block), axis 1 = KV head index.
    Each program streams KV tiles for one sequence, maintaining the
    flash-attention running max (M), running exp-sum (L), and output
    accumulator (acc), and writes the final normalized output.
    """
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)

    # Map the flat q-block index back to its owning sequence.
    seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs,
                           BLOCK_Q, True)

    q_block_start_idx = tl.load(query_start_len_ptr +
                                seq_idx) // BLOCK_Q + seq_idx

    q_block_local_idx = q_block_global_idx - q_block_start_idx

    cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
    cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

    cur_batch_query_len = cur_batch_in_all_stop_index \
        - cur_batch_in_all_start_index

    # This q-block lies entirely past this sequence's queries (upper-bound
    # launch grid) -- nothing to do.
    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
        return

    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)
    offs_t = tl.arange(0, TILE_SIZE)
    # Each query token occupies num_queries_per_kv consecutive rows of the
    # BLOCK_M tile (one row per query head sharing this KV head).
    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

    query_offset_0 = cur_batch_in_all_start_index + query_pos
    query_offset_1 = kv_head_idx * num_queries_per_kv + \
        offs_m % num_queries_per_kv
    query_offset = (query_offset_0[:, None] * query_stride_0 +
                    query_offset_1[:, None] * query_stride_1 + offs_d[None, :])

    # Masks: head-dim padding, query-token padding, head-index padding.
    dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
    query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
    query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

    # Q : (BLOCK_M, HEAD_SIZE_PADDED)
    Q = tl.load(
        query_ptr + query_offset,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    if not USE_SINKS:
        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    else:
        # Attention sinks: seed the running max with the per-head sink
        # logit so it contributes to the softmax denominator.
        M = tl.load(
            sink_ptr + query_offset_1,
            mask=query_mask_1,
            other=float("-inf"),
        ).to(dtype=tl.float32)

    L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # context length for this particular sequences
    context_len = seq_len - cur_batch_query_len

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1,
                              mask=query_mask_1,
                              other=0.0)

    # query-query attention bias
    if USE_QQ_BIAS:
        qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
                            )  # shape: [BLOCK_M]

    # compute the length of the longest sequence prefix spanned by any
    # query token in the current q_block (q_block_local_idx)
    max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
        BLOCK_M - 1) // num_queries_per_kv + 1

    # adjust for potential padding in the last q_block by considering the
    # actual sequence length
    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)

    # calculate the number of tiles that need to be processed to
    # cover the longest sequence prefix (due to causal masking, tiles beyond
    # this prefix can be skipped)
    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)

    # ---- Sliding-window tile pruning --------------------
    # Default: keep previous global behavior
    tile_start = 0
    tile_end = num_tiles
    if SLIDING_WINDOW > 0:
        # Query rows covered by this Q-block
        qpos_lo = q_block_local_idx * BLOCK_Q
        qpos_hi = tl.minimum(
            qpos_lo + (BLOCK_M - 1) // num_queries_per_kv,
            cur_batch_query_len - 1,
        )
        # For sliding window, each query position q can only attend to
        # keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs]
        # where q_abs = context_len + q
        # The union of allowed key positions for this Q-block is:
        # [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi]
        first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1
        last_allowed_key = context_len + qpos_hi
        # Convert to tile indices and clamp
        tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE)
        tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles)

    # iterate through tiles (now limited to the sliding window range)
    for j in range(tile_start, tile_end):
        seq_offset = j * TILE_SIZE + offs_t
        tile_mask = seq_offset < max_seq_prefix_len

        # Translate logical token positions to physical cache blocks.
        physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
                                     seq_offset // BLOCK_SIZE).to(tl.int64)

        v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
                    kv_head_idx * stride_v_cache_2 +
                    offs_d[None, :] * stride_v_cache_3 +
                    (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)

        k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
                    kv_head_idx * stride_k_cache_2 +
                    offs_d[:, None] * stride_k_cache_3 +
                    (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)

        # K : (HEAD_SIZE, TILE_SIZE)
        K_load = tl.load(key_cache_ptr + k_offset,
                         mask=dim_mask[:, None] & tile_mask[None, :],
                         other=0.0)

        # Dequantize fp8 K unless Q is itself fp8 (then scales are fused
        # elsewhere -- NOTE(review): confirmed only by the branch shape here).
        if K_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                K = K_load
            else:
                K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (TILE_SIZE, HEAD_SIZE)
        V_load = tl.load(value_cache_ptr + v_offset,
                         mask=dim_mask[None, :] & tile_mask[:, None],
                         other=0.0)

        if V_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                V = V_load
            else:
                V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        # Causal mask: key position must not exceed the query's absolute
        # position (context_len + query_pos).
        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1

        # S : (BLOCK_M, TILE_SIZE)
        S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)

        S += scale * tl.dot(Q, K)

        if USE_SOFTCAP:
            S = apply_softcap(S, softcap)

        S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
                     S, float("-inf"))

        if SLIDING_WINDOW > 0:
            S = tl.where((context_len + query_pos[:, None] - seq_offset)
                         < SLIDING_WINDOW, S, float("-inf"))

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

        if USE_QQ_BIAS:
            # compute key positions relative to query section
            key_rel_pos = seq_offset - context_len  # shape: [BLOCK_SIZE]
            # load bias only for keys that correspond to queries
            is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
            qq_bias = tl.load(
                qq_bias_row_ptrs + key_rel_pos[None, :],
                mask=is_query_key[None, :],  # avoid OOB for context keys
                other=0.0,
            )
            S += qq_bias

        # compute running maximum
        # m_j : (BLOCK_M,)
        m_j = tl.maximum(M, tl.max(S, axis=1))

        # For sliding window there's a chance the max is -inf due to masking of
        # the entire row. In this case we need to set m_j 0 to avoid NaN
        m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

        # P : (BLOCK_M, TILE_SIZE)
        P = tl.exp(S - m_j[:, None])

        # l_j : (BLOCK_M,)
        l_j = tl.sum(P, axis=1)

        # alpha : (BLOCK_M, ) -- rescale factor for the previous running state
        alpha = tl.exp(M - m_j)

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc += tl.dot(P.to(V.dtype), V)

    # epilogue: normalize by the total softmax denominator
    acc = acc / L[:, None]
    if USE_FP8:
        # Scale and clamp into the representable fp8 range before the
        # implicit cast performed by tl.store.
        acc = acc * tl.load(out_scale)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)

    output_offset = (query_offset_0[:, None] * output_stride_0 +
                     query_offset_1[:, None] * output_stride_1 +
                     offs_d[None, :])

    tl.store(
        output_ptr + output_offset,
        acc,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
    )
|
||||
|
||||
|
||||
@triton.jit
def kernel_unified_attention_3d(
        segm_output_ptr,
        # [num_tokens, num_query_heads, num_segments, head_size]
        segm_max_ptr,  # [num_tokens, num_query_heads, num_segments]
        segm_expsum_ptr,  # [num_tokens, num_query_heads, num_segments]
        query_ptr,  # [num_tokens, num_query_heads, head_size]
        key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
        value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
        sink_ptr,  # [num_query_heads]
        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
        seq_lens_ptr,  # [num_seqs]
        alibi_slopes_ptr,  # [num_query_heads]
        qq_bias_ptr,  # [num_query_tokens, num_query_tokens]
        scale,  # float32
        k_scale,  # float32
        v_scale,  # float32
        softcap,  # float32
        num_query_heads: tl.constexpr,  # int
        num_queries_per_kv: tl.constexpr,  # int
        block_table_stride: tl.int64,  # int
        query_stride_0: tl.int64,  # int
        query_stride_1: tl.int64,  # int, should be equal to head_size
        qq_bias_stride_0: tl.int64,  # int
        BLOCK_SIZE: tl.constexpr,  # int
        TILE_SIZE: tl.constexpr,  # int, must be power of 2
        HEAD_SIZE: tl.constexpr,  # int
        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
        USE_ALIBI_SLOPES: tl.constexpr,  # bool
        USE_QQ_BIAS: tl.constexpr,  # bool
        USE_SOFTCAP: tl.constexpr,  # bool
        USE_SINKS: tl.constexpr,  # bool
        SLIDING_WINDOW: tl.constexpr,  # int
        stride_k_cache_0: tl.int64,  # int
        stride_k_cache_1: tl.int64,  # int
        stride_k_cache_2: tl.int64,  # int
        stride_k_cache_3: tl.constexpr,  # int
        stride_v_cache_0: tl.int64,  # int
        stride_v_cache_1: tl.int64,  # int
        stride_v_cache_2: tl.int64,  # int
        stride_v_cache_3: tl.constexpr,  # int
        query_start_len_ptr,  # [num_seqs+1]
        BLOCK_Q: tl.constexpr,  # int
        num_seqs: tl.int32,
        BLOCK_M: tl.constexpr,  # int
        NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
):
    """Split-KV (3D-grid) variant of the unified attention kernel.

    Grid: axis 0 = global q-block index, axis 1 = KV head, axis 2 = KV
    segment. Each program attends over only its segment's KV tiles and
    writes *partial* results -- un-normalized accumulator, running max, and
    exp-sum -- which ``reduce_segments`` later combines into the final
    output. Mirrors kernel_unified_attention_2d except there is no
    normalization epilogue and no sliding-window tile pruning here.
    """
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)
    segm_idx = tl.program_id(2)

    seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs,
                           BLOCK_Q, True)

    q_block_start_idx = tl.load(query_start_len_ptr +
                                seq_idx) // BLOCK_Q + seq_idx

    q_block_local_idx = q_block_global_idx - q_block_start_idx

    cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
    cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

    cur_batch_query_len = cur_batch_in_all_stop_index \
        - cur_batch_in_all_start_index

    # q-block past this sequence's queries (upper-bound launch grid).
    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
        return

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # number of segments for this particular sequence
    num_segments = NUM_SEGMENTS_PER_SEQ
    tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)

    # This segment starts beyond the sequence -- nothing to attend over.
    if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len:
        return

    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)
    offs_t = tl.arange(0, TILE_SIZE)
    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

    query_offset_0 = cur_batch_in_all_start_index + query_pos
    query_offset_1 = kv_head_idx * num_queries_per_kv + \
        offs_m % num_queries_per_kv
    query_offset = (query_offset_0[:, None] * query_stride_0 +
                    query_offset_1[:, None] * query_stride_1 + offs_d[None, :])

    dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
    query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
    query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

    # Q : (BLOCK_M, HEAD_SIZE_PADDED)
    Q = tl.load(
        query_ptr + query_offset,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    if USE_SINKS:
        # Apply the sink logit in exactly one segment (segment 0) so the
        # final reduction counts it once.
        if segm_idx == 0:
            M = tl.load(
                sink_ptr + query_offset_1,
                mask=query_mask_1,
                other=float("-inf"),
            ).to(dtype=tl.float32)
        else:
            M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    else:
        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)

    L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

    # context length for this particular sequences
    context_len = seq_len - cur_batch_query_len

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1,
                              mask=query_mask_1,
                              other=0.0)

    # query-query attention bias
    if USE_QQ_BIAS:
        qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
                            )  # shape: [BLOCK_M]

    # compute the length of the longest sequence prefix spanned by any
    # query token in the current q_block (q_block_local_idx)
    max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
        BLOCK_M - 1) // num_queries_per_kv + 1

    # adjust for potential padding in the last q_block by considering the
    # actual sequence length
    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)

    # calculate the number of tiles that need to be processed to
    # cover the longest sequence prefix (due to causal masking, tiles beyond
    # this prefix can be skipped)
    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)

    # iterate through tiles within current segment
    for j in range(
            segm_idx * tiles_per_segment,
            min((segm_idx + 1) * tiles_per_segment, num_tiles),
    ):
        seq_offset = j * TILE_SIZE + offs_t
        tile_mask = seq_offset < max_seq_prefix_len

        physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
                                     seq_offset // BLOCK_SIZE).to(tl.int64)

        v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
                    kv_head_idx * stride_v_cache_2 +
                    offs_d[None, :] * stride_v_cache_3 +
                    (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)

        k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
                    kv_head_idx * stride_k_cache_2 +
                    offs_d[:, None] * stride_k_cache_3 +
                    (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)

        # K : (HEAD_SIZE, TILE_SIZE)
        K_load = tl.load(key_cache_ptr + k_offset,
                         mask=dim_mask[:, None] & tile_mask[None, :],
                         other=0.0)

        # Dequantize fp8 K/V unless Q is itself fp8.
        if K_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                K = K_load
            else:
                K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (TILE_SIZE, HEAD_SIZE)
        V_load = tl.load(value_cache_ptr + v_offset,
                         mask=dim_mask[None, :] & tile_mask[:, None],
                         other=0.0)

        if V_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                V = V_load
            else:
                V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        # Causal mask within the tile.
        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1

        # S : (BLOCK_M, TILE_SIZE)
        S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
        S += scale * tl.dot(Q, K)

        if USE_SOFTCAP:
            S = apply_softcap(S, softcap)

        S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
                     S, float("-inf"))

        if SLIDING_WINDOW > 0:
            S = tl.where((context_len + query_pos[:, None] - seq_offset)
                         < SLIDING_WINDOW, S, float("-inf"))

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

        if USE_QQ_BIAS:
            # compute key positions relative to query section
            key_rel_pos = seq_offset - context_len  # shape: [BLOCK_SIZE]
            # load bias only for keys that correspond to queries
            is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
            qq_bias = tl.load(
                qq_bias_row_ptrs + key_rel_pos[None, :],
                mask=is_query_key[None, :],  # avoid OOB for context keys
                other=0.0,
            )
            S += qq_bias

        # compute running maximum
        # m_j : (BLOCK_M,)
        m_j = tl.maximum(M, tl.max(S, axis=1))

        # For sliding window there's a chance the max is -inf due to masking of
        # the entire row. In this case we need to set m_j 0 to avoid NaN
        m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

        # P : (BLOCK_M, TILE_SIZE,)
        P = tl.exp(S - m_j[:, None])

        # l_j : (BLOCK_M,)
        l_j = tl.sum(P, axis=1)

        # alpha : (BLOCK_M, )
        alpha = tl.exp(M - m_j)

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc += tl.dot(P.to(V.dtype), V)

    # Store the *un-normalized* partial output plus the (max, expsum)
    # statistics needed by reduce_segments to merge segments.
    segm_output_offset = (
        query_offset_0[:, None].to(tl.int64) *
        (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :])
    tl.store(
        segm_output_ptr + segm_output_offset,
        acc,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
    )
    segm_offset = (query_offset_0.to(tl.int64) *
                   (num_query_heads * NUM_SEGMENTS_PER_SEQ) +
                   query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx)
    tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1)
    tl.store(segm_expsum_ptr + segm_offset,
             L,
             mask=query_mask_0 & query_mask_1)
|
||||
|
||||
|
||||
@triton.jit
def reduce_segments(
        output_ptr,  # [num_tokens, num_query_heads, head_size]
        segm_output_ptr,
        #[num_tokens, num_query_heads, max_num_segments, head_size]
        segm_max_ptr,  # [num_tokens, num_query_heads, max_num_segments]
        segm_expsum_ptr,  # [num_tokens, num_query_heads, max_num_segments]
        seq_lens_ptr,  # [num_seqs]
        num_seqs,  # int
        num_query_heads: tl.constexpr,  # int
        out_scale_inv,  # float32
        output_stride_0: tl.int64,  # int
        output_stride_1: tl.int64,  # int, should be equal to head_size
        block_table_stride: tl.int64,  # int (unused here; kept for launch parity)
        TILE_SIZE: tl.constexpr,  # int
        HEAD_SIZE: tl.constexpr,  # int, must be power of 2
        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
        query_start_len_ptr,  # [num_seqs+1]
        BLOCK_Q: tl.constexpr,  # int
        NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
        USE_FP8: tl.constexpr,  # bool
        FP8_MIN: tl.constexpr = float8_info.min,
        FP8_MAX: tl.constexpr = float8_info.max,
):
    """Merge per-segment partial attention results into the final output.

    Grid: axis 0 = query token, axis 1 = query head. Rescales every
    segment's un-normalized accumulator to a common max (standard
    log-sum-exp merge), sums the accumulators and exp-sums, then divides
    once to produce the normalized attention output.
    """
    query_token_idx = tl.program_id(0)
    query_head_idx = tl.program_id(1)

    # token-index mode (use_q_block_mode=False) of the shared binary search
    seq_idx = find_seq_idx(query_start_len_ptr, query_token_idx, num_seqs,
                           BLOCK_Q, False)

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # number of segments for this particular sequence
    num_segments = NUM_SEGMENTS_PER_SEQ
    tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)

    # create masks for subsequent loads: only act_num_segments segments
    # actually ran for this sequence (same formula as the 3d kernel's
    # early-exit), the rest hold garbage and must be masked out.
    act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE)
    segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full(
        [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32)
    dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
                        0).to(tl.int1)

    # load segment maxima
    segm_offset = (query_token_idx.to(tl.int64) *
                   (num_query_heads * NUM_SEGMENTS_PER_SEQ) +
                   query_head_idx * NUM_SEGMENTS_PER_SEQ +
                   tl.arange(0, NUM_SEGMENTS_PER_SEQ))
    segm_max = tl.load(segm_max_ptr + segm_offset,
                       mask=segm_mask,
                       other=float("-inf"))
    overall_max = tl.max(segm_max)

    # load and rescale segment exp sums to the global max
    segm_expsum = tl.load(segm_expsum_ptr + segm_offset,
                          mask=segm_mask,
                          other=0.0)
    segm_expsum = segm_expsum * tl.exp(segm_max - overall_max)
    overall_expsum = tl.sum(segm_expsum)

    # load, rescale, and add segment attention outputs
    segm_output_offset = (
        query_token_idx.to(tl.int64) *
        (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED +
        tl.arange(0, HEAD_SIZE_PADDED)[None, :])
    segm_output = tl.load(
        segm_output_ptr + segm_output_offset,
        mask=segm_mask[:, None] & dim_mask[None, :],
        other=0.0,
    )
    segm_output *= tl.exp(segm_max - overall_max)[:, None]
    acc_sum = tl.sum(segm_output, axis=0)
    # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
    acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)

    if USE_FP8:
        # scale and clamp into the fp8-representable range before the
        # implicit cast performed by tl.store
        acc = acc * tl.load(out_scale_inv)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)

    # write result
    output_offset = (query_token_idx * output_stride_0 +
                     query_head_idx * output_stride_1 +
                     tl.arange(0, HEAD_SIZE_PADDED))
    tl.store(output_ptr + output_offset, acc, mask=dim_mask)
|
||||
|
||||
|
||||
def unified_attention(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    max_seqlen_q,
    seqused_k,
    max_seqlen_k,
    softmax_scale,
    causal,
    window_size,
    block_table,
    softcap,
    q_descale,
    k_descale,
    v_descale,
    alibi_slopes=None,
    output_scale=None,
    qq_bias=None,
    # Optional tensor for sinks
    sinks=None,
):
    """Launch the unified paged-attention Triton kernels.

    Writes the attention output into ``out`` in place (returns None).
    Batches that contain prefill work (``max_seqlen_q > 1``) or expose
    enough grid parallelism use the single-pass 2D kernel; small decode
    batches use the split-KV 3D kernel followed by a segment-reduction
    kernel that merges the per-segment softmax partials.

    NOTE(review): from the indexing below, ``q`` appears to be laid out as
    (num_tokens, num_query_heads, head_size) and ``k``/``v`` as paged caches
    of shape (num_blocks, block_size, num_kv_heads, head_size) -- confirm
    against callers. ``max_seqlen_k`` is accepted but unused here.
    """

    assert causal, "Only causal attention is supported"
    assert q_descale is None, "Q scales not supported"

    if sinks is not None:
        # One sink logit per query head.
        assert sinks.shape[0] == q.shape[1], \
            "Sinks must be num_query_heads size"

    use_alibi_slopes = alibi_slopes is not None
    use_qq_bias = qq_bias is not None

    block_size = v.shape[1]
    num_seqs = len(seqused_k)
    num_query_heads = q.shape[1]
    num_kv_heads = k.shape[2]
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_size = q.shape[2]

    # BLOCK_M: rows of the query tile; must cover all query heads that share
    # one KV head, hence the next_power_of_2 fallback for wide GQA ratios.
    BLOCK_M = 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(
        num_queries_per_kv)
    BLOCK_Q = BLOCK_M // num_queries_per_kv

    # Ideally we would launch with kernel with:
    # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks.
    # However, it is slow to realize the query_lens on cpu.
    # Instead we use upper-bound:
    # \sum_i[ceil(query_len[i] / BLOCK_Q)]
    #   <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1]
    #   = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs
    #   <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs
    #   = floor(q.shape[0] / BLOCK_Q) + num_seqs
    total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs

    # Assigning default tile sizes for prefill and decode.
    # Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
    # and at least 16 for all other data types.
    TILE_SIZE_PREFILL = 32
    TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32

    # if batch contains a prefill
    if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
        kernel_unified_attention_2d[(
            total_num_q_blocks,
            num_kv_heads,
        )](
            output_ptr=out,
            query_ptr=q,
            key_cache_ptr=k,
            value_cache_ptr=v,
            sink_ptr=sinks,
            block_tables_ptr=block_table,
            seq_lens_ptr=seqused_k,
            alibi_slopes_ptr=alibi_slopes,
            qq_bias_ptr=qq_bias,
            scale=softmax_scale,
            k_scale=k_descale,
            v_scale=v_descale,
            out_scale=1 / output_scale if output_scale is not None else 1.0,
            softcap=softcap,
            num_query_heads=num_query_heads,
            num_queries_per_kv=num_queries_per_kv,
            block_table_stride=block_table.stride(0),
            query_stride_0=q.stride(0),
            query_stride_1=q.stride(1),
            output_stride_0=out.stride(0),
            output_stride_1=out.stride(1),
            qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
            BLOCK_SIZE=block_size,
            TILE_SIZE=TILE_SIZE_PREFILL,
            HEAD_SIZE=head_size,
            HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
            USE_ALIBI_SLOPES=use_alibi_slopes,
            USE_QQ_BIAS=use_qq_bias,
            USE_SOFTCAP=(softcap > 0),
            USE_SINKS=(sinks is not None),
            SLIDING_WINDOW=(1 + window_size[0]),
            stride_k_cache_0=k.stride(0),
            stride_k_cache_1=k.stride(1),
            stride_k_cache_2=k.stride(2),
            stride_k_cache_3=k.stride(3),
            stride_v_cache_0=v.stride(0),
            stride_v_cache_1=v.stride(1),
            stride_v_cache_2=v.stride(2),
            stride_v_cache_3=v.stride(3),
            query_start_len_ptr=cu_seqlens_q,
            BLOCK_Q=BLOCK_Q,
            num_seqs=num_seqs,
            BLOCK_M=BLOCK_M,
            USE_FP8=output_scale is not None,
        )
    else:
        # for initial version, NUM_SEGMENTS = 16 is chosen as a default
        # value that showed good performance in tests
        NUM_SEGMENTS = 16

        # Per-segment partial outputs / softmax statistics in fp32; merged by
        # the reduce_segments kernel below.
        segm_output = torch.empty(
            q.shape[0],
            num_query_heads,
            NUM_SEGMENTS,
            triton.next_power_of_2(head_size),
            dtype=torch.float32,
            device=q.device,
        )
        segm_max = torch.empty(
            q.shape[0],
            num_query_heads,
            NUM_SEGMENTS,
            dtype=torch.float32,
            device=q.device,
        )
        segm_expsum = torch.empty(
            q.shape[0],
            num_query_heads,
            NUM_SEGMENTS,
            dtype=torch.float32,
            device=q.device,
        )

        kernel_unified_attention_3d[(
            total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
                segm_output_ptr=segm_output,
                segm_max_ptr=segm_max,
                segm_expsum_ptr=segm_expsum,
                query_ptr=q,
                key_cache_ptr=k,
                value_cache_ptr=v,
                sink_ptr=sinks,
                block_tables_ptr=block_table,
                seq_lens_ptr=seqused_k,
                alibi_slopes_ptr=alibi_slopes,
                qq_bias_ptr=qq_bias,
                scale=softmax_scale,
                k_scale=k_descale,
                v_scale=v_descale,
                softcap=softcap,
                num_query_heads=num_query_heads,
                num_queries_per_kv=num_queries_per_kv,
                block_table_stride=block_table.stride(0),
                query_stride_0=q.stride(0),
                query_stride_1=q.stride(1),
                qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
                BLOCK_SIZE=block_size,
                TILE_SIZE=TILE_SIZE_DECODE,
                HEAD_SIZE=head_size,
                HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
                USE_ALIBI_SLOPES=use_alibi_slopes,
                USE_QQ_BIAS=use_qq_bias,
                USE_SOFTCAP=(softcap > 0),
                USE_SINKS=(sinks is not None),
                SLIDING_WINDOW=(1 + window_size[0]),
                stride_k_cache_0=k.stride(0),
                stride_k_cache_1=k.stride(1),
                stride_k_cache_2=k.stride(2),
                stride_k_cache_3=k.stride(3),
                stride_v_cache_0=v.stride(0),
                stride_v_cache_1=v.stride(1),
                stride_v_cache_2=v.stride(2),
                stride_v_cache_3=v.stride(3),
                query_start_len_ptr=cu_seqlens_q,
                BLOCK_Q=BLOCK_Q,
                num_seqs=num_seqs,
                BLOCK_M=BLOCK_M,
                NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
            )
        reduce_segments[(q.shape[0], num_query_heads)](
            output_ptr=out,
            segm_output_ptr=segm_output,
            segm_max_ptr=segm_max,
            segm_expsum_ptr=segm_expsum,
            seq_lens_ptr=seqused_k,
            num_seqs=num_seqs,
            num_query_heads=num_query_heads,
            out_scale_inv=1 /
            output_scale if output_scale is not None else 1.0,
            output_stride_0=out.stride(0),
            output_stride_1=out.stride(1),
            block_table_stride=block_table.stride(0),
            TILE_SIZE=TILE_SIZE_DECODE,
            HEAD_SIZE=head_size,
            HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
            query_start_len_ptr=cu_seqlens_q,
            BLOCK_Q=BLOCK_Q,
            NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
            USE_FP8=output_scale is not None,
        )
|
||||
245
vllm/attention/selector.py
Normal file
245
vllm/attention/selector.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from typing import Generator, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    """Map a backend name string to its _Backend enum member.

    Returns:
        * The matching _Backend member when ``backend_name`` names a valid
          in-tree backend.
        * None when the name is unknown -- either an invalid in-tree name or
          an out-of-tree platform is loaded.
    """
    assert backend_name is not None
    members = _Backend.__members__
    if backend_name in members:
        return _Backend[backend_name]
    return None
|
||||
|
||||
|
||||
def get_env_variable_attn_backend() -> Optional[_Backend]:
    """Read the attention-backend override from the environment.

    Returns:
        The _Backend enum value named by the vLLM attention backend
        environment variable, or None when no override is set.
    """
    name = os.environ.get(STR_BACKEND_ENV_VAR)
    if name is None:
        return None
    return backend_name_to_enum(name)
|
||||
|
||||
|
||||
# Global state allows a particular choice of backend
# to be forced, overriding the logic which auto-selects
# a backend based on system & workload configuration
# (default behavior if this variable is None)
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
#
# Mutated only through global_force_attn_backend() and read through
# get_global_forced_attn_backend().
forced_attn_backend: Optional[_Backend] = None
|
||||
|
||||
|
||||
def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
    """Force every attention operation onto the specified backend.

    Passing ``None`` re-enables automatic backend selection.

    Args:
        attn_backend: the backend to force, or None to revert to auto.
    """
    global forced_attn_backend
    forced_attn_backend = attn_backend
|
||||
|
||||
|
||||
def get_global_forced_attn_backend() -> Optional[_Backend]:
    """Return the globally forced attention backend, if any.

    A ``None`` result means automatic backend selection is in effect.
    """
    current_override = forced_attn_backend
    return current_override
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _IsSupported:
|
||||
can_import: bool
|
||||
head_size: bool
|
||||
dtype: bool
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return self.can_import and self.head_size and self.dtype
|
||||
|
||||
|
||||
def is_attn_backend_supported(
    attn_backend: Union[str, type[AttentionBackend]],
    head_size: int,
    dtype: torch.dtype,
    *,
    allow_import_error: bool = True,
) -> _IsSupported:
    """Probe whether ``attn_backend`` supports this head size and dtype.

    Args:
        attn_backend: a backend class, or its qualified name to resolve.
        head_size: attention head size to validate.
        dtype: dtype to validate.
        allow_import_error: when True, a failed import of a string-named
            backend yields an all-False verdict instead of raising.

    Returns:
        An _IsSupported verdict; truthy only when everything checks out.

    Raises:
        ImportError: resolution failed and allow_import_error is False.
        NotImplementedError: the backend exposes no head-size or dtype
            validation hook.
    """
    if isinstance(attn_backend, str):
        try:
            attn_backend = resolve_obj_by_qualname(attn_backend)
        except ImportError:
            if not allow_import_error:
                raise

            return _IsSupported(can_import=False, head_size=False, dtype=False)

    assert isinstance(attn_backend, type)

    # TODO: Update the interface once V0 is removed
    head_size_getter = getattr(attn_backend, "get_supported_head_sizes", None)
    head_size_validator = getattr(attn_backend, "validate_head_size", None)
    if head_size_getter:
        is_head_size_supported = head_size in head_size_getter()
    elif head_size_validator:
        try:
            head_size_validator(head_size)
            is_head_size_supported = True
        except Exception:
            is_head_size_supported = False
    else:
        raise NotImplementedError(f"{attn_backend.__name__} does not support "
                                  "head size validation")

    dtype_getter = getattr(attn_backend, "get_supported_dtypes", None)
    if dtype_getter:
        is_dtype_supported = dtype in dtype_getter()
    else:
        raise NotImplementedError(f"{attn_backend.__name__} does not support "
                                  "dtype validation")

    return _IsSupported(
        can_import=True,
        head_size=is_head_size_supported,
        dtype=is_dtype_supported,
    )
|
||||
|
||||
|
||||
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]:
    """Select which attention backend to use and lazily import it."""
    # envs.VLLM_USE_V1 must be read OUTSIDE the cached helper: reading it
    # inside an @cache-decorated function would pin the first value seen
    # even if the environment changes between calls.
    use_v1 = envs.VLLM_USE_V1
    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        block_size=block_size,
        use_v1=use_v1,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
    )
|
||||
|
||||
|
||||
@cache
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    use_v1: bool = False,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]:
    """Resolve and import the attention backend class for this configuration.

    Memoized with @cache: all arguments are hashable and the number of
    distinct configurations per process is small, so the cache stays
    bounded in practice.

    Selection precedence: globally forced backend > VLLM_ATTENTION_BACKEND
    environment variable > platform auto-selection.

    Raises:
        ValueError: the env-var backend name is invalid, or the current
            platform cannot provide a backend class for this configuration.
    """

    # Check whether a particular choice of backend was
    # previously forced.
    #
    # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
    # ENVIRONMENT VARIABLE.
    selected_backend = None
    backend_by_global_setting: Optional[_Backend] = (
        get_global_forced_attn_backend())
    if backend_by_global_setting is not None:
        selected_backend = backend_by_global_setting
    else:
        # Check the environment variable and override if specified
        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
        if backend_by_env_var is not None:
            # Legacy "_VLLM_V1" suffix is accepted but stripped with a
            # deprecation warning.
            if backend_by_env_var.endswith("_VLLM_V1"):
                logger.warning(
                    "The suffix '_VLLM_V1' in the environment variable "
                    "%s is no longer necessary as V0 backends have been "
                    "deprecated. Please remove this suffix from your "
                    "environment variable setting.", STR_BACKEND_ENV_VAR)
                backend_by_env_var = backend_by_env_var.removesuffix(
                    "_VLLM_V1")
            selected_backend = backend_name_to_enum(backend_by_env_var)
            if selected_backend is None:
                raise ValueError(
                    f"Invalid attention backend: '{backend_by_env_var}'. "
                    f"Valid backends are: {list(_Backend.__members__.keys())}")

    # get device-specific attn_backend
    attention_cls = current_platform.get_attn_backend_cls(
        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
        use_mla, has_sink, use_sparse)
    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}")
    return resolve_obj_by_qualname(attention_cls)
|
||||
|
||||
|
||||
@contextmanager
def global_force_attn_backend_context_manager(
        attn_backend: _Backend) -> Generator[None, None, None]:
    """Temporarily force a global attention-backend override.

    On exit the previous override state (possibly None, i.e. automatic
    selection) is restored, even if the enclosed block raises.

    Args:
        attn_backend: attention backend to force within the context.

    Returns:
        A generator-based context manager yielding None.
    """
    # Remember whatever override was active before entering.
    prior_backend = get_global_forced_attn_backend()

    # Install the requested override for the duration of the context.
    global_force_attn_backend(attn_backend)

    try:
        yield
    finally:
        # Restore the pre-existing override state.
        global_force_attn_backend(prior_backend)
|
||||
0
vllm/attention/utils/__init__.py
Normal file
0
vllm/attention/utils/__init__.py
Normal file
85
vllm/attention/utils/fa_utils.py
Normal file
85
vllm/attention/utils/fa_utils.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from vllm import _custom_ops as ops
|
||||
reshape_and_cache_flash = ops.reshape_and_cache_flash
|
||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||
get_scheduler_metadata)
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
reshape_and_cache_flash = ops.reshape_and_cache_flash
|
||||
flash_attn_varlen_func = ops.flash_attn_varlen_func
|
||||
get_scheduler_metadata = ops.get_scheduler_metadata
|
||||
|
||||
|
||||
def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
    """Choose which FlashAttention major version (2 or 3) to use.

    Args:
        requires_alibi: when True, FA3 is downgraded to FA2 since FA3 does
            not support ALiBi here.

    Returns:
        2 or 3, or None when no supported FlashAttention build is available
        (import failure or an assertion about capability/version fails).
    """
    # import here to avoid circular dependencies
    from vllm.platforms import current_platform
    if current_platform.is_xpu():
        return 2
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason, is_fa_version_supported)
        device_capability = current_platform.get_device_capability()

        assert device_capability is not None

        # 1. default version depending on platform
        fa_version = 3 if (device_capability.major == 9
                           and is_fa_version_supported(3)) else 2

        # 2. override if passed by environment
        if envs.VLLM_FLASH_ATTN_VERSION is not None:
            assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
            fa_version = envs.VLLM_FLASH_ATTN_VERSION

        # 3. fallback for unsupported combinations
        if device_capability.major == 10 and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 on Blackwell platform "
                "defaulting to FA version 2.")
            fa_version = 2

        if requires_alibi and fa_version == 3:
            logger.warning_once("Cannot use FA version 3 with ALiBi, "
                                "defaulting to FA version 2.")
            fa_version = 2

        # Bug fix: the previous message read "Cannot use FA version %d is
        # not supported due to %s" -- two sentences fused together.
        if not is_fa_version_supported(fa_version):
            logger.error("FA version %d is not supported due to %s",
                         fa_version, fa_version_unsupported_reason(fa_version))

        assert is_fa_version_supported(fa_version)
        return fa_version
    except (ImportError, AssertionError):
        return None
|
||||
|
||||
|
||||
def flash_attn_supports_fp8() -> bool:
    """Report whether FP8 FlashAttention is usable (FA3 on SM 9.x)."""
    # Check the FA version first so device capability is only queried when
    # FA3 is actually selected (preserves the original short-circuit).
    if get_flash_attn_version() != 3:
        return False
    return current_platform.get_device_capability().major == 9
|
||||
|
||||
|
||||
def flash_attn_supports_mla():
    """Report whether FlashAttention MLA is usable (CUDA + FA3 + SM 9.x)."""
    from vllm.platforms import current_platform

    # Guard clause: MLA via flash-attn is CUDA-only.
    if not current_platform.is_cuda():
        return False
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            is_fa_version_supported)
        return (is_fa_version_supported(3)
                and current_platform.get_device_capability()[0] == 9)
    except (ImportError, AssertionError):
        return False
|
||||
|
||||
|
||||
def is_flash_attn_varlen_func_available() -> bool:
    """The varlen flash-attn entry points exist only on CUDA and XPU."""
    if current_platform.is_cuda():
        return True
    return current_platform.is_xpu()
|
||||
33
vllm/attention/utils/kv_sharing_utils.py
Normal file
33
vllm/attention/utils/kv_sharing_utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
def validate_kv_sharing_target(current_layer_name, target_layer_name,
                               static_forward_context):
    """Validate that ``target_layer_name`` may serve as the KV-sharing
    target of ``current_layer_name``.

    The target must be a distinct Attention layer registered earlier in the
    forward context, with the same attention type as the current layer.

    Raises:
        ValueError: the target is the current layer itself, does not precede
            it, is not a registered Attention layer, or has a different
            attention type.
    """
    prefix = (f"Specified KV sharing target layer for {current_layer_name} "
              f"is not valid: target layer {target_layer_name} ")

    if target_layer_name == current_layer_name:
        raise ValueError(prefix + "cannot be the same as the current layer.")

    if target_layer_name not in static_forward_context:
        from vllm.model_executor.models.utils import extract_layer_index

        # Missing from the static fwd context means either:
        # a) the target layer does not come BEFORE the current layer, or
        # b) the target layer is not an Attention layer that exists in the
        #    model.
        if (extract_layer_index(current_layer_name)
                <= extract_layer_index(target_layer_name)):
            raise ValueError(prefix + "must come before the current layer.")
        raise ValueError(prefix +
                         "is not a valid Attention layer in the model.")

    # Currently KV sharing is only supported between layers of the same type
    expected = static_forward_context[current_layer_name].attn_type
    actual = static_forward_context[target_layer_name].attn_type
    if actual != expected:
        raise ValueError(
            prefix +
            f"must be the same type as the current layer ({expected}).")
|
||||
Reference in New Issue
Block a user