v1.0
This commit is contained in:
107
__init__.py
Normal file
107
__init__.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
|
||||||
|
|
||||||
|
# version.py should be an independent library, and we always import it
# first. Some customizations rely on this assumption.
|
||||||
|
from .version import __version__, __version_tuple__ # isort:skip
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
# Import the environment-variable overrides before any other module so that
# the variables are already set by the time those modules are imported.
|
||||||
|
import vllm.env_override # noqa: F401
|
||||||
|
|
||||||
|
# Lazy-export table: public attribute name -> "relative.module:attribute".
# The value is split on ":" by the module-level __getattr__ to locate the
# module to import and the attribute to read from it.
MODULE_ATTRS: dict[str, str] = {
    # backward-compatibility linter markers
    "bc_linter_skip": "._bc_linter:bc_linter_skip",
    "bc_linter_include": "._bc_linter:bc_linter_include",
    # engine
    "AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
    "EngineArgs": ".engine.arg_utils:EngineArgs",
    "AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
    "LLMEngine": ".engine.llm_engine:LLMEngine",
    # entrypoints / executors
    "LLM": ".entrypoints.llm:LLM",
    "initialize_ray_cluster": ".v1.executor.ray_utils:initialize_ray_cluster",
    # inputs
    "PromptType": ".inputs:PromptType",
    "TextPrompt": ".inputs:TextPrompt",
    "TokensPrompt": ".inputs:TokensPrompt",
    # models and request parameters
    "ModelRegistry": ".model_executor.models:ModelRegistry",
    "SamplingParams": ".sampling_params:SamplingParams",
    "PoolingParams": ".pooling_params:PoolingParams",
    # outputs
    "ClassificationOutput": ".outputs:ClassificationOutput",
    "ClassificationRequestOutput": ".outputs:ClassificationRequestOutput",
    "CompletionOutput": ".outputs:CompletionOutput",
    "EmbeddingOutput": ".outputs:EmbeddingOutput",
    "EmbeddingRequestOutput": ".outputs:EmbeddingRequestOutput",
    "PoolingOutput": ".outputs:PoolingOutput",
    "PoolingRequestOutput": ".outputs:PoolingRequestOutput",
    "RequestOutput": ".outputs:RequestOutput",
    "ScoringOutput": ".outputs:ScoringOutput",
    "ScoringRequestOutput": ".outputs:ScoringRequestOutput",
}
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
    # Static type checkers see ordinary imports of the public names.
    from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.engine.llm_engine import LLMEngine
    from vllm.entrypoints.llm import LLM
    from vllm.inputs import PromptType, TextPrompt, TokensPrompt
    from vllm.model_executor.models import ModelRegistry
    from vllm.outputs import (
        ClassificationOutput,
        ClassificationRequestOutput,
        CompletionOutput,
        EmbeddingOutput,
        EmbeddingRequestOutput,
        PoolingOutput,
        PoolingRequestOutput,
        RequestOutput,
        ScoringOutput,
        ScoringRequestOutput,
    )
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams
    from vllm.v1.executor.ray_utils import initialize_ray_cluster

    from ._bc_linter import bc_linter_include, bc_linter_skip
else:

    def __getattr__(name: str) -> typing.Any:
        """Resolve a public attribute lazily on first access.

        Names listed in ``MODULE_ATTRS`` are looked up by importing the
        mapped module relative to this package and reading the mapped
        attribute from it.

        Raises:
            AttributeError: if ``name`` is not a key of ``MODULE_ATTRS``.
        """
        from importlib import import_module

        if name not in MODULE_ATTRS:
            raise AttributeError(f"module {__package__} has no attribute {name}")

        # Entries have the form "<relative module>:<attribute>".
        relative_module, attribute = MODULE_ATTRS[name].split(":")
        loaded = import_module(relative_module, __package__)
        return getattr(loaded, attribute)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"__version__",
|
||||||
|
"bc_linter_skip",
|
||||||
|
"bc_linter_include",
|
||||||
|
"__version_tuple__",
|
||||||
|
"LLM",
|
||||||
|
"ModelRegistry",
|
||||||
|
"PromptType",
|
||||||
|
"TextPrompt",
|
||||||
|
"TokensPrompt",
|
||||||
|
"SamplingParams",
|
||||||
|
"RequestOutput",
|
||||||
|
"CompletionOutput",
|
||||||
|
"PoolingOutput",
|
||||||
|
"PoolingRequestOutput",
|
||||||
|
"EmbeddingOutput",
|
||||||
|
"EmbeddingRequestOutput",
|
||||||
|
"ClassificationOutput",
|
||||||
|
"ClassificationRequestOutput",
|
||||||
|
"ScoringOutput",
|
||||||
|
"ScoringRequestOutput",
|
||||||
|
"LLMEngine",
|
||||||
|
"EngineArgs",
|
||||||
|
"AsyncLLMEngine",
|
||||||
|
"AsyncEngineArgs",
|
||||||
|
"initialize_ray_cluster",
|
||||||
|
"PoolingParams",
|
||||||
|
]
|
||||||
BIN
__pycache__/__init__.cpython-312.pyc
Normal file
BIN
__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/_aiter_ops.cpython-312.pyc
Normal file
BIN
__pycache__/_aiter_ops.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/_bc_linter.cpython-312.pyc
Normal file
BIN
__pycache__/_bc_linter.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/_custom_ops.cpython-312.pyc
Normal file
BIN
__pycache__/_custom_ops.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/_ipex_ops.cpython-312.pyc
Normal file
BIN
__pycache__/_ipex_ops.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/beam_search.cpython-312.pyc
Normal file
BIN
__pycache__/beam_search.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/collect_env.cpython-312.pyc
Normal file
BIN
__pycache__/collect_env.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/connections.cpython-312.pyc
Normal file
BIN
__pycache__/connections.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/env_override.cpython-312.pyc
Normal file
BIN
__pycache__/env_override.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/envs.cpython-312.pyc
Normal file
BIN
__pycache__/envs.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/forward_context.cpython-312.pyc
Normal file
BIN
__pycache__/forward_context.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/logger.cpython-312.pyc
Normal file
BIN
__pycache__/logger.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/logits_process.cpython-312.pyc
Normal file
BIN
__pycache__/logits_process.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/logprobs.cpython-312.pyc
Normal file
BIN
__pycache__/logprobs.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/outputs.cpython-312.pyc
Normal file
BIN
__pycache__/outputs.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/pooling_params.cpython-312.pyc
Normal file
BIN
__pycache__/pooling_params.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/sampling_params.cpython-312.pyc
Normal file
BIN
__pycache__/sampling_params.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/scalar_type.cpython-312.pyc
Normal file
BIN
__pycache__/scalar_type.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/scripts.cpython-312.pyc
Normal file
BIN
__pycache__/scripts.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/sequence.cpython-312.pyc
Normal file
BIN
__pycache__/sequence.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/tasks.cpython-312.pyc
Normal file
BIN
__pycache__/tasks.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/tracing.cpython-312.pyc
Normal file
BIN
__pycache__/tracing.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/version.cpython-312.pyc
Normal file
BIN
__pycache__/version.cpython-312.pyc
Normal file
Binary file not shown.
983
_aiter_ops.py
Normal file
983
_aiter_ops.py
Normal file
@@ -0,0 +1,983 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import functools
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||||
|
|
||||||
|
|
||||||
|
def is_aiter_found() -> bool:
    """Return True when a module spec for the ``aiter`` package can be found."""
    import importlib.util

    aiter_spec = importlib.util.find_spec("aiter")
    return aiter_spec is not None
|
||||||
|
|
||||||
|
|
||||||
|
# `find_spec` is not torch.compile compatible, and aiter availability may be
# checked inside forward passes that are torch compiled. The lookup is
# therefore done once here, at import time, and cached in a module-level
# global so that it does not cause torch.compile breaks.
IS_AITER_FOUND: bool = is_aiter_found()
|
||||||
|
|
||||||
|
|
||||||
|
def if_aiter_supported(func: Callable) -> Callable:
    """Decorate ``func`` so it only runs where ROCm AITER is supported.

    The wrapped function is called only when the current platform is ROCm,
    the ``aiter`` package was found at import time, and the device is a
    gfx9 arch. Otherwise the wrapper returns ``None`` without calling
    ``func``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Platform and library checks come first; `on_gfx9` is only
        # imported once we know we are on ROCm with aiter available.
        if not (current_platform.is_rocm() and IS_AITER_FOUND):
            return None

        from vllm.platforms.rocm import on_gfx9

        if not on_gfx9():
            return None

        return func(*args, **kwargs)

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_group_fp8_quant_impl(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to FP8 with aiter's per-1x128 HIP quantizer.

    ``group_size`` is only used to validate that the last dimension of
    ``x`` is divisible by it; the quantizer is always the one selected by
    ``QuantType.per_1x128``.
    """
    assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
    from aiter import QuantType, dtypes, get_hip_quant

    quantize_per_1x128 = get_hip_quant(QuantType.per_1x128)
    contiguous_input = x.contiguous()
    return quantize_per_1x128(contiguous_input, quant_dtype=dtypes.fp8)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_group_fp8_quant_fake(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation: allocate uninitialized outputs of the right shape.

    Unpacks ``x.shape`` into exactly two dimensions ``(M, N)`` and returns
    an ``(M, N)`` FP8 tensor plus an ``(M, ceil(N / group_size))`` float32
    tensor, both on ``x.device``.
    """
    from aiter import dtypes

    rows, cols = x.shape
    groups_per_row = (cols + group_size - 1) // group_size

    quantized = torch.empty((rows, cols), dtype=dtypes.fp8, device=x.device)
    block_scales = torch.empty(
        (rows, groups_per_row), dtype=torch.float32, device=x.device
    )
    return quantized, block_scales
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_fused_moe_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
    quant_method: int = 0,
    doweight_stage1: bool = False,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Forward to ``aiter.fused_moe.fused_moe``.

    The integer ``activation_method`` and ``quant_method`` are converted to
    aiter's ``ActivationType`` and ``QuantType`` enums; every other argument
    is passed through positionally in signature order.
    """
    from aiter import ActivationType, QuantType
    from aiter.fused_moe import fused_moe

    return fused_moe(
        hidden_states,
        w1,
        w2,
        topk_weight,
        topk_ids,
        expert_mask,
        ActivationType(activation_method),
        QuantType(quant_method),
        doweight_stage1,
        w1_scale,
        w2_scale,
        a1_scale,
        a2_scale,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_fused_moe_fake(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weight: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
expert_mask: torch.Tensor | None = None,
|
||||||
|
activation_method: int = 0,
|
||||||
|
quant_method: int = 0,
|
||||||
|
doweight_stage1: bool = False,
|
||||||
|
w1_scale: torch.Tensor | None = None,
|
||||||
|
w2_scale: torch.Tensor | None = None,
|
||||||
|
a1_scale: torch.Tensor | None = None,
|
||||||
|
a2_scale: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.empty_like(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_asm_moe_tkw1_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: torch.Tensor | None = None,
    fc2_scale: torch.Tensor | None = None,
    fc1_smooth_scale: torch.Tensor | None = None,
    fc2_smooth_scale: torch.Tensor | None = None,
    a16: bool = False,
    per_tensor_quant_scale: torch.Tensor | None = None,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
) -> torch.Tensor:
    """Forward to ``aiter.fused_moe_bf16_asm.asm_moe_tkw1``.

    The first five tensors are passed positionally and the rest by keyword;
    ``activation_method`` is converted to aiter's ``ActivationType`` and
    passed as ``activation``.
    """
    from aiter import ActivationType
    from aiter.fused_moe_bf16_asm import asm_moe_tkw1

    keyword_args = dict(
        fc1_scale=fc1_scale,
        fc2_scale=fc2_scale,
        fc1_smooth_scale=fc1_smooth_scale,
        fc2_smooth_scale=fc2_smooth_scale,
        a16=a16,
        per_tensor_quant_scale=per_tensor_quant_scale,
        expert_mask=expert_mask,
        activation=ActivationType(activation_method),
    )
    return asm_moe_tkw1(
        hidden_states, w1, w2, topk_weights, topk_ids, **keyword_args
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_asm_moe_tkw1_fake(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
fc1_scale: torch.Tensor | None = None,
|
||||||
|
fc2_scale: torch.Tensor | None = None,
|
||||||
|
fc1_smooth_scale: torch.Tensor | None = None,
|
||||||
|
fc2_smooth_scale: torch.Tensor | None = None,
|
||||||
|
a16: bool = False,
|
||||||
|
per_tensor_quant_scale: torch.Tensor | None = None,
|
||||||
|
expert_mask: torch.Tensor | None = None,
|
||||||
|
activation_method: int = 0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.empty_like(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_topk_softmax_impl(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> None:
    """Forward to ``aiter.topk_softmax``; nothing is returned.

    This op is registered with ``topk_weights``, ``topk_indices`` and
    ``token_expert_indices`` declared as mutated arguments.
    """
    from aiter import topk_softmax

    topk_softmax(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_topk_softmax_fake(
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_indices: torch.Tensor,
|
||||||
|
token_expert_indices: torch.Tensor,
|
||||||
|
gating_output: torch.Tensor,
|
||||||
|
renormalize: bool,
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_biased_grouped_topk_impl(
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    routed_scaling_factor: float = 1.0,  # mul to topk_weights
) -> None:
    """Forward to ``aiter.biased_grouped_topk``; nothing is returned.

    All arguments are passed through positionally in signature order. This
    op is registered with ``topk_weights`` and ``topk_ids`` declared as
    mutated arguments.
    """
    from aiter import biased_grouped_topk

    positional = (
        gating_output,
        correction_bias,
        topk_weights,
        topk_ids,
        num_expert_group,
        topk_group,
        need_renorm,
        routed_scaling_factor,
    )
    biased_grouped_topk(*positional)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_biased_grouped_topk_fake(
|
||||||
|
gating_output: torch.Tensor,
|
||||||
|
correction_bias: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
num_expert_group: int,
|
||||||
|
topk_group: int,
|
||||||
|
need_renorm: bool,
|
||||||
|
routed_scaling_factor: float = 1.0, # mul to topk_weights
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_grouped_topk_impl(
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,  # mul to topk_weights
) -> None:
    """Forward to ``aiter.grouped_topk``; nothing is returned.

    ``scoring_func`` is reduced to a boolean that is True only when it
    equals ``"softmax"``; that flag is passed in its place. This op is
    registered with ``topk_weights`` and ``topk_ids`` declared as mutated
    arguments.
    """
    from aiter import grouped_topk

    grouped_topk(
        gating_output,
        topk_weights,
        topk_ids,
        num_expert_group,
        topk_group,
        need_renorm,
        scoring_func == "softmax",
        routed_scaling_factor,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_grouped_topk_fake(
|
||||||
|
gating_output: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
num_expert_group: int,
|
||||||
|
topk_group: int,
|
||||||
|
need_renorm: bool,
|
||||||
|
scoring_func: str = "softmax",
|
||||||
|
routed_scaling_factor: float = 1.0, # mul to topk_weights
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_mla_decode_fwd_impl(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: torch.Tensor | None = None,
    kv_indices: torch.Tensor | None = None,
    kv_last_page_lens: torch.Tensor | None = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    """Forward to ``aiter.mla.mla_decode_fwd``; nothing is returned.

    ``kv_buffer`` is viewed as ``(-1, 1, 1, q.shape[-1])`` before the call.
    Note that ``max_seqlen_qo`` is passed after the kv index tensors, which
    differs from its position in this signature. This op is registered
    with ``o`` declared as a mutated argument.
    """
    from aiter.mla import mla_decode_fwd

    kv_view = kv_buffer.view(-1, 1, 1, q.shape[-1])
    mla_decode_fwd(
        q,
        kv_view,
        o,
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        max_seqlen_qo,
        sm_scale=sm_scale,
        logit_cap=logit_cap,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_mla_decode_fwd_fake(
|
||||||
|
q: torch.Tensor,
|
||||||
|
kv_buffer: torch.Tensor,
|
||||||
|
o: torch.Tensor,
|
||||||
|
qo_indptr: torch.Tensor,
|
||||||
|
max_seqlen_qo: int,
|
||||||
|
kv_indptr: torch.Tensor | None = None,
|
||||||
|
kv_indices: torch.Tensor | None = None,
|
||||||
|
kv_last_page_lens: torch.Tensor | None = None,
|
||||||
|
sm_scale: float = 1.0,
|
||||||
|
logit_cap: float = 0.0,
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_gemm_a8w8_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Forward to ``aiter.gemm_a8w8_CK`` and return its result."""
    from aiter import gemm_a8w8_CK

    # Layout expected by gemm_a8w8_CK(a, b, scale_a, scale_b, bias):
    #   a is [M, K]
    #   b is [N, K]
    # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
    product = gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
    return product
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_gemm_a8w8_fake(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
output_dtype: torch.dtype = torch.float16,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
m = A.shape[0]
|
||||||
|
n = B.shape[0]
|
||||||
|
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
|
||||||
|
return Y
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_gemm_a8w8_blockscale_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Forward to ``aiter.gemm_a8w8_blockscale``, passing ``output_dtype`` as ``dtype``."""
    from aiter import gemm_a8w8_blockscale

    product = gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
    return product
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_gemm_a8w8_blockscale_fake(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
output_dtype: torch.dtype = torch.float16,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
m = A.shape[0]
|
||||||
|
n = B.shape[0]
|
||||||
|
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
|
||||||
|
return Y
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_rms_norm_impl(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply ``aiter.rms_norm`` to ``x``.

    Inputs with more than two dimensions are reshaped to
    ``(-1, last_dim)`` for the call and the result is reshaped back to the
    original shape; other inputs are passed through unchanged.
    """
    from aiter import rms_norm

    if x.dim() <= 2:
        return rms_norm(x, weight, variance_epsilon)

    original_shape = x.shape
    flattened = x.reshape(-1, original_shape[-1])
    normalized = rms_norm(flattened, weight, variance_epsilon)
    return normalized.reshape(original_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_rms_norm_fake(
|
||||||
|
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.empty_like(x)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_rmsnorm2d_fwd_with_add_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Call ``aiter.rmsnorm2d_fwd_with_add`` with freshly allocated outputs.

    Two buffers are allocated, one shaped like ``x`` and one shaped like
    ``residual``, and handed to the kernel as its output arguments.

    Returns:
        ``(output, residual_out)``: the two buffers passed to the kernel.
    """
    from aiter import rmsnorm2d_fwd_with_add

    residual_buffer = torch.empty_like(residual)
    normed_buffer = torch.empty_like(x)

    # Argument order: output, input, residual input, residual output,
    # weight, epsilon.
    rmsnorm2d_fwd_with_add(
        normed_buffer, x, residual, residual_buffer, weight, variance_epsilon
    )
    return normed_buffer, residual_buffer
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
variance_epsilon: float,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
return torch.empty_like(x), torch.empty_like(residual)
|
||||||
|
|
||||||
|
|
||||||
|
# Global flag to ensure ops are registered only once
|
||||||
|
_OPS_REGISTERED = False
|
||||||
|
|
||||||
|
|
||||||
|
class rocm_aiter_ops:
|
||||||
|
_AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
|
||||||
|
_LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
|
||||||
|
_RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
|
||||||
|
_FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
|
||||||
|
_MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
|
||||||
|
_PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
|
||||||
|
_MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
|
||||||
|
_TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
|
||||||
|
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
|
||||||
|
_FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
|
||||||
|
_TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
|
||||||
|
_MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
|
||||||
|
_TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
|
||||||
|
|
||||||
|
@classmethod
@if_aiter_supported
def is_enabled(cls) -> bool:
    """Report whether the main AITER switch is on.

    The `if_aiter_supported` decorator checks the platform, device arch and
    aiter availability first; when that check fails this call yields None.
    """
    master_switch = cls._AITER_ENABLED
    return master_switch
|
||||||
|
|
||||||
|
@classmethod
@if_aiter_supported
def is_linear_enabled(cls) -> bool:
    """Verify device specs and that both the main and linear flags are set."""
    master_switch = cls._AITER_ENABLED
    return master_switch and cls._LINEAR_ENABLED
|
||||||
|
|
||||||
|
@classmethod
@if_aiter_supported
def is_linear_fp8_enaled(cls) -> bool:
    """Verify that AITER linear is enabled and the platform uses FP8 fnuz.

    NOTE: the method name is misspelled ("enaled"); it is kept as-is so
    existing callers keep working.
    """
    linear_on = cls.is_linear_enabled()
    return linear_on and current_platform.is_fp8_fnuz()
|
||||||
|
|
||||||
|
@classmethod
@if_aiter_supported
def is_rmsnorm_enabled(cls) -> bool:
    """Verify device specs and that both the main and RMSNorm flags are set."""
    master_switch = cls._AITER_ENABLED
    return master_switch and cls._RMSNORM_ENABLED
|
||||||
|
|
||||||
|
@classmethod
@if_aiter_supported
def is_fused_moe_enabled(cls) -> bool:
    """Verify device specs and that both the main and fused-MoE flags are set."""
    master_switch = cls._AITER_ENABLED
    return master_switch and cls._FMOE_ENABLED
|
||||||
|
|
||||||
|
@classmethod
@if_aiter_supported
def is_fusion_moe_shared_experts_enabled(cls) -> bool:
    """Verify that fused MoE is enabled and the shared-experts fusion flag is set."""
    fused_moe_on = cls.is_fused_moe_enabled()
    return fused_moe_on and cls._MOE_SHARED_EXPERTS_ENABLED
|
||||||
|
|
||||||
|
@classmethod
@if_aiter_supported
def is_mla_enabled(cls) -> bool:
    """Verify device specs and that both the main and MLA flags are set."""
    master_switch = cls._AITER_ENABLED
    return master_switch and cls._MLA_ENABLED
|
||||||
|
|
||||||
|
@classmethod
@if_aiter_supported
def is_mha_enabled(cls) -> bool:
    """Verify device specs and that both the main and MHA flags are set."""
    master_switch = cls._AITER_ENABLED
    return master_switch and cls._MHA_ENABLED
|
||||||
|
|
||||||
|
@classmethod
@if_aiter_supported
def is_pa_attn_enabled(cls) -> bool:
    """Verify device specs and that both the main and paged-attention flags are set."""
    master_switch = cls._AITER_ENABLED
    return master_switch and cls._PG_ATTN_ENABLED
|
||||||
|
|
||||||
|
@classmethod
@if_aiter_supported
def is_triton_unified_attn_enabled(cls) -> bool:
    """Verify device specs and that both the main and unified-attention flags are set."""
    master_switch = cls._AITER_ENABLED
    return master_switch and cls._TRITON_UNIFIED_ATTN_ENABLED
|
||||||
|
|
||||||
|
@classmethod
@if_aiter_supported
def is_fp8bmm_enabled(cls) -> bool:
    """Verify device specs and that both the main and FP8 BMM flags are set."""
    master_switch = cls._AITER_ENABLED
    return master_switch and cls._FP8BMM_ENABLED
|
||||||
|
|
||||||
|
@classmethod
@if_aiter_supported
def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
    """Verify device specs and that both the main and FP4 ASM GEMM flags are set."""
    master_switch = cls._AITER_ENABLED
    return master_switch and cls._FP4_GEMM_DYNAMIC_QUANT_ASM
|
||||||
|
|
||||||
|
@classmethod
@if_aiter_supported
def is_triton_rotary_embed_enabled(cls) -> bool:
    """Verify device specs and that both the main and Triton RoPE flags are set."""
    master_switch = cls._AITER_ENABLED
    return master_switch and cls._TRITON_ROTARY_EMBED
|
||||||
|
|
||||||
|
@classmethod
@if_aiter_supported
def is_triton_gemm_enabled(cls) -> bool:
    """Verify device specs and that both the main and Triton GEMM flags are set."""
    master_switch = cls._AITER_ENABLED
    return master_switch and cls._TRITON_UNQUANT_GEMM
|
||||||
|
|
||||||
|
@staticmethod
@if_aiter_supported
def register_ops_once() -> None:
    """Register every AITER-backed custom op, at most once per process.

    The module-level ``_OPS_REGISTERED`` flag guards the registration and
    is set to True after the last op is registered. Because of the
    `if_aiter_supported` decorator, nothing is registered when the
    platform check fails.
    """
    global _OPS_REGISTERED
    if _OPS_REGISTERED:
        return

    if is_torch_equal_or_newer("2.7.0"):
        tags = tuple()
    else:
        tags = (torch.Tag.needs_fixed_stride_order,)

    # One row per op: (op name, real implementation, fake implementation,
    # names of arguments the op mutates). Rows are registered top to bottom.
    op_table = [
        (
            "rocm_aiter_group_fp8_quant",
            _rocm_aiter_group_fp8_quant_impl,
            _rocm_aiter_group_fp8_quant_fake,
            [],
        ),
        (
            "rocm_aiter_asm_moe_tkw1",
            _rocm_aiter_asm_moe_tkw1_impl,
            _rocm_aiter_asm_moe_tkw1_fake,
            [],
        ),
        (
            "rocm_aiter_fused_moe",
            _rocm_aiter_fused_moe_impl,
            _rocm_aiter_fused_moe_fake,
            [],
        ),
        (
            "rocm_aiter_topk_softmax",
            _rocm_aiter_topk_softmax_impl,
            _rocm_aiter_topk_softmax_fake,
            ["topk_weights", "topk_indices", "token_expert_indices"],
        ),
        (
            "rocm_aiter_biased_grouped_topk",
            _rocm_aiter_biased_grouped_topk_impl,
            _rocm_aiter_biased_grouped_topk_fake,
            ["topk_weights", "topk_ids"],
        ),
        (
            "rocm_aiter_grouped_topk",
            _rocm_aiter_grouped_topk_impl,
            _rocm_aiter_grouped_topk_fake,
            ["topk_weights", "topk_ids"],
        ),
        (
            "rocm_aiter_mla_decode_fwd",
            _rocm_aiter_mla_decode_fwd_impl,
            _rocm_aiter_mla_decode_fwd_fake,
            ["o"],
        ),
        (
            "rocm_aiter_gemm_a8w8",
            _rocm_aiter_gemm_a8w8_impl,
            _rocm_aiter_gemm_a8w8_fake,
            [],
        ),
        (
            "rocm_aiter_gemm_a8w8_blockscale",
            _rocm_aiter_gemm_a8w8_blockscale_impl,
            _rocm_aiter_gemm_a8w8_blockscale_fake,
            [],
        ),
        (
            "rocm_aiter_rms_norm",
            _rocm_aiter_rms_norm_impl,
            _rocm_aiter_rms_norm_fake,
            [],
        ),
        (
            "rocm_aiter_rmsnorm2d_fwd_with_add",
            _rocm_aiter_rmsnorm2d_fwd_with_add_impl,
            _rocm_aiter_rmsnorm2d_fwd_with_add_fake,
            [],
        ),
    ]

    for op_name, op_func, fake_impl, mutates_args in op_table:
        if op_name == "rocm_aiter_mla_decode_fwd":
            # The MLA decode op is the only one registered with `tags`
            # and without an explicit dispatch key.
            extra_kwargs = {"tags": tags}
        else:
            extra_kwargs = {"dispatch_key": current_platform.dispatch_key}

        direct_register_custom_op(
            op_name=op_name,
            op_func=op_func,
            mutates_args=mutates_args,
            fake_impl=fake_impl,
            **extra_kwargs,
        )

    _OPS_REGISTERED = True
|
||||||
|
|
||||||
|
@staticmethod
def rms_norm2d_with_add(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Call the `vllm.rocm_aiter_rmsnorm2d_fwd_with_add` custom op.

    Arguments are forwarded unchanged in signature order.
    """
    fused_norm_op = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
    return fused_norm_op(x, residual, weight, variance_epsilon)
|
||||||
|
|
||||||
|
@staticmethod
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Call the `vllm.rocm_aiter_rms_norm` custom op with the given arguments."""
    norm_op = torch.ops.vllm.rocm_aiter_rms_norm
    return norm_op(x, weight, variance_epsilon)
|
||||||
|
|
||||||
|
@staticmethod
def gemm_a8w8(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Call the `vllm.rocm_aiter_gemm_a8w8` custom op.

    Arguments are forwarded unchanged in signature order.
    """
    gemm_op = torch.ops.vllm.rocm_aiter_gemm_a8w8
    return gemm_op(A, B, As, Bs, bias, output_dtype)
|
||||||
|
|
||||||
|
@staticmethod
def gemm_a8w8_blockscale(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Call the `vllm.rocm_aiter_gemm_a8w8_blockscale` custom op.

    ``block_size`` is accepted but is not forwarded to the op; only
    ``A``, ``B``, ``As``, ``Bs`` and ``output_dtype`` are passed on.
    """
    blockscale_op = torch.ops.vllm.rocm_aiter_gemm_a8w8_blockscale
    return blockscale_op(A, B, As, Bs, output_dtype)
|
||||||
|
|
||||||
|
@staticmethod
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
    quant_method: int = 0,
    doweight_stage1: bool = False,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Call the `vllm.rocm_aiter_fused_moe` custom op.

    Every argument is forwarded positionally in signature order.
    """
    fused_moe_op = torch.ops.vllm.rocm_aiter_fused_moe
    return fused_moe_op(
        hidden_states,
        w1,
        w2,
        topk_weight,
        topk_ids,
        expert_mask,
        activation_method,
        quant_method,
        doweight_stage1,
        w1_scale,
        w2_scale,
        a1_scale,
        a2_scale,
    )
|
||||||
|
|
||||||
|
@staticmethod
def asm_moe_tkw1(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: torch.Tensor | None = None,
    fc2_scale: torch.Tensor | None = None,
    fc1_smooth_scale: torch.Tensor | None = None,
    fc2_smooth_scale: torch.Tensor | None = None,
    a16: bool = False,
    per_tensor_quant_scale: torch.Tensor | None = None,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
) -> torch.Tensor:
    """Call the ``rocm_aiter_asm_moe_tkw1`` vLLM custom op.

    Every parameter is forwarded positionally, in signature order, and the
    op's result is returned unchanged.
    """
    # The op is called positionally, so this order must track the signature.
    return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        fc1_scale,
        fc2_scale,
        fc1_smooth_scale,
        fc2_smooth_scale,
        a16,
        per_tensor_quant_scale,
        expert_mask,
        activation_method,
    )
|
||||||
|
|
||||||
|
@staticmethod
def topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> tuple[torch.Tensor, ...]:
    """Call the ``rocm_aiter_topk_softmax`` op and return the two output buffers.

    The op's own return value is discarded; the caller-supplied
    ``topk_weights`` and ``topk_indices`` tensors are returned.
    NOTE(review): presumably the op fills those tensors in place; confirm
    against the op implementation.
    """
    torch.ops.vllm.rocm_aiter_topk_softmax(
        topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
    )
    return topk_weights, topk_indices
|
||||||
|
|
||||||
|
@staticmethod
def biased_grouped_topk(
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    routed_scaling_factor: float = 1.0,
) -> None:
    """Call the ``rocm_aiter_biased_grouped_topk`` vLLM custom op.

    Arguments are forwarded positionally in signature order. Nothing is
    returned.
    NOTE(review): results are presumably written into ``topk_weights`` /
    ``topk_ids``; confirm against the op implementation.
    """
    torch.ops.vllm.rocm_aiter_biased_grouped_topk(
        gating_output,
        correction_bias,
        topk_weights,
        topk_ids,
        num_expert_group,
        topk_group,
        need_renorm,
        routed_scaling_factor,
    )
|
||||||
|
|
||||||
|
@staticmethod
def grouped_topk(
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
) -> None:
    """Call the ``rocm_aiter_grouped_topk`` vLLM custom op.

    Arguments are forwarded positionally in signature order. Nothing is
    returned.
    NOTE(review): results are presumably written into ``topk_weights`` /
    ``topk_ids``; confirm against the op implementation.
    """
    torch.ops.vllm.rocm_aiter_grouped_topk(
        gating_output,
        topk_weights,
        topk_ids,
        num_expert_group,
        topk_group,
        need_renorm,
        scoring_func,
        routed_scaling_factor,
    )
|
||||||
|
|
||||||
|
@staticmethod
def mla_decode_fwd(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    sm_scale: float,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: torch.Tensor | None = None,
    kv_indices: torch.Tensor | None = None,
    kv_last_page_lens: torch.Tensor | None = None,
    logit_cap: float = 0.0,
) -> None:
    """Call the ``rocm_aiter_mla_decode_fwd`` vLLM custom op.

    ``kv_buffer`` is viewed as ``(-1, 1, 1, q.shape[-1])`` before the call.
    ``sm_scale`` and ``logit_cap`` are passed by keyword; the other
    arguments are positional. Nothing is returned.
    NOTE(review): the output is presumably written into ``o``; confirm.
    """
    torch.ops.vllm.rocm_aiter_mla_decode_fwd(
        q,
        # Reinterpret the KV buffer with the same last dimension as q.
        kv_buffer.view(-1, 1, 1, q.shape[-1]),
        o,
        qo_indptr,
        max_seqlen_qo,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        sm_scale=sm_scale,
        logit_cap=logit_cap,
    )
|
||||||
|
|
||||||
|
@staticmethod
def triton_fp4_gemm_dynamic_qaunt(
    x: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    out_dtype: torch.dtype | None = torch.bfloat16,
    x_scales: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run AITER's Triton ``gemm_afp4wfp4``, quantizing ``x`` first if needed.

    Args:
        x: Activation tensor. Quantized with ``dynamic_mxfp4_quant`` when
            ``x_scales`` is None; otherwise used as-is.
        weight: Weight tensor; ``weight.shape[0]`` sets the output's
            second dimension.
        weight_scale: Passed to the kernel transposed (``weight_scale.T``).
        out_dtype: dtype of the allocated output.
        x_scales: Scales for ``x``. Supplying them skips quantization.

    Returns:
        The ``(x_q.shape[0], weight.shape[0])`` tensor allocated here and
        handed to the kernel.
    """
    # AITER is imported lazily, so it is only needed when this is called.
    from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
    from aiter.ops.triton.quant import dynamic_mxfp4_quant

    if x_scales is None:
        x_q, x_s = dynamic_mxfp4_quant(x)
    else:
        # Caller supplied scales, so x is taken to be quantized already.
        x_q = x
        x_s = x_scales

    y = torch.empty(
        x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype
    )

    # The kernel's return value is ignored; presumably it fills y in place.
    gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
    return y
|
||||||
|
|
||||||
|
@staticmethod
def triton_rotary_embed(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    head_size: int,
    rotary_dim: int,
    is_neox_style: bool,
) -> None:
    """Apply AITER's Triton ``rope_cached_thd_positions_2c_fwd_inplace``.

    ``query`` and ``key`` are viewed as ``(num_tokens, -1, head_size)`` and
    only their first ``rotary_dim`` features are passed to the kernel.
    Nothing is returned.
    NOTE(review): the kernel name says "inplace", so the caller's tensors
    are presumably modified through these views; confirm.
    """
    from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace

    num_tokens = positions.numel()
    # The cache's last dimension is split in half: cos first, then sin.
    cos, sin = cos_sin_cache.chunk(2, dim=-1)
    query_shape = query.shape
    key_shape = key.shape
    rotate_style = 0 if is_neox_style else 1

    query = query.view(num_tokens, -1, head_size)
    key = key.view(num_tokens, -1, head_size)
    query_ = query[..., :rotary_dim]
    key_ = key[..., :rotary_dim]
    positions = positions.view(*query.shape[:1])
    rope_cached_thd_positions_2c_fwd_inplace(
        positions,
        sin,
        cos,
        query_,
        key_,
        rotate_style,
        reuse_freqs_front_part=True,
        is_nope_first=False,
    )
    # These rebind local names only; the function returns nothing, so the
    # restored shapes are not visible to the caller.
    query = query.view(query_shape)
    key = key.view(key_shape)
|
||||||
|
|
||||||
|
@staticmethod
def triton_fp8_bmm(
    X: torch.Tensor,
    WQ: torch.Tensor,
    w_scale: torch.Tensor,
    group_size: int = 128,
    bias: torch.Tensor | None = None,
    dtype: torch.dtype | None = torch.bfloat16,
    splitK: int | None = None,
    YQ: torch.Tensor | None = None,
    transpose_bm: bool | None = False,
    config: dict | None = None,
) -> torch.Tensor:
    """Call AITER's Triton batched a8w8 GEMM and return its result.

    ``X``, ``WQ`` and ``w_scale`` are passed positionally; every other
    parameter is forwarded by keyword under the same name.
    """
    # Imported lazily, so AITER is only needed when this is called.
    # ruff: noqa: E501 # isort: skip
    from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (
        batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm,
    )

    return aiter_triton_fp8_bmm(
        X,
        WQ,
        w_scale,
        group_size=group_size,
        bias=bias,
        dtype=dtype,
        splitK=splitK,
        YQ=YQ,
        transpose_bm=transpose_bm,
        config=config,
    )
|
||||||
|
|
||||||
|
@staticmethod
def triton_gemm_a8w8_blockscale(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Call AITER's Triton ``gemm_a8w8_blockscale`` and return its result.

    ``block_size`` is accepted but is not forwarded to the kernel;
    ``output_dtype`` is passed as the kernel's ``dtype`` keyword.
    """
    # Imported lazily, so AITER is only needed when this is called.
    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale

    return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
|
||||||
|
|
||||||
|
@staticmethod
def group_fp8_quant(
    input_2d: torch.Tensor,
    group_size: int = 128,
) -> tuple[torch.Tensor, ...]:
    """Call the ``rocm_aiter_group_fp8_quant`` vLLM custom op.

    Only ``group_size == 128`` is accepted; any other value fails the
    assertion below. The op's result is returned unchanged.
    """
    assert group_size == 128, "Group size must be 128"
    return torch.ops.vllm.rocm_aiter_group_fp8_quant(input_2d, group_size)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:
|
||||||
|
return (n, k) in [
|
||||||
|
(1024, 8192),
|
||||||
|
(2112, 7168),
|
||||||
|
(3072, 1536),
|
||||||
|
(32768, 8192),
|
||||||
|
(4096, 7168),
|
||||||
|
(4608, 7168),
|
||||||
|
(512, 7168),
|
||||||
|
(7168, 2048),
|
||||||
|
(7168, 256),
|
||||||
|
(8192, 1024),
|
||||||
|
(8192, 32768),
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
def shuffle_weight(
    self, tensor: torch.Tensor | None = None, layout: tuple[int, int] = (16, 16)
) -> torch.Tensor:
    """Shuffle a single tensor with AITER's ``shuffle_weight``.

    This is a ``@staticmethod`` whose first parameter is named ``self``, so
    nothing is bound to it automatically. Previously the natural call
    ``shuffle_weight(tensor)`` failed with a missing-argument TypeError.
    Both call shapes are now accepted:

    * ``shuffle_weight(tensor, layout=...)``: the tensor arrives in the
      first slot and ``tensor`` is left at its default.
    * ``shuffle_weight(anything, tensor, layout)``: the legacy shape with
      an explicit placeholder first argument, which keeps working.

    Args:
        self: Placeholder kept for backward compatibility; holds the
            tensor when ``tensor`` is omitted.
        tensor: Tensor to shuffle, or None when it was passed first.
        layout: Block layout forwarded to AITER as ``layout=``.

    Returns:
        Whatever AITER's ``shuffle_weight`` returns for the tensor.
    """
    # Imported lazily, so AITER is only needed when this is called.
    from aiter.ops.shuffle import shuffle_weight

    if tensor is None:
        # Static call without a placeholder: the first argument is the tensor.
        tensor = self
    return shuffle_weight(tensor, layout=layout)
|
||||||
|
|
||||||
|
@staticmethod
def shuffle_weights(
    *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
) -> tuple[torch.Tensor, ...]:
    """
    Run AITER's shuffle_weight over every given tensor.

    Each tensor is rearranged into the requested block layout; the
    results come back in the same order as the inputs.

    Args:
        *tensors: Any number of torch.Tensor objects to shuffle.
        layout: Pair of block sizes forwarded to AITER as ``layout=``.
            Defaults to (16, 16).

    Returns:
        A tuple with one shuffled tensor per input tensor.
    """
    from aiter.ops.shuffle import shuffle_weight as _aiter_shuffle

    shuffled = [_aiter_shuffle(item, layout=layout) for item in tensors]
    return tuple(shuffled)
|
||||||
|
|
||||||
|
|
||||||
|
# Runs at import time: importing this module triggers op registration.
rocm_aiter_ops.register_ops_once()
|
||||||
54
_bc_linter.py
Normal file
54
_bc_linter.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# vllm/_bc_linter.py
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any, TypeVar, overload
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def bc_linter_skip(obj: T) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]: ...
|
||||||
|
|
||||||
|
|
||||||
|
def bc_linter_skip(obj: Any = None, *, reason: str | None = None):
|
||||||
|
"""
|
||||||
|
No-op decorator to mark symbols/files for BC-linter suppression.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@bc_linter_skip
|
||||||
|
def legacy_api(...): ...
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _wrap(x: T) -> T:
|
||||||
|
return x
|
||||||
|
|
||||||
|
return _wrap if obj is None else obj
|
||||||
|
|
||||||
|
|
||||||
|
@overload
def bc_linter_include(obj: T) -> T: ...


@overload
def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]: ...


def bc_linter_include(obj: Any = None, *, reason: str | None = None):
    """
    Marker decorator for BC-linter inclusion; it has no runtime effect.

    Used bare, it returns the decorated object unchanged. Called without
    an object (optionally with ``reason=``), it returns a decorator that
    does the same. ``reason`` is accepted but never read here.

    Usage:
        @bc_linter_include
        def public_api(...): ...
    """
    if obj is not None:
        return obj

    def _identity(target: T) -> T:
        return target

    return _identity
|
||||||
|
|
||||||
|
|
||||||
|
# Names exported by `from ... import *`.
__all__ = ["bc_linter_skip", "bc_linter_include"]
|
||||||
3512
_custom_ops.py
Normal file
3512
_custom_ops.py
Normal file
File diff suppressed because it is too large
Load Diff
457
_ipex_ops.py
Normal file
457
_ipex_ops.py
Normal file
@@ -0,0 +1,457 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
try:
    import intel_extension_for_pytorch as ipex
except ImportError as e:
    # A failed import is only logged at debug level. `ipex` is then left
    # undefined, so any later use of it in this module raises NameError.
    logger.debug("Import error msg: %s", e.msg)
|
||||||
|
|
||||||
|
|
||||||
|
class ipex_ops:
    """Static wrappers that forward kernel calls to IPEX / XPU back ends.

    Every method is a ``@staticmethod``. Most call into ``ipex`` (the
    module-level ``intel_extension_for_pytorch`` import),
    ``torch.ops.torch_ipex`` or ``torch.xpu``. Several signatures carry
    parameters the body never reads.
    NOTE(review): those unused parameters are presumably kept to mirror
    another back end's interface; confirm.
    """

    @staticmethod
    def _reshape_activation_tensor(
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Split a 2-D tensor into two halves along dimension 1.

        ``x`` is read as ``(num, 2 * d)`` and two ``(num, d)`` tensors are
        returned: the first and second half of each row.
        """
        num = x.size(0)
        d = x.size(1) // 2
        x = x.reshape(num, 2, d)
        x1, x2 = torch.chunk(x, chunks=2, dim=1)
        x1 = x1.reshape(num, d)
        x2 = x2.reshape(num, d)
        return x1, x2

    @staticmethod
    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Call ``ipex.llm.functional.silu_and_mul``; note the order is (x, out)."""
        ipex.llm.functional.silu_and_mul(x, out)

    @staticmethod
    def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Call ``ipex.llm.functional.gelu_and_mul``; note the order is (x, out)."""
        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Call ``ipex.llm.functional.gelu_and_mul`` with (x, out).

        NOTE(review): this is the same IPEX call as ``gelu_and_mul`` and no
        tanh-specific argument is passed; confirm that is intended.
        """
        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
        """Return ``torch.nn.functional.gelu(x)`` with default arguments."""
        return torch.nn.functional.gelu(x)

    @staticmethod
    def gelu_new(x: torch.Tensor) -> torch.Tensor:
        """Return ``torch.nn.functional.gelu(x)`` with default arguments."""
        return torch.nn.functional.gelu(x)

    @staticmethod
    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
        """Call ``ipex.llm.functional.gelu_quick``; note the order is (x, out)."""
        ipex.llm.functional.gelu_quick(x, out)

    @staticmethod
    def paged_attention_v1(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: torch.Tensor | None,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        """Call IPEX ``PagedAttention.single_query_kv_attention``.

        Only ``kv_cache_dtype == "auto"`` is accepted. ``k_scale``,
        ``v_scale``, ``tp_rank`` and the ``blocksparse_*`` parameters are
        not read.
        """
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        # Number of query heads per KV head.
        num_queries_per_tokens = num_heads // num_kv_heads
        ipex.llm.modules.PagedAttention.single_query_kv_attention(
            out,
            query.contiguous(),
            # The key cache is passed viewed with the value cache's shape.
            key_cache.view_as(value_cache),
            value_cache,
            num_queries_per_tokens,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def paged_attention_v2(
        out: torch.Tensor,
        exp_sum: torch.Tensor,
        max_logits: torch.Tensor,
        tmp_out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: torch.Tensor | None,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        """Call IPEX ``PagedAttention.single_query_kv_attention``.

        The kernel call is identical to ``paged_attention_v1``.
        ``exp_sum``, ``max_logits`` and ``tmp_out`` are not read, nor are
        ``k_scale``, ``v_scale``, ``tp_rank`` and the ``blocksparse_*``
        parameters. Only ``kv_cache_dtype == "auto"`` is accepted.
        """
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        # Number of query heads per KV head.
        num_queries_per_tokens = num_heads // num_kv_heads
        ipex.llm.modules.PagedAttention.single_query_kv_attention(
            out,
            query.contiguous(),
            # The key cache is passed viewed with the value cache's shape.
            key_cache.view_as(value_cache),
            value_cache,
            num_queries_per_tokens,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def rotary_embedding(
        positions: torch.Tensor,  # [batch_size, seq_len]
        query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
        key: torch.Tensor,  # [batch_size, seq_len, num_kv_heads*head_size]
        head_size: int,
        cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
        is_neox: bool,
    ) -> None:
        """Call ``ipex.llm.functional.rotary_embedding_batched``.

        ``rot_dim`` is taken from ``cos_sin_cache.size(1)`` and appended as
        the last argument. Nothing is returned.
        """
        rot_dim = cos_sin_cache.size(1)
        ipex.llm.functional.rotary_embedding_batched(
            positions, query, key, head_size, cos_sin_cache, is_neox, rot_dim
        )

    @staticmethod
    def rms_norm(
        input: torch.Tensor, weight: torch.Tensor, epsilon: float
    ) -> torch.Tensor:
        """Run ``torch.ops.torch_ipex.rms_norm_vllm`` into a new tensor.

        Allocates ``out`` with ``torch.empty_like(input)``, passes it to the
        op with a contiguous ``input``, and returns ``out``.
        """
        out = torch.empty_like(input)
        torch.ops.torch_ipex.rms_norm_vllm(out, input.contiguous(), weight, epsilon)
        return out

    @staticmethod
    def fused_add_rms_norm(
        input: torch.Tensor,
        residual: torch.Tensor,
        weight: torch.Tensor,
        epsilon: float,
    ) -> None:
        """Call ``torch.ops.torch_ipex.fused_add_rms_norm_vllm``.

        Nothing is returned.
        NOTE(review): the op presumably updates ``input`` / ``residual`` in
        place; confirm.
        """
        torch.ops.torch_ipex.fused_add_rms_norm_vllm(input, residual, weight, epsilon)

    @staticmethod
    def varlen_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        seqlen_q: torch.Tensor,
        seqlen_k: torch.Tensor,
        alibi_slopes: torch.Tensor | None,
        max_seqlen_q: int,
        max_seqlen_k: int,
        pdropout: float,
        softmax_scale: float,
        zero_tensors: bool,
        is_causal: bool,
        return_softmax: bool,
        gen_: torch.Generator,
        window_size_left: float,
        window_size_right: float,
        logits_soft_cap: float,
    ) -> None:
        """Call ``ipex.llm.functional.varlen_attention`` for the installed build.

        The build is chosen by whether ``ipex.__version__`` ends in
        ``"cpu"``. The CPU call omits ``alibi_slopes``, the window sizes and
        the soft cap, so those must be unset (see the checks below). The
        other build receives all of them.

        Raises:
            ValueError: On the CPU build when ``logits_soft_cap != 0.0``.
        """
        if ipex.__version__.endswith("cpu"):
            if logits_soft_cap != 0.0:
                raise ValueError("IPEX CPU does not support logits_soft_cap")
            # The CPU call takes neither alibi slopes nor window sizes.
            assert alibi_slopes is None
            assert window_size_left < 0 and window_size_right < 0
            ipex.llm.functional.varlen_attention(
                query.contiguous(),
                key.contiguous(),
                value.contiguous(),
                out,
                seqlen_q.int(),
                seqlen_k.int(),
                max_seqlen_q,
                max_seqlen_k,
                pdropout,
                softmax_scale,
                zero_tensors,
                is_causal,
                return_softmax,
                gen_,
            )
        else:  # XPU build
            ipex.llm.functional.varlen_attention(
                query.contiguous(),
                key.contiguous(),
                value.contiguous(),
                out,
                seqlen_q.int(),
                seqlen_k.int(),
                alibi_slopes,
                max_seqlen_q,
                max_seqlen_k,
                pdropout,
                softmax_scale,
                zero_tensors,
                is_causal,
                return_softmax,
                gen_,
                window_size_left,
                window_size_right,
                logits_soft_cap,
            )

    @staticmethod
    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
        """Call IPEX ``PagedAttention.reshape_and_cache``.

        Only ``kv_cache_dtype == "auto"`` is accepted. ``k_scale`` and
        ``v_scale`` are not read.
        """
        assert kv_cache_dtype == "auto"
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping
        )

    @staticmethod
    def reshape_and_cache_flash(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor | None = None,
        v_scale: torch.Tensor | None = None,
        k_scale_float: float = 1.0,
        v_scale_float: float = 1.0,
    ) -> None:
        """Call IPEX ``PagedAttention.reshape_and_cache_flash``.

        The float scales ``k_scale_float`` / ``v_scale_float`` are
        forwarded. The tensor-valued ``k_scale`` / ``v_scale`` are not read.
        """
        ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale_float,
            v_scale_float,
        )

    @staticmethod
    def flash_attn_varlen_func(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float | None = None,
        causal: bool = False,
        out: torch.Tensor | None = None,
        block_table: torch.Tensor | None = None,
        alibi_slopes: torch.Tensor | None = None,
        window_size: list[int] | None = None,
        softcap: float | None = 0.0,
        seqused_k: torch.Tensor | None = None,
        cu_seqlens_k: torch.Tensor | None = None,
        # passed in qwen vl
        dropout_p: float = 0.0,
        # The following parameters are not used in ipex kernel currently,
        # we keep API compatible to CUDA's.
        scheduler_metadata=None,
        fa_version: int = 2,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        num_splits=0,
        s_aux: torch.Tensor | None = None,
    ):
        """Run variable-length attention, with or without a block table.

        Without ``block_table`` this calls ``ipex_ops.varlen_attention`` and
        returns ``out``. With one it returns the result of IPEX
        ``PagedAttention.flash_attn_varlen_func``.

        ``out`` is allocated with ``q``'s shape, dtype and device when not
        supplied. ``window_size`` defaults to ``(-1, -1)`` and must have
        length 2 when given.

        On the path without a block table, ``alibi_slopes``, ``softcap``
        and ``dropout_p`` are not forwarded: the call passes the literals
        ``None``, ``-1`` and ``0.0`` in those positions. ``cu_seqlens_k``
        must not be None there, and ``softmax_scale`` defaults to
        ``q.shape[-1] ** -0.5``.
        """
        if out is None:
            out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
        real_window_size: tuple[int, int]
        if window_size is None:
            real_window_size = (-1, -1)
        else:
            assert len(window_size) == 2
            real_window_size = (window_size[0], window_size[1])

        if block_table is None:
            assert cu_seqlens_k is not None, (
                "cu_seqlens_k can't be None when calling varlen_attention."
            )
            if softmax_scale is None:
                softmax_scale = q.shape[-1] ** (-0.5)
            # NOTE(review): the last argument (logits_soft_cap) is the
            # literal -1. The CPU branch of varlen_attention raises
            # ValueError for any value other than 0.0, so this path looks
            # XPU-only; confirm.
            ipex_ops.varlen_attention(
                q.contiguous(),
                k.contiguous(),
                v.contiguous(),
                out,
                cu_seqlens_q,
                cu_seqlens_k,
                None,
                max_seqlen_q,
                max_seqlen_k,
                0.0,
                softmax_scale,
                False,
                causal,
                False,
                None,
                real_window_size[0],
                real_window_size[1],
                -1,
            )
            return out
        else:
            return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
                out,
                q.contiguous(),
                k,
                v,
                cu_seqlens_q,
                seqused_k,
                max_seqlen_q,
                max_seqlen_k,
                softmax_scale,
                causal,
                block_table,
                alibi_slopes,
                sink=s_aux,
                softcap=softcap,
                window_size_left=real_window_size[0],
                window_size_right=real_window_size[1],
                k_scale=1.0,
                v_scale=1.0,
            )

    @staticmethod
    def get_scheduler_metadata(
        batch_size,
        max_seqlen_q,
        max_seqlen_k,
        num_heads_q,
        num_heads_kv,
        headdim,
        cache_seqlens: torch.Tensor,
        qkv_dtype=torch.bfloat16,
        headdim_v=None,
        cu_seqlens_q: torch.Tensor | None = None,
        cu_seqlens_k_new: torch.Tensor | None = None,
        cache_leftpad: torch.Tensor | None = None,
        page_size: int | None = None,
        max_seqlen_k_new=0,
        causal=False,
        window_size=(-1, -1),  # -1 means infinite context window
        has_softcap=False,
        num_splits=0,  # Can be tuned for speed
        pack_gqa=None,  # Can be tuned for speed
        sm_margin=0,  # Can be tuned if some SMs are used for communication
    ) -> None:
        """Unimplemented stub: log a warning and return None.

        None of the parameters are read.
        """
        logger.warning_once(
            "get_scheduler_metadata is not implemented for ipex_ops, returning None."
        )
        return None

    @staticmethod
    def copy_blocks(
        key_caches: list[torch.Tensor],
        value_caches: list[torch.Tensor],
        block_mapping: torch.Tensor,
    ) -> None:
        """Call ``torch.xpu.copy_blocks`` with the given caches and mapping."""
        torch.xpu.copy_blocks(  # type: ignore
            key_caches,
            value_caches,
            block_mapping,
        )

    @staticmethod
    def swap_blocks(
        src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
    ) -> None:
        """Call ``torch.xpu.swap_blocks`` with the given tensors and mapping."""
        torch.xpu.swap_blocks(src, dst, block_mapping)  # type: ignore

    @staticmethod
    def scaled_fp8_quant(
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
        num_token_padding: int | None = None,
        scale_ub: torch.Tensor | None = None,
        use_per_token_if_dynamic: bool = False,
        output: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize input tensor to FP8 and return quantized tensor and scale.

        This function is designed for both static and dynamic quantization:
        If you provide the scale, it will use static scaling and if you omit
        it, the scale will be determined dynamically. Currently, XPU platform
        only supports dynamic quantization. The function also allows optional
        padding of the output tensors for downstream kernels that will benefit
        from padding.

        As implemented here, ``scale`` must be None and
        ``use_per_token_if_dynamic`` must be False; both are asserted.

        Args:
            input: The input tensor to be quantized to FP8
            scale: Optional scaling factor for the FP8 quantization
            scale_ub: Optional upper bound for scaling factor in dynamic
                per token case. Not read by this implementation.
            num_token_padding: If specified, pad the first dimension
                of the output to at least this value.
            use_per_token_if_dynamic: Whether to do per_tensor or per_token
                in the dynamic quantization case.
            output: Optional preallocated output tensor. Its dtype must be
                the platform FP8 dtype, and it cannot be combined with
                ``num_token_padding``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
                scaling factor.
        """
        # This code assumes batch_dim and num_tokens are flattened
        assert input.ndim == 2
        shape: tuple[int, int] | torch.Size = input.shape
        out_dtype: torch.dtype = current_platform.fp8_dtype()
        # Truthiness test: a padding of 0 is treated like None.
        if num_token_padding:
            shape = (max(num_token_padding, input.shape[0]), shape[1])
        if output is None:
            output = torch.empty(shape, device=input.device, dtype=out_dtype)
        else:
            assert num_token_padding is None, (
                "padding not supported if output passed in"
            )
        assert output.dtype == out_dtype
        assert scale is None, "only dynamic fp8 quantization supported on XPU"
        assert not use_per_token_if_dynamic, (
            "per token dynamic fp8 quantization not supported on XPU"
        )
        # One-element float32 scale buffer handed to the op.
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        torch.ops.torch_ipex.dynamic_scaled_fp8_quant(output, input, scale)

        return output, scale
|
||||||
0
assets/__init__.py
Normal file
0
assets/__init__.py
Normal file
BIN
assets/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
assets/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
assets/__pycache__/audio.cpython-312.pyc
Normal file
BIN
assets/__pycache__/audio.cpython-312.pyc
Normal file
Binary file not shown.
BIN
assets/__pycache__/base.cpython-312.pyc
Normal file
BIN
assets/__pycache__/base.cpython-312.pyc
Normal file
Binary file not shown.
BIN
assets/__pycache__/image.cpython-312.pyc
Normal file
BIN
assets/__pycache__/image.cpython-312.pyc
Normal file
Binary file not shown.
BIN
assets/__pycache__/video.cpython-312.pyc
Normal file
BIN
assets/__pycache__/video.cpython-312.pyc
Normal file
Binary file not shown.
43
assets/audio.py
Normal file
43
assets/audio.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
|
from vllm.utils.import_utils import PlaceholderModule
|
||||||
|
|
||||||
|
from .base import VLLM_S3_BUCKET_URL, get_vllm_public_assets
|
||||||
|
|
||||||
|
try:
    import librosa
except ImportError:
    # librosa is optional: bind a placeholder so this module still imports.
    # NOTE(review): presumably the placeholder raises on first use; confirm.
    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]

# Prefix passed as `s3_prefix` when fetching audio assets.
ASSET_DIR = "multimodal_asset"

# The asset names accepted by `AudioAsset.name`.
AudioAssetName = Literal["winning_call", "mary_had_lamb"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class AudioAsset:
    """A named public audio asset stored as ``<name>.ogg``."""

    name: AudioAssetName

    @property
    def filename(self) -> str:
        """File name of the asset: the name plus the ``.ogg`` suffix."""
        return self.name + ".ogg"

    @property
    def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
        """Fetch the asset and load it with librosa at its native rate."""
        local_file = get_vllm_public_assets(
            filename=self.filename, s3_prefix=ASSET_DIR
        )
        loaded = librosa.load(local_file, sr=None)
        return loaded

    def get_local_path(self) -> Path:
        """Fetch the asset if needed and return its local path."""
        local_file = get_vllm_public_assets(
            filename=self.filename, s3_prefix=ASSET_DIR
        )
        return local_file

    @property
    def url(self) -> str:
        """URL of the asset under the public bucket."""
        relative_key = f"{ASSET_DIR}/{self.name}.ogg"
        return urljoin(VLLM_S3_BUCKET_URL, relative_key)
|
||||||
40
assets/base.py
Normal file
40
assets/base.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.connections import global_http_connection
|
||||||
|
|
||||||
|
# Base URL that asset downloads in this module are built from.
VLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache_dir() -> Path:
    """Return the asset cache directory, creating it when it is missing."""
    cache_dir = Path(envs.VLLM_ASSETS_CACHE)
    # Parents are created too, and an existing directory is not an error.
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
def get_vllm_public_assets(filename: str, s3_prefix: str | None = None) -> Path:
    """
    Return the local path of a public vLLM asset, downloading it if absent.

    The file is stored as ``<cache dir>/vllm_public_assets/<filename>``;
    ``s3_prefix`` only affects the download URL, not the local path. When
    the local file already exists, nothing is downloaded.
    """
    download_dir = get_cache_dir() / "vllm_public_assets"
    download_dir.mkdir(parents=True, exist_ok=True)

    local_path = download_dir / filename
    if local_path.exists():
        return local_path

    # The remote key carries the prefix when one was given.
    remote_key = filename if s3_prefix is None else s3_prefix + "/" + filename
    global_http_connection.download_file(
        f"{VLLM_S3_BUCKET_URL}/{remote_key}",
        local_path,
        timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
    )
    return local_path
|
||||||
59
assets/image.py
Normal file
59
assets/image.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from .base import get_vllm_public_assets
|
||||||
|
|
||||||
|
VLM_IMAGES_DIR = "vision_model_images"
|
||||||
|
|
||||||
|
# Closed set of names accepted by `ImageAsset.name`.
ImageAssetName = Literal[
    "stop_sign",
    "cherry_blossom",
    "hato",
    "2560px-Gfp-wisconsin-madison-the-nature-boardwalk",
    "Grayscale_8bits_palette_sample_image",
    "1280px-Venn_diagram_rgb",
    "RGBA_comp",
    "237-400x300",
    "231-200x300",
    "27-500x500",
    "17-150x600",
    "handelsblatt-preview",
    "paper-11",
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ImageAsset:
    """A named image asset whose files are fetched via `get_vllm_public_assets`."""

    name: ImageAssetName

    def get_path(self, ext: str) -> Path:
        """
        Return the path that `get_vllm_public_assets` yields for
        `<name>.<ext>` under the `VLM_IMAGES_DIR` prefix.
        """
        asset_file = f"{self.name}.{ext}"
        return get_vllm_public_assets(filename=asset_file, s3_prefix=VLM_IMAGES_DIR)

    @property
    def pil_image(self, ext="jpg") -> Image.Image:
        """
        The asset opened with PIL.

        Because this is accessed as a property, `ext` always takes its
        default value ("jpg").
        """
        return Image.open(self.get_path(ext))

    @property
    def image_embeds(self) -> torch.Tensor:
        """
        Image embeddings, only used for testing purposes with llava 1.5.
        """
        embeds_file = self.get_path("pt")
        return torch.load(embeds_file, weights_only=True, map_location="cpu")

    def read_bytes(self, ext: str) -> bytes:
        """Return the raw bytes of the `<name>.<ext>` asset file."""
        asset_file = Path(self.get_path(ext))
        return asset_file.read_bytes()
|
||||||
149
assets/video.py
Normal file
149
assets/video.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Any, ClassVar, Literal
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from vllm.utils.import_utils import PlaceholderModule
|
||||||
|
|
||||||
|
from .base import get_cache_dir
|
||||||
|
|
||||||
|
try:
|
||||||
|
import librosa
|
||||||
|
except ImportError:
|
||||||
|
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
def download_video_asset(filename: str) -> str:
    """
    Return a local path for a video file from the huggingface dataset
    repo: raushan-testing-hf/videos-test

    If `<cache dir>/video-example-data/<filename>` already exists, that path
    is returned as a string; otherwise the file is fetched with
    `hf_hub_download` (using that directory as `cache_dir`) and the path it
    reports is returned. Results are memoized per filename by `lru_cache`.
    """
    video_directory = get_cache_dir() / "video-example-data"
    video_directory.mkdir(parents=True, exist_ok=True)

    video_path = video_directory / filename
    if video_path.exists():
        return str(video_path)

    return hf_hub_download(
        repo_id="raushan-testing-hf/videos-test",
        filename=filename,
        repo_type="dataset",
        cache_dir=video_directory,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
    """Decode evenly spaced frames of a video into one stacked RGB array.

    Args:
        path: Path of the video file, opened with OpenCV.
        num_frames: Number of frames to sample. Any value <= 0 means
            "every frame reported by the container".

    Returns:
        The selected frames stacked along a new leading axis, converted
        from OpenCV's BGR order to RGB.

    Raises:
        ValueError: If the file cannot be opened, or fewer frames than
            requested could be decoded (including no frames at all).
    """
    import cv2

    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video file {path}")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        num_frames = num_frames if num_frames > 0 else total_frames

        # A set gives O(1) membership per frame instead of scanning the
        # whole index array on every iteration.
        wanted = set(
            np.linspace(0, total_frames - 1, num_frames, dtype=int).tolist()
        )

        frames = []
        for idx in range(total_frames):
            if not cap.grab():  # advance to the next frame
                break
            if idx in wanted:  # only decompress the frames we need
                ret, frame = cap.retrieve()
                if ret:
                    # OpenCV uses BGR format, we need to convert it to RGB
                    # for PIL and transformers compatibility
                    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always hand the decoder / file handle back, even on errors.
        cap.release()

    # Validate before stacking so an empty result yields this descriptive
    # error rather than numpy's generic "need at least one array" one.
    if not frames or len(frames) < num_frames:
        raise ValueError(
            f"Could not read enough frames from video file {path}"
            f" (expected {num_frames} frames, got {len(frames)})"
        )
    return np.stack(frames)
|
||||||
|
|
||||||
|
|
||||||
|
def video_to_pil_images_list(path: str, num_frames: int = -1) -> list[Image.Image]:
    """Decode a video with `video_to_ndarrays` and wrap each frame as a PIL image."""
    decoded = video_to_ndarrays(path, num_frames)
    return list(map(Image.fromarray, decoded))
|
||||||
|
|
||||||
|
|
||||||
|
def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]:
    """Build the metadata dict describing a (possibly subsampled) video.

    Args:
        path: Path of the video file, opened with OpenCV.
        num_frames: Number of frames the caller intends to use. `-1`, or
            any value above the container's frame count, is replaced by
            the full frame count.

    Returns:
        A dict with the keys `total_num_frames`, `fps`, `duration`,
        `video_backend`, `frames_indices` and `do_sample_frames`.

    Raises:
        ValueError: If the file cannot be opened.
    """
    import cv2

    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video file {path}")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Only container properties are needed; release the capture right
        # away instead of leaking it.
        cap.release()

    duration = total_frames / fps if fps > 0 else 0

    if num_frames == -1 or num_frames > total_frames:
        num_frames = total_frames

    metadata = {
        "total_num_frames": num_frames,
        # NOTE(review): this is duration / frame count (seconds per frame),
        # not frames per second, and it raises ZeroDivisionError when
        # num_frames is 0. Kept unchanged for compatibility -- confirm
        # what consumers of "fps" expect.
        "fps": duration / num_frames,
        "duration": duration,
        "video_backend": "opencv",
        "frames_indices": list(range(num_frames)),
        # extra field used to control hf processor's video
        # sampling behavior
        "do_sample_frames": num_frames == total_frames,
    }
    return metadata
|
||||||
|
|
||||||
|
|
||||||
|
# Closed set of names accepted by `VideoAsset.name`.
VideoAssetName = Literal["baby_reading"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class VideoAsset:
    """A named sample video with accessors for frames, metadata and audio.

    `num_frames` is forwarded to the frame and metadata helpers unchanged.
    """

    name: VideoAssetName
    num_frames: int = -1

    # Maps each asset name to the file name requested from the hub.
    _NAME_TO_FILE: ClassVar[dict[VideoAssetName, str]] = {
        "baby_reading": "sample_demo_1.mp4",
    }

    @property
    def filename(self) -> str:
        """File name registered for `self.name` in `_NAME_TO_FILE`."""
        return self._NAME_TO_FILE[self.name]

    @property
    def video_path(self) -> str:
        """Local path of the video as returned by `download_video_asset`."""
        return download_video_asset(self.filename)

    @property
    def pil_images(self) -> list[Image.Image]:
        """Frames of the video as PIL images."""
        return video_to_pil_images_list(self.video_path, self.num_frames)

    @property
    def np_ndarrays(self) -> npt.NDArray:
        """Frames of the video as returned by `video_to_ndarrays`."""
        return video_to_ndarrays(self.video_path, self.num_frames)

    @property
    def metadata(self) -> dict[str, Any]:
        """Metadata dict produced by `video_get_metadata`."""
        return video_get_metadata(self.video_path, self.num_frames)

    def get_audio(self, sampling_rate: float | None = None) -> npt.NDArray:
        """
        Read audio data from the video asset, used in Qwen2.5-Omni examples.

        See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
        """
        loaded = librosa.load(self.video_path, sr=sampling_rate)
        # Only the first element of librosa's result is returned.
        return loaded[0]
|
||||||
18
attention/__init__.py
Normal file
18
attention/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import (
|
||||||
|
AttentionBackend,
|
||||||
|
AttentionMetadata,
|
||||||
|
AttentionType,
|
||||||
|
)
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
|
||||||
|
# Names re-exported as the public API of this package.
__all__ = [
    "Attention",
    "AttentionBackend",
    "AttentionMetadata",
    "AttentionType",
    "get_attn_backend",
]
|
||||||
BIN
attention/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
attention/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
attention/__pycache__/layer.cpython-312.pyc
Normal file
BIN
attention/__pycache__/layer.cpython-312.pyc
Normal file
Binary file not shown.
BIN
attention/__pycache__/selector.cpython-312.pyc
Normal file
BIN
attention/__pycache__/selector.cpython-312.pyc
Normal file
Binary file not shown.
0
attention/backends/__init__.py
Normal file
0
attention/backends/__init__.py
Normal file
BIN
attention/backends/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
attention/backends/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
attention/backends/__pycache__/abstract.cpython-312.pyc
Normal file
BIN
attention/backends/__pycache__/abstract.cpython-312.pyc
Normal file
Binary file not shown.
BIN
attention/backends/__pycache__/registry.cpython-312.pyc
Normal file
BIN
attention/backends/__pycache__/registry.cpython-312.pyc
Normal file
Binary file not shown.
BIN
attention/backends/__pycache__/utils.cpython-312.pyc
Normal file
BIN
attention/backends/__pycache__/utils.cpython-312.pyc
Normal file
Binary file not shown.
391
attention/backends/abstract.py
Normal file
391
attention/backends/abstract.py
Normal file
@@ -0,0 +1,391 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.config.cache import CacheDType
|
||||||
|
from vllm.platforms.interface import DeviceCapability
|
||||||
|
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionType:
    """
    Attention type.
    Use string to be compatible with `torch.compile`.

    Each attribute below is a plain `str` constant; the string that
    follows it describes that constant.
    """

    DECODER = "decoder"
    """Decoder attention between previous layer Q/K/V."""
    ENCODER = "encoder"
    """Encoder attention between previous layer Q/K/V for encoder-decoder."""
    ENCODER_ONLY = "encoder_only"
    """Encoder attention between previous layer Q/K/V."""
    ENCODER_DECODER = "encoder_decoder"
    """Attention between dec. Q and enc. K/V for encoder-decoder."""
|
||||||
|
|
||||||
|
|
||||||
|
class MultipleOf:
    """Marker meaning "any multiple of `base`".

    Used as an entry of `AttentionBackend.supported_kernel_block_sizes`,
    where a block size matches when `block_size % base == 0`.
    """

    base: int

    def __init__(self, base: int):
        # Stored as given; no validation is performed here.
        self.base = base
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    Subclasses provide the factory hooks (`get_name`, `get_impl_cls`,
    `get_builder_cls`, `get_kv_cache_shape`) and may override the
    `supports_*` / `is_*` capability checks that `validate_configuration`
    combines into a list of rejection reasons.
    """

    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)]
    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        # Deliberately not abstract: the default raises NotImplementedError.
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        """Return `(module, qualified class name)` of this backend class."""
        return (cls.__module__, cls.__qualname__)

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Head sizes this backend accepts; an empty list means no restriction."""
        return []

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        """True if `head_size` is listed, or if no head sizes are listed."""
        supported_head_sizes = cls.get_supported_head_sizes()
        return (not supported_head_sizes) or head_size in supported_head_sizes

    @classmethod
    def supports_dtype(cls, dtype: torch.dtype) -> bool:
        """True if `dtype` is in `supported_dtypes`."""
        return dtype in cls.supported_dtypes

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
        """True for `None`, for an empty support list, or for a listed dtype."""
        if kv_cache_dtype is None:
            return True
        return (not cls.supported_kv_cache_dtypes) or (
            kv_cache_dtype in cls.supported_kv_cache_dtypes
        )

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        """Check `block_size` against `supported_kernel_block_sizes`.

        `None` is always accepted. Otherwise the size must be one of the
        values of `vllm.config.cache.BlockSize` and, unless the support
        list is empty, match an entry: equal to an `int` entry or a
        multiple of a `MultipleOf` entry's base.
        """
        from vllm.config.cache import BlockSize

        if block_size is None:
            return True

        valid_sizes = get_args(BlockSize)
        if block_size not in valid_sizes:
            return False

        if not cls.supported_kernel_block_sizes:
            return True

        for supported_size in cls.supported_kernel_block_sizes:
            is_multiple_of = (
                isinstance(supported_size, MultipleOf)
                and block_size % supported_size.base == 0
            )
            is_int_equal = (
                isinstance(supported_size, int) and block_size == supported_size
            )
            if is_multiple_of or is_int_equal:
                return True
        return False

    @classmethod
    def is_mla(cls) -> bool:
        return False

    @classmethod
    def supports_sink(cls) -> bool:
        return False

    @classmethod
    def is_sparse(cls) -> bool:
        return False

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """Check if backend supports a given attention type.

        By default, only supports decoder attention.
        Backends should override this to support other attention types.
        """
        # AttentionType is defined in this module; referencing it directly
        # avoids importing the parent package from inside this method.
        return attn_type == AttentionType.DECODER

    @classmethod
    def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
        return True

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: "DeviceCapability",
    ) -> str | None:
        """Hook for cross-parameter checks.

        Returns `None` when the combination is acceptable, otherwise a
        reason string that `validate_configuration` appends to its result.
        The default accepts everything.
        """
        return None

    @classmethod
    def validate_configuration(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: "DeviceCapability",
        attn_type: str,
    ) -> list[str]:
        """Run every capability check and collect the failures.

        Returns:
            Human-readable reasons why this backend cannot serve the
            given configuration; an empty list means no check failed.
        """
        invalid_reasons = []
        if not cls.supports_head_size(head_size):
            invalid_reasons.append("head_size not supported")
        if not cls.supports_dtype(dtype):
            invalid_reasons.append("dtype not supported")
        if not cls.supports_kv_cache_dtype(kv_cache_dtype):
            invalid_reasons.append("kv_cache_dtype not supported")
        if not cls.supports_block_size(block_size):
            invalid_reasons.append("block_size not supported")
        # MLA and sparse must match exactly, in both directions.
        if use_mla != cls.is_mla():
            if use_mla:
                invalid_reasons.append("MLA not supported")
            else:
                invalid_reasons.append("non-MLA not supported")
        # Sink support is one-directional: only a request for sinks can fail.
        if has_sink and not cls.supports_sink():
            invalid_reasons.append("sink setting not supported")
        if use_sparse != cls.is_sparse():
            if use_sparse:
                invalid_reasons.append("sparse not supported")
            else:
                invalid_reasons.append("non-sparse not supported")
        if not cls.supports_compute_capability(device_capability):
            invalid_reasons.append("compute capability not supported")
        if not cls.supports_attn_type(attn_type):
            invalid_reasons.append(f"attention type {attn_type} not supported")
        combination_reason = cls.supports_combination(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla,
            has_sink,
            use_sparse,
            device_capability,
        )
        if combination_reason is not None:
            invalid_reasons.append(combination_reason)
        return invalid_reasons

    @classmethod
    def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
        return None
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionMetadata:
    """Empty base type; it is the bound of the `T` type variable in this module."""
|
||||||
|
|
||||||
|
|
||||||
|
# Type variable for the metadata type an attention implementation consumes.
T = TypeVar("T", bound=AttentionMetadata)
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionLayer(Protocol):
    """Structural type of the `layer` argument taken by `AttentionImpl.forward`.

    Declares the scale attributes (tensor and float variants) and the
    `forward` signature such a layer object is expected to expose.
    """

    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionImpl(ABC, Generic[T]):
    """Abstract base for attention implementations, generic over metadata type `T`."""

    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # some attention backends might not always want to return lse
    # even if they can return lse (for efficiency reasons)
    need_to_return_lse_for_decode: bool = False

    dcp_world_size: int
    dcp_rank: int

    def __new__(cls, *args, **kwargs):
        """Create the instance and populate its DCP-related attributes."""
        # use __new__ so that all subclasses will call this
        instance = super().__new__(cls)
        try:
            from vllm.distributed.parallel_state import get_dcp_group

            world_size = get_dcp_group().world_size
            rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            world_size, rank = 1, 0
        instance.dcp_world_size = world_size
        instance.dcp_rank = rank
        # The lse is only requested when DCP is active and the impl can
        # actually provide it.
        instance.need_to_return_lse_for_decode = (
            instance.dcp_world_size > 1 and instance.can_return_lse_for_decode
        )
        return instance

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: QuantKey):
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        :param quant_key: QuantKey object that describes the quantization op
        :return: is fusion supported for this type of quantization
        """
        return False

    def supports_quant_query_input(self) -> bool:
        """
        Check if this attention implementation supports pre-quantized query input.

        When True, the attention layer will quantize queries before passing them
        to this backend, allowing torch.compile to fuse the quantization with
        previous operations. This is typically supported when using FP8 KV cache
        with compatible attention kernels (e.g., TRT-LLM).
        TODO add support to more backends:
        https://github.com/vllm-project/vllm/issues/25584

        Returns:
            bool: True if the implementation can accept pre-quantized queries.
        """
        return False

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Hook with an empty default body; subclasses may override it."""
|
||||||
|
|
||||||
|
|
||||||
|
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    """Abstract `AttentionImpl` variant whose constructor and `forward`
    take MLA-specific arguments."""

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        q_lora_rank: int | None,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: ColumnParallelLinear,
        indexer: object | None = None,
    ) -> None:
        # Unlike the base signature, every argument except `indexer`
        # has no default here.
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True for every KV-cache dtype string other than "auto"."""
    is_auto = kv_cache_dtype == "auto"
    return not is_auto
|
||||||
195
attention/backends/registry.py
Normal file
195
attention/backends/registry.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Attention backend registry"""
|
||||||
|
|
||||||
|
import enum
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import TYPE_CHECKING, cast
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class _AttentionBackendEnumMeta(enum.EnumMeta):
|
||||||
|
"""Metaclass for AttentionBackendEnum to provide better error messages."""
|
||||||
|
|
||||||
|
def __getitem__(cls, name: str):
|
||||||
|
"""Get backend by name with helpful error messages."""
|
||||||
|
try:
|
||||||
|
return super().__getitem__(name)
|
||||||
|
except KeyError:
|
||||||
|
members = cast("dict[str, AttentionBackendEnum]", cls.__members__).values()
|
||||||
|
valid_backends = ", ".join(m.name for m in members)
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown attention backend: '{name}'. "
|
||||||
|
f"Valid options are: {valid_backends}"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
    """Enumeration of all supported attention backends.

    The enum value is the default class path, but this can be overridden
    at runtime using register_backend().

    To get the actual backend class (respecting overrides), use:
        backend.get_class()
    """

    # NOTE(review): TORCH_SDPA and CUSTOM below share the value "". Under
    # standard `enum` semantics a later member with a duplicate value is an
    # alias of the earlier one -- confirm that this is intended.
    FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
    TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
    XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"
    ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
    ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
    ROCM_AITER_FA = (
        "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
    )
    TORCH_SDPA = ""  # this tag is only used for ViT
    FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
    FLASHINFER_MLA = (
        "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
    )
    TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
    CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
    FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
    FLASHMLA_SPARSE = (
        "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
    )
    FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
    PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
    IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend"
    NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
    FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
    TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
    ROCM_AITER_UNIFIED_ATTN = (
        "vllm.v1.attention.backends.rocm_aiter_unified_attn."
        "RocmAiterUnifiedAttentionBackend"
    )
    CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
    # Placeholder for third-party/custom backends - must be registered before use
    CUSTOM = ""

    def get_path(self, include_classname: bool = True) -> str:
        """Get the class path for this backend (respects overrides).

        Args:
            include_classname: When False, the text after the last "."
                is dropped, leaving the module part of the path.

        Returns:
            The fully qualified class path string

        Raises:
            ValueError: If Backend.CUSTOM is used without being registered
        """
        class_path = _OVERRIDES.get(self, self.value)
        if class_path:
            if include_classname:
                return class_path
            module_path = class_path.rsplit(".", 1)[0]
            return module_path
        # Neither an override nor a default path is available.
        raise ValueError(
            f"Backend {self.name} must be registered before use. "
            f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
        )

    def get_class(self) -> "type[AttentionBackend]":
        """Get the backend class (respects overrides).

        Returns:
            The backend class

        Raises:
            ImportError: If the backend class cannot be imported
            ValueError: If Backend.CUSTOM is used without being registered
        """
        qualified_name = self.get_path()
        return resolve_obj_by_qualname(qualified_name)

    def is_overridden(self) -> bool:
        """Check if this backend has been overridden.

        Returns:
            True if the backend has a registered override
        """
        return self in _OVERRIDES

    def clear_override(self) -> None:
        """Clear any override for this backend, reverting to the default."""
        _OVERRIDES.pop(self, None)
|
||||||
|
|
||||||
|
|
||||||
|
# Runtime overrides: backend enum member -> class path registered for it.
_OVERRIDES: dict[AttentionBackendEnum, str] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_backend(
    backend: AttentionBackendEnum, class_path: str | None = None
) -> Callable[[type], type]:
    """Register or override a backend implementation.

    Args:
        backend: The AttentionBackendEnum member to register
        class_path: Optional class path. If not provided and used as
            decorator, will be auto-generated from the class.

    Returns:
        Decorator function if class_path is None, otherwise a no-op

    Examples:
        # Override an existing backend
        @register_backend(AttentionBackendEnum.FLASH_ATTN)
        class MyCustomFlashAttn:
            ...

        # Register a custom third-party backend
        @register_backend(AttentionBackendEnum.CUSTOM)
        class MyCustomBackend:
            ...

        # Direct registration
        register_backend(
            AttentionBackendEnum.CUSTOM,
            "my.module.MyCustomBackend"
        )
    """
    if class_path is None:
        # Decorator form: the override is recorded when the class is decorated.
        def decorator(cls: type) -> type:
            _OVERRIDES[backend] = ".".join((cls.__module__, cls.__qualname__))
            return cls

        return decorator

    # Direct form: record the override now and return an identity function.
    _OVERRIDES[backend] = class_path
    return lambda x: x
|
||||||
|
|
||||||
|
|
||||||
|
# Backwards compatibility alias for plugins
|
||||||
|
class _BackendMeta(type):
    """Metaclass to provide deprecation warnings when accessing _Backend."""

    # Single copy of the warning text shared by both access paths.
    _DEPRECATION_MESSAGE = (
        "_Backend has been renamed to AttentionBackendEnum. "
        "Please update your code to use AttentionBackendEnum instead. "
        "_Backend will be removed in a future release."
    )

    def __getattribute__(cls, name: str):
        """Forward every attribute lookup to AttentionBackendEnum.

        A warning is logged unless `name` is one of `__class__`,
        `__mro__` or `__name__`.
        """
        quiet = name in ("__class__", "__mro__", "__name__")
        if not quiet:
            logger.warning(_BackendMeta._DEPRECATION_MESSAGE)
        return getattr(AttentionBackendEnum, name)

    def __getitem__(cls, name: str):
        """Log the deprecation warning, then look `name` up on AttentionBackendEnum."""
        logger.warning(_BackendMeta._DEPRECATION_MESSAGE)
        return AttentionBackendEnum[name]
|
||||||
|
|
||||||
|
|
||||||
|
class _Backend(metaclass=_BackendMeta):
    """Deprecated alias: use AttentionBackendEnum instead.

    Kept only so existing plugins keep working; it will be removed in a
    future release. The class has no members of its own; its metaclass
    handles every lookup.
    """
|
||||||
33
attention/backends/utils.py
Normal file
33
attention/backends/utils.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Attention backend utils"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
PAD_SLOT_ID = -1
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MLADims:
|
||||||
|
q_lora_rank: int | None
|
||||||
|
kv_lora_rank: int
|
||||||
|
qk_nope_head_dim: int
|
||||||
|
qk_rope_head_dim: int
|
||||||
|
v_head_dim: int
|
||||||
|
|
||||||
|
|
||||||
|
def get_mla_dims(model_config: ModelConfig) -> MLADims:
    """Collect the MLA dimension fields from a model's HF text config.

    `q_lora_rank` is optional and becomes None when the config does not
    define it. The other four fields are read directly, so a config that
    lacks any of them raises AttributeError.
    """
    text_cfg = model_config.hf_text_config

    # Optional field first, matching the keyword order of MLADims.
    optional_q_rank = getattr(text_cfg, "q_lora_rank", None)

    mandatory_fields = (
        "kv_lora_rank",
        "qk_nope_head_dim",
        "qk_rope_head_dim",
        "v_head_dim",
    )
    mandatory = {field: getattr(text_cfg, field) for field in mandatory_fields}

    return MLADims(q_lora_rank=optional_q_rank, **mandatory)
|
||||||
1051
attention/layer.py
Normal file
1051
attention/layer.py
Normal file
File diff suppressed because it is too large
Load Diff
0
attention/layers/__init__.py
Normal file
0
attention/layers/__init__.py
Normal file
BIN
attention/layers/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
attention/layers/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
attention/layers/__pycache__/cross_attention.cpython-312.pyc
Normal file
BIN
attention/layers/__pycache__/cross_attention.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
121
attention/layers/chunked_local_attention.py
Normal file
121
attention/layers/chunked_local_attention.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import functools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
||||||
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
from vllm.config import CacheConfig
|
||||||
|
from vllm.config.vllm import VllmConfig
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.v1.attention.backends.utils import (
|
||||||
|
AttentionCGSupport,
|
||||||
|
AttentionMetadataBuilder,
|
||||||
|
CommonAttentionMetadata,
|
||||||
|
make_local_attention_virtual_batches,
|
||||||
|
subclass_attention_backend,
|
||||||
|
)
|
||||||
|
from vllm.v1.kv_cache_interface import (
|
||||||
|
AttentionSpec,
|
||||||
|
ChunkedLocalAttentionSpec,
|
||||||
|
KVCacheSpec,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..layer import Attention
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache
def create_chunked_local_attention_backend(
    underlying_attn_backend: AttentionBackend,
    attention_chunk_size: int,
    block_size: int,
) -> type[AttentionBackend]:
    """Derive a chunked-local-attention backend from an existing backend.

    The returned backend is produced by `subclass_attention_backend` with a
    metadata builder that first rewrites the common metadata through
    `make_local_attention_virtual_batches` and then delegates to the
    underlying builder. Results are memoized per argument triple.
    """
    base_builder_cls = underlying_attn_backend.get_builder_cls()
    assert issubclass(base_builder_cls, AttentionMetadataBuilder)

    class ChunkedLocalAttentionBuilder(base_builder_cls):  # type: ignore
        @classmethod
        def get_cudagraph_support(
            cls: type["AttentionMetadataBuilder"],
            vllm_config: VllmConfig,
            kv_cache_spec: AttentionSpec,
        ) -> AttentionCGSupport:
            # Overridden explicitly so that a specialized getter on the
            # underlying builder cannot re-enable CUDA graph support.
            # @override is left off only because of a mypy limitation
            # with the type variable.
            return AttentionCGSupport.NEVER

        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            # Rewrite the batch into local-attention virtual batches, then
            # let the wrapped builder do the real work.
            local_metadata = make_local_attention_virtual_batches(
                attention_chunk_size, common_attn_metadata, block_size
            )
            return super().build(common_prefix_len, local_metadata, fast_build)

    return subclass_attention_backend(
        name_prefix=f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_",
        attention_backend_cls=underlying_attn_backend,
        builder_cls=ChunkedLocalAttentionBuilder,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkedLocalAttention(Attention):
    """Attention layer whose backend is wrapped for chunked local attention.

    The layer selects a regular backend via `get_attn_backend`, wraps it with
    `create_chunked_local_attention_backend`, and hands the result to the
    base `Attention` constructor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        attention_chunk_size: int,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        kv_sharing_target_layer_name: str | None = None,
        prefix: str = "",
    ):
        self.attention_chunk_size = attention_chunk_size

        # Without a cache config, fall back to the "auto" dtype and
        # 16-token blocks.
        if cache_config is None:
            kv_cache_dtype, block_size = "auto", 16
        else:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size

        base_backend = get_attn_backend(
            head_size, torch.get_default_dtype(), kv_cache_dtype, block_size
        )
        chunked_backend = create_chunked_local_attention_backend(
            base_backend, attention_chunk_size, block_size
        )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=chunked_backend,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Return the `ChunkedLocalAttentionSpec` describing this layer's cache."""
        assert self.attention_chunk_size
        spec_kwargs = dict(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
            attention_chunk_size=self.attention_chunk_size,
        )
        return ChunkedLocalAttentionSpec(**spec_kwargs)
|
||||||
178
attention/layers/cross_attention.py
Normal file
178
attention/layers/cross_attention.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import functools
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import (
|
||||||
|
AttentionBackend,
|
||||||
|
AttentionMetadata,
|
||||||
|
AttentionType,
|
||||||
|
)
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils.math_utils import cdiv
|
||||||
|
from vllm.v1.attention.backends.utils import (
|
||||||
|
CommonAttentionMetadata,
|
||||||
|
subclass_attention_backend,
|
||||||
|
)
|
||||||
|
from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_max_encoder_len(vllm_config: "VllmConfig") -> int:
|
||||||
|
"""Gets the max number of encoder input tokens from the config."""
|
||||||
|
sc = vllm_config.scheduler_config
|
||||||
|
assert sc and isinstance(sc.max_num_encoder_input_tokens, int), (
|
||||||
|
"max_num_encoder_input_tokens must be int for enc-dec models"
|
||||||
|
)
|
||||||
|
return sc.max_num_encoder_input_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cross_slot_mapping(
    encoder_seq_lens: np.ndarray,
    block_table_tensor: torch.Tensor,
    kv_cache_spec: CrossAttentionSpec,
    device: torch.device,
) -> torch.Tensor:
    """Build the flat cross-attention slot mapping for all active requests.

    For each request with a non-zero encoder length, token position ``i`` is
    mapped to ``block_id * block_size + i % block_size``, where ``block_id``
    is entry ``i // block_size`` of that request's block-table row. The
    per-request mappings are concatenated in request order. When no request
    has encoder tokens, an empty int64 tensor on ``device`` is returned.
    """
    block_size = kv_cache_spec.block_size
    per_request_slots = []

    # Only requests with a non-zero encoder length contribute slots. Most
    # parallel requests are expected to be decoding, so this loop should
    # visit few entries.
    for req_index in np.nonzero(encoder_seq_lens)[0]:
        num_tokens = encoder_seq_lens[req_index].item()

        # Leading entries of this request's row that cover num_tokens.
        row_blocks = block_table_tensor[req_index][: cdiv(num_tokens, block_size)]

        # All needed blocks are assumed to be allocated already.
        positions = torch.arange(num_tokens, dtype=torch.int64, device=device)
        block_of_position = row_blocks[positions // block_size]
        per_request_slots.append(
            block_of_position * block_size + positions % block_size
        )

    if not per_request_slots:
        return torch.empty(0, dtype=torch.int64, device=device)
    return torch.cat(per_request_slots)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache
def create_cross_attention_backend(
    underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
    """Derive a cross-attention backend from an existing backend.

    The returned backend uses a metadata builder that works on a shallow
    copy of the common metadata: it clears `causal`, sets every sequence
    length to the configured maximum encoder length, recomputes the slot
    mapping with `_get_cross_slot_mapping`, and then delegates to the
    underlying builder. Results are memoized per underlying backend.
    """
    base_builder_cls = underlying_attn_backend.get_builder_cls()

    class CrossAttentionBuilder(base_builder_cls):  # type: ignore
        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            # Shallow copy so the caller's metadata object is not modified.
            cross_metadata = copy(common_attn_metadata)
            cross_metadata.causal = False

            encoder_len_cap = _get_max_encoder_len(self.vllm_config)
            cross_metadata.max_seq_len = encoder_len_cap

            # Every request gets the same (maximum) encoder length, on both
            # the builder's device and the CPU.
            per_request_shape = (cross_metadata.num_reqs,)
            cross_metadata.seq_lens = torch.full(
                per_request_shape,
                encoder_len_cap,
                dtype=torch.int32,
                device=self.device,
            )
            cross_metadata.seq_lens_cpu = torch.full(
                per_request_shape,
                encoder_len_cap,
                dtype=torch.int32,
                device="cpu",
            )

            cross_metadata.slot_mapping = _get_cross_slot_mapping(
                cross_metadata.encoder_seq_lens,
                cross_metadata.block_table_tensor,
                self.kv_cache_spec,
                self.device,
            )
            return super().build(common_prefix_len, cross_metadata, fast_build)

    return subclass_attention_backend(
        name_prefix="CrossAttention_",
        attention_backend_cls=underlying_attn_backend,
        builder_cls=CrossAttentionBuilder,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttention(Attention):
    """
    Cross-attention for encoder-decoder models.
    Handles attention between decoder queries and encoder keys/values.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        cache_config: CacheConfig | None = None,
        attn_type: str | None = None,
        **kwargs,
    ):
        # Without a cache config, fall back to the "auto" dtype and
        # 16-token blocks.
        if cache_config is None:
            kv_cache_dtype, block_size = "auto", 16
        else:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size

        base_backend = get_attn_backend(
            head_size, torch.get_default_dtype(), kv_cache_dtype, block_size
        )
        cross_backend = create_cross_attention_backend(base_backend)

        # `attn_type` is accepted for signature compatibility only; any
        # explicit value must already be ENCODER_DECODER.
        assert attn_type is None or attn_type == AttentionType.ENCODER_DECODER, (
            "CrossAttention only supports AttentionType.ENCODER_DECODER"
        )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=cross_backend,
            attn_type=AttentionType.ENCODER_DECODER,
            **kwargs,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Return the `CrossAttentionSpec` describing this layer's cache."""
        spec_kwargs = dict(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
        )
        return CrossAttentionSpec(**spec_kwargs)
|
||||||
103
attention/layers/encoder_only_attention.py
Normal file
103
attention/layers/encoder_only_attention.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import functools
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import (
|
||||||
|
AttentionBackend,
|
||||||
|
AttentionMetadata,
|
||||||
|
AttentionType,
|
||||||
|
)
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
from vllm.config import CacheConfig
|
||||||
|
from vllm.config.vllm import VllmConfig
|
||||||
|
from vllm.v1.attention.backends.utils import (
|
||||||
|
CommonAttentionMetadata,
|
||||||
|
subclass_attention_backend,
|
||||||
|
)
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache
def create_encoder_only_attention_backend(
    underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
    """Derive an encoder-only backend from an existing backend.

    The returned backend uses a metadata builder that clears `causal` on a
    shallow copy of the common metadata before delegating to the underlying
    builder. Results are memoized per underlying backend.
    """
    base_builder_cls = underlying_attn_backend.get_builder_cls()

    class EncoderOnlyAttentionBuilder(base_builder_cls):  # type: ignore
        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            # Shallow copy so the caller's metadata object is not modified.
            non_causal_metadata = copy(common_attn_metadata)
            non_causal_metadata.causal = False
            return super().build(common_prefix_len, non_causal_metadata, fast_build)

    return subclass_attention_backend(
        name_prefix="EncoderOnlyAttention_",
        attention_backend_cls=underlying_attn_backend,
        builder_cls=EncoderOnlyAttentionBuilder,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderOnlyAttention(Attention):
    """
    Encoder attention is a special case that doesn't need a KV Cache.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        cache_config: CacheConfig | None = None,
        attn_type: str | None = None,
        **kwargs,
    ):
        # Without a cache config, fall back to the "auto" dtype and
        # 16-token blocks.
        if cache_config is None:
            kv_cache_dtype, block_size = "auto", 16
        else:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size

        base_backend = get_attn_backend(
            head_size,
            torch.get_default_dtype(),
            kv_cache_dtype,
            block_size,
            attn_type=AttentionType.ENCODER_ONLY,
        )
        encoder_backend = create_encoder_only_attention_backend(base_backend)

        # `attn_type` is accepted for signature compatibility only; any
        # explicit value must already be ENCODER_ONLY.
        assert attn_type is None or attn_type == AttentionType.ENCODER_ONLY, (
            "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
        )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=encoder_backend,
            attn_type=AttentionType.ENCODER_ONLY,
            **kwargs,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Return None: this layer does not need a KV cache."""
        return None
|
||||||
0
attention/ops/__init__.py
Normal file
0
attention/ops/__init__.py
Normal file
BIN
attention/ops/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
attention/ops/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
attention/ops/__pycache__/common.cpython-312.pyc
Normal file
BIN
attention/ops/__pycache__/common.cpython-312.pyc
Normal file
Binary file not shown.
BIN
attention/ops/__pycache__/flashmla.cpython-312.pyc
Normal file
BIN
attention/ops/__pycache__/flashmla.cpython-312.pyc
Normal file
Binary file not shown.
BIN
attention/ops/__pycache__/merge_attn_states.cpython-312.pyc
Normal file
BIN
attention/ops/__pycache__/merge_attn_states.cpython-312.pyc
Normal file
Binary file not shown.
BIN
attention/ops/__pycache__/paged_attn.cpython-312.pyc
Normal file
BIN
attention/ops/__pycache__/paged_attn.cpython-312.pyc
Normal file
Binary file not shown.
BIN
attention/ops/__pycache__/pallas_kv_cache_update.cpython-312.pyc
Normal file
BIN
attention/ops/__pycache__/pallas_kv_cache_update.cpython-312.pyc
Normal file
Binary file not shown.
BIN
attention/ops/__pycache__/prefix_prefill.cpython-312.pyc
Normal file
BIN
attention/ops/__pycache__/prefix_prefill.cpython-312.pyc
Normal file
Binary file not shown.
BIN
attention/ops/__pycache__/rocm_aiter_paged_attn.cpython-312.pyc
Normal file
BIN
attention/ops/__pycache__/rocm_aiter_paged_attn.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
attention/ops/__pycache__/vit_attn_wrappers.cpython-312.pyc
Normal file
BIN
attention/ops/__pycache__/vit_attn_wrappers.cpython-312.pyc
Normal file
Binary file not shown.
401
attention/ops/chunked_prefill_paged_decode.py
Normal file
401
attention/ops/chunked_prefill_paged_decode.py
Normal file
@@ -0,0 +1,401 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Authors:
|
||||||
|
# - Burkhard Ringlein <ngl@zurich.ibm.com>
|
||||||
|
# - Jan van Lunteren <jvl@zurich.ibm.com>
|
||||||
|
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
|
||||||
|
# - Thomas Parnell <tpa@zurich.ibm.com>
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
from .prefix_prefill import context_attention_fwd
|
||||||
|
|
||||||
|
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def cdiv_fn(x, y):
    # Ceiling division: bump x by (y - 1) so the floor division rounds up.
    bumped = x + (y - 1)
    return bumped // y
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def kernel_paged_attention_2d(
    output_ptr,  # [num_tokens, num_query_heads, head_size]
    query_ptr,  # [num_tokens, num_query_heads, head_size]
    key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
    value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
    sink_ptr,  # [num_query_heads]
    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
    seq_lens_ptr,  # [num_seqs]
    alibi_slopes_ptr,  # [num_query_heads]
    scale,  # float32
    k_scale,  # float32, read with tl.load
    v_scale,  # float32, read with tl.load
    out_scale_inv,  # read with tl.load when USE_FP8 is set
    num_query_heads: tl.constexpr,  # int
    num_queries_per_kv: tl.constexpr,  # int
    num_queries_per_kv_padded: tl.constexpr,  # int
    block_table_stride: tl.int64,  # int
    query_stride_0: tl.int64,  # int
    query_stride_1: tl.int64,  # int, should be equal to head_size
    output_stride_0: tl.int64,  # int
    output_stride_1: tl.int64,  # int, should be equal to head_size
    BLOCK_SIZE: tl.constexpr,  # int
    HEAD_SIZE: tl.constexpr,  # int
    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
    USE_ALIBI_SLOPES: tl.constexpr,  # bool
    SLIDING_WINDOW: tl.constexpr,  # int
    x: tl.constexpr,  # int
    stride_k_cache_0: tl.int64,  # int
    stride_k_cache_1: tl.int64,  # int
    stride_k_cache_2: tl.int64,  # int
    stride_k_cache_3: tl.int64,  # int
    stride_k_cache_4: tl.int64,  # int
    stride_v_cache_0: tl.int64,  # int
    stride_v_cache_1: tl.int64,  # int
    stride_v_cache_2: tl.int64,  # int
    stride_v_cache_3: tl.int64,  # int
    filter_by_query_len: tl.constexpr,  # bool
    query_start_len_ptr,  # [num_seqs+1]
    USE_SINKS: tl.constexpr,  # bool
    USE_FP8: tl.constexpr,
    FP8_MIN: tl.constexpr = float8_info.min,
    FP8_MAX: tl.constexpr = float8_info.max,
):
    # Program axis 0 selects the sequence, axis 1 the KV head. Each program
    # handles one query token for all query heads that share that KV head,
    # walking the sequence's KV blocks with a running-softmax accumulator.
    seq_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)

    if filter_by_query_len:
        # Locate this sequence's query tokens; sequences with more than one
        # query token are skipped by this kernel.
        q_token_start = tl.load(query_start_len_ptr + seq_idx)
        q_token_stop = tl.load(query_start_len_ptr + seq_idx + 1)
        num_q_tokens = q_token_stop - q_token_start
        if num_q_tokens > 1:
            return
    else:
        q_token_start = seq_idx

    # Index vectors reused for loads, the block loop and the final store.
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)
    offs_n = tl.arange(0, BLOCK_SIZE)

    query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
        0, num_queries_per_kv_padded
    )

    query_offset = (
        q_token_start * query_stride_0 + query_head_idx[:, None] * query_stride_1
    )

    # Padded head slots beyond this KV group, or beyond the last query head,
    # are masked out.
    head_mask = (query_head_idx < (kv_head_idx + 1) * num_queries_per_kv) & (
        query_head_idx < num_query_heads
    )

    # Padded head-dimension lanes are masked out.
    dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)

    # Mask shared by the query load and the output store.
    io_mask = dim_mask[None, :] & head_mask[:, None]

    # Q : (num_queries_per_kv_padded, HEAD_SIZE_PADDED)
    Q = tl.load(
        query_ptr + query_offset + offs_d[None, :],
        mask=io_mask,
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    # Running maximum: seeded from the per-head sink values when enabled,
    # otherwise from -inf.
    if USE_SINKS:
        running_max = tl.load(
            sink_ptr + query_head_idx,
            mask=head_mask,
            other=float("-inf"),
        ).to(dtype=tl.float32)
    else:
        running_max = tl.full(
            [num_queries_per_kv_padded], float("-inf"), dtype=tl.float32
        )

    running_sum = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
    acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], dtype=tl.float32)

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # alibi slope for each head handled by this program
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(
            alibi_slopes_ptr + query_head_idx, mask=head_mask, other=0.0
        )

    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)

    # iterate through tiles
    for j in range(0, num_blocks):
        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)

        v_offset = (
            physical_block_idx * stride_v_cache_0
            + kv_head_idx * stride_v_cache_1
            + offs_d[None, :] * stride_v_cache_2
            + offs_n[:, None] * stride_v_cache_3
        )

        # The key cache splits the head dimension into (head_size // x, x).
        k_offset = (
            physical_block_idx * stride_k_cache_0
            + kv_head_idx * stride_k_cache_1
            + (offs_d[:, None] // x) * stride_k_cache_2
            + offs_n[None, :] * stride_k_cache_3
            + (offs_d[:, None] % x) * stride_k_cache_4
        )

        # K : (HEAD_SIZE_PADDED, BLOCK_SIZE)
        K_load = tl.load(key_cache_ptr + k_offset, mask=dim_mask[:, None], other=0.0)

        # FP8 cache entries are rescaled and cast to the query dtype.
        if K_load.dtype.is_fp8():
            K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (BLOCK_SIZE, HEAD_SIZE_PADDED)
        V_load = tl.load(value_cache_ptr + v_offset, mask=dim_mask[None, :], other=0.0)

        if V_load.dtype.is_fp8():
            V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        # Positions past seq_len in the last block are masked out.
        seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
        seq_mask = seq_offset[None, :] < boundary

        # scores : (num_queries_per_kv_padded, BLOCK_SIZE)
        scores = tl.where(head_mask[:, None] & seq_mask, 0.0, float("-inf")).to(
            tl.float32
        )
        scores += scale * tl.dot(Q, K)

        context_len = seq_len - 1

        if SLIDING_WINDOW > 0:
            scores = tl.where(
                (context_len - seq_offset) < SLIDING_WINDOW, scores, -10000
            )

        if USE_ALIBI_SLOPES:
            scores += alibi_slope[:, None] * (seq_offset - context_len)

        # compute running maximum
        # new_max : (num_queries_per_kv_padded,)
        new_max = tl.maximum(running_max, tl.max(scores, axis=1))

        # probs : (num_queries_per_kv_padded, BLOCK_SIZE)
        probs = tl.exp(scores - new_max[:, None])

        # block_sum : (num_queries_per_kv_padded,)
        block_sum = tl.sum(probs, axis=1)

        # alpha : (num_queries_per_kv_padded,)
        # Rescale factor for everything accumulated under the old maximum.
        alpha = tl.exp(running_max - new_max)

        # acc : (num_queries_per_kv_padded, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        running_sum = running_sum * alpha + block_sum
        running_max = new_max

        # acc : (num_queries_per_kv_padded, HEAD_SIZE_PADDED)
        acc += tl.dot(probs.to(V.dtype), V)

    # epilogue: normalize by the softmax denominator
    acc = acc / running_sum[:, None]
    if USE_FP8:
        acc = acc * tl.load(out_scale_inv)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)

    output_offset = q_token_start * output_stride_0 + query_head_idx * output_stride_1

    tl.store(
        output_ptr + output_offset[:, None] + offs_d[None, :],
        acc,
        mask=io_mask,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def chunked_prefill_paged_decode(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
output,
|
||||||
|
kv_cache_dtype,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
block_table,
|
||||||
|
query_start_loc,
|
||||||
|
seq_lens,
|
||||||
|
max_seq_len,
|
||||||
|
max_query_len,
|
||||||
|
k_scale,
|
||||||
|
v_scale,
|
||||||
|
alibi_slopes=None,
|
||||||
|
sliding_window=None,
|
||||||
|
sm_scale=None,
|
||||||
|
output_scale=None,
|
||||||
|
# Optional tensor for sinks
|
||||||
|
sinks=None,
|
||||||
|
):
|
||||||
|
if sm_scale is None:
|
||||||
|
sm_scale = 1.0 / (query.shape[1] ** 0.5)
|
||||||
|
|
||||||
|
use_alibi_slopes = alibi_slopes is not None
|
||||||
|
|
||||||
|
if sliding_window is None or sliding_window <= 0:
|
||||||
|
sliding_window = 0
|
||||||
|
|
||||||
|
if max_query_len > 1:
|
||||||
|
context_attention_fwd(
|
||||||
|
q=query,
|
||||||
|
k=key,
|
||||||
|
v=value,
|
||||||
|
o=output,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
k_cache=key_cache,
|
||||||
|
v_cache=value_cache,
|
||||||
|
b_loc=block_table,
|
||||||
|
b_start_loc=query_start_loc,
|
||||||
|
b_seq_len=seq_lens,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
max_input_len=max_query_len,
|
||||||
|
k_scale=k_scale,
|
||||||
|
v_scale=v_scale,
|
||||||
|
alibi_slopes=alibi_slopes,
|
||||||
|
sliding_window=sliding_window,
|
||||||
|
sm_scale=sm_scale,
|
||||||
|
skip_decode=True,
|
||||||
|
fp8_out_scale=output_scale,
|
||||||
|
sinks=sinks,
|
||||||
|
)
|
||||||
|
|
||||||
|
block_size = value_cache.shape[3]
|
||||||
|
num_seqs = len(seq_lens)
|
||||||
|
num_query_heads = query.shape[1]
|
||||||
|
num_kv_heads = key.shape[1]
|
||||||
|
num_queries_per_kv = query.shape[1] // key.shape[1]
|
||||||
|
head_size = query.shape[2]
|
||||||
|
|
||||||
|
# Conversion of FP8 Tensor from uint8 storage to
|
||||||
|
# appropriate torch.dtype for interpretation by Triton
|
||||||
|
if "fp8" in kv_cache_dtype:
|
||||||
|
assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
|
||||||
|
assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
|
||||||
|
|
||||||
|
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
|
||||||
|
target_dtype = current_platform.fp8_dtype()
|
||||||
|
elif kv_cache_dtype == "fp8_e5m2":
|
||||||
|
target_dtype = torch.float8_e5m2
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
|
||||||
|
|
||||||
|
key_cache = key_cache.view(target_dtype)
|
||||||
|
value_cache = value_cache.view(target_dtype)
|
||||||
|
|
||||||
|
num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), 16)
|
||||||
|
|
||||||
|
from vllm.platforms.rocm import use_rocm_custom_paged_attention
|
||||||
|
|
||||||
|
use_custom = use_rocm_custom_paged_attention(
|
||||||
|
query.dtype,
|
||||||
|
head_size,
|
||||||
|
block_size,
|
||||||
|
num_queries_per_kv,
|
||||||
|
max_seq_len,
|
||||||
|
sliding_window,
|
||||||
|
kv_cache_dtype,
|
||||||
|
alibi_slopes,
|
||||||
|
sinks,
|
||||||
|
)
|
||||||
|
if use_custom:
|
||||||
|
_PARTITION_SIZE_ROCM = 256
|
||||||
|
max_num_partitions = (
|
||||||
|
max_seq_len + _PARTITION_SIZE_ROCM - 1
|
||||||
|
) // _PARTITION_SIZE_ROCM
|
||||||
|
assert _PARTITION_SIZE_ROCM % block_size == 0
|
||||||
|
total_num_seq = block_table.shape[0]
|
||||||
|
tmp_output = torch.empty(
|
||||||
|
size=(total_num_seq, num_query_heads, max_num_partitions, head_size),
|
||||||
|
dtype=query.dtype,
|
||||||
|
device=output.device,
|
||||||
|
)
|
||||||
|
exp_sums = torch.empty(
|
||||||
|
size=(total_num_seq, num_query_heads, max_num_partitions),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=output.device,
|
||||||
|
)
|
||||||
|
max_logits = torch.empty_like(exp_sums)
|
||||||
|
|
||||||
|
ops.paged_attention_rocm(
|
||||||
|
output,
|
||||||
|
exp_sums,
|
||||||
|
max_logits,
|
||||||
|
tmp_output,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
num_kv_heads,
|
||||||
|
scale=sm_scale,
|
||||||
|
block_tables=block_table,
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
query_start_loc=query_start_loc,
|
||||||
|
block_size=block_size,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
alibi_slopes=alibi_slopes,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
k_scale=k_scale,
|
||||||
|
v_scale=v_scale,
|
||||||
|
fp8_out_scale=output_scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
kernel_paged_attention_2d[
|
||||||
|
(
|
||||||
|
num_seqs,
|
||||||
|
num_kv_heads,
|
||||||
|
)
|
||||||
|
](
|
||||||
|
output_ptr=output,
|
||||||
|
query_ptr=query,
|
||||||
|
key_cache_ptr=key_cache,
|
||||||
|
value_cache_ptr=value_cache,
|
||||||
|
sink_ptr=sinks,
|
||||||
|
block_tables_ptr=block_table,
|
||||||
|
seq_lens_ptr=seq_lens,
|
||||||
|
alibi_slopes_ptr=alibi_slopes,
|
||||||
|
scale=sm_scale,
|
||||||
|
k_scale=k_scale,
|
||||||
|
v_scale=v_scale,
|
||||||
|
out_scale_inv=1.0 / output_scale if output_scale is not None else 1.0,
|
||||||
|
num_query_heads=num_query_heads,
|
||||||
|
num_queries_per_kv=num_queries_per_kv,
|
||||||
|
num_queries_per_kv_padded=num_queries_per_kv_padded,
|
||||||
|
block_table_stride=block_table.stride(0),
|
||||||
|
query_stride_0=query.stride(0),
|
||||||
|
query_stride_1=query.stride(1),
|
||||||
|
output_stride_0=output.stride(0),
|
||||||
|
output_stride_1=output.stride(1),
|
||||||
|
BLOCK_SIZE=block_size,
|
||||||
|
HEAD_SIZE=head_size,
|
||||||
|
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
|
||||||
|
USE_ALIBI_SLOPES=use_alibi_slopes,
|
||||||
|
SLIDING_WINDOW=sliding_window,
|
||||||
|
x=key_cache.shape[4],
|
||||||
|
stride_k_cache_0=key_cache.stride(0),
|
||||||
|
stride_k_cache_1=key_cache.stride(1),
|
||||||
|
stride_k_cache_2=key_cache.stride(2),
|
||||||
|
stride_k_cache_3=key_cache.stride(3),
|
||||||
|
stride_k_cache_4=key_cache.stride(4),
|
||||||
|
stride_v_cache_0=value_cache.stride(0),
|
||||||
|
stride_v_cache_1=value_cache.stride(1),
|
||||||
|
stride_v_cache_2=value_cache.stride(2),
|
||||||
|
stride_v_cache_3=value_cache.stride(3),
|
||||||
|
filter_by_query_len=True,
|
||||||
|
query_start_len_ptr=query_start_loc,
|
||||||
|
USE_SINKS=sinks is not None,
|
||||||
|
USE_FP8=output_scale is not None,
|
||||||
|
)
|
||||||
414
attention/ops/common.py
Normal file
414
attention/ops/common.py
Normal file
@@ -0,0 +1,414 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.distributed.parallel_state import GroupCoordinator
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _correct_attn_cp_out_kernel(
    outputs_ptr,
    new_output_ptr,
    lses_ptr,
    vlse_ptr,
    outputs_stride_B,
    outputs_stride_H,
    outputs_stride_D,
    lses_stride_N,
    lses_stride_B,
    lses_stride_H,
    lse_idx,
    HEAD_DIM: tl.constexpr,
    N_ROUNDED: tl.constexpr,
):
    """
    Rescale one rank's attention output using the all-gathered lses.

    Each program handles one (batch, head) pair: it reduces the N gathered
    log-sum-exp values into a single merged lse, stores that merged value,
    and multiplies this rank's output row by exp(lse[lse_idx] - merged_lse).
    A cross-rank reduction is still needed afterwards to obtain the final
    attention output.

    Args:
        outputs_ptr (triton.PointerType):
            Pointer to input tensor of shape [ B, H, D ]
        new_output_ptr (triton.PointerType):
            Pointer to output tensor of shape [ B, H, D ]. It is addressed
            with the same outputs_stride_* values as outputs_ptr.
        lses_ptr (triton.PointerType):
            Pointer to input tensor of shape [ N, B, H ]
        vlse_ptr (triton.PointerType):
            Pointer to output tensor of shape [ B, H ]. It is addressed with
            lses_stride_B / lses_stride_H, so it must share that layout.
        outputs_stride_B, outputs_stride_H, outputs_stride_D:
            Element strides of the [ B, H, D ] tensors.
        lses_stride_N, lses_stride_B, lses_stride_H:
            Element strides of the [ N, B, H ] lses tensor.
        lse_idx:
            Index along N of the lse row that belongs to the output being
            corrected.
        HEAD_DIM (tl.constexpr):
            Number of elements loaded/stored along D.
        N_ROUNDED (tl.constexpr):
            Number of lse entries loaded along N.
            NOTE(review): the loads below are unmasked, so this presumably
            has to equal the real N -- confirm against the launcher.
    """
    # One program per (batch, head); int64 so the stride products cannot
    # overflow 32-bit arithmetic.
    batch_idx = tl.program_id(axis=0).to(tl.int64)
    head_idx = tl.program_id(axis=1).to(tl.int64)
    d_offsets = tl.arange(0, HEAD_DIM)
    num_n_offsets = tl.arange(0, N_ROUNDED)

    # shape = [N]
    lse_offsets = (
        num_n_offsets * lses_stride_N
        + batch_idx * lses_stride_B
        + head_idx * lses_stride_H
    )

    # calc final lse
    lse = tl.load(lses_ptr + lse_offsets)
    # Map NaN (lse != lse) and +inf entries to -inf so they contribute
    # exp(-inf) = 0 to the sum below.
    lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
    lse_max = tl.max(lse, axis=0)
    # If every entry is -inf, use 0 as the shift so that the subtraction
    # below does not evaluate (-inf) - (-inf).
    lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
    # Max-shifted log-sum-exp: log(sum(exp(lse - max))) + max.
    lse -= lse_max
    lse_exp = tl.exp(lse)
    lse_acc = tl.sum(lse_exp, axis=0)
    lse = tl.log(lse_acc)
    lse += lse_max

    # Store the merged lse for this (batch, head); note the offset reuses
    # the B/H strides of `lses`.
    lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
    tl.store(vlse_ptr + lse_offsets, lse)

    # shape = [D]
    output_offsets = (
        batch_idx * outputs_stride_B
        + head_idx * outputs_stride_H
        + d_offsets * outputs_stride_D
    )

    # correct output
    # Re-read the single lse selected by lse_idx (the raw, unsanitized value).
    lse_offset = (
        lse_idx * lses_stride_N + batch_idx * lses_stride_B + head_idx * lses_stride_H
    )
    lse_tmp = tl.load(lses_ptr + lse_offset)
    lse_finally = lse_tmp - lse
    # NaN / +inf differences become -inf, i.e. a scale factor of 0.
    lse_finally = tl.where(
        (lse_finally != lse_finally) | (lse_finally == float("inf")),
        -float("inf"),
        lse_finally,
    )
    factor = tl.exp(lse_finally)
    output = tl.load(outputs_ptr + output_offsets)
    output = output * factor

    tl.store(new_output_ptr + output_offsets, output)
|
||||||
|
|
||||||
|
|
||||||
|
class CPTritonContext:
    """Cache the handle returned by the first kernel launch.

    The first ``call_kernel`` launches through the JIT entry point and keeps
    whatever that launch returns; later calls launch through the kept handle
    instead, which is how the context avoids recompilation of the Triton JIT.
    """

    def __init__(self):
        # None until the first call_kernel() has run.
        self.inner_kernel = None

    def call_kernel(self, kernel, grid, *regular_args, **const_args):
        """Launch ``kernel`` on ``grid``, reusing the cached handle if present.

        Args:
            kernel: Object indexable by ``grid`` that yields a launcher.
            grid: Launch grid forwarded to the indexing operation.
            *regular_args: Positional (runtime) kernel arguments.
            **const_args: Keyword arguments; only forwarded on the first
                launch, the cached handle is called with positional
                arguments alone.
        """
        cached = self.inner_kernel
        if cached is not None:
            # Subsequent launches: const_args are intentionally not passed.
            cached[grid](*regular_args)
            return
        # First launch: remember what the launch returns for reuse.
        self.inner_kernel = kernel[grid](*regular_args, **const_args)
|
||||||
|
|
||||||
|
|
||||||
|
def correct_attn_out(
    out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rescale ``out`` in place with the all-gathered log-sum-exp values.

    Args:
        out: Tensor of shape [ B, H, D ] (a [ B, 1, H, D ] input is squeezed).
        lses: Tensor of shape [ N, B, H ]; a 4-D input with a size-1 last
            dim or size-1 dim 1 is squeezed to 3-D first.
        cp_rank: Current rank in the context-parallel group; selects which
            row of ``lses`` belongs to ``out``.
        ctx: Triton context to avoid recompilation. A fresh one is created
            when None is passed.

    Returns:
        Tuple of (out, lse) with corrected attention and final log-sum-exp.
        ``out`` is the same tensor the kernel wrote into.
    """
    if ctx is None:
        ctx = CPTritonContext()

    # --- Bring both inputs to 3-D views ---
    if out.ndim == 4 and out.shape[1] == 1:
        out = out.squeeze(1)
    assert out.ndim == 3, f"expected out [B,H,D] or [B,1,H,D], got {tuple(out.shape)}"

    if lses.ndim == 4 and lses.shape[-1] == 1:
        lses = lses.squeeze(-1)
    if lses.ndim == 4 and lses.shape[1] == 1:
        lses = lses.squeeze(1)
    assert lses.ndim == 3, (
        f"expected lses [N,B,H] (optionally with a 1-sized extra dim), "
        f"got {tuple(lses.shape)}"
    )

    batch, num_heads, head_dim = out.shape
    num_ranks = lses.shape[0]

    # Strides of the normalized 3-D views.
    out_strides = out.stride()
    stride_n, stride_b, stride_h = lses.stride()

    # The kernel addresses the merged-LSE buffer with the B/H strides of
    # `lses`, so allocate it with exactly that layout; this keeps writes
    # correct even when `lses` is a non-contiguous view.
    merged_lse = torch.empty_strided(
        (batch, num_heads),
        (stride_b, stride_h),
        device=lses.device,
        dtype=lses.dtype,
    )

    # One program per (batch, head). `out` is passed as both the input and
    # the output buffer, i.e. the correction happens in place.
    ctx.call_kernel(
        _correct_attn_cp_out_kernel,
        (batch, num_heads, 1),
        out,
        out,
        lses,
        merged_lse,
        *out_strides,
        stride_n,
        stride_b,
        stride_h,
        cp_rank,
        HEAD_DIM=head_dim,
        N_ROUNDED=num_ranks,
    )
    return out, merged_lse
|
||||||
|
|
||||||
|
|
||||||
|
def cp_lse_ag_out_rs(
    cp_attn_out: torch.Tensor,
    cp_attn_lse: torch.Tensor,
    cp_group: GroupCoordinator,
    ctx: CPTritonContext = None,
    return_lse=False,
):
    """All-gather the LSEs, correct the local output, then reduce-scatter it.

    Args:
        cp_attn_out: [ B, H, D ] local attention output.
        cp_attn_lse: [ B, H ] local log-sum-exp.
        cp_group: Context-parallel group used for the collectives.
        ctx: Triton context to avoid recompilation; created when None.
        return_lse: When True, also return the log-sum-exp.

    Returns:
        ``out`` when ``return_lse`` is False, otherwise ``(out, lse)``.
        With more than one rank, ``out`` is reduce-scattered along dim 1 and
        ``lse`` is the slice of the merged LSE along dim 1 that matches this
        rank. With a single rank the inputs are returned unchanged.
    """
    if cp_group.world_size == 1:
        # Nothing to merge. Still honour `return_lse` so callers always get
        # the arity they asked for.
        if return_lse:
            return cp_attn_out, cp_attn_lse
        return cp_attn_out

    if ctx is None:
        ctx = CPTritonContext()

    cp_attn_lse = cp_attn_lse.contiguous()
    # Reinterpret the gathered result as [world_size, *lse_shape] directly;
    # no scratch tensor is needed just to carry the target shape.
    lses = cp_group.all_gather(cp_attn_lse, dim=0).view(
        (cp_group.world_size,) + cp_attn_lse.shape
    )
    out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
    out = cp_group.reduce_scatter(out, dim=1)

    if return_lse:
        # Keep only the slice along dim 1 that corresponds to this rank's
        # share after the reduce-scatter.
        cp_num_heads = lse.shape[1] // cp_group.world_size
        cp_rank = cp_group.rank_in_group
        lse = lse[:, cp_num_heads * cp_rank : cp_num_heads * (cp_rank + 1)]
        return out, lse
    return out
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _pack_seq_kernel(
    x_ptr,  # [N, D]
    out_ptr,  # [B, Lmax, D]
    lengths_ptr,  # *i32, [B]
    N: tl.constexpr,
    D: tl.constexpr,
    Lmax: tl.constexpr,
    PAD_VALUE: tl.constexpr,
    BLOCK_T: tl.constexpr,  # timesteps per program
    BLOCK_D: tl.constexpr,  # features per program
):
    """Copy one [BLOCK_T, BLOCK_D] tile of a sequence into the padded output.

    Grid axes are (batch, time block, feature block). Every in-range tile
    position is first filled with PAD_VALUE; positions below the sequence
    length are then overwritten with the matching rows of ``x``.

    Both ``x_ptr`` and ``out_ptr`` are indexed with a row pitch of exactly
    ``D`` and unit element stride, and ``lengths_ptr`` with unit stride.
    ``N`` is accepted but not referenced in the body.
    """
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # Compute start index and sequence length from cumulative lengths.
    # Each program re-sums the lengths of all preceding sequences.
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    # valid time positions for this block
    t_mask = off_t < Lmax

    # compute input row indices for valid (b, t)
    in_row = in_start + off_t
    valid_row = (off_t < seq_len) & t_mask

    # Pointers
    # x_ptr: row-major [N, D]
    x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]

    # out_ptr: row-major [B, Lmax, D]
    out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :]

    # Initialize with PAD (cast will occur as needed based on out_ptr dtype)
    d_mask = off_d[None, :] < D
    pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
    tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)

    # Load & write only where within seq_len. The load and the store share
    # the same mask, so masked-off lanes of x_vals are never written.
    x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask)
    tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def pack_seq_triton(
    x: torch.Tensor,
    lengths: torch.Tensor,
    pad_value: float = -float("inf"),
    block_t: int = 64,
    block_d: int = 64,
) -> torch.Tensor:
    """
    Pack sequences of different lengths into a batched tensor.

    Args:
        x: [N, ...] - input tensor where N is total number of tokens.
            Must have at least 2 dimensions; trailing dimensions are
            flattened for the kernel and restored on the result.
        lengths: [B] - sequence lengths for each batch
        pad_value: value to use for padding
        block_t: block size for time dimension
        block_d: block size for feature dimension

    Returns:
        packed: [B, Lmax, ...] - packed tensor, where Lmax = max(lengths)
    """

    # Handle multi-dimensional input by reshaping to (N, -1)
    original_shape = x.shape
    if len(original_shape) > 2:
        N = original_shape[0]
        x_reshaped = x.reshape(N, -1)
        D = x_reshaped.shape[1]
    else:
        N, D = x.shape
        x_reshaped = x

    # The kernel computes addresses as `row * D + col`, i.e. it assumes a
    # dense row-major buffer. A strided view (e.g. a slice of a larger
    # tensor) would be read incorrectly, so materialize it if needed.
    # `.contiguous()` is a no-op for tensors that are already dense.
    x_reshaped = x_reshaped.contiguous()

    B = lengths.numel()
    Lmax = int(lengths.max().item())

    # Starts are computed inside the kernel from lengths, which are read
    # with unit stride as int32.
    lengths_i32 = lengths.int().contiguous()

    out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype)

    grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
    _pack_seq_kernel[grid](
        x_reshaped,
        out,
        lengths_i32,
        N,
        D,
        Lmax,
        PAD_VALUE=float(pad_value),
        BLOCK_T=block_t,
        BLOCK_D=block_d,
        num_warps=4,
        num_stages=2,
    )

    # Reshape output back to original dimensions (except first dimension)
    if len(original_shape) > 2:
        output_shape = (B, Lmax) + original_shape[1:]
        out = out.reshape(output_shape)

    return out
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _unpack_seq_triton_kernel(
    packed_ptr,  # [B, Lmax, D]
    out_ptr,  # [N, D]
    lengths_ptr,  # *i32, [B]
    B: tl.constexpr,
    Lmax: tl.constexpr,
    D: tl.constexpr,
    BLOCK_T: tl.constexpr,  # timesteps per program
    BLOCK_D: tl.constexpr,  # features per program
):
    """Copy one [BLOCK_T, BLOCK_D] tile from the padded tensor to flat rows.

    Grid axes are (batch, time block, feature block). Only positions below
    the sequence length are read and written; padding is skipped.

    Both ``packed_ptr`` and ``out_ptr`` are indexed with a row pitch of
    exactly ``D`` and unit element stride, and ``lengths_ptr`` with unit
    stride. ``B`` is accepted but not referenced in the body.
    """
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # bounds: compute start from cumulative lengths.
    # Each program re-sums the lengths of all preceding sequences.
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    # valid time positions for this block
    t_mask = off_t < Lmax
    valid_row = (off_t < seq_len) & t_mask

    # compute output row indices for valid (b, t)
    out_row = in_start + off_t

    # Pointers
    # packed_ptr: row-major [B, Lmax, D]
    packed_row_ptr = packed_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :]

    # out_ptr: row-major [N, D]
    out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]

    # Load from packed tensor and store to output. The load and the store
    # share the same mask, so masked-off lanes are never written.
    d_mask = off_d[None, :] < D
    packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask)
    tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def unpack_seq_triton(
    packed_tensor: torch.Tensor,
    lengths: torch.Tensor,
    block_t: int = 64,
    block_d: int = 64,
) -> torch.Tensor:
    """
    Unpack a packed decode query tensor back to the original format.
    Efficient Triton implementation.

    Args:
        packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton.
            Must have at least 3 dimensions; trailing dimensions are
            flattened for the kernel and restored on the result.
        lengths: [B] - sequence lengths for each batch
        block_t: block size for time dimension
        block_d: block size for feature dimension

    Returns:
        unpacked_tensor: [N, ...] where N = sum(lengths)
    """

    # Handle multi-dimensional input by reshaping to (B, Lmax, -1)
    original_shape = packed_tensor.shape
    if len(original_shape) > 3:
        B, Lmax = original_shape[:2]
        packed_reshaped = packed_tensor.reshape(B, Lmax, -1)
        D = packed_reshaped.shape[2]
    else:
        B, Lmax, D = packed_tensor.shape
        packed_reshaped = packed_tensor

    # The kernel computes addresses as `(b * Lmax + t) * D + col`, i.e. it
    # assumes a dense row-major buffer. A strided view would be read
    # incorrectly, so materialize it if needed. `.contiguous()` is a no-op
    # for tensors that are already dense.
    packed_reshaped = packed_reshaped.contiguous()

    # Calculate total number of elements
    N = int(lengths.sum().item())

    # Lengths are read by the kernel with unit stride as int32.
    lengths_i32 = lengths.int().contiguous()

    out = torch.empty((N, D), device=packed_tensor.device, dtype=packed_tensor.dtype)

    grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
    _unpack_seq_triton_kernel[grid](
        packed_reshaped,
        out,
        lengths_i32,
        B,
        Lmax,
        D,
        BLOCK_T=block_t,
        BLOCK_D=block_d,
        num_warps=4,
        num_stages=2,
    )

    # Reshape output back to original dimensions (except first dimension)
    if len(original_shape) > 3:
        output_shape = (N,) + original_shape[2:]
        out = out.reshape(output_shape)

    return out
|
||||||
252
attention/ops/flashmla.py
Normal file
252
attention/ops/flashmla.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# Probe for the compiled FlashMLA extension modules once at import time and
# record the outcome in module-level flags. A failed import is not an error
# here; it only marks the corresponding extension as unavailable.
# NOTE(review): importing these modules presumably registers the
# `torch.ops._flashmla_C` / `torch.ops._flashmla_extension_C` namespaces used
# by the functions in this file -- confirm against the build.
if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401

        _flashmla_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False
else:
    # Non-CUDA platforms never attempt the import.
    _flashmla_C_AVAILABLE = False

if current_platform.is_cuda():
    try:
        import vllm._flashmla_extension_C  # noqa: F401

        _flashmla_extension_C_AVAILABLE = True
    except ImportError:
        _flashmla_extension_C_AVAILABLE = False
else:
    # Non-CUDA platforms never attempt the import.
    _flashmla_extension_C_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_flashmla_available() -> tuple[bool, str | None]:
    """Report whether both FlashMLA extension modules were importable.

    Returns:
        ``(True, None)`` when both availability flags are set, otherwise
        ``(False, reason)`` where ``reason`` names the first missing module.
    """
    reason: str | None = None
    if not _flashmla_C_AVAILABLE:
        reason = (
            "vllm._flashmla_C is not available, likely was not "
            "compiled due to insufficient nvcc version or a supported arch "
            "was not in the list of target arches to compile for."
        )
    elif not _flashmla_extension_C_AVAILABLE:
        reason = (
            "vllm._flashmla_extension_C is not available, likely "
            "was not compiled due to a build error."
        )

    if reason is None:
        return True, None
    return False, reason
|
||||||
|
|
||||||
|
|
||||||
|
def is_flashmla_dense_supported() -> tuple[bool, str | None]:
    """Check whether the dense FlashMLA path can be used.

    Return: is_supported_flag, unsupported_reason (optional).
    """
    available, unavailable_reason = _is_flashmla_available()
    if not available:
        return False, unavailable_reason

    # Only a compute-capability major version of 9 is accepted.
    capability_major = current_platform.get_device_capability()[0]
    if capability_major == 9:
        return True, None
    return False, "FlashMLA Dense is only supported on Hopper devices."
|
||||||
|
|
||||||
|
|
||||||
|
def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
    """Check whether the sparse FlashMLA path can be used.

    Return: is_supported_flag, unsupported_reason (optional).
    """
    available, unavailable_reason = _is_flashmla_available()
    if not available:
        return False, unavailable_reason

    # Compute-capability major versions 9 and 10 are accepted.
    capability_major = current_platform.get_device_capability()[0]
    if capability_major in (9, 10):
        return True, None
    return (
        False,
        "FlashMLA Sparse is only supported on Hopper and Blackwell devices.",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_q_tokens_per_head_k: int,
    num_heads_k: int,
    num_heads_q: int | None = None,
    is_fp8_kvcache: bool = False,
    topk: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the decoding metadata consumed by `flash_mla_with_kvcache`.

    Arguments:
    - cache_seqlens: (batch_size), dtype torch.int32.
    - num_q_tokens_per_head_k:
            Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
    - num_heads_k: The number of k heads.
    - num_heads_q:
            The number of q heads.
            This argument is optional when sparse attention is not enabled
    - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
    - topk: If not None, sparse attention will be enabled,
            and only tokens in the `indices` array
            passed to `flash_mla_with_kvcache` will be attended to.

    Returns:
    - tile_scheduler_metadata:
            (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
    - num_splits: (batch_size + 1), dtype torch.int32.
    """
    # The dedicated dense-FP8 op is used only for an FP8 cache without
    # sparse attention; every other combination goes to the general op.
    dense_fp8 = is_fp8_kvcache and topk is None
    if not dense_fp8:
        return torch.ops._flashmla_C.get_mla_decoding_metadata(
            cache_seqlens,
            num_q_tokens_per_head_k,
            num_heads_k,
            num_heads_q,
            is_fp8_kvcache,
            topk,
        )
    return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
        cache_seqlens,
        num_q_tokens_per_head_k,
        num_heads_k,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def flash_mla_with_kvcache(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k_cache: torch.Tensor,
|
||||||
|
block_table: torch.Tensor,
|
||||||
|
cache_seqlens: torch.Tensor,
|
||||||
|
head_dim_v: int,
|
||||||
|
tile_scheduler_metadata: torch.Tensor,
|
||||||
|
num_splits: torch.Tensor,
|
||||||
|
softmax_scale: float | None = None,
|
||||||
|
causal: bool = False,
|
||||||
|
descale_q: torch.Tensor | None = None,
|
||||||
|
descale_k: torch.Tensor | None = None,
|
||||||
|
is_fp8_kvcache: bool = False,
|
||||||
|
indices: torch.Tensor | None = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Arguments:
|
||||||
|
- q: (batch_size, seq_len_q, num_heads_q, head_dim).
|
||||||
|
- k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
||||||
|
- block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
||||||
|
- cache_seqlens: (batch_size), torch.int32.
|
||||||
|
- head_dim_v: Head dimension of v.
|
||||||
|
- tile_scheduler_metadata:
|
||||||
|
(num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
|
||||||
|
returned by get_mla_metadata.
|
||||||
|
- num_splits:
|
||||||
|
(batch_size + 1), torch.int32, returned by get_mla_metadata.
|
||||||
|
- softmax_scale: float.
|
||||||
|
The scale of QK^T before applying softmax.
|
||||||
|
Default to 1 / sqrt(head_dim).
|
||||||
|
- causal: bool. Whether to apply causal attention mask.
|
||||||
|
- descale_q: (batch_size),
|
||||||
|
torch.float32. Descaling factors for Q, used for fp8 quantization.
|
||||||
|
- descale_k: (batch_size),
|
||||||
|
torch.float32. Descaling factors for K, used for fp8 quantization.
|
||||||
|
- is_fp8_kvcache: bool.
|
||||||
|
Whether the k_cache and v_cache are in fp8 format.
|
||||||
|
For the format of FP8 KV cache, please refer to README.md
|
||||||
|
- indices: (batch_size, seq_len_q, topk), torch.int32.
|
||||||
|
If not None, sparse attention will be enabled,
|
||||||
|
and only tokens in the `indices` array will be attended to.
|
||||||
|
Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
|
||||||
|
For details about how to set up `indices`, please refer to README.md.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||||
|
- softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||||
|
"""
|
||||||
|
if softmax_scale is None:
|
||||||
|
softmax_scale = q.shape[-1] ** (-0.5)
|
||||||
|
if indices is not None:
|
||||||
|
# NOTE (zyongye): sparse attention is also causal
|
||||||
|
# since it only attend to the tokens before
|
||||||
|
# but here `causal` should not be specified
|
||||||
|
assert not causal, "causal must be `false` if sparse attention is enabled."
|
||||||
|
assert (descale_q is None) == (descale_k is None), (
|
||||||
|
"descale_q and descale_k should be both None or both not None"
|
||||||
|
)
|
||||||
|
|
||||||
|
if indices is None and q.element_size() == 1:
|
||||||
|
out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
|
||||||
|
q,
|
||||||
|
k_cache,
|
||||||
|
head_dim_v,
|
||||||
|
cache_seqlens,
|
||||||
|
block_table,
|
||||||
|
softmax_scale,
|
||||||
|
causal,
|
||||||
|
tile_scheduler_metadata,
|
||||||
|
num_splits,
|
||||||
|
descale_q,
|
||||||
|
descale_k,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
|
||||||
|
q,
|
||||||
|
k_cache,
|
||||||
|
head_dim_v,
|
||||||
|
cache_seqlens,
|
||||||
|
block_table,
|
||||||
|
softmax_scale,
|
||||||
|
causal,
|
||||||
|
tile_scheduler_metadata,
|
||||||
|
num_splits,
|
||||||
|
is_fp8_kvcache,
|
||||||
|
indices,
|
||||||
|
)
|
||||||
|
return out, softmax_lse
|
||||||
|
|
||||||
|
|
||||||
|
def flash_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel.

    Thin wrapper that forwards all arguments, in order, to
    `ops.sparse_prefill_fwd` and returns its result unchanged.

    Args:
    - q: [s_q, h_q, d_qk], bfloat16
    - kv: [s_kv, h_kv, d_qk], bfloat16
    - indices: [s_q, h_kv, topk], int32.
        Invalid indices should be set to -1 or numbers >= s_kv
    - sm_scale: float
    - d_v: The dimension of value vectors. Can only be 512

    Returns:
    - (output, max_logits, lse)
        About the definition of output,
        max_logits and lse, please refer to README.md
        - output: [s_q, h_q, d_v], bfloat16
        - max_logits:  [s_q, h_q], float
        - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    return ops.sparse_prefill_fwd(q, kv, indices, sm_scale, d_v)
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# TODO: Add fake functions
|
||||||
|
#
|
||||||
|
# @register_fake("_flashmla_C::get_mla_metadata")
|
||||||
|
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# return ....
|
||||||
|
#
|
||||||
|
# @register_fake("_flashmla_C::fwd_kvcache_mla")
|
||||||
|
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# return ....
|
||||||
|
#
|
||||||
47
attention/ops/merge_attn_states.py
Normal file
47
attention/ops/merge_attn_states.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
|
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: torch.Tensor | None = None,
) -> None:
    """Dispatch the merge of prefix/suffix attention states to a backend.

    The custom CUDA op is chosen when the platform is CUDA, ``output`` has a
    dtype the op supports, and its head size fits the op's packing width;
    every other case uses the Triton implementation. All arguments are
    forwarded unchanged to the selected backend.
    NOTE(review): the result is presumably written into ``output`` (and
    ``output_lse`` when given) by the backend -- confirm there.
    """
    # NOTE(DefTruth): The custom merge_attn_states CUDA kernel does not
    # support FP8 dtypes, so those fall back to the Triton kernel.
    #
    # NOTE(DefTruth): The custom CUDA kernel loads/stores 128 bits
    # (16 bytes) per memory issue within a thread, so the head size must be
    # a multiple of the pack size (float32 -> 4, half/bfloat16 -> 8).
    # `output` is [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]. The checks are chained
    # with `and` so the shape is only inspected once the earlier ones pass.
    use_cuda_op = (
        current_platform.is_cuda()
        and output.dtype in (torch.float32, torch.half, torch.bfloat16)
        and output.shape[2] % (4 if output.dtype == torch.float32 else 8) == 0
    )

    if use_cuda_op:
        from vllm._custom_ops import merge_attn_states as merge_impl
    else:
        from vllm.attention.ops.triton_merge_attn_states import (
            merge_attn_states as merge_impl,
        )

    return merge_impl(
        output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
    )
|
||||||
262
attention/ops/paged_attn.py
Normal file
262
attention/ops/paged_attn.py
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import HAS_TRITON
|
||||||
|
|
||||||
|
if current_platform.is_cuda_alike():
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
elif current_platform.is_xpu():
|
||||||
|
from vllm._ipex_ops import ipex_ops as ops
|
||||||
|
|
||||||
|
if HAS_TRITON:
|
||||||
|
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||||
|
|
||||||
|
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||||
|
_PARTITION_SIZE = 512
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class PagedAttentionMetadata:
    """Metadata for PagedAttention.

    All three fields are required (no defaults); the two tensor fields may
    be None.
    """

    # (batch_size,). The length of each sequence, counting all tokens seen
    # so far for that sequence.
    seq_lens_tensor: torch.Tensor | None
    # Maximum sequence length in the batch. 0 if it is prefill-only batch.
    max_decode_seq_len: int
    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if it is
    # cuda-graph captured.
    block_tables: torch.Tensor | None
|
||||||
|
|
||||||
|
|
||||||
|
class PagedAttention:
|
||||||
|
@staticmethod
|
||||||
|
def get_supported_head_sizes() -> list[int]:
|
||||||
|
return [32, 64, 80, 96, 112, 120, 128, 192, 256]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_kv_cache_shape(
|
||||||
|
num_blocks: int,
|
||||||
|
block_size: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
cache_dtype_str: str = "auto",
|
||||||
|
) -> tuple[int, ...]:
|
||||||
|
return (2, num_blocks, block_size * num_kv_heads * head_size)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def split_kv_cache(
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
x = 16 // kv_cache.element_size()
|
||||||
|
num_blocks = kv_cache.shape[1]
|
||||||
|
|
||||||
|
key_cache = kv_cache[0]
|
||||||
|
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x)
|
||||||
|
value_cache = kv_cache[1]
|
||||||
|
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
|
||||||
|
return key_cache, value_cache
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def write_to_paged_cache(
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
slot_mapping: torch.Tensor,
|
||||||
|
kv_cache_dtype: str,
|
||||||
|
k_scale: torch.Tensor,
|
||||||
|
v_scale: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
ops.reshape_and_cache(
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
slot_mapping.flatten(),
|
||||||
|
kv_cache_dtype,
|
||||||
|
k_scale,
|
||||||
|
v_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward_decode(
|
||||||
|
query: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
max_seq_len: int,
|
||||||
|
kv_cache_dtype: str,
|
||||||
|
num_kv_heads: int,
|
||||||
|
scale: float,
|
||||||
|
alibi_slopes: torch.Tensor | None,
|
||||||
|
k_scale: torch.Tensor,
|
||||||
|
v_scale: torch.Tensor,
|
||||||
|
tp_rank: int = 0,
|
||||||
|
blocksparse_local_blocks: int = 0,
|
||||||
|
blocksparse_vert_stride: int = 0,
|
||||||
|
blocksparse_block_size: int = 64,
|
||||||
|
blocksparse_head_sliding_step: int = 0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
|
||||||
|
# use blocksparse paged attention
|
||||||
|
block_size = value_cache.size(-1)
|
||||||
|
assert (
|
||||||
|
blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0
|
||||||
|
), (
|
||||||
|
f"{blocksparse_block_size=} needs to be a multiple of"
|
||||||
|
f"{block_size=} used in block_tables."
|
||||||
|
)
|
||||||
|
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
block_size = value_cache.shape[3]
|
||||||
|
num_seqs, num_heads, head_size = query.shape
|
||||||
|
max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||||
|
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||||
|
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||||
|
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||||
|
# sequences or heads is large, we use V1 since there is enough work
|
||||||
|
# to parallelize.
|
||||||
|
# TODO(woosuk): Tune this heuristic.
|
||||||
|
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
|
||||||
|
use_v1 = max_seq_len <= 8192 and (
|
||||||
|
max_num_partitions == 1 or num_seqs * num_heads > 512
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_v1:
|
||||||
|
# Run PagedAttention V1.
|
||||||
|
ops.paged_attention_v1(
|
||||||
|
output,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
num_kv_heads,
|
||||||
|
scale,
|
||||||
|
block_tables,
|
||||||
|
seq_lens,
|
||||||
|
block_size,
|
||||||
|
max_seq_len,
|
||||||
|
alibi_slopes,
|
||||||
|
kv_cache_dtype,
|
||||||
|
k_scale,
|
||||||
|
v_scale,
|
||||||
|
tp_rank,
|
||||||
|
blocksparse_local_blocks,
|
||||||
|
blocksparse_vert_stride,
|
||||||
|
blocksparse_block_size,
|
||||||
|
blocksparse_head_sliding_step,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Run PagedAttention V2.
|
||||||
|
assert _PARTITION_SIZE % block_size == 0
|
||||||
|
tmp_output = torch.empty(
|
||||||
|
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||||
|
dtype=output.dtype,
|
||||||
|
device=output.device,
|
||||||
|
)
|
||||||
|
exp_sums = torch.empty(
|
||||||
|
size=(num_seqs, num_heads, max_num_partitions),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=output.device,
|
||||||
|
)
|
||||||
|
max_logits = torch.empty_like(exp_sums)
|
||||||
|
ops.paged_attention_v2(
|
||||||
|
output,
|
||||||
|
exp_sums,
|
||||||
|
max_logits,
|
||||||
|
tmp_output,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
num_kv_heads,
|
||||||
|
scale,
|
||||||
|
block_tables,
|
||||||
|
seq_lens,
|
||||||
|
block_size,
|
||||||
|
max_seq_len,
|
||||||
|
alibi_slopes,
|
||||||
|
kv_cache_dtype,
|
||||||
|
k_scale,
|
||||||
|
v_scale,
|
||||||
|
tp_rank,
|
||||||
|
blocksparse_local_blocks,
|
||||||
|
blocksparse_vert_stride,
|
||||||
|
blocksparse_block_size,
|
||||||
|
blocksparse_head_sliding_step,
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward_prefix(
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache_dtype: str,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
query_start_loc: torch.Tensor,
|
||||||
|
seq_lens_tensor: torch.Tensor,
|
||||||
|
max_query_len: int,
|
||||||
|
alibi_slopes: torch.Tensor | None,
|
||||||
|
sliding_window: int | None,
|
||||||
|
k_scale: torch.Tensor,
|
||||||
|
v_scale: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
max_seq_len = None
|
||||||
|
context_attention_fwd(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
output,
|
||||||
|
kv_cache_dtype,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
block_tables,
|
||||||
|
# query_start_loc is (batch_size + 1,)
|
||||||
|
query_start_loc,
|
||||||
|
seq_lens_tensor,
|
||||||
|
max_seq_len,
|
||||||
|
max_query_len,
|
||||||
|
k_scale,
|
||||||
|
v_scale,
|
||||||
|
alibi_slopes,
|
||||||
|
sliding_window,
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def swap_blocks(
|
||||||
|
src_kv_cache: torch.Tensor,
|
||||||
|
dst_kv_cache: torch.Tensor,
|
||||||
|
src_to_dst: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
src_key_cache = src_kv_cache[0]
|
||||||
|
dst_key_cache = dst_kv_cache[0]
|
||||||
|
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
|
||||||
|
|
||||||
|
src_value_cache = src_kv_cache[1]
|
||||||
|
dst_value_cache = dst_kv_cache[1]
|
||||||
|
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def copy_blocks(
|
||||||
|
kv_caches: list[torch.Tensor],
|
||||||
|
src_to_dists: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||||
|
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||||
|
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||||
130
attention/ops/pallas_kv_cache_update.py
Normal file
130
attention/ops/pallas_kv_cache_update.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import functools
|
||||||
|
|
||||||
|
import jax
|
||||||
|
from jax.experimental import pallas as pl
|
||||||
|
from jax.experimental.pallas import tpu as pltpu
|
||||||
|
|
||||||
|
from vllm.utils.math_utils import cdiv
|
||||||
|
|
||||||
|
|
||||||
|
def _kv_cache_update_kernel(
    # Prefetch
    slices_ref,  # [3, padded_num_slices], list of (kv_cache_start,
    # new_kv_start, slice_len)
    num_slices_ref,  # [1]
    # Input
    new_kv_hbm_ref,  # [num_tokens, num_combined_kv_heads, head_dim]
    kv_cache_hbm_ref,  # [total_num_pages * page_size, num_combined_kv_heads,
    # head_dim]
    # Output
    _,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    # Scratch
    scratch,  # [num_slices_per_block, page_size, num_combined_kv_heads,
    # head_dim]
    sem,
):
    """Pallas kernel body: copy slices of new KV into the KV cache.

    Each program handles ``scratch.shape[0]`` consecutive slice entries. It
    first copies every slice from ``new_kv_hbm_ref`` into ``scratch``, waits,
    then copies from ``scratch`` into ``kv_cache_hbm_ref`` and waits again.
    Slice indices at or beyond ``num_slices_ref[0]`` are given start 0 and
    length 0.
    """
    async_copies = []
    block_idx = pl.program_id(0)
    num_slices_per_block = scratch.shape[0]

    # Copy from new_kv_hbm_ref to scratch
    for i in range(num_slices_per_block):
        offset_i = i + block_idx * num_slices_per_block
        # Out-of-range slice entries select start/length 0.
        new_kv_start = jax.lax.select(
            offset_i < num_slices_ref[0], slices_ref[1, offset_i], 0
        )
        length = jax.lax.select(
            offset_i < num_slices_ref[0], slices_ref[2, offset_i], 0
        )
        async_copy = pltpu.make_async_copy(
            new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
            scratch.at[i, pl.ds(0, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)

    # All inbound copies must finish before scratch is read back out.
    for async_copy in async_copies:
        async_copy.wait()

    # Copy from scratch to kv_cache_hbm_ref
    async_copies.clear()
    for i in range(num_slices_per_block):
        offset_i = i + block_idx * num_slices_per_block
        kv_cache_start = jax.lax.select(
            offset_i < num_slices_ref[0], slices_ref[0, offset_i], 0
        )
        length = jax.lax.select(
            offset_i < num_slices_ref[0], slices_ref[2, offset_i], 0
        )
        async_copy = pltpu.make_async_copy(
            scratch.at[i, pl.ds(0, length), ...],
            kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)
    for async_copy in async_copies:
        async_copy.wait()
|
||||||
|
|
||||||
|
|
||||||
|
@functools.partial(
    jax.jit,
    static_argnames=["page_size", "num_slices_per_block"],
)
def kv_cache_update(
    # [total_num_token, num_combined_kv_heads, head_dim]
    new_kv: jax.Array,
    # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
    slices: jax.Array,
    # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    kv_cache: jax.Array,
    # [1]
    num_kv_update_slices: jax.Array,
    *,
    page_size: int = 32,
    num_slices_per_block: int = 8,
):
    """Write slices of ``new_kv`` into ``kv_cache`` with a Pallas kernel.

    Builds a ``pl.pallas_call`` around ``_kv_cache_update_kernel`` with
    ``slices`` and ``num_kv_update_slices`` as scalar-prefetch arguments and
    a VMEM scratch buffer of
    ``(num_slices_per_block, page_size, num_combined_kv_heads, head_dim)``.

    Returns:
        The kernel's single output, which has the shape and dtype of
        ``kv_cache`` and is aliased to the ``kv_cache`` input.

    Raises:
        AssertionError: if the trailing two dimensions of ``kv_cache`` and
            ``new_kv`` differ, or if ``head_dim`` is not a multiple of 128.
    """
    _, num_combined_kv_heads, head_dim = new_kv.shape
    assert kv_cache.shape[1] == num_combined_kv_heads
    assert kv_cache.shape[2] == head_dim
    assert head_dim % 128 == 0
    # TODO: Add a dynamic check to make sure that all the slice lengths are
    # smaller than or equal to page_size

    in_specs = [
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ]

    out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
    out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]

    scalar_prefetches = [slices, num_kv_update_slices]
    scratch = pltpu.VMEM(
        (num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
        new_kv.dtype,
    )

    scratch_shapes = [
        scratch,
        pltpu.SemaphoreType.DMA,
    ]

    kernel = pl.pallas_call(
        _kv_cache_update_kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=len(scalar_prefetches),
            in_specs=in_specs,
            out_specs=out_specs,
            # One program per group of num_slices_per_block slices.
            grid=(cdiv(num_kv_update_slices[0], num_slices_per_block),),
            scratch_shapes=scratch_shapes,
        ),
        out_shape=out_shape,
        # Operand order is (*scalar_prefetches, new_kv, kv_cache), so this
        # index refers to kv_cache; it is aliased to output 0.
        input_output_aliases={len(scalar_prefetches) + 1: 0},
    )

    return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
|
||||||
814
attention/ops/prefix_prefill.py
Normal file
814
attention/ops/prefix_prefill.py
Normal file
@@ -0,0 +1,814 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
|
||||||
|
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
# Static kernels parameters
|
||||||
|
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
|
||||||
|
NUM_WARPS = 4 if current_platform.is_rocm() else 8
|
||||||
|
|
||||||
|
# To check compatibility
|
||||||
|
IS_TURING = current_platform.get_device_capability() == (7, 5)
|
||||||
|
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||||
|
|
||||||
|
|
||||||
|
# Here's an example autotuner config for this kernel. This config does provide
|
||||||
|
# a performance improvement, but dramatically increases first call latency in
|
||||||
|
# triton 3.2. Because of this tradeoff, it's currently commented out.
|
||||||
|
# @triton.autotune(
|
||||||
|
# configs=[
|
||||||
|
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \
|
||||||
|
# "num_unroll_cache": 4, \
|
||||||
|
# "num_unroll_request": 1 } | \
|
||||||
|
# ({"kpack": 2, "waves_per_eu": 2} \
|
||||||
|
# if current_platform.is_rocm() else {}), \
|
||||||
|
# num_warps=4, \
|
||||||
|
# num_stages=1)
|
||||||
|
# ],
|
||||||
|
# key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"]
|
||||||
|
# )
|
||||||
|
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    K_cache,
    V_cache,
    sink_ptr,
    B_Loc,
    sm_scale,
    k_scale,
    v_scale,
    out_scale_inv,
    B_Start_Loc,
    B_Seqlen,
    x: tl.constexpr,
    Out,
    stride_b_loc_b,
    stride_b_loc_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    stride_k_cache_bs,
    stride_k_cache_h,
    stride_k_cache_d,
    stride_k_cache_bl: tl.constexpr,
    stride_k_cache_x,
    stride_v_cache_bs,
    stride_v_cache_h,
    stride_v_cache_d,
    stride_v_cache_bl,
    num_queries_per_kv: tl.constexpr,
    IN_PRECISION: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DMODEL_PADDED: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_N: tl.constexpr,
    SLIDING_WINDOW: tl.constexpr,
    num_unroll_cache: tl.constexpr,
    num_unroll_request: tl.constexpr,
    SKIP_DECODE: tl.constexpr,
    USE_SINKS: tl.constexpr,
    USE_FP8: tl.constexpr,
    MAX_Q_LEN: tl.constexpr = 0,
    MAX_CTX_LEN: tl.constexpr = 0,
    FP8_MIN: tl.constexpr = float8_info.min,
    FP8_MAX: tl.constexpr = float8_info.max,
):
    # Prefix-prefill attention kernel.
    #
    # Program ids: axis 0 -> batch entry, axis 1 -> query head,
    # axis 2 -> block of BLOCK_M query rows.
    #
    # Phase 1 attends the query block to the cached context, read from
    # K_cache / V_cache through the block table B_Loc (no causal mask).
    # Phase 2 attends it to the request's own new K / V with a causal mask.
    # Both phases keep a running row maximum (m_i) and running normalizer
    # (l_i) and rescale the accumulator by alpha whenever the maximum grows;
    # the final result is acc / l_i.
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    # Several query heads share one KV head.
    cur_kv_head = cur_head // num_queries_per_kv

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
    cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
    cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len

    # Optionally skip entries whose query length is exactly one token.
    if SKIP_DECODE and cur_batch_query_len == 1:
        return

    # start position inside of the query
    # generally, N goes over kv, while M goes over query_len
    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    # [BLOCK_SIZE]; starts at 0
    offs_bs_n = tl.arange(0, BLOCK_SIZE)
    # [N]; starts at 0
    offs_n = tl.arange(0, BLOCK_N)
    # [D]; starts at 0
    offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
    # [M]; starts at current position in query
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # [M,D]
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :] * stride_qd
    )

    # True for the first BLOCK_DMODEL lanes, False for the padding lanes.
    dim_mask = tl.where(tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(
        tl.int1
    )  # [D]

    q = tl.load(
        Q + off_q,
        mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len),
        other=0.0,
    )  # [M,D]

    # initialize pointer to m and l
    if not USE_SINKS:
        m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    else:
        # With sinks, the running max starts from the per-head sink value.
        m_i = tl.load(
            sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64),
            mask=(offs_m < cur_batch_query_len),
            other=float("-inf"),
        ).to(dtype=tl.float32)

    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)  # [M,D]

    # compute query against context (no causal mask here)
    for start_n in tl.range(
        0, cur_batch_ctx_len, BLOCK_SIZE, loop_unroll_factor=num_unroll_cache
    ):
        start_n = tl.multiple_of(start_n, BLOCK_SIZE)
        # -- compute qk ----
        # Block-table entry for the cache block covering this step.
        bn = tl.load(
            B_Loc
            + cur_batch * stride_b_loc_b
            + (start_n // BLOCK_SIZE) * stride_b_loc_s
        ).to(tl.int64)
        # [D,BLOCK_SIZE]
        off_k = (
            bn[None, :] * stride_k_cache_bs
            + cur_kv_head * stride_k_cache_h
            + (offs_d[:, None] // x) * stride_k_cache_d
            + ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl
            + (offs_d[:, None] % x) * stride_k_cache_x
        )

        # [BLOCK_SIZE,D]
        off_v = (
            bn[:, None] * stride_v_cache_bs
            + cur_kv_head * stride_v_cache_h
            + offs_d[None, :] * stride_v_cache_d
            + offs_bs_n[:, None] * stride_v_cache_bl
        )

        # Only pay for a masked load on the ragged last block or when the
        # head dimension is padded.
        if (
            start_n + BLOCK_SIZE > cur_batch_ctx_len
            or BLOCK_DMODEL != BLOCK_DMODEL_PADDED
        ):
            k_load = tl.load(
                K_cache + off_k,
                mask=dim_mask[:, None]
                & ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len),
                other=0.0,
            )  # [D,N]
        else:
            k_load = tl.load(K_cache + off_k)

        # FP8 cache entries are scaled by k_scale and cast to q's dtype.
        if k_load.dtype.is_fp8():
            k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
        else:
            k = k_load

        qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32)  # [M,N]
        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
        qk = tl.where(
            (start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")
        )
        qk *= sm_scale
        if SLIDING_WINDOW > 0:
            # (cur_batch_ctx_len + offs_m[:, None]) are the positions of
            # Q entries in sequence
            # (start_n + offs_bs_n[None, :]) are the positions of
            # KV entries in sequence
            # So the condition makes sure each entry in Q only attends
            # to KV entries not more than SLIDING_WINDOW away.
            #
            # We can't use -inf here, because the
            # sliding window may lead to the entire row being masked.
            # This then makes m_ij contain -inf, which causes NaNs in
            # exp().
            qk = tl.where(
                (cur_batch_ctx_len + offs_m[:, None]) - (start_n + offs_bs_n[None, :])
                < SLIDING_WINDOW,
                qk,
                -10000,
            )

        # compute running maximum
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        alpha = tl.exp(m_i - m_ij)
        acc = acc * alpha[:, None]

        # update acc
        if (
            start_n + BLOCK_SIZE > cur_batch_ctx_len
            or BLOCK_DMODEL != BLOCK_DMODEL_PADDED
        ):
            v_load = tl.load(
                V_cache + off_v,
                mask=dim_mask[None, :]
                & ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len),
                other=0.0,
            )  # [N,D]
        else:
            v_load = tl.load(V_cache + off_v)

        if v_load.dtype.is_fp8():
            v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
        else:
            v = v_load
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
        # # update m_i and l_i
        l_i = l_i * alpha + l_ij
        m_i = m_ij

    # Offsets into the new (non-cached) K / V of this request.
    off_k = (
        offs_n[None, :] * stride_kbs
        + cur_kv_head * stride_kh
        + offs_d[:, None] * stride_kd
    )
    off_v = (
        offs_n[:, None] * stride_vbs
        + cur_kv_head * stride_vh
        + offs_d[None, :] * stride_vd
    )
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # block_mask is 0 when we're already past the current query length
    block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

    # compute query against itself (with causal mask)
    for start_n in tl.range(
        0,
        block_mask * (start_m + 1) * BLOCK_M,
        BLOCK_N,
        loop_unroll_factor=num_unroll_request,
    ):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=dim_mask[:, None]
            & ((start_n + offs_n[None, :]) < cur_batch_query_len),
            other=0.0,
        )

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
        qk *= sm_scale
        # apply causal mask
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
        if SLIDING_WINDOW > 0:
            qk = tl.where(
                offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
                qk,
                -10000,
            )

        # compute running maximum
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        alpha = tl.exp(m_i - m_ij)
        acc = acc * alpha[:, None]

        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=dim_mask[None, :]
            & ((start_n + offs_n[:, None]) < cur_batch_query_len),
            other=0.0,
        )
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
        # update m_i and l_i
        l_i = l_i * alpha + l_ij
        m_i = m_ij

    # Final normalization by the running sum.
    acc = acc / l_i[:, None]

    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :] * stride_od
    )
    out_ptrs = Out + off_o
    if USE_FP8:
        # Scale by out_scale_inv and clamp into [FP8_MIN, FP8_MAX].
        acc = acc * tl.load(out_scale_inv)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
    tl.store(
        out_ptrs, acc, mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)
    )
    return
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _fwd_kernel_alibi(
    Q,
    K,
    V,
    K_cache,
    V_cache,
    B_Loc,
    sm_scale,
    k_scale,
    v_scale,
    B_Start_Loc,
    B_Seqlen,
    Alibi_slopes,
    block_size,
    x,
    Out,
    stride_b_loc_b,
    stride_b_loc_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    stride_k_cache_bs,
    stride_k_cache_h,
    stride_k_cache_d,
    stride_k_cache_bl,
    stride_k_cache_x,
    stride_v_cache_bs,
    stride_v_cache_h,
    stride_v_cache_d,
    stride_v_cache_bl,
    num_queries_per_kv: int,
    IN_PRECISION: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,  # head size
    BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
    BLOCK_N: tl.constexpr,
    SKIP_DECODE: tl.constexpr,
):
    # Prefix-prefill attention kernel with an ALiBi bias.
    #
    # Program ids: axis 0 -> batch entry, axis 1 -> query head,
    # axis 2 -> block of BLOCK_M query rows.
    #
    # Phase 1 attends the query block to the cached prefix (K_cache /
    # V_cache addressed through B_Loc); phase 2 attends it to the request's
    # own new K / V under a causal mask. In both phases a per-head bias
    # (key_pos - query_pos) * alibi_slope is added to the scores, and a
    # running max (m_i) / running sum (l_i) is maintained; the result is
    # acc / l_i.
    # attn_bias[]
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // num_queries_per_kv

    # cur_batch_seq_len: the length of prompts
    # cur_batch_ctx_len: the length of prefix
    # cur_batch_in_all_start_index: the start id of the dim=0
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
    cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
    cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len

    # Optionally skip entries whose query length is exactly one token.
    if SKIP_DECODE and cur_batch_query_len == 1:
        return

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :] * stride_qd
    )

    # True for the first BLOCK_DMODEL lanes, False for the padding lanes.
    dim_mask = tl.where(tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(
        tl.int1
    )

    q = tl.load(
        Q + off_q,
        mask=dim_mask[None, :]
        & (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
        other=0.0,
    )

    # # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)

    # Query positions are offset by the prefix length; key positions for the
    # cached prefix start at 0.
    alibi_slope = tl.load(Alibi_slopes + cur_head)
    alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
    alibi_start_k = 0
    for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        # Block-table entries for each of the BLOCK_N key positions.
        bn = tl.load(
            B_Loc
            + cur_batch * stride_b_loc_b
            + ((start_n + offs_n) // block_size) * stride_b_loc_s,
            mask=(start_n + offs_n) < cur_batch_ctx_len,
            other=0,
        ).to(tl.int64)
        off_k = (
            bn[None, :] * stride_k_cache_bs
            + cur_kv_head * stride_k_cache_h
            + (offs_d[:, None] // x) * stride_k_cache_d
            + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl
            + (offs_d[:, None] % x) * stride_k_cache_x
        )
        off_v = (
            bn[:, None] * stride_v_cache_bs
            + cur_kv_head * stride_v_cache_h
            + offs_d[None, :] * stride_v_cache_d
            + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl
        )
        k_load = tl.load(
            K_cache + off_k,
            mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
            other=0.0,
        )  # [D,N]

        # FP8 cache entries are scaled by k_scale and cast to q's dtype.
        if k_load.dtype.is_fp8():
            k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
        else:
            k = k_load

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
        qk = tl.where(
            (start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")
        )
        qk *= sm_scale

        # load alibi
        alibi = (
            tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - alibi_start_q[:, None]
        ) * alibi_slope
        # Entries with a positive bias, or query rows at/after
        # cur_batch_seq_len, are set to -inf.
        alibi = tl.where(
            (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
            alibi,
            float("-inf"),
        )
        qk += alibi
        alibi_start_k += BLOCK_N

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v_load = tl.load(
            V_cache + off_v,
            mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
            other=0.0,
        )
        if v_load.dtype.is_fp8():
            v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
        else:
            v = v_load
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision="ieee")
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # Offsets into the new (non-cached) K / V of this request.
    off_k = (
        offs_n[None, :] * stride_kbs
        + cur_kv_head * stride_kh
        + offs_d[:, None] * stride_kd
    )
    off_v = (
        offs_n[:, None] * stride_vbs
        + cur_kv_head * stride_vh
        + offs_d[None, :] * stride_vd
    )
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # block_mask is 0 when this query block starts past the query length,
    # which makes the loop below run zero iterations.
    block_mask = tl.where(block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

    # init alibi
    alibi_slope = tl.load(Alibi_slopes + cur_head)
    alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
    # Key positions for the new tokens start right after the prefix.
    alibi_start_k = cur_batch_ctx_len
    # # init debugger
    # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
    # offset_db_k = tl.arange(0, BLOCK_N)
    # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=dim_mask[:, None]
            & ((start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len),
            other=0.0,
        )

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.dot(q, k, acc=qk, input_precision="ieee")
        qk *= sm_scale
        # Causal mask: a query row only sees keys at or before its position.
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))

        # load alibi
        alibi = (
            tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - alibi_start_q[:, None]
        ) * alibi_slope
        alibi = tl.where(
            (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
            alibi,
            float("-inf"),
        )
        qk += alibi
        alibi_start_k += BLOCK_N

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=dim_mask[None, :]
            & ((start_n + offs_n[:, None]) < cur_batch_seq_len - cur_batch_ctx_len),
            other=0.0,
        )
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision="ieee")
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # Final normalization by the running sum.
    acc = acc / l_i[:, None]

    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :] * stride_od
    )
    out_ptrs = Out + off_o
    tl.store(
        out_ptrs,
        acc,
        mask=dim_mask[None, :]
        & (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
    )
    return
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
def context_attention_fwd(
    q,
    k,
    v,
    o,
    kv_cache_dtype: str,
    k_cache,
    v_cache,
    b_loc,
    b_start_loc,
    b_seq_len,
    max_seq_len,
    max_input_len,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    alibi_slopes=None,
    sliding_window=None,
    sm_scale=None,
    skip_decode=False,
    fp8_out_scale=None,
    sinks=None,
):
    """Launch the Triton prefix-prefill attention kernel, writing into ``o``.

    The function validates (and, for FP8, reinterprets) the KV caches,
    derives the launch parameters and dispatches to ``_fwd_kernel_alibi``
    when ``alibi_slopes`` is given, or to ``_fwd_kernel`` otherwise.

    Args:
        q: Query tensor. Its last dimension is the head size and dimension 1
            is the number of query heads.
        k: Key tensor for the new tokens. Its last dimension must equal the
            head size of ``q``; ``q.shape[1] // k.shape[1]`` is used as the
            number of query heads per KV head.
        v: Value tensor for the new tokens; same head size as ``q``/``k``.
        o: Output tensor handed to the kernel together with its strides.
        kv_cache_dtype: KV-cache dtype name. Any name containing ``"fp8"``
            must be one of ``"fp8"``, ``"fp8_e4m3"`` or ``"fp8_e5m2"``.
        k_cache: Key cache; per the inline layout note it is
            ``[num_blocks, num_kv_heads, head_size/x, block_size, x]``.
        v_cache: Value cache; per the inline layout note it is
            ``[num_blocks, num_kv_heads, head_size, block_size]``.
        b_loc: 2-D index table passed with its two strides (presumably the
            block table -- confirm against the kernel).
        b_start_loc: Must contain ``batch + 1`` entries.
        b_seq_len: One entry per sequence; its length defines the batch size.
        max_seq_len: Not used by this function; kept for interface
            compatibility.
        max_input_len: Used to size the launch grid along the query axis.
        k_scale: Forwarded to the kernel.
        v_scale: Forwarded to the kernel.
        alibi_slopes: When not ``None`` the ALiBi kernel is launched.
        sliding_window: ``None`` or a non-positive value disables the
            sliding window (the kernel receives ``0``).
        sm_scale: Softmax scale; defaults to ``1 / sqrt(head_size)``.
        skip_decode: Forwarded to the kernel as ``SKIP_DECODE``.
        fp8_out_scale: When given, its reciprocal is passed to the kernel and
            ``USE_FP8`` is enabled. Not supported together with ALiBi.
        sinks: Forwarded to the kernel; ``USE_SINKS`` is enabled when it is
            not ``None``. Not supported together with ALiBi.

    Raises:
        ValueError: If ``kv_cache_dtype`` names an unknown FP8 format, or if
            a cache is still ``torch.uint8`` after the FP8 reinterpretation.
        AssertionError: If the FP8 caches have an unexpected dtype, the head
            sizes of ``q``/``k``/``v`` differ, ``b_start_loc`` does not have
            ``batch + 1`` entries, or ``sinks``/``fp8_out_scale`` are
            combined with ``alibi_slopes``.
    """
    # Conversion of FP8 tensors from uint8 storage to the torch dtype that
    # Triton should use to interpret them.
    if "fp8" in kv_cache_dtype:
        assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
        assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]

        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            target_dtype = current_platform.fp8_dtype()
        elif kv_cache_dtype == "fp8_e5m2":
            target_dtype = torch.float8_e5m2
        else:
            raise ValueError(f"Unsupported FP8 dtype: {kv_cache_dtype}")

        k_cache = k_cache.view(target_dtype)
        v_cache = v_cache.view(target_dtype)

    # A cache that is still uint8 here was not reinterpreted as FP8 above
    # (e.g. kv_cache_dtype='auto'), so the kernel cannot interpret it. The
    # previous condition mixed `or`/`and` without parentheses, which treated
    # k_cache and v_cache differently; both caches are now checked alike.
    if k_cache.dtype == torch.uint8 or v_cache.dtype == torch.uint8:
        raise ValueError(
            f"kv_cache_dtype={kv_cache_dtype!r} unsupported for "
            "FP8 KV Cache prefill kernel"
        )

    q_dtype_is_f32 = q.dtype is torch.float32

    # Turing does not have tensor cores for float32 multiplication, so use
    # "ieee" precision as a fallback to keep the Triton kernels working.
    # There is also a warning in vllm/config.py informing users about this
    # fallback implementation.
    IN_PRECISION = "ieee" if IS_TURING and q_dtype_is_f32 else None

    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    # round up Lk to a power of 2 - this is required for Triton block size
    Lk_padded = triton.next_power_of_2(Lk)

    if sm_scale is None:
        sm_scale = 1.0 / (Lq**0.5)
    batch, head = b_seq_len.shape[0], q.shape[1]
    num_queries_per_kv = q.shape[1] // k.shape[1]

    assert batch + 1 == len(b_start_loc)

    # 0 means "disable"
    if sliding_window is None or sliding_window <= 0:
        sliding_window = 0

    if alibi_slopes is not None:
        assert sinks is None, "Sinks arg is not supported with alibi"
        assert fp8_out_scale is None, "FP8 output not supported with alibi"
        # Halve the block size for fp32 inputs because of the increased use
        # of GPU shared memory.
        BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
        # grid axes: (batch, head, query blocks)
        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
        _fwd_kernel_alibi[grid](
            q,
            k,
            v,
            k_cache,
            v_cache,
            b_loc,
            sm_scale,
            k_scale,
            v_scale,
            b_start_loc,
            b_seq_len,
            alibi_slopes,
            v_cache.shape[3],
            k_cache.shape[4],
            o,
            b_loc.stride(0),
            b_loc.stride(1),
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            k_cache.stride(4),  # [num_blocks, num_kv_heads, head_size/x, block_size, x]
            v_cache.stride(0),
            v_cache.stride(1),
            v_cache.stride(2),
            v_cache.stride(3),  # [num_blocks, num_kv_heads, head_size, block_size]
            num_queries_per_kv=num_queries_per_kv,
            IN_PRECISION=IN_PRECISION,
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_DMODEL_PADDED=Lk_padded,
            BLOCK_N=BLOCK,
            SKIP_DECODE=skip_decode,
            num_warps=NUM_WARPS,
            num_stages=1,
        )
        return

    extra_kargs = {}
    if current_platform.is_rocm():
        extra_kargs = {"kpack": 1, "waves_per_eu": 2}

    # The kernel receives the reciprocal of the FP8 output scale (1.0 when
    # FP8 output is disabled).
    fp8_out_scale_inv = 1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0

    grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"]))
    _fwd_kernel[grid](
        q,
        k,
        v,
        k_cache,
        v_cache,
        sinks,
        b_loc,
        sm_scale,
        k_scale,
        v_scale,
        fp8_out_scale_inv,
        b_start_loc,
        b_seq_len,
        k_cache.shape[4],
        o,
        b_loc.stride(0),
        b_loc.stride(1),
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        k_cache.stride(0),
        k_cache.stride(1),
        k_cache.stride(2),
        k_cache.stride(3),
        k_cache.stride(4),  # [num_blocks, num_kv_heads, head_size/x, block_size, x]
        v_cache.stride(0),
        v_cache.stride(1),
        v_cache.stride(2),
        v_cache.stride(3),  # [num_blocks, num_kv_heads, head_size, block_size]
        BLOCK_SIZE=v_cache.shape[3],
        num_queries_per_kv=num_queries_per_kv,
        IN_PRECISION=IN_PRECISION,
        BLOCK_DMODEL=Lk,
        BLOCK_DMODEL_PADDED=Lk_padded,
        SLIDING_WINDOW=sliding_window,
        SKIP_DECODE=skip_decode,
        USE_FP8=fp8_out_scale is not None,
        BLOCK_M=128,
        BLOCK_N=64,
        num_unroll_cache=4,
        num_unroll_request=1,
        num_warps=4,
        num_stages=1,
        USE_SINKS=sinks is not None,
        **extra_kargs,
    )
    return
|
||||||
123
attention/ops/rocm_aiter_paged_attn.py
Normal file
123
attention/ops/rocm_aiter_paged_attn.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import aiter as rocm_aiter
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.ops.paged_attn import PagedAttention
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.math_utils import cdiv
|
||||||
|
|
||||||
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
|
||||||
|
|
||||||
|
class AITERPagedAttention(PagedAttention):
    """Paged-attention ops that route quantized KV caches through ``aiter``.

    For the ``int8``, ``fp8`` and ``fp8_e4m3`` KV-cache dtypes the methods
    call into the ``aiter`` package; for every other dtype they defer to the
    ``PagedAttention`` base implementation.
    """

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> None:
        """Write ``key``/``value`` into the paged KV cache.

        Args:
            key: Keys to store.
            value: Values to store.
            key_cache: Destination key cache.
            value_cache: Destination value cache.
            slot_mapping: Destination slots; flattened before being handed to
                the ``aiter`` op.
            kv_cache_dtype: KV-cache dtype name selecting the code path.
            k_scale: Key scale tensor forwarded to the underlying op.
            v_scale: Value scale tensor forwarded to the underlying op.
        """
        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
            # Not a dtype handled by aiter: use the base implementation.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )
            return

        # Reinterpret the cache storage as the quantized dtype.
        kv_cache_torch_dtype = FP8_DTYPE if "fp8" in kv_cache_dtype else torch.int8
        key_cache = key_cache.view(kv_cache_torch_dtype)
        value_cache = value_cache.view(kv_cache_torch_dtype)

        rocm_aiter.reshape_and_cache_with_pertoken_quant(
            key,
            value,
            key_cache,
            value_cache,
            k_scale,
            v_scale,
            slot_mapping.flatten(),
            # NOTE(review): meaning of this positional flag is defined by
            # aiter -- confirm against its signature.
            True,
        )

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: torch.Tensor | None,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> torch.Tensor:
        """Run decode attention over the paged KV cache.

        For dtypes other than ``int8``/``fp8``/``fp8_e4m3`` every argument is
        forwarded to ``PagedAttention.forward_decode``. Otherwise
        ``rocm_aiter.pa_fwd_asm`` is called; on that path ``num_kv_heads``,
        ``scale``, ``alibi_slopes``, ``tp_rank`` and the block-sparse
        arguments are not passed to the op (the block-sparse block size is
        only validated).

        Returns:
            The attention output; on the ``aiter`` path a new tensor with the
            same shape and dtype as ``query``.

        Raises:
            AssertionError: If ``blocksparse_vert_stride > 1`` and
                ``blocksparse_block_size`` is not a positive multiple of the
                cache block size.
        """
        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
            return PagedAttention.forward_decode(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                block_tables=block_tables,
                seq_lens=seq_lens,
                max_seq_len=max_seq_len,
                kv_cache_dtype=kv_cache_dtype,
                num_kv_heads=num_kv_heads,
                scale=scale,
                alibi_slopes=alibi_slopes,
                k_scale=k_scale,
                v_scale=v_scale,
                tp_rank=tp_rank,
                blocksparse_local_blocks=blocksparse_local_blocks,
                blocksparse_vert_stride=blocksparse_vert_stride,
                blocksparse_block_size=blocksparse_block_size,
                blocksparse_head_sliding_step=blocksparse_head_sliding_step,
            )

        if "fp8" in kv_cache_dtype:
            key_cache = key_cache.view(current_platform.fp8_dtype())
            value_cache = value_cache.view(current_platform.fp8_dtype())

        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            block_size = value_cache.size(-1)
            assert (
                blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0
            ), (
                # Trailing space keeps the two message fragments separated.
                f"{blocksparse_block_size=} needs to be a multiple of "
                f"{block_size=} used in block_tables."
            )

        output = torch.empty_like(query)
        block_size = value_cache.shape[3]
        max_num_blocks_per_seq = cdiv(max_seq_len, block_size)

        rocm_aiter.pa_fwd_asm(
            query,
            key_cache,
            value_cache,
            block_tables,
            seq_lens,
            max_num_blocks_per_seq,
            k_scale,
            v_scale,
            output,
        )
        return output
|
||||||
712
attention/ops/triton_decode_attention.py
Normal file
712
attention/ops/triton_decode_attention.py
Normal file
@@ -0,0 +1,712 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
|
||||||
|
# which was originally adapted from
|
||||||
|
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
|
||||||
|
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
|
||||||
|
|
||||||
|
# Changes:
|
||||||
|
# - Add support for page size >= 1.
|
||||||
|
|
||||||
|
# Copyright 2025 vLLM Team
|
||||||
|
# Copyright 2023-2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
Memory-efficient attention for decoding.
|
||||||
|
It supports page size >= 1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
is_hip_ = current_platform.is_rocm()
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Only print the following warnings when triton version < 3.2.0.
|
||||||
|
# The issue won't affect performance or accuracy.
|
||||||
|
if version.parse(triton.__version__) < version.parse("3.2.0"):
|
||||||
|
logger.warning(
|
||||||
|
"The following error message 'operation scheduled before its operands' "
|
||||||
|
"can be ignored."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def tanh(x):
    # tanh expressed through the sigmoid: tanh(x) = 2 * sigmoid(2x) - 1.
    doubled_sigmoid = 2 * tl.sigmoid(2 * x)
    return doubled_sigmoid - 1
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _fwd_kernel_stage1(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    # Stage 1 of split-KV decoding for one (batch, head, kv split) program.
    # It runs an online softmax over this split's slice of the sequence and
    # stores the normalized partial output (Lv values) followed by the
    # log-sum-exp of the split (at offset Lv) into Att_Out.
    batch_idx = tl.program_id(0)
    head_idx = tl.program_id(1)
    split_idx = tl.program_id(2)

    kv_head_idx = head_idx // kv_group_num

    qk_dim_offs = tl.arange(0, BLOCK_DMODEL)
    v_dim_offs = tl.arange(0, BLOCK_DV)
    # The block sizes may exceed the real head sizes; mask the padding.
    qk_dim_mask = qk_dim_offs < Lk
    v_dim_mask = v_dim_offs < Lv
    seq_len = tl.load(B_Seqlen + batch_idx)
    req_idx = batch_idx

    q_offs = batch_idx * stride_qbs + head_idx * stride_qh + qk_dim_offs
    q = tl.load(Q + q_offs, mask=qk_dim_mask, other=0.0)

    # Token range [split_begin, split_end) owned by this split.
    tokens_per_split = tl.cdiv(seq_len, NUM_KV_SPLITS)
    split_begin = tokens_per_split * split_idx
    split_end = tl.minimum(split_begin + tokens_per_split, seq_len)

    running_max = -float("inf")
    running_sum = 0.0
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    # Empty splits store nothing.
    if split_end > split_begin:
        for block_start in range(split_begin, split_end, BLOCK_N):
            token_offs = block_start + tl.arange(0, BLOCK_N)
            token_mask = token_offs < split_end

            # Translate logical token positions into cache slots via the
            # per-request page table.
            page_ids = tl.load(
                Req_to_tokens
                + stride_req_to_tokens_b * req_idx
                + token_offs // PAGE_SIZE,
                mask=token_mask,
                other=0,
            )
            kv_slots = page_ids * PAGE_SIZE + token_offs % PAGE_SIZE

            k_offs = (
                kv_slots[:, None] * stride_buf_kbs
                + kv_head_idx * stride_buf_kh
                + qk_dim_offs[None, :]
            )
            k = tl.load(
                K_Buffer + k_offs,
                mask=(token_mask[:, None]) & (qk_dim_mask[None, :]),
                other=0.0,
            )
            scores = tl.sum(q[None, :] * k, 1)
            scores *= sm_scale

            if logit_cap > 0:
                scores = logit_cap * tanh(scores / logit_cap)

            # Positions past the split end must not contribute.
            scores = tl.where(token_mask, scores, float("-inf"))

            v_offs = (
                kv_slots[:, None] * stride_buf_vbs
                + kv_head_idx * stride_buf_vh
                + v_dim_offs[None, :]
            )
            v = tl.load(
                V_Buffer + v_offs,
                mask=(token_mask[:, None]) & (v_dim_mask[None, :]),
                other=0.0,
            )

            # Online-softmax update: rescale the old state to the new max.
            new_max = tl.maximum(tl.max(scores, 0), running_max)
            rescale = tl.exp(running_max - new_max)
            probs = tl.exp(scores - new_max)
            acc *= rescale
            acc += tl.sum(probs[:, None] * v, 0)

            running_sum = running_sum * rescale + tl.sum(probs, 0)
            running_max = new_max

        out_offs = (
            batch_idx * stride_mid_ob
            + head_idx * stride_mid_oh
            + split_idx * stride_mid_os
            + v_dim_offs
        )

        tl.store(
            Att_Out + out_offs,
            acc / running_sum,
            mask=(v_dim_mask),
        )

        # The split's log-sum-exp lives right after the Lv output values.
        lse_offs = (
            batch_idx * stride_mid_ob
            + head_idx * stride_mid_oh
            + split_idx * stride_mid_os
            + Lv
        )

        tl.store(
            Att_Out + lse_offs,
            running_max + tl.log(running_sum),
        )
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_att_m_fwd(
    q,
    k_buffer,
    v_buffer,
    att_out,
    Req_to_tokens,
    B_Seqlen,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap,
):
    """Launch ``_fwd_kernel_stage1`` on a (batch, head, kv split) grid.

    Head sizes are taken from the last dimension of ``k_buffer`` and
    ``v_buffer`` and padded to the next power of two for the kernel's block
    sizes. The partial results are written into ``att_out``.
    """
    block_n = 8 if is_hip_ else 64

    head_dim_qk = k_buffer.shape[-1]
    head_dim_v = v_buffer.shape[-1]
    num_seqs, num_heads = q.shape[0], q.shape[1]
    heads_per_kv = q.shape[1] // k_buffer.shape[-2]

    # Fewer warps are used when several query heads share one KV head.
    if heads_per_kv == 1:
        warps = 4
    elif is_hip_:
        warps = 1
    else:
        warps = 2

    launch_grid = (num_seqs, num_heads, num_kv_splits)
    _fwd_kernel_stage1[launch_grid](
        q,
        k_buffer,
        v_buffer,
        sm_scale,
        Req_to_tokens,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        # The K/V buffers are assumed to be laid out as
        # (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM).
        k_buffer.stride(-3),
        k_buffer.stride(-2),
        v_buffer.stride(-3),
        v_buffer.stride(-2),
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=heads_per_kv,
        BLOCK_DMODEL=triton.next_power_of_2(head_dim_qk),
        BLOCK_DV=triton.next_power_of_2(head_dim_v),
        BLOCK_N=block_n,
        NUM_KV_SPLITS=num_kv_splits,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        num_warps=warps,
        num_stages=2,
        Lk=head_dim_qk,
        Lv=head_dim_v,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _fwd_grouped_kernel_stage1(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    # Stage 1 of split-KV decoding for a block of up to BLOCK_H query heads
    # that share one KV head. For every head it runs an online softmax over
    # this split's slice of the sequence and stores the normalized partial
    # output (Lv values) followed by the split's log-sum-exp (at offset Lv)
    # into Att_Out.
    batch_idx = tl.program_id(0)
    head_block_idx = tl.program_id(1)
    kv_head_idx = head_block_idx // tl.cdiv(kv_group_num, BLOCK_H)
    split_idx = tl.program_id(2)

    # Number of real heads handled by one program.
    if kv_group_num > BLOCK_H:
        VALID_BLOCK_H: tl.constexpr = BLOCK_H
    else:
        VALID_BLOCK_H: tl.constexpr = kv_group_num
    head_ids = head_block_idx * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    head_mask = (head_ids < (head_block_idx + 1) * VALID_BLOCK_H) & (
        head_ids < q_head_num
    )

    qk_dim_offs = tl.arange(0, BLOCK_DMODEL)
    v_dim_offs = tl.arange(0, BLOCK_DV)
    qk_dim_mask = qk_dim_offs < Lk
    v_dim_mask = v_dim_offs < Lv
    seq_len = tl.load(B_Seqlen + batch_idx)
    req_idx = batch_idx

    q_offs = (
        batch_idx * stride_qbs + head_ids[:, None] * stride_qh + qk_dim_offs[None, :]
    )
    q = tl.load(
        Q + q_offs, mask=(head_mask[:, None]) & (qk_dim_mask[None, :]), other=0.0
    )

    # Optional second slice of the head dimension, located right after the
    # first BLOCK_DMODEL elements.
    if BLOCK_DPE > 0:
        pe_dim_offs = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        pe_dim_mask = pe_dim_offs < Lk
        q_pe_offs = (
            batch_idx * stride_qbs
            + head_ids[:, None] * stride_qh
            + pe_dim_offs[None, :]
        )
        q_pe = tl.load(
            Q + q_pe_offs,
            mask=(head_mask[:, None]) & (pe_dim_mask[None, :]),
            other=0.0,
        )

    # Token range [split_begin, split_end) owned by this split.
    tokens_per_split = tl.cdiv(seq_len, NUM_KV_SPLITS)
    split_begin = tokens_per_split * split_idx
    split_end = tl.minimum(split_begin + tokens_per_split, seq_len)

    running_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    running_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

    # Empty splits store nothing.
    if split_end > split_begin:
        for block_start in range(split_begin, split_end, BLOCK_N):
            token_offs = block_start + tl.arange(0, BLOCK_N)
            token_mask = token_offs < split_end

            # Translate logical token positions into cache slots via the
            # per-request page table.
            page_ids = tl.load(
                Req_to_tokens
                + stride_req_to_tokens_b * req_idx
                + token_offs // PAGE_SIZE,
                mask=token_mask,
                other=0,
            )
            kv_slots = page_ids * PAGE_SIZE + token_offs % PAGE_SIZE

            # K is loaded with the head dimension first so that
            # tl.dot(q, k) yields one row of scores per head.
            k_offs = (
                kv_slots[None, :] * stride_buf_kbs
                + kv_head_idx * stride_buf_kh
                + qk_dim_offs[:, None]
            )
            k = tl.load(
                K_Buffer + k_offs,
                mask=(token_mask[None, :]) & (qk_dim_mask[:, None]),
                other=0.0,
            )
            scores = tl.dot(q, k.to(q.dtype))
            if BLOCK_DPE > 0:
                k_pe_offs = (
                    kv_slots[None, :] * stride_buf_kbs
                    + kv_head_idx * stride_buf_kh
                    + pe_dim_offs[:, None]
                )
                k_pe = tl.load(
                    K_Buffer + k_pe_offs,
                    mask=(token_mask[None, :]) & (pe_dim_mask[:, None]),
                    other=0.0,
                )
                scores += tl.dot(q_pe, k_pe.to(q_pe.dtype))
            scores *= sm_scale

            if logit_cap > 0:
                scores = logit_cap * tanh(scores / logit_cap)

            # Padding heads and positions past the split end must not
            # contribute.
            scores = tl.where(
                head_mask[:, None] & (token_mask[None, :]), scores, float("-inf")
            )

            v_offs = (
                kv_slots[:, None] * stride_buf_vbs
                + kv_head_idx * stride_buf_vh
                + v_dim_offs[None, :]
            )
            v = tl.load(
                V_Buffer + v_offs,
                mask=(token_mask[:, None]) & (v_dim_mask[None, :]),
                other=0.0,
            )

            # Online-softmax update: rescale the old state to the new max.
            new_max = tl.maximum(tl.max(scores, 1), running_max)
            rescale = tl.exp(running_max - new_max)
            probs = tl.exp(scores - new_max[:, None])
            acc *= rescale[:, None]
            acc += tl.dot(probs.to(v.dtype), v)

            running_sum = running_sum * rescale + tl.sum(probs, 1)
            running_max = new_max

        out_offs = (
            batch_idx * stride_mid_ob
            + head_ids[:, None] * stride_mid_oh
            + split_idx * stride_mid_os
            + v_dim_offs[None, :]
        )

        tl.store(
            Att_Out + out_offs,
            acc / running_sum[:, None],
            mask=(head_mask[:, None]) & (v_dim_mask[None, :]),
        )

        # The per-head log-sum-exp lives right after the Lv output values.
        lse_offs = (
            batch_idx * stride_mid_ob
            + head_ids * stride_mid_oh
            + split_idx * stride_mid_os
            + Lv
        )

        tl.store(
            Att_Out + lse_offs,
            running_max + tl.log(running_sum),
            mask=head_mask,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_grouped_att_m_fwd(
    q,
    k_buffer,
    v_buffer,
    att_out,
    Req_to_tokens,
    B_Seqlen,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap,
):
    """Launch ``_fwd_grouped_kernel_stage1`` for grouped query heads.

    The grid is (batch, head blocks, kv splits), where each head block covers
    ``min(16, kv_group_num)`` query heads. Key head sizes 576 and 288 are
    split into a power-of-two part plus a remainder part (512 + 64 and
    256 + 32); other sizes are padded to the next power of two. The partial
    results are written into ``att_out``.
    """
    head_dim_qk = k_buffer.shape[-1]
    head_dim_v = v_buffer.shape[-1]

    # [TODO] work around shmem limit on MI3xx
    block_n = 16 if is_hip_ and head_dim_qk >= 576 else 32

    # Head sizes that are processed as (main part, remainder part).
    split_head_dims = {576: (512, 64), 288: (256, 32)}
    if head_dim_qk in split_head_dims:
        block_dmodel, block_dpe = split_head_dims[head_dim_qk]
    else:
        block_dmodel, block_dpe = triton.next_power_of_2(head_dim_qk), 0
    block_dv = triton.next_power_of_2(head_dim_v)

    num_seqs, num_heads = q.shape[0], q.shape[1]
    heads_per_kv = q.shape[1] // k_buffer.shape[-2]

    block_h = 16
    launch_grid = (
        num_seqs,
        triton.cdiv(num_heads, min(block_h, heads_per_kv)),
        num_kv_splits,
    )

    if is_hip_:
        # https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        platform_kwargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
        stages = 1
    else:
        platform_kwargs = {}
        stages = 2

    _fwd_grouped_kernel_stage1[launch_grid](
        q,
        k_buffer,
        v_buffer,
        sm_scale,
        Req_to_tokens,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        # The K/V buffers are assumed to be laid out as
        # (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM).
        k_buffer.stride(-3),
        k_buffer.stride(-2),
        v_buffer.stride(-3),
        v_buffer.stride(-2),
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=heads_per_kv,
        q_head_num=num_heads,
        BLOCK_DMODEL=block_dmodel,
        BLOCK_DPE=block_dpe,
        BLOCK_DV=block_dv,
        BLOCK_N=block_n,
        BLOCK_H=block_h,
        NUM_KV_SPLITS=num_kv_splits,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        num_warps=4,
        num_stages=stages,
        Lk=head_dim_qk,
        Lv=head_dim_v,
        **platform_kwargs,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _fwd_kernel_stage2(
    Mid_O,
    o,
    lse,
    B_Seqlen,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_obs,
    stride_oh,
    stride_lse_bs,
    NUM_KV_SPLITS: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
):
    # Stage 2 of split-KV decoding for one (batch, head) program. It merges
    # the per-split partial outputs stored in Mid_O, weighting each by its
    # log-sum-exp (stored at offset Lv), and writes the final output to `o`
    # and the combined log-sum-exp to `lse`.
    batch_idx = tl.program_id(0)
    head_idx = tl.program_id(1)

    seq_len = tl.load(B_Seqlen + batch_idx)

    v_dim_offs = tl.arange(0, BLOCK_DV)
    v_dim_mask = v_dim_offs < Lv

    running_sum = 0.0
    running_max = -float("inf")
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    partial_out_offs = batch_idx * stride_mid_ob + head_idx * stride_mid_oh + v_dim_offs
    partial_lse_offs = batch_idx * stride_mid_ob + head_idx * stride_mid_oh + Lv

    # The split length does not depend on the split index.
    tokens_per_split = tl.cdiv(seq_len, NUM_KV_SPLITS)

    for split_idx in range(0, NUM_KV_SPLITS):
        split_begin = tokens_per_split * split_idx
        split_end = tl.minimum(split_begin + tokens_per_split, seq_len)

        # Only non-empty splits were written by stage 1.
        if split_end > split_begin:
            partial_out = tl.load(
                Mid_O + partial_out_offs + split_idx * stride_mid_os,
                mask=v_dim_mask,
                other=0.0,
            )
            partial_lse = tl.load(Mid_O + partial_lse_offs + split_idx * stride_mid_os)
            new_max = tl.maximum(partial_lse, running_max)

            # Rescale the accumulated state to the new maximum, then add
            # this split's contribution.
            rescale = tl.exp(running_max - new_max)
            acc *= rescale
            split_weight = tl.exp(partial_lse - new_max)
            acc += split_weight * partial_out

            running_sum = running_sum * rescale + split_weight
            running_max = new_max

    tl.store(
        o + batch_idx * stride_obs + head_idx * stride_oh + v_dim_offs,
        acc / running_sum,
        mask=v_dim_mask,
    )
    combined_lse = running_max + tl.log(running_sum)
    tl.store(
        lse + batch_idx * stride_lse_bs + head_idx,
        combined_lse,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_softmax_reducev_fwd(
    logits,
    q,
    o,
    lse,
    v_buffer,
    b_seq_len,
    num_kv_splits,
):
    """Launch ``_fwd_kernel_stage2`` on a (batch, head) grid.

    The kernel reduces the per-split partial results held in ``logits`` into
    the final output ``o`` and the log-sum-exp tensor ``lse``. The value head
    size is taken from the last dimension of ``v_buffer``.
    """
    num_seqs, num_heads = q.shape[0], q.shape[1]
    head_dim_v = v_buffer.shape[-1]

    # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
    # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
    platform_kwargs = (
        {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} if is_hip_ else {}
    )

    _fwd_kernel_stage2[(num_seqs, num_heads)](
        logits,
        o,
        lse,
        b_seq_len,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        lse.stride(0),
        NUM_KV_SPLITS=num_kv_splits,
        BLOCK_DV=triton.next_power_of_2(head_dim_v),
        Lv=head_dim_v,
        num_warps=4,
        num_stages=2,
        **platform_kwargs,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def decode_attention_fwd_normal(
    q,
    k_buffer,
    v_buffer,
    o,
    lse,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    """Run two-stage decode attention using the per-head stage-1 kernel.

    Stage 1 writes per-split partial results into ``attn_logits``; stage 2
    reduces them into ``o`` and ``lse``.
    """
    # Stage 1: per-split partial attention.
    stage1_args = (
        q,
        k_buffer,
        v_buffer,
        attn_logits,
        req_to_token,
        b_seq_len,
        num_kv_splits,
        sm_scale,
        page_size,
        logit_cap,
    )
    _decode_att_m_fwd(*stage1_args)

    # Stage 2: merge the splits.
    _decode_softmax_reducev_fwd(
        attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
    )
|
||||||
|
|
||||||
|
|
||||||
|
def decode_attention_fwd_grouped(
    q,
    k_buffer,
    v_buffer,
    o,
    lse,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    """Two-stage decode attention via ``_decode_grouped_att_m_fwd``.

    Stage 1 is given ``attn_logits`` as its intermediate buffer; stage 2
    (``_decode_softmax_reducev_fwd``) is then given the same buffer together
    with ``o`` and ``lse``. All arguments are forwarded positionally and
    unchanged.
    """
    # Stage 1: grouped per-split attention into attn_logits.
    _decode_grouped_att_m_fwd(
        q, k_buffer, v_buffer, attn_logits, req_to_token, b_seq_len,
        num_kv_splits, sm_scale, page_size, logit_cap,
    )
    # Stage 2: reduce across the KV splits.
    _decode_softmax_reducev_fwd(
        attn_logits,
        q,
        o,
        lse,
        v_buffer,
        b_seq_len,
        num_kv_splits,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    lse,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size=1,
    logit_cap=0.0,
):
    """Dispatch decode attention to the normal or the grouped implementation.

    The choice is made on ``q.shape[1] // v_buffer.shape[-2]``: a ratio of 1
    selects ``decode_attention_fwd_normal`` (MHA), anything else selects
    ``decode_attention_fwd_grouped`` (GQA/MQA/MLA). Every argument is
    forwarded positionally and unchanged.

    Raises:
        AssertionError: if ``num_kv_splits != attn_logits.shape[2]``.
    """
    assert num_kv_splits == attn_logits.shape[2]

    # presumably query heads per KV head -- confirm the layouts with callers
    heads_per_kv = q.shape[1] // v_buffer.shape[-2]

    # 1 -> MHA, otherwise -> GQA/MQA/MLA
    attention_impl = (
        decode_attention_fwd_normal
        if heads_per_kv == 1
        else decode_attention_fwd_grouped
    )
    attention_impl(
        q,
        k_buffer,
        v_buffer,
        o,
        lse,
        req_to_token,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
        page_size,
        logit_cap,
    )
|
||||||
105
attention/ops/triton_merge_attn_states.py
Normal file
105
attention/ops/triton_merge_attn_states.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
|
||||||
|
# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||||
|
# can be used to combine partial attention results (in the split-KV case)
|
||||||
|
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: torch.Tensor | None = None,
) -> None:
    """Launch ``merge_attn_states_kernel`` to combine two partial results.

    One kernel program is launched per ``(token, query head)`` pair, as given
    by the first two dimensions of ``output``.

    Args:
        output: Destination tensor; its first three dimensions are read as
            (tokens, query heads, head size).
        prefix_output: First partial attention output.
        prefix_lse: Log-sum-exp values belonging to ``prefix_output``.
        suffix_output: Second partial attention output.
        suffix_lse: Log-sum-exp values belonging to ``suffix_output``.
        output_lse: Optional destination for the merged log-sum-exp; when
            ``None`` the kernel is told not to write it.
    """
    token_count, head_count, head_dim = (
        output.shape[0],
        output.shape[1],
        output.shape[2],
    )
    # The kernel is passed the head size rounded up to a power of two.
    head_dim_pow2 = triton.next_power_of_2(head_dim)
    store_lse = output_lse is not None
    launch_grid = (token_count, head_count)

    # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
    merge_attn_states_kernel[launch_grid](
        output,
        output_lse,
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        head_dim,
        head_dim_pow2,
        store_lse,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def merge_attn_states_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    output_lse,  # [NUM_HEADS, NUM_TOKENS]
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse,  # [NUM_HEADS, NUM_TOKENS]
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse,  # [NUM_HEADS, NUM_TOKENS]
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,
    OUTPUT_LSE: tl.constexpr,
):
    """Merge one (token, head) row of two partial attention outputs.

    Program axis 0 indexes tokens and axis 1 indexes heads. The two inputs
    are blended with weights exp(lse - max_lse) / sum, and the merged
    log-sum-exp is written only when OUTPUT_LSE is set.
    """
    token_idx = tl.program_id(0)
    num_tokens = tl.num_programs(0)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    lse_p = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
    lse_s = tl.load(suffix_lse + head_idx * num_tokens + token_idx)

    # FA2 reports +inf and FA3 reports -inf when the sum-exp is 0 (this
    # happens with zero-length sequences). Treat +inf as the FA2 case and map
    # it to -inf so both backends are handled the same way; +inf is not
    # otherwise meaningful here.
    lse_p = float("-inf") if lse_p == float("inf") else lse_p
    lse_s = float("-inf") if lse_s == float("inf") else lse_s

    # Shift by the larger LSE before exponentiating.
    lse_max = tl.maximum(lse_p, lse_s)
    lse_p = lse_p - lse_max
    lse_s = lse_s - lse_max
    # These exponentials double as the numerators of the blend weights.
    weight_p = tl.exp(lse_p)
    weight_s = tl.exp(lse_s)
    weight_sum = weight_p + weight_s

    if OUTPUT_LSE:
        merged_lse = tl.log(weight_sum) + lse_max
        tl.store(output_lse + head_idx * num_tokens + token_idx, merged_lse)

    dim_offs = tl.arange(0, PADDED_HEAD_SIZE)
    dim_valid = dim_offs < HEAD_SIZE

    # Row start pointers for this (token, head) in each tensor.
    prefix_row = prefix_output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE
    suffix_row = suffix_output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE
    out_row = output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE

    prefix_vals = tl.load(prefix_row + dim_offs, mask=dim_valid)
    suffix_vals = tl.load(suffix_row + dim_offs, mask=dim_valid)

    # NOTE(woosuk): Be careful with the numerical stability.
    # Form the normalized scale first and only then multiply the outputs;
    # do not multiply the outputs by the raw exponentials directly.
    scale_p = weight_p / weight_sum
    scale_s = weight_s / weight_sum
    merged = prefix_vals * scale_p + suffix_vals * scale_s
    tl.store(out_row + dim_offs, merged, mask=dim_valid)
|
||||||
184
attention/ops/triton_reshape_and_cache_flash.py
Normal file
184
attention/ops/triton_reshape_and_cache_flash.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def reshape_and_cache_kernel_flash(
    key_ptr,  # [num_tokens, num_heads, head_size]
    value_ptr,  # [num_tokens, num_heads, head_size]
    key_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
    value_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
    slot_mapping_ptr,  # [num_tokens]
    k_scale,  # float32
    v_scale,  # float32
    # strides
    key_stride: tl.int64,
    value_stride: tl.int64,
    block_stride: tl.int64,
    page_stride: tl.int64,
    num_heads: tl.constexpr,
    head_size: tl.constexpr,
    block_size: tl.constexpr,
    # FP8 flags
    FP8_KV_CACHE: tl.constexpr,
    # tune parameters
    TILE_SIZE: tl.constexpr,
):
    """Copy one tile of a token's key and value into the paged KV cache.

    Program axis 0 selects the token and axis 1 selects a TILE_SIZE-wide tile
    of that token's flattened num_heads * head_size elements. The destination
    slot is read from slot_mapping_ptr; a negative slot makes the program
    exit without writing. When FP8_KV_CACHE is set and the loaded data is not
    already fp8, it is divided by the loaded k_scale / v_scale before the
    store.
    """
    token_idx = tl.program_id(axis=0)
    slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
    if slot_idx < 0:
        # Padding token that should be ignored.
        return

    # Flat element positions covered by this tile, and which of them exist.
    tile_pos = tl.program_id(axis=1) * TILE_SIZE + tl.arange(0, TILE_SIZE)
    in_bounds = tile_pos < (num_heads * head_size)

    # Destination offset: cache block, then the slot inside that block.
    cache_block = slot_idx // block_size
    slot_in_block = slot_idx % block_size
    dst_base = cache_block * block_stride + slot_in_block * page_stride

    # [TILE_SIZE]
    key_in = tl.load(key_ptr + token_idx * key_stride + tile_pos, mask=in_bounds)
    if FP8_KV_CACHE:
        if key_in.dtype.is_fp8():
            key_out = key_in
        else:
            # tl.store will do the correct implicit cast to fp8,
            # based on the key_cache_ptr.dtype.element_ty
            key_out = key_in / tl.load(k_scale)
    else:
        key_out = key_in

    # [TILE_SIZE]
    value_in = tl.load(
        value_ptr + token_idx * value_stride + tile_pos, mask=in_bounds
    )
    if FP8_KV_CACHE:
        if value_in.dtype.is_fp8():
            value_out = value_in
        else:
            # tl.store will do the correct implicit cast to fp8,
            # based on the value_cache_ptr.dtype.element_ty
            value_out = value_in / tl.load(v_scale)
    else:
        value_out = value_in

    tl.store(key_cache_ptr + dst_base + tile_pos, key_out, mask=in_bounds)
    tl.store(value_cache_ptr + dst_base + tile_pos, value_out, mask=in_bounds)
    return
|
||||||
|
|
||||||
|
|
||||||
|
def triton_reshape_and_cache_flash(
    key: torch.Tensor,  # [num_tokens, num_heads, head_size]
    value: torch.Tensor,  # [num_tokens, num_heads, head_size]
    # [num_blocks, block_size, num_heads, head_size]
    key_cache: torch.Tensor,
    # [num_blocks, block_size, num_heads, head_size]
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,  # [num_tokens]
    kv_cache_dtype: str,  # "auto", "fp8"
    k_scale: torch.Tensor,  # float32
    v_scale: torch.Tensor,  # float32
):
    """Write ``key`` / ``value`` into the paged KV cache with a Triton kernel.

    Launches ``reshape_and_cache_kernel_flash`` on a grid of
    ``(slot_mapping.shape[0], cdiv(num_heads * head_size, TILE_SIZE))``.

    Args:
        key: New keys; ``shape[1]`` and ``shape[2]`` are read as
            ``num_heads`` and ``head_size``.
        value: New values; only its stride 0 is read here.
        key_cache: Destination key cache; ``shape[1]`` is the block size and
            strides 0/1 are the block/page strides. Stride 2 must equal
            ``head_size``.
        value_cache: Destination value cache. It is assumed to share the
            layout of ``key_cache`` (only ``key_cache`` strides are read).
        slot_mapping: Destination slot per token.
        kv_cache_dtype: ``"auto"`` or a string starting with ``"fp8"``.
        k_scale: Scale tensor forwarded to the kernel for keys.
        v_scale: Scale tensor forwarded to the kernel for values.

    Raises:
        AssertionError: on a non-contiguous head dimension, an unsupported
            ``kv_cache_dtype`` string, a ``uint8`` cache tensor that was not
            reinterpreted as fp8, or an unsupported fp8 dtype.
    """
    num_heads = key.shape[1]
    head_size = key.shape[2]
    block_size = key_cache.shape[1]
    n = num_heads * head_size

    key_stride = key.stride()[0]
    value_stride = value.stride()[0]
    block_stride = key_cache.stride()[0]
    page_stride = key_cache.stride()[1]

    head_stride = key_cache.stride()[2]
    assert head_stride == head_size, "only continous heads are supported"

    assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
        f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
    )
    kv_cache_torch_dtype = (
        current_platform.fp8_dtype()
        if kv_cache_dtype.startswith("fp8")
        else key_cache.dtype
    )

    if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"):
        # Reinterpret the cache as fp8 to avoid an erroneous implicit cast in
        # the Triton kernel (tl.store to uint8)
        # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
        key_cache = key_cache.view(kv_cache_torch_dtype)
        value_cache = value_cache.view(kv_cache_torch_dtype)
    # Check the tensor the kernel will actually store into. The previous
    # check compared the `kv_cache_dtype` string with a torch dtype, which is
    # never equal, so it could not fail.
    assert key_cache.dtype != torch.uint8, (
        "explicit fp8 cast and store to "
        "uint8 is not supported by triton reshape_and_cache_flash"
    )

    FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
    assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
        torch.float8_e4m3fn,
        torch.float8_e5m2,
        torch.uint8,
        torch.float8_e4m3fnuz,
    ], (
        "unsupported dtype of KV cache tensor, got "
        # f-prefix added so the offending dtype is actually interpolated.
        f"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, "
        "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."
    )

    # heuristics instead of autotuning
    TILE_SIZE = min(2048, triton.next_power_of_2(n))
    if current_platform.is_rocm() or current_platform.is_xpu():
        num_stages = 4
        num_warps = 8
    else:  # cuda
        num_stages = 10
        num_warps = 16
        if torch.cuda.get_device_capability(key.device)[0] < 9:
            # Smaller tiles on devices with compute capability major < 9.
            TILE_SIZE = min(512, TILE_SIZE)

    # TODO(ngl): maybe replace with static launch grid to avoid overhead if
    # using cudagraphs
    grid = lambda meta: (
        slot_mapping.shape[0],
        triton.cdiv(n, meta["TILE_SIZE"]),
    )

    reshape_and_cache_kernel_flash[grid](
        key_ptr=key,
        value_ptr=value,
        key_cache_ptr=key_cache,
        value_cache_ptr=value_cache,
        slot_mapping_ptr=slot_mapping,
        k_scale=k_scale,
        v_scale=v_scale,
        # strides
        key_stride=key_stride,
        value_stride=value_stride,
        block_stride=block_stride,
        page_stride=page_stride,
        num_heads=num_heads,
        head_size=head_size,
        block_size=block_size,
        # FP8 flags
        FP8_KV_CACHE=FP8_KV_CACHE,
        # autotune parameters
        TILE_SIZE=TILE_SIZE,
        num_warps=num_warps,
        num_stages=num_stages,
    )
|
||||||
941
attention/ops/triton_unified_attention.py
Normal file
941
attention/ops/triton_unified_attention.py
Normal file
@@ -0,0 +1,941 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Authors:
|
||||||
|
# - Burkhard Ringlein <ngl@zurich.ibm.com>
|
||||||
|
# - Jan van Lunteren <jvl@zurich.ibm.com>
|
||||||
|
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
|
||||||
|
# - Thomas Parnell <tpa@zurich.ibm.com>
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
# Module-level logger for this file.
logger = init_logger(__name__)
# Numeric limits of the current platform's FP8 dtype; its min/max are the
# default FP8_MIN / FP8_MAX clamp bounds of kernel_unified_attention_2d.
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def cdiv_fn(x, y):
    """Return ``(x + y - 1) // y``, i.e. x / y rounded up for positive y."""
    rounded_up = x + y - 1
    return rounded_up // y
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def apply_softcap(S, x):
    """Soft-cap scores: return ``x * tanh(S / x)``.

    Args:
        S: Scores to cap.
        x: Cap value; the result lies in ``[-x, x]``.

    tanh is evaluated as ``1 - 2 / (exp(2z) + 1)`` with ``z = S / x``. The
    textbook form ``(e^z - e^-z) / (e^z + e^-z)`` yields inf / inf = NaN once
    ``exp`` overflows for large ``|z|``. This form instead saturates:
    exp -> inf gives 1 and exp -> 0 gives -1.
    """
    Sdiv = S / x
    return x * (1.0 - 2.0 / (tl.exp(2.0 * Sdiv) + 1.0))
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def find_seq_idx(
    query_start_len_ptr,
    target_idx,
    num_seqs,
    BLOCK_Q: tl.constexpr,
    use_q_block_mode: tl.constexpr,
):
    """Binary-search ``query_start_len_ptr[0:num_seqs]`` for ``target_idx``.

    Returns ``i - 1`` where ``i`` is the first index whose search key is
    greater than ``target_idx``. The key at index ``i`` is the loaded value
    itself, or ``value // BLOCK_Q + i`` when ``use_q_block_mode`` is set.
    Assumes the keys are non-decreasing in ``i`` -- TODO confirm with callers.
    """
    lo: tl.int32 = 0
    hi = num_seqs
    while lo < hi:
        probe = (lo + hi) // 2
        start = tl.load(query_start_len_ptr + probe)
        # use_q_block_mode is constexpr, so this branch is resolved statically.
        if use_q_block_mode:
            probe_key = start // BLOCK_Q + probe
        else:
            probe_key = start

        if probe_key <= target_idx:
            lo = probe + 1
        else:
            hi = probe

    return lo - 1
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def kernel_unified_attention_2d(
    output_ptr,  # [num_tokens, num_query_heads, head_size]
    query_ptr,  # [num_tokens, num_query_heads, head_size]
    key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
    value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
    sink_ptr,  # [num_query_heads]
    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
    seq_lens_ptr,  # [num_seqs]
    alibi_slopes_ptr,  # [num_query_heads]
    qq_bias_ptr,  # [num_query_tokens, num_query_tokens]
    scale,  # float32
    k_scale,  # float32
    v_scale,  # float32
    out_scale,  # float32
    softcap,  # float32
    num_query_heads: tl.constexpr,  # int
    num_queries_per_kv: tl.constexpr,  # int
    block_table_stride: tl.int64,  # int
    query_stride_0: tl.int64,  # int
    query_stride_1: tl.int64,  # int, should be equal to head_size
    output_stride_0: tl.int64,  # int
    output_stride_1: tl.int64,  # int, should be equal to head_size
    qq_bias_stride_0: tl.int64,  # int
    BLOCK_SIZE: tl.constexpr,  # int
    TILE_SIZE: tl.constexpr,  # int must be power of 2
    HEAD_SIZE: tl.constexpr,  # int
    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
    USE_ALIBI_SLOPES: tl.constexpr,  # bool
    USE_QQ_BIAS: tl.constexpr,  # bool
    USE_SOFTCAP: tl.constexpr,  # bool
    USE_SINKS: tl.constexpr,  # bool
    SLIDING_WINDOW: tl.constexpr,  # int
    stride_k_cache_0: tl.int64,  # int
    stride_k_cache_1: tl.int64,  # int
    stride_k_cache_2: tl.int64,  # int
    stride_k_cache_3: tl.constexpr,  # int
    stride_v_cache_0: tl.int64,  # int
    stride_v_cache_1: tl.int64,  # int
    stride_v_cache_2: tl.int64,  # int
    stride_v_cache_3: tl.constexpr,  # int
    query_start_len_ptr,  # [num_seqs+1]
    BLOCK_Q: tl.constexpr,  # int
    num_seqs: tl.int32,
    BLOCK_M: tl.constexpr,  # int
    USE_FP8: tl.constexpr,  # bool
    FP8_MIN: tl.constexpr = float8_info.min,
    FP8_MAX: tl.constexpr = float8_info.max,
):
    """Paged attention over one query block and one KV head.

    Program axis 0 is a global query-block index and axis 1 is the KV head.
    The program finds the sequence its query block belongs to, then walks the
    KV cache in tiles of TILE_SIZE keys through the block table. It keeps a
    running max (M), running exp-sum (L) and accumulator (acc), and stores
    acc / L to output_ptr at the end.

    Each of the BLOCK_M rows is one (query token, query head) pair: row r
    maps to token ``r // num_queries_per_kv`` of the block and to head
    ``kv_head_idx * num_queries_per_kv + r % num_queries_per_kv``.
    Presumably BLOCK_M == BLOCK_Q * num_queries_per_kv (padded) -- confirm
    against the launcher.

    Optional features are selected by constexpr flags: ALiBi slopes,
    query-query bias, soft-capping, attention sinks (M is seeded from
    sink_ptr), a sliding window (SLIDING_WINDOW > 0), and FP8 output scaling
    and clamping (USE_FP8).
    """
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)

    # Locate the sequence that owns this global query block.
    seq_idx = find_seq_idx(
        query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
    )

    # Global index of the first query block of that sequence (same key
    # formula that find_seq_idx uses in q-block mode).
    q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx

    q_block_local_idx = q_block_global_idx - q_block_start_idx

    cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
    cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

    cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index

    # This query block lies entirely past the sequence's query tokens.
    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
        return

    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)
    offs_t = tl.arange(0, TILE_SIZE)
    # Query token position (within the sequence's query section) per row.
    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

    # Global token index and query head index per row.
    query_offset_0 = cur_batch_in_all_start_index + query_pos
    query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv
    query_offset = (
        query_offset_0[:, None] * query_stride_0
        + query_offset_1[:, None] * query_stride_1
        + offs_d[None, :]
    )

    # Masks for head-dim padding, rows past the query length, and rows whose
    # head index is out of range.
    dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
    query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
    query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

    # Q : (BLOCK_M, HEAD_SIZE_PADDED)
    Q = tl.load(
        query_ptr + query_offset,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    # Running row maximum: -inf, or the per-head sink value when sinks are on.
    if not USE_SINKS:
        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    else:
        M = tl.load(
            sink_ptr + query_offset_1,
            mask=query_mask_1,
            other=float("-inf"),
        ).to(dtype=tl.float32)

    # Running exp-sum (starts at 1.0) and un-normalized output accumulator.
    L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # context length for this particular sequences
    context_len = seq_len - cur_batch_query_len

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(
            alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0
        )

    # query-query attention bias
    if USE_QQ_BIAS:
        qq_bias_row_ptrs = (
            qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
        )  # shape: [BLOCK_M, 1]

    # compute the length of the longest sequence prefix spanned by any
    # query token in the current q_block (q_block_local_idx)
    max_seq_prefix_len = (
        context_len
        + q_block_local_idx * BLOCK_Q
        + (BLOCK_M - 1) // num_queries_per_kv
        + 1
    )

    # adjust for potential padding in the last q_block by considering the
    # actual sequence length
    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)

    # calculate the number of tiles that need to be processed to
    # cover the longest sequence prefix (due to causal masking, tiles beyond
    # this prefix can be skipped)
    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)

    # ---- Sliding-window tile pruning --------------------
    # Default: keep previous global behavior
    tile_start = 0
    tile_end = num_tiles
    if SLIDING_WINDOW > 0:
        # Query rows covered by this Q-block
        qpos_lo = q_block_local_idx * BLOCK_Q
        qpos_hi = tl.minimum(
            qpos_lo + (BLOCK_M - 1) // num_queries_per_kv,
            cur_batch_query_len - 1,
        )
        # For sliding window, each query position q can only attend to
        # keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs]
        # where q_abs = context_len + q
        # The union of allowed key positions for this Q-block is:
        # [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi]
        first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1
        last_allowed_key = context_len + qpos_hi
        # Convert to tile indices and clamp
        tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE)
        tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles)

    # iterate through tiles (now limited to the sliding window range)
    for j in range(tile_start, tile_end):
        # Key positions covered by this tile; positions past the prefix are
        # masked out of the K/V loads below.
        seq_offset = j * TILE_SIZE + offs_t
        tile_mask = seq_offset < max_seq_prefix_len

        # Physical cache block of each key position.
        # NOTE(review): this load is not masked with tile_mask -- confirm the
        # block table is readable for every position of the last tile.
        physical_block_idx = tl.load(
            block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
        ).to(tl.int64)

        v_offset = (
            physical_block_idx[:, None] * stride_v_cache_0
            + kv_head_idx * stride_v_cache_2
            + offs_d[None, :] * stride_v_cache_3
            + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
        )

        # Same addressing as V but transposed, so K loads as (dim, key).
        k_offset = (
            physical_block_idx[None, :] * stride_k_cache_0
            + kv_head_idx * stride_k_cache_2
            + offs_d[:, None] * stride_k_cache_3
            + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
        )

        # K : (HEAD_SIZE_PADDED, TILE_SIZE)
        K_load = tl.load(
            key_cache_ptr + k_offset,
            mask=dim_mask[:, None] & tile_mask[None, :],
            other=0.0,
        )

        # An fp8 cache is used as-is when Q is fp8 too; otherwise it is
        # dequantized with k_scale and cast to Q's dtype.
        if K_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                K = K_load
            else:
                K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (TILE_SIZE, HEAD_SIZE_PADDED)
        V_load = tl.load(
            value_cache_ptr + v_offset,
            mask=dim_mask[None, :] & tile_mask[:, None],
            other=0.0,
        )

        # Same fp8 handling as for K, using v_scale.
        if V_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                V = V_load
            else:
                V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        # Causal mask: a query at absolute position context_len + query_pos
        # may attend to keys at positions <= its own.
        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1

        # S : (BLOCK_M, TILE_SIZE)
        S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)

        S += scale * tl.dot(Q, K)

        if USE_SOFTCAP:
            S = apply_softcap(S, softcap)

        # Invalid rows and non-causal positions get -inf.
        S = tl.where(
            query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
        )

        # Also drop keys that are SLIDING_WINDOW or more positions behind
        # the query.
        if SLIDING_WINDOW > 0:
            S = tl.where(
                (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
                S,
                float("-inf"),
            )

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

        if USE_QQ_BIAS:
            # compute key positions relative to query section
            key_rel_pos = seq_offset - context_len  # shape: [TILE_SIZE]
            # load bias only for keys that correspond to queries
            is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
            qq_bias = tl.load(
                qq_bias_row_ptrs + key_rel_pos[None, :],
                mask=is_query_key[None, :],  # avoid OOB for context keys
                other=0.0,
            )
            S += qq_bias

        # compute running maximum
        # m_j : (BLOCK_M,)
        m_j = tl.maximum(M, tl.max(S, axis=1))

        # For sliding window there's a chance the max is -inf due to masking of
        # the entire row. In this case we need to set m_j 0 to avoid NaN
        m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

        # P : (BLOCK_M, TILE_SIZE)
        P = tl.exp(S - m_j[:, None])

        # l_j : (BLOCK_M,)
        l_j = tl.sum(P, axis=1)

        # Rescale factor for what was accumulated under the old maximum.
        # alpha : (BLOCK_M, )
        alpha = tl.exp(M - m_j)

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc += tl.dot(P.to(V.dtype), V)

    # epilogue: normalize by the running exp-sum
    acc = acc / L[:, None]
    if USE_FP8:
        # Scale by out_scale and clamp into [FP8_MIN, FP8_MAX] before storing.
        acc = acc * tl.load(out_scale)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)

    output_offset = (
        query_offset_0[:, None] * output_stride_0
        + query_offset_1[:, None] * output_stride_1
        + offs_d[None, :]
    )

    tl.store(
        output_ptr + output_offset,
        acc,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
    )
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def kernel_unified_attention_3d(
    segm_output_ptr,
    # [num_tokens, num_query_heads, num_segments, head_size_padded]
    segm_max_ptr,  # [num_tokens, num_query_heads, num_segments]
    segm_expsum_ptr,  # [num_tokens, num_query_heads, num_segments]
    query_ptr,  # [num_tokens, num_query_heads, head_size]
    # NOTE: the K/V caches are addressed below as
    # block * stride_0 + token_in_block * stride_1 + kv_head * stride_2
    # + dim * stride_3, i.e. a 4-D [num_blks, blk_size, num_kv_heads, head_size]
    # view.
    key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
    value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
    sink_ptr,  # [num_query_heads]
    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
    seq_lens_ptr,  # [num_seqs]
    alibi_slopes_ptr,  # [num_query_heads]
    qq_bias_ptr,  # [num_query_tokens, num_query_tokens]
    scale,  # float32
    k_scale,  # float32
    v_scale,  # float32
    softcap,  # float32
    num_query_heads: tl.constexpr,  # int
    num_queries_per_kv: tl.constexpr,  # int
    block_table_stride: tl.int64,  # int
    query_stride_0: tl.int64,  # int
    query_stride_1: tl.int64,  # int, should be equal to head_size
    qq_bias_stride_0: tl.int64,  # int
    BLOCK_SIZE: tl.constexpr,  # int
    TILE_SIZE: tl.constexpr,  # int, must be power of 2
    HEAD_SIZE: tl.constexpr,  # int
    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
    USE_ALIBI_SLOPES: tl.constexpr,  # bool
    USE_QQ_BIAS: tl.constexpr,  # bool
    USE_SOFTCAP: tl.constexpr,  # bool
    USE_SINKS: tl.constexpr,  # bool
    SLIDING_WINDOW: tl.constexpr,  # int
    stride_k_cache_0: tl.int64,  # int
    stride_k_cache_1: tl.int64,  # int
    stride_k_cache_2: tl.int64,  # int
    stride_k_cache_3: tl.constexpr,  # int
    stride_v_cache_0: tl.int64,  # int
    stride_v_cache_1: tl.int64,  # int
    stride_v_cache_2: tl.int64,  # int
    stride_v_cache_3: tl.constexpr,  # int
    query_start_len_ptr,  # [num_seqs+1]
    BLOCK_Q: tl.constexpr,  # int
    num_seqs: tl.int32,
    BLOCK_M: tl.constexpr,  # int
    NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
):
    """Segmented (split-KV) paged attention: compute one partial result.

    Each program handles one query block (axis 0), one KV head (axis 1) and
    one KV segment (axis 2). It runs an online-softmax pass over the tiles of
    its segment and stores the un-normalized accumulator, the running maximum
    and the running exp-sum into `segm_output_ptr`, `segm_max_ptr` and
    `segm_expsum_ptr`. Normalization across segments is not done here; the
    stored triples are meant to be merged afterwards (see `reduce_segments`).
    """
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)
    segm_idx = tl.program_id(2)

    # map the global query-block index back to the sequence it belongs to
    seq_idx = find_seq_idx(
        query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
    )

    # index of the first query block assigned to this sequence; the
    # `+ seq_idx` term matches the launch grid, which reserves one extra
    # block per sequence
    q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx

    q_block_local_idx = q_block_global_idx - q_block_start_idx

    cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
    cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

    cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index

    # surplus program: this query block lies past the end of the sequence
    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
        return

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # number of segments for this particular sequence
    num_segments = NUM_SEGMENTS_PER_SEQ
    tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)

    # this segment starts at or beyond seq_len, so there is nothing to do;
    # note that nothing is stored for such segments
    if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len:
        return

    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)
    offs_t = tl.arange(0, TILE_SIZE)
    # each row of the block is a (query token, query head) pair; consecutive
    # groups of num_queries_per_kv rows share the same query token
    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

    query_offset_0 = cur_batch_in_all_start_index + query_pos
    query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv
    query_offset = (
        query_offset_0[:, None] * query_stride_0
        + query_offset_1[:, None] * query_stride_1
        + offs_d[None, :]
    )

    # masks for head-dim padding, query tokens past the sequence end, and
    # query heads past num_query_heads
    dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
    query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
    query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

    # Q : (BLOCK_M, HEAD_SIZE_PADDED)
    Q = tl.load(
        query_ptr + query_offset,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    # running maximum M: only segment 0 is seeded with the per-head sink
    # value; every other case starts from -inf
    if USE_SINKS:
        if segm_idx == 0:
            M = tl.load(
                sink_ptr + query_offset_1,
                mask=query_mask_1,
                other=float("-inf"),
            ).to(dtype=tl.float32)
        else:
            M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    else:
        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)

    # running exp-sum L and un-normalized output accumulator
    L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

    # context length for this particular sequence
    context_len = seq_len - cur_batch_query_len

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(
            alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0
        )

    # query-query attention bias
    if USE_QQ_BIAS:
        qq_bias_row_ptrs = (
            qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
        )  # shape: [BLOCK_M, 1]

    # compute the length of the longest sequence prefix spanned by any
    # query token in the current q_block (q_block_local_idx)
    max_seq_prefix_len = (
        context_len
        + q_block_local_idx * BLOCK_Q
        + (BLOCK_M - 1) // num_queries_per_kv
        + 1
    )

    # adjust for potential padding in the last q_block by considering the
    # actual sequence length
    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)

    # calculate the number of tiles that need to be processed to
    # cover the longest sequence prefix (due to causal masking, tiles beyond
    # this prefix can be skipped)
    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)

    # iterate through tiles within current segment
    for j in range(
        segm_idx * tiles_per_segment,
        min((segm_idx + 1) * tiles_per_segment, num_tiles),
    ):
        seq_offset = j * TILE_SIZE + offs_t
        tile_mask = seq_offset < max_seq_prefix_len

        # translate logical key positions into physical cache blocks
        # NOTE(review): this load is not masked by tile_mask - confirm the
        # block table is always readable for every position in the tile
        physical_block_idx = tl.load(
            block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
        ).to(tl.int64)

        v_offset = (
            physical_block_idx[:, None] * stride_v_cache_0
            + kv_head_idx * stride_v_cache_2
            + offs_d[None, :] * stride_v_cache_3
            + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
        )

        k_offset = (
            physical_block_idx[None, :] * stride_k_cache_0
            + kv_head_idx * stride_k_cache_2
            + offs_d[:, None] * stride_k_cache_3
            + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
        )

        # K : (HEAD_SIZE_PADDED, TILE_SIZE)
        K_load = tl.load(
            key_cache_ptr + k_offset,
            mask=dim_mask[:, None] & tile_mask[None, :],
            other=0.0,
        )

        # an fp8 cache is dequantized with k_scale unless Q is fp8 as well
        if K_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                K = K_load
            else:
                K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (TILE_SIZE, HEAD_SIZE_PADDED)
        V_load = tl.load(
            value_cache_ptr + v_offset,
            mask=dim_mask[None, :] & tile_mask[:, None],
            other=0.0,
        )

        # same dequantization rule as for K, using v_scale
        if V_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                V = V_load
            else:
                V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        # causal mask: a query at query_pos sees keys up to and including
        # absolute position context_len + query_pos
        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1

        # S : (BLOCK_M, TILE_SIZE)
        S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
        S += scale * tl.dot(Q, K)

        if USE_SOFTCAP:
            S = apply_softcap(S, softcap)

        S = tl.where(
            query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
        )

        # sliding window: drop keys that are SLIDING_WINDOW or more positions
        # behind the query
        if SLIDING_WINDOW > 0:
            S = tl.where(
                (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
                S,
                float("-inf"),
            )

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

        if USE_QQ_BIAS:
            # compute key positions relative to query section
            key_rel_pos = seq_offset - context_len  # shape: [TILE_SIZE]
            # load bias only for keys that correspond to queries
            is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
            qq_bias = tl.load(
                qq_bias_row_ptrs + key_rel_pos[None, :],
                mask=is_query_key[None, :],  # avoid OOB for context keys
                other=0.0,
            )
            S += qq_bias

        # compute running maximum
        # m_j : (BLOCK_M,)
        m_j = tl.maximum(M, tl.max(S, axis=1))

        # For sliding window there's a chance the max is -inf due to masking of
        # the entire row. In this case we need to set m_j 0 to avoid NaN
        m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

        # P : (BLOCK_M, TILE_SIZE,)
        P = tl.exp(S - m_j[:, None])

        # l_j : (BLOCK_M,)
        l_j = tl.sum(P, axis=1)

        # alpha : (BLOCK_M, )
        # rescale factor that moves previous partial results to the new maximum
        alpha = tl.exp(M - m_j)

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc += tl.dot(P.to(V.dtype), V)

    # store the un-normalized partial output of this segment; offsets are
    # computed in int64
    segm_output_offset = (
        query_offset_0[:, None].to(tl.int64)
        * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
        + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
        + segm_idx * HEAD_SIZE_PADDED
        + tl.arange(0, HEAD_SIZE_PADDED)[None, :]
    )
    tl.store(
        segm_output_ptr + segm_output_offset,
        acc,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
    )
    # store the per-row running maximum and exp-sum of this segment
    segm_offset = (
        query_offset_0.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
        + query_offset_1 * NUM_SEGMENTS_PER_SEQ
        + segm_idx
    )
    tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1)
    tl.store(segm_expsum_ptr + segm_offset, L, mask=query_mask_0 & query_mask_1)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def reduce_segments(
    output_ptr,  # [num_tokens, num_query_heads, head_size]
    segm_output_ptr,
    # [num_tokens, num_query_heads, max_num_segments, head_size_padded]
    segm_max_ptr,  # [num_tokens, num_query_heads, max_num_segments]
    segm_expsum_ptr,  # [num_tokens, num_query_heads, max_num_segments]
    seq_lens_ptr,  # [num_seqs]
    num_seqs,  # int
    num_query_heads: tl.constexpr,  # int
    out_scale_inv,  # float32
    output_stride_0: tl.int64,  # int
    output_stride_1: tl.int64,  # int, should be equal to head_size
    block_table_stride: tl.int64,  # int (not referenced in this kernel)
    TILE_SIZE: tl.constexpr,  # int
    HEAD_SIZE: tl.constexpr,  # int
    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
    query_start_len_ptr,  # [num_seqs+1]
    BLOCK_Q: tl.constexpr,  # int
    NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
    USE_FP8: tl.constexpr,  # bool
    FP8_MIN: tl.constexpr = float8_info.min,
    FP8_MAX: tl.constexpr = float8_info.max,
):
    """Merge per-segment partial attention results into the final output.

    Each program handles one query token (axis 0) and one query head (axis 1).
    It loads the per-segment maxima, exp-sums and un-normalized outputs,
    rescales them to a common maximum, sums them, normalizes by the combined
    exp-sum and writes the result to `output_ptr`. With `USE_FP8` the result
    is additionally multiplied by `out_scale_inv` and clamped to
    [FP8_MIN, FP8_MAX].
    """
    query_token_idx = tl.program_id(0)
    query_head_idx = tl.program_id(1)

    seq_idx = find_seq_idx(
        query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False
    )

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # number of segments for this particular sequence
    num_segments = NUM_SEGMENTS_PER_SEQ
    tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)

    # create masks for subsequent loads
    # only the first act_num_segments segments cover positions < seq_len;
    # the remaining slots are masked out of every load below
    act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE)
    segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full(
        [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32
    )
    dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, 0).to(tl.int1)

    # load segment maxima
    segm_offset = (
        query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
        + query_head_idx * NUM_SEGMENTS_PER_SEQ
        + tl.arange(0, NUM_SEGMENTS_PER_SEQ)
    )
    segm_max = tl.load(segm_max_ptr + segm_offset, mask=segm_mask, other=float("-inf"))
    overall_max = tl.max(segm_max)

    # load and rescale segment exp sums
    segm_expsum = tl.load(segm_expsum_ptr + segm_offset, mask=segm_mask, other=0.0)
    segm_expsum = segm_expsum * tl.exp(segm_max - overall_max)
    overall_expsum = tl.sum(segm_expsum)

    # load, rescale, and add segment attention outputs
    segm_output_offset = (
        query_token_idx.to(tl.int64)
        * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
        + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
        + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED
        + tl.arange(0, HEAD_SIZE_PADDED)[None, :]
    )
    segm_output = tl.load(
        segm_output_ptr + segm_output_offset,
        mask=segm_mask[:, None] & dim_mask[None, :],
        other=0.0,
    )
    segm_output *= tl.exp(segm_max - overall_max)[:, None]
    acc_sum = tl.sum(segm_output, axis=0)
    # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
    acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)

    if USE_FP8:
        acc = acc * tl.load(out_scale_inv)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)

    # write result
    output_offset = (
        query_token_idx * output_stride_0
        + query_head_idx * output_stride_1
        + tl.arange(0, HEAD_SIZE_PADDED)
    )
    tl.store(output_ptr + output_offset, acc, mask=dim_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def unified_attention(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    max_seqlen_q,
    seqused_k,
    max_seqlen_k,
    softmax_scale,
    causal,
    window_size,
    block_table,
    softcap,
    q_descale,
    k_descale,
    v_descale,
    alibi_slopes=None,
    output_scale=None,
    qq_bias=None,
    # Optional tensor for sinks
    sinks=None,
):
    """Launch the Triton paged-attention kernels and write the result to `out`.

    Two launch strategies are used:

    * the single-pass 2D kernel when `max_seqlen_q > 1` or when the launch
      grid (`total_num_q_blocks * num_kv_heads`) exceeds 128 programs;
    * otherwise the segmented 3D kernel, which splits the KV range into
      `NUM_SEGMENTS` parts, followed by `reduce_segments` to merge them.

    Only causal attention without query scales is supported (asserted).
    `max_seqlen_k` is accepted but not used by this function.
    """
    assert causal, "Only causal attention is supported"
    assert q_descale is None, "Q scales not supported"

    num_tokens, num_query_heads, head_size = q.shape[0], q.shape[1], q.shape[2]

    if sinks is not None:
        assert sinks.shape[0] == num_query_heads, "Sinks must be num_query_heads size"

    block_size = v.shape[1]
    num_seqs = len(seqused_k)
    num_kv_heads = k.shape[2]
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_size_padded = triton.next_power_of_2(head_size)

    if num_queries_per_kv <= 16:
        BLOCK_M = 16
    else:
        BLOCK_M = triton.next_power_of_2(num_queries_per_kv)
    BLOCK_Q = BLOCK_M // num_queries_per_kv

    # Ideally we would launch with kernel with:
    # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks.
    # However, it is slow to realize the query_lens on cpu.
    # Instead we use upper-bound:
    # \sum_i[ceil(query_len[i] / BLOCK_Q)]
    #   <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1]
    #    = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs
    #   <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs
    #    = floor(q.shape[0] / BLOCK_Q) + num_seqs
    total_num_q_blocks = num_tokens // BLOCK_Q + num_seqs

    # Assigning default tile sizes for prefill and decode.
    # Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
    # and at least 16 for all other data types.
    TILE_SIZE_PREFILL = 32
    TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32

    # Values used by more than one launch below.
    use_fp8 = output_scale is not None
    inv_output_scale = 1 / output_scale if use_fp8 else 1.0

    # Arguments that both attention kernels take unchanged.
    shared_kwargs = dict(
        query_ptr=q,
        key_cache_ptr=k,
        value_cache_ptr=v,
        sink_ptr=sinks,
        block_tables_ptr=block_table,
        seq_lens_ptr=seqused_k,
        alibi_slopes_ptr=alibi_slopes,
        qq_bias_ptr=qq_bias,
        scale=softmax_scale,
        k_scale=k_descale,
        v_scale=v_descale,
        softcap=softcap,
        num_query_heads=num_query_heads,
        num_queries_per_kv=num_queries_per_kv,
        block_table_stride=block_table.stride(0),
        query_stride_0=q.stride(0),
        query_stride_1=q.stride(1),
        qq_bias_stride_0=qq_bias.stride(0) if qq_bias is not None else 0,
        BLOCK_SIZE=block_size,
        HEAD_SIZE=head_size,
        HEAD_SIZE_PADDED=head_size_padded,
        USE_ALIBI_SLOPES=alibi_slopes is not None,
        USE_QQ_BIAS=qq_bias is not None,
        USE_SOFTCAP=(softcap > 0),
        USE_SINKS=(sinks is not None),
        SLIDING_WINDOW=(1 + window_size[0]),
        stride_k_cache_0=k.stride(0),
        stride_k_cache_1=k.stride(1),
        stride_k_cache_2=k.stride(2),
        stride_k_cache_3=k.stride(3),
        stride_v_cache_0=v.stride(0),
        stride_v_cache_1=v.stride(1),
        stride_v_cache_2=v.stride(2),
        stride_v_cache_3=v.stride(3),
        query_start_len_ptr=cu_seqlens_q,
        BLOCK_Q=BLOCK_Q,
        num_seqs=num_seqs,
        BLOCK_M=BLOCK_M,
    )

    # if batch contains a prefill (or the 2D grid is already large enough)
    if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
        kernel_unified_attention_2d[(total_num_q_blocks, num_kv_heads)](
            output_ptr=out,
            out_scale=inv_output_scale,
            output_stride_0=out.stride(0),
            output_stride_1=out.stride(1),
            TILE_SIZE=TILE_SIZE_PREFILL,
            USE_FP8=use_fp8,
            **shared_kwargs,
        )
        return

    # for initial version, NUM_SEGMENTS = 16 is chosen as a default
    # value that showed good performance in tests
    NUM_SEGMENTS = 16

    # Scratch buffers holding the per-segment partial results.
    segm_shape = (num_tokens, num_query_heads, NUM_SEGMENTS)
    segm_output = torch.empty(
        *segm_shape,
        head_size_padded,
        dtype=torch.float32,
        device=q.device,
    )
    segm_max = torch.empty(segm_shape, dtype=torch.float32, device=q.device)
    segm_expsum = torch.empty(segm_shape, dtype=torch.float32, device=q.device)

    kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
        segm_output_ptr=segm_output,
        segm_max_ptr=segm_max,
        segm_expsum_ptr=segm_expsum,
        TILE_SIZE=TILE_SIZE_DECODE,
        NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
        **shared_kwargs,
    )
    reduce_segments[(num_tokens, num_query_heads)](
        output_ptr=out,
        segm_output_ptr=segm_output,
        segm_max_ptr=segm_max,
        segm_expsum_ptr=segm_expsum,
        seq_lens_ptr=seqused_k,
        num_seqs=num_seqs,
        num_query_heads=num_query_heads,
        out_scale_inv=inv_output_scale,
        output_stride_0=out.stride(0),
        output_stride_1=out.stride(1),
        block_table_stride=block_table.stride(0),
        TILE_SIZE=TILE_SIZE_DECODE,
        HEAD_SIZE=head_size,
        HEAD_SIZE_PADDED=head_size_padded,
        query_start_len_ptr=cu_seqlens_q,
        BLOCK_Q=BLOCK_Q,
        NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
        USE_FP8=use_fp8,
    )
|
||||||
178
attention/ops/vit_attn_wrappers.py
Normal file
178
attention/ops/vit_attn_wrappers.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
This file contains ops for ViT attention to be compatible with torch.compile
|
||||||
|
as there are operations here not supported by torch.compile (for instance,
|
||||||
|
`.tolist()` in xformers attn, or `.item()` in flash attention)
|
||||||
|
|
||||||
|
Using these ops and wrapping vision blocks with `torch.compile` can speed up
|
||||||
|
throughput in vision models by ~5% relative on H100, and improve token
|
||||||
|
latencies by ~7% (see qwen2_5_vl for example usage)
|
||||||
|
|
||||||
|
To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
|
def xformers_attn_seqlens_wrapper(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    """Run xformers memory-efficient attention with a block-diagonal mask.

    The mask is built from `seqlens` (converted to a Python list), and the
    `b s h d` attention output is returned as a contiguous `s b (h d)` tensor.
    """
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalMask

    block_diag_mask = BlockDiagonalMask.from_seqlens(
        q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device
    )
    attn_out = xops.memory_efficient_attention_forward(
        q, k, v, attn_bias=block_diag_mask, p=0, scale=None
    )
    return einops.rearrange(attn_out, "b s h d -> s b (h d)").contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
def xformers_attn_seqlens_wrapper_fake(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    """Shape-only stand-in: return an uninitialized `(s, b, h * d)` tensor.

    The result takes its dtype and device from `q`, which is unpacked as
    `(b, s, h, d)`; the other arguments are ignored.
    """
    batch, seq_len, num_heads, head_dim = q.shape
    out_shape = (seq_len, batch, num_heads * head_dim)
    return torch.empty(out_shape, dtype=q.dtype, device=q.device)
|
||||||
|
|
||||||
|
|
||||||
|
# Register the xformers wrapper together with its shape-only fake
# implementation; this file calls the op as
# torch.ops.vllm.xformers_attn_seqlens_wrapper.
direct_register_custom_op(
    op_name="xformers_attn_seqlens_wrapper",
    op_func=xformers_attn_seqlens_wrapper,
    fake_impl=xformers_attn_seqlens_wrapper_fake,
)
|
||||||
|
|
||||||
|
|
||||||
|
def vit_xformers_attn_wrapper(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    """Call the `xformers_attn_seqlens_wrapper` op via `torch.ops.vllm`."""
    custom_op = torch.ops.vllm.xformers_attn_seqlens_wrapper
    return custom_op(q, k, v, seqlens)
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attn_maxseqlen_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    """Run non-causal variable-length flash attention for ViT blocks.

    Args:
        q, k, v: tensors whose two leading dims (`b s`) are flattened into
            one before the kernel call.
        cu_seqlens: cumulative sequence lengths, used for both Q and K.
        max_seqlen: scalar tensor holding the maximum sequence length.
        batch_size: value of `b` used to un-flatten the output.
        is_rocm_aiter: import `flash_attn_varlen_func` from `aiter`.
        use_upstream_fa: otherwise import it from `flash_attn`; if both
            flags are false, vLLM's own `fa_utils` provides it.

    Returns:
        Contiguous tensor laid out as `s b (h d)`.
    """
    if is_rocm_aiter:
        from aiter import flash_attn_varlen_func
    elif use_upstream_fa:
        from flash_attn import flash_attn_varlen_func
    else:
        from vllm.attention.utils.fa_utils import flash_attn_varlen_func

    # `.item()` copies the scalar to the host; read it once and reuse the
    # value for both the Q and K limits instead of converting twice.
    max_len = max_seqlen.item()

    q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
    output = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_len,
        max_seqlen_k=max_len,
        dropout_p=0.0,
        causal=False,
    )
    context_layer = einops.rearrange(
        output, "(b s) h d -> s b (h d)", b=batch_size
    ).contiguous()
    return context_layer
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attn_maxseqlen_wrapper_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    """Shape-only stand-in: return an uninitialized `(s, b, h * d)` tensor.

    Only `q` is consulted (shape `(b, s, h, d)`, dtype and device); every
    other argument is ignored.
    """
    batch, seq_len, num_heads, head_dim = q.shape
    out_shape = (seq_len, batch, num_heads * head_dim)
    return torch.empty(out_shape, dtype=q.dtype, device=q.device)
|
||||||
|
|
||||||
|
|
||||||
|
# Register the flash-attention wrapper together with its shape-only fake
# implementation; this file calls the op as
# torch.ops.vllm.flash_attn_maxseqlen_wrapper.
direct_register_custom_op(
    op_name="flash_attn_maxseqlen_wrapper",
    op_func=flash_attn_maxseqlen_wrapper,
    fake_impl=flash_attn_maxseqlen_wrapper_fake,
)
|
||||||
|
|
||||||
|
|
||||||
|
def vit_flash_attn_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    """Call the `flash_attn_maxseqlen_wrapper` op via `torch.ops.vllm`."""
    custom_op = torch.ops.vllm.flash_attn_maxseqlen_wrapper
    return custom_op(
        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
    )
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Once we have a torch 2.10, we can use tensor slices
|
||||||
|
# so we won't need to wrap this in custom ops
|
||||||
|
def torch_sdpa_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
) -> torch.Tensor:
    """Apply scaled-dot-product attention independently to each packed sequence.

    Args:
        q, k, v: tensors laid out as `(b, s, h, d)`, with several sequences
            packed along the `s` axis.
        cu_seqlens: 1-D tensor of cumulative sequence boundaries along `s`;
            consecutive entries delimit one sequence.

    Returns:
        Contiguous tensor laid out as `(s, b, h * d)`.
    """
    # Read all boundaries with a single conversion instead of turning two
    # tensor elements into slice indices on every iteration.
    bounds = cu_seqlens.tolist()

    outputs = []
    for start_idx, end_idx in zip(bounds, bounds[1:]):
        # (b, s, h, d) -> (b, h, s, d), the layout SDPA attends over
        q_i, k_i, v_i = (
            x[:, start_idx:end_idx].transpose(1, 2) for x in (q, k, v)
        )
        output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
        # back to (b, s, h, d) so the pieces can be joined along s
        outputs.append(output_i.transpose(1, 2))

    context_layer = torch.cat(outputs, dim=1)
    # (b, s, h, d) -> (s, b, h * d)
    context_layer = context_layer.permute(1, 0, 2, 3).flatten(2).contiguous()
    return context_layer
|
||||||
|
|
||||||
|
|
||||||
|
def torch_sdpa_wrapper_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
) -> torch.Tensor:
    """Shape-only stand-in: return an uninitialized `(s, b, h * d)` tensor.

    Only `q` is consulted (shape `(b, s, h, d)`, dtype and device); the
    remaining arguments are ignored.
    """
    batch, seq_len, num_heads, head_dim = q.shape
    out_shape = (seq_len, batch, num_heads * head_dim)
    return torch.empty(out_shape, dtype=q.dtype, device=q.device)
|
||||||
|
|
||||||
|
|
||||||
|
# Register the SDPA wrapper together with its shape-only fake
# implementation; this file calls the op as
# torch.ops.vllm.torch_sdpa_wrapper.
direct_register_custom_op(
    op_name="torch_sdpa_wrapper",
    op_func=torch_sdpa_wrapper,
    fake_impl=torch_sdpa_wrapper_fake,
)
|
||||||
|
|
||||||
|
|
||||||
|
def vit_torch_sdpa_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
) -> torch.Tensor:
    """Call the `torch_sdpa_wrapper` op via `torch.ops.vllm`."""
    custom_op = torch.ops.vllm.torch_sdpa_wrapper
    return custom_op(q, k, v, cu_seqlens)
|
||||||
231
attention/selector.py
Normal file
231
attention/selector.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
from collections.abc import Generator
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from functools import cache
|
||||||
|
from typing import cast, get_args
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
|
from vllm.config.cache import CacheDType
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils import STR_BACKEND_ENV_VAR
|
||||||
|
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
    """
    Read the attention-backend override from the environment.

    The variable named by `STR_BACKEND_ENV_VAR` is looked up by name in
    `AttentionBackendEnum`.

    Returns:

    * AttentionBackendEnum member if the variable is set
    * None if the variable is not set
    """
    requested = os.environ.get(STR_BACKEND_ENV_VAR)
    if requested is None:
        return None
    return AttentionBackendEnum[requested]
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level override for attention backend selection.
#
# When set to a backend, that choice is forced, overriding the logic
# which auto-selects a backend based on system & workload configuration.
# When None (the default), auto-selection applies.
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: AttentionBackendEnum | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
    """
    Set the process-wide forced attention backend.

    The value is stored in the module-level `forced_attn_backend` variable.
    Passing `None` clears the override, which re-enables automatic
    backend selection.

    Arguments:

    * attn_backend: backend selection (None to revert to auto)
    """
    # Rebind the module global rather than creating a function-local name.
    global forced_attn_backend
    forced_attn_backend = attn_backend
|
||||||
|
|
||||||
|
|
||||||
|
def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
    """
    Return the attention backend currently forced at module level.

    Returns:

    * The forced AttentionBackendEnum member, if one has been set
    * None while automatic backend selection is in effect
    """
    current_override = forced_attn_backend
    return current_override
|
||||||
|
|
||||||
|
|
||||||
|
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Validate `kv_cache_dtype`, then select and lazily import a backend.

    The actual selection is delegated to the memoized
    `_cached_get_attn_backend`; this wrapper only checks that a non-None
    `kv_cache_dtype` is one of the literal values listed in `CacheDType`.
    """
    if kv_cache_dtype is not None:
        allowed_dtypes = get_args(CacheDType)
        assert kv_cache_dtype in allowed_dtypes, (
            f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
            f"Valid values are: {allowed_dtypes}"
        )

    # The cast only narrows the static type; the runtime value is unchanged.
    typed_kv_cache_dtype = cast(CacheDType | None, kv_cache_dtype)

    # Arguments are forwarded by keyword, in signature order.
    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=typed_kv_cache_dtype,
        block_size=block_size,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
        attn_type=attn_type,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@cache
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: CacheDType | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Resolve and import the attention backend class for these arguments.

    Results are memoized per argument combination by `functools.cache`, so
    the forced-backend global and the environment variable are only
    consulted on a cache miss.

    Selection precedence:

    1. the globally forced backend, if one is set;
    2. otherwise `envs.VLLM_ATTENTION_BACKEND`, if set (a trailing
       `_VLLM_V1` is stripped with a warning);
    3. otherwise `None` is handed to the platform, which picks the backend.

    Raises:
        ValueError: if the environment variable names an unknown backend, or
            the platform returns no backend class path.
    """
    # A previously forced backend wins.
    #
    # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
    # ENVIRONMENT VARIABLE.
    selected_backend: AttentionBackendEnum | None = get_global_forced_attn_backend()

    if selected_backend is None:
        # No forced backend: fall back to the environment variable, if any.
        env_choice: str | None = envs.VLLM_ATTENTION_BACKEND
        if env_choice is not None:
            if env_choice.endswith("_VLLM_V1"):
                logger.warning(
                    "The suffix '_VLLM_V1' in the environment variable "
                    "%s is no longer necessary as V0 backends have been "
                    "deprecated. Please remove this suffix from your "
                    "environment variable setting.",
                    STR_BACKEND_ENV_VAR,
                )
                env_choice = env_choice.removesuffix("_VLLM_V1")
            try:
                selected_backend = AttentionBackendEnum[env_choice]
            except KeyError as e:
                raise ValueError(
                    f"Invalid attention backend: '{env_choice}'. Valid "
                    f"backends are: {list(AttentionBackendEnum.__members__.keys())}"
                ) from e

    # get device-specific attn_backend
    from vllm.platforms import current_platform

    platform_select = current_platform.get_attn_backend_cls

    # Positional arguments shared by both call shapes below.
    leading_args = (selected_backend, head_size, dtype, kv_cache_dtype, block_size)
    trailing_args = (use_mla, has_sink, use_sparse, attn_type)

    if "use_v1" in inspect.signature(platform_select).parameters:
        logger.warning_once(
            "use_v1 parameter for get_attn_backend_cls is deprecated and will "
            "be removed in v0.13.0 or v1.0.0, whichever is soonest. Please "
            "remove it from your plugin code."
        )
        # Platforms that still declare `use_v1` receive it positionally
        # between block_size and use_mla, always as True.
        attention_cls = platform_select(*leading_args, True, *trailing_args)
    else:
        attention_cls = platform_select(*leading_args, *trailing_args)

    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )
    backend = resolve_obj_by_qualname(attention_cls)

    # Adjust kv cache layout if the selected backend requires a specific one
    required_layout = backend.get_required_kv_cache_layout()
    if required_layout is None:
        return backend

    from vllm.v1.attention.backends.utils import set_kv_cache_layout

    set_kv_cache_layout(required_layout)
    logger.info(
        "Using %s KV cache layout for %s backend.",
        required_layout,
        backend.get_name(),
    )
    return backend
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def global_force_attn_backend_context_manager(
    attn_backend: AttentionBackendEnum,
) -> Generator[None, None, None]:
    """
    Globally force a vLLM attention backend override within a
    context manager, reverting the global attention backend
    override to its prior state upon exiting the context
    manager.

    The memoized backend selection (`_cached_get_attn_backend`) is cleared
    both when the override is installed and when it is reverted, so a
    selection cached under one override is never served under another.

    Arguments:

    * attn_backend: attention backend to force

    Returns:

    * Generator
    """

    # Save the current state of the global backend override (if any)
    original_value = get_global_forced_attn_backend()

    # Globally force the new backend override
    global_force_attn_backend(attn_backend)

    # Drop selections memoized before the override was installed; without
    # this, a cached result for the same arguments would be returned inside
    # the context and the forced backend would be silently ignored.
    _cached_get_attn_backend.cache_clear()

    # Yield control back to the enclosed code block
    try:
        yield
    finally:
        # Revert the original global backend override, if any
        global_force_attn_backend(original_value)
        # Selections made under the forced backend must not outlive it.
        _cached_get_attn_backend.cache_clear()
|
||||||
0
attention/utils/__init__.py
Normal file
0
attention/utils/__init__.py
Normal file
BIN
attention/utils/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
attention/utils/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
attention/utils/__pycache__/fa_utils.cpython-312.pyc
Normal file
BIN
attention/utils/__pycache__/fa_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
attention/utils/__pycache__/kv_sharing_utils.cpython-312.pyc
Normal file
BIN
attention/utils/__pycache__/kv_sharing_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
attention/utils/__pycache__/kv_transfer_utils.cpython-312.pyc
Normal file
BIN
attention/utils/__pycache__/kv_transfer_utils.cpython-312.pyc
Normal file
Binary file not shown.
108
attention/utils/fa_utils.py
Normal file
108
attention/utils/fa_utils.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
logger = init_logger(__name__)

# Bind the flash-attention entry points for the current platform at import
# time. On platforms that are neither CUDA nor XPU none of these names are
# defined by this module.
if current_platform.is_cuda():
    from vllm import _custom_ops as ops

    reshape_and_cache_flash = ops.reshape_and_cache_flash
    # NOTE(review): unlike the XPU branch, this branch does not bind
    # `get_scheduler_metadata` - confirm no CUDA caller imports it from here.
    from ixformer.contrib.vllm_flash_attn import (
        flash_attn_varlen_func,
        flash_attn_varlen_int8_func,
        flash_attn_with_kvcache,
    )

elif current_platform.is_xpu():
    from vllm._ipex_ops import ipex_ops as ops

    reshape_and_cache_flash = ops.reshape_and_cache_flash
    flash_attn_varlen_func = ops.flash_attn_varlen_func
    flash_attn_with_kvcache = ops.flash_attn_with_kvcache
    get_scheduler_metadata = ops.get_scheduler_metadata
|
||||||
|
|
||||||
|
|
||||||
|
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
    """Pick the FlashAttention major version to use on this platform.

    Args:
        requires_alibi: when True, version 3 is never selected and the
            function falls back to version 2.

    Returns:
        2 on XPU. Otherwise 2 or 3 according to the device capability, the
        `VLLM_FLASH_ATTN_VERSION` override and the fallbacks below, or None
        when no usable version can be determined: the flash-attn interface
        cannot be imported, the device capability is unknown, the override
        is not 2 or 3, or the chosen version is reported as unsupported.
    """
    # import here to avoid circular dependencies
    from vllm.platforms import current_platform

    if current_platform.is_xpu():
        return 2
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason,
            is_fa_version_supported,
        )

        device_capability = current_platform.get_device_capability()

        # Explicit checks instead of `assert`: asserts are removed under
        # `python -O`, which would let the None case reach `.major` below.
        if device_capability is None:
            return None

        # 1. default version depending on platform
        fa_version = (
            3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
        )

        # 2. override if passed by environment
        env_version = envs.VLLM_FLASH_ATTN_VERSION
        if env_version is not None:
            if env_version not in (2, 3):
                return None
            fa_version = env_version

        # 3. fallback for unsupported combinations
        if device_capability.major == 10 and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 on Blackwell platform "
                "defaulting to FA version 2."
            )
            fa_version = 2

        if requires_alibi and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
            )
            fa_version = 2

        if not is_fa_version_supported(fa_version):
            logger.error(
                "Cannot use FA version %d: not supported due to %s",
                fa_version,
                fa_version_unsupported_reason(fa_version),
            )
            return None

        return fa_version
    except (ImportError, AssertionError):
        # AssertionError is still caught in case the imported helpers
        # themselves assert.
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attn_supports_fp8() -> bool:
    """Return True only for FlashAttention version 3 on a major-9 device.

    The device capability is queried only after `get_flash_attn_version()`
    has returned 3.
    """
    if get_flash_attn_version() != 3:
        return False
    return current_platform.get_device_capability().major == 9
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attn_supports_sinks() -> bool:
    """Report sink support; this implementation unconditionally says yes."""
    supported = True
    return supported
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attn_supports_mla():
    """Report whether FlashAttention can be used for MLA on this platform.

    True requires all of: a CUDA platform, an importable flash-attn
    interface that reports version 3 as supported, and a device capability
    whose major component is 9. Import or assertion failures while probing
    are treated as "not supported".
    """
    from vllm.platforms import current_platform

    if not current_platform.is_cuda():
        return False

    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            is_fa_version_supported,
        )

        fa3_supported = is_fa_version_supported(3)
        # The capability is only queried when version 3 is supported.
        return fa3_supported and current_platform.get_device_capability()[0] == 9
    except (ImportError, AssertionError):
        return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_flash_attn_varlen_func_available() -> bool:
    """Return a truthy value on CUDA or XPU platforms.

    These are the two platforms for which this module binds
    `flash_attn_varlen_func` at import time.
    """
    on_cuda = current_platform.is_cuda()
    # XPU is only probed when the CUDA check is falsy.
    return on_cuda or current_platform.is_xpu()
|
||||||
33
attention/utils/kv_sharing_utils.py
Normal file
33
attention/utils/kv_sharing_utils.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
def validate_kv_sharing_target(
    current_layer_name, target_layer_name, static_forward_context
):
    """Check that `target_layer_name` is a legal KV sharing target.

    Args:
        current_layer_name: name of the layer that wants to reuse a KV cache.
        target_layer_name: name of the layer whose KV cache would be shared.
        static_forward_context: mapping from layer name to an object that
            exposes an `attn_type` attribute.

    Raises:
        ValueError: if the target is the current layer itself, is missing
            from `static_forward_context`, or has a different `attn_type`
            than the current layer.
    """
    prefix = (
        f"Specified KV sharing target layer for {current_layer_name} "
        f"is not valid: target layer {target_layer_name} "
    )

    if current_layer_name == target_layer_name:
        raise ValueError(prefix + "cannot be the same as the current layer.")

    if target_layer_name not in static_forward_context:
        from vllm.model_executor.models.utils import extract_layer_index

        # If target layer name is not in the static fwd context, it means either
        # a) the target layer does not come BEFORE the current layer, or
        # b) the target layer is not an Attention layer that exists in the model
        current_idx = extract_layer_index(current_layer_name)
        target_idx = extract_layer_index(target_layer_name)
        if current_idx <= target_idx:
            reason = "must come before the current layer."
        else:
            reason = "is not a valid Attention layer in the model."
        raise ValueError(prefix + reason)

    # Currently KV sharing is only supported between layers of the same type
    expected = static_forward_context[current_layer_name].attn_type
    if static_forward_context[target_layer_name].attn_type != expected:
        raise ValueError(
            prefix + f"must be the same type as the current layer ({expected})."
        )
|
||||||
60
attention/utils/kv_transfer_utils.py
Normal file
60
attention/utils/kv_transfer_utils.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import inspect
|
||||||
|
from collections.abc import Callable
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from vllm.distributed.kv_transfer import (
|
||||||
|
get_kv_transfer_group,
|
||||||
|
has_kv_transfer_group,
|
||||||
|
is_v1_kv_transfer_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_transfer_kv_layer(func: Callable) -> Callable:
    """Decorator that handles KV layer transfer prior and after execution of
    an attention layer, if enabled. Otherwise, the wrapper is a no-op.

    On entry: waits for the KV layer from the connector.
    On exit: saves the KV layer to the connector.

    The decorated function must declare a parameter named `layer_name`; it
    may be passed either positionally or by keyword.

    Raises:
        TypeError: at decoration time, if `func` has no `layer_name`
            parameter.
    """
    # Import at runtime to avoid circular dependency
    from vllm.attention.layer import get_attention_context

    # Inspect the signature ONCE when the decorator is applied.
    sig = inspect.signature(func)
    param_names = list(sig.parameters.keys())

    # Find the index of 'layer_name' parameter.
    try:
        layer_name_index = param_names.index("layer_name")
    except ValueError as e:
        raise TypeError(
            f"Function {func.__name__} must have a 'layer_name' parameter"
        ) from e

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
            return func(*args, **kwargs)

        # A keyword-passed layer_name is absent from `args`, so indexing
        # positionally would raise IndexError or pick up the wrong argument.
        if "layer_name" in kwargs:
            layer_name: str = kwargs["layer_name"]
        else:
            layer_name = args[layer_name_index]

        # Extract attention context (layer-specific metadata, layer, and kv_cache)
        attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
        connector = get_kv_transfer_group()
        if attn_metadata is None or not connector.has_connector_metadata():
            return func(*args, **kwargs)

        # Wait for KV layer on entry
        connector.wait_for_layer_load(layer_name)

        # Execute the function
        result = func(*args, **kwargs)

        # Save KV cache layer on exit
        connector.save_kv_layer(layer_name, kv_cache, attn_metadata)

        return result

    return wrapper
|
||||||
88
beam_search.py
Normal file
88
beam_search.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
|
from vllm.logprobs import Logprob
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.multimodal import MultiModalDataDict
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BeamSearchSequence:
    """One candidate sequence tracked during beam search.

    Holds the token ids and the accumulated log probability of the
    candidate. `text` is optional and is only filled in when the sequence
    is about to be returned to the user.
    """

    # Required fields. The tokens include the prompt.
    tokens: list[int]
    logprobs: list[dict[int, Logprob]]

    # Optional fields; every one of them has a default.
    lora_request: LoRARequest | None = None
    cum_logprob: float = 0.0
    text: str | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = None
    # Quoted annotation: the name is only imported under TYPE_CHECKING.
    multi_modal_data: Optional["MultiModalDataDict"] = None
    mm_processor_kwargs: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BeamSearchOutput:
    """Result container for a beam search run.

    `sequences` lists the best beam search sequences; its length is equal
    to the beam width.
    """

    sequences: list[BeamSearchSequence]
|
||||||
|
|
||||||
|
|
||||||
|
class BeamSearchInstance:
    """Beam search state for one prompt: live beams plus completed ones."""

    def __init__(
        self,
        prompt_tokens: list[int],
        lora_request: LoRARequest | None = None,
        logprobs: list[dict[int, Logprob]] | None = None,
        **kwargs,
    ):
        """Start with a single beam that holds only the prompt.

        Args:
            prompt_tokens: token ids of the prompt; stored on the initial
                beam as given (not copied).
            lora_request: forwarded to the initial beam.
            logprobs: optional initial log probabilities; shallow-copied
                into a new list, or an empty list when None.
            **kwargs: forwarded unchanged to `BeamSearchSequence`.
        """
        seed_logprobs = list(logprobs) if logprobs is not None else []
        seed_beam = BeamSearchSequence(
            tokens=prompt_tokens,
            logprobs=seed_logprobs,
            lora_request=lora_request,
            **kwargs,
        )
        self.beams: list[BeamSearchSequence] = [seed_beam]
        self.completed: list[BeamSearchSequence] = []
|
||||||
|
|
||||||
|
|
||||||
|
def get_beam_search_score(
    tokens: list[int],
    cumulative_logprob: float,
    eos_token_id: int,
    length_penalty: float = 1.0,
) -> float:
    """Calculate the beam search score with length penalty.

    The score is `cumulative_logprob / (length ** length_penalty)`, where
    `length` is the number of tokens, not counting a trailing EOS token.
    `length` is clamped to at least 1, so an empty or EOS-only sequence
    scores `cumulative_logprob` instead of raising.

    Adapted from

    https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
    """
    seq_len = len(tokens)
    # Guard the index so an empty token list does not raise IndexError.
    if tokens and tokens[-1] == eos_token_id:
        seq_len -= 1

    # Avoid ZeroDivisionError when nothing but (at most) an EOS token remains.
    seq_len = max(seq_len, 1)

    return cumulative_logprob / (seq_len**length_penalty)
|
||||||
|
|
||||||
|
|
||||||
|
def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
    """Build a key function that scores a `BeamSearchSequence`.

    The returned callable closes over `eos_token_id` and `length_penalty`
    and returns `get_beam_search_score` for the sequence's tokens and
    cumulative log probability.
    """

    def sort_beams_key(x: BeamSearchSequence) -> float:
        beam_tokens = x.tokens
        beam_logprob = x.cum_logprob
        return get_beam_search_score(
            beam_tokens, beam_logprob, eos_token_id, length_penalty
        )

    return sort_beams_key
|
||||||
0
benchmarks/__init__.py
Normal file
0
benchmarks/__init__.py
Normal file
BIN
benchmarks/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
benchmarks/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
benchmarks/__pycache__/datasets.cpython-312.pyc
Normal file
BIN
benchmarks/__pycache__/datasets.cpython-312.pyc
Normal file
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user