first commit
This commit is contained in:
BIN
vllm/_C.abi3.so
Normal file
BIN
vllm/_C.abi3.so
Normal file
Binary file not shown.
102
vllm/__init__.py
Normal file
102
vllm/__init__.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
|
||||||
|
|
||||||
|
# version.py must remain a standalone module, and it is always imported
# first. Some customizations rely on this import ordering.
|
||||||
|
from .version import __version__, __version_tuple__ # isort:skip
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
# The environment variables override should be imported before any other
|
||||||
|
# modules to ensure that the environment variables are set before any
|
||||||
|
# other modules are imported.
|
||||||
|
import vllm.env_override # noqa: F401
|
||||||
|
|
||||||
|
MODULE_ATTRS = {
|
||||||
|
"bc_linter_skip": "._bc_linter:bc_linter_skip",
|
||||||
|
"bc_linter_include": "._bc_linter:bc_linter_include",
|
||||||
|
"AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
|
||||||
|
"EngineArgs": ".engine.arg_utils:EngineArgs",
|
||||||
|
"AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
|
||||||
|
"LLMEngine": ".engine.llm_engine:LLMEngine",
|
||||||
|
"LLM": ".entrypoints.llm:LLM",
|
||||||
|
"initialize_ray_cluster": ".executor.ray_utils:initialize_ray_cluster",
|
||||||
|
"PromptType": ".inputs:PromptType",
|
||||||
|
"TextPrompt": ".inputs:TextPrompt",
|
||||||
|
"TokensPrompt": ".inputs:TokensPrompt",
|
||||||
|
"ModelRegistry": ".model_executor.models:ModelRegistry",
|
||||||
|
"SamplingParams": ".sampling_params:SamplingParams",
|
||||||
|
"PoolingParams": ".pooling_params:PoolingParams",
|
||||||
|
"ClassificationOutput": ".outputs:ClassificationOutput",
|
||||||
|
"ClassificationRequestOutput": ".outputs:ClassificationRequestOutput",
|
||||||
|
"CompletionOutput": ".outputs:CompletionOutput",
|
||||||
|
"EmbeddingOutput": ".outputs:EmbeddingOutput",
|
||||||
|
"EmbeddingRequestOutput": ".outputs:EmbeddingRequestOutput",
|
||||||
|
"PoolingOutput": ".outputs:PoolingOutput",
|
||||||
|
"PoolingRequestOutput": ".outputs:PoolingRequestOutput",
|
||||||
|
"RequestOutput": ".outputs:RequestOutput",
|
||||||
|
"ScoringOutput": ".outputs:ScoringOutput",
|
||||||
|
"ScoringRequestOutput": ".outputs:ScoringRequestOutput",
|
||||||
|
}
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||||
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
|
from vllm.engine.llm_engine import LLMEngine
|
||||||
|
from vllm.entrypoints.llm import LLM
|
||||||
|
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||||
|
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
|
||||||
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
from vllm.outputs import (ClassificationOutput,
|
||||||
|
ClassificationRequestOutput, CompletionOutput,
|
||||||
|
EmbeddingOutput, EmbeddingRequestOutput,
|
||||||
|
PoolingOutput, PoolingRequestOutput,
|
||||||
|
RequestOutput, ScoringOutput,
|
||||||
|
ScoringRequestOutput)
|
||||||
|
from vllm.pooling_params import PoolingParams
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
|
||||||
|
from ._bc_linter import bc_linter_include, bc_linter_skip
|
||||||
|
else:
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> typing.Any:
    """Lazily resolve a public attribute of this package on first access.

    ``MODULE_ATTRS`` maps each exported name to a ``"<module>:<attribute>"``
    target. The module part is imported relative to ``__package__`` and the
    named attribute is fetched from it.

    Args:
        name: The attribute being looked up on the package.

    Returns:
        The object found at the mapped ``module:attribute`` location.

    Raises:
        AttributeError: If ``name`` has no entry in ``MODULE_ATTRS``.
    """
    target = MODULE_ATTRS.get(name)
    if target is None:
        raise AttributeError(
            f'module {__package__} has no attribute {name}')

    # Imported here rather than at module level, as in the original layout.
    from importlib import import_module

    module_path, attribute = target.split(":")
    return getattr(import_module(module_path, __package__), attribute)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"__version__",
|
||||||
|
"bc_linter_skip",
|
||||||
|
"bc_linter_include",
|
||||||
|
"__version_tuple__",
|
||||||
|
"LLM",
|
||||||
|
"ModelRegistry",
|
||||||
|
"PromptType",
|
||||||
|
"TextPrompt",
|
||||||
|
"TokensPrompt",
|
||||||
|
"SamplingParams",
|
||||||
|
"RequestOutput",
|
||||||
|
"CompletionOutput",
|
||||||
|
"PoolingOutput",
|
||||||
|
"PoolingRequestOutput",
|
||||||
|
"EmbeddingOutput",
|
||||||
|
"EmbeddingRequestOutput",
|
||||||
|
"ClassificationOutput",
|
||||||
|
"ClassificationRequestOutput",
|
||||||
|
"ScoringOutput",
|
||||||
|
"ScoringRequestOutput",
|
||||||
|
"LLMEngine",
|
||||||
|
"EngineArgs",
|
||||||
|
"AsyncLLMEngine",
|
||||||
|
"AsyncEngineArgs",
|
||||||
|
"initialize_ray_cluster",
|
||||||
|
"PoolingParams",
|
||||||
|
]
|
||||||
BIN
vllm/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/_bc_linter.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/_bc_linter.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/_custom_ops.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/_custom_ops.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/_ipex_ops.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/_ipex_ops.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/_version.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/_version.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/beam_search.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/beam_search.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/collect_env.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/collect_env.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/connections.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/connections.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/env_override.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/env_override.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/envs.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/envs.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/forward_context.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/forward_context.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/logger.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/logger.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/logits_process.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/logits_process.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/logprobs.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/logprobs.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/outputs.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/outputs.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/pooling_params.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/pooling_params.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/sampling_params.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/sampling_params.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/scalar_type.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/scalar_type.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/scripts.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/scripts.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/sequence.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/sequence.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/tasks.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/tasks.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/test_utils.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/test_utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/tracing.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/tracing.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/version.cpython-310.pyc
Normal file
BIN
vllm/__pycache__/version.cpython-310.pyc
Normal file
Binary file not shown.
59
vllm/_bc_linter.py
Normal file
59
vllm/_bc_linter.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# vllm/_bc_linter.py
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Callable, TypeVar, overload
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
@overload
def bc_linter_skip(obj: T) -> T:
    ...


@overload
def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]:
    ...


def bc_linter_skip(obj: Any = None, *, reason: str | None = None):
    """
    Mark a symbol for BC-linter suppression; has no runtime effect.

    Works both as a bare decorator and as a decorator factory:

        @bc_linter_skip
        def legacy_api(...): ...

        @bc_linter_skip(reason="...")
        def other_legacy_api(...): ...

    Args:
        obj: The decorated object when used without parentheses.
        reason: Optional free-form note; accepted but not used at runtime.

    Returns:
        ``obj`` itself when one is given, otherwise an identity decorator.
    """
    if obj is not None:
        # Bare ``@bc_linter_skip`` form: hand the object straight back.
        return obj

    def _wrap(x: T) -> T:
        return x

    return _wrap
|
||||||
|
|
||||||
|
|
||||||
|
@overload
def bc_linter_include(obj: T) -> T:
    ...


@overload
def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]:
    ...


def bc_linter_include(obj: Any = None, *, reason: str | None = None):
    """
    Identity decorator usable with or without parentheses.

    Usage:
        @bc_linter_include
        def public_api(...): ...

    Args:
        obj: The decorated object when used without parentheses.
        reason: Optional free-form note; accepted but not used at runtime.

    Returns:
        ``obj`` itself when one is given, otherwise an identity decorator.
    """
    if obj is not None:
        # Bare ``@bc_linter_include`` form: hand the object straight back.
        return obj

    def _wrap(x: T) -> T:
        return x

    return _wrap
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["bc_linter_skip", "bc_linter_include"]
|
||||||
2044
vllm/_custom_ops.py
Normal file
2044
vllm/_custom_ops.py
Normal file
File diff suppressed because it is too large
Load Diff
BIN
vllm/_flashmla_C.abi3.so
Normal file
BIN
vllm/_flashmla_C.abi3.so
Normal file
Binary file not shown.
BIN
vllm/_flashmla_extension_C.abi3.so
Normal file
BIN
vllm/_flashmla_extension_C.abi3.so
Normal file
Binary file not shown.
393
vllm/_ipex_ops.py
Normal file
393
vllm/_ipex_ops.py
Normal file
@@ -0,0 +1,393 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
|
except ImportError as e:
|
||||||
|
logger.debug("Import error msg: %s", e.msg)
|
||||||
|
|
||||||
|
|
||||||
|
class ipex_ops:
    """Namespace of static wrappers that forward op calls to IPEX kernels.

    Each method adapts a vLLM-style argument list to the corresponding
    ``intel_extension_for_pytorch`` (``ipex``) or ``torch.xpu`` entry point.
    The module-level import of ``ipex`` is wrapped in ``try/except
    ImportError``, so methods that touch ``ipex`` fail with ``NameError`` if
    the extension could not be imported.
    """

    @staticmethod
    def _reshape_activation_tensor(
            x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Split a 2-D tensor of width ``2 * d`` into two ``(num, d)`` halves.

        The first returned tensor holds columns ``[0, d)`` of each row and
        the second holds columns ``[d, 2 * d)``.
        """
        num_rows = x.size(0)
        half_width = x.size(1) // 2
        first, second = torch.chunk(x.reshape(num_rows, 2, half_width),
                                    chunks=2,
                                    dim=1)
        return (first.reshape(num_rows, half_width),
                second.reshape(num_rows, half_width))

    @staticmethod
    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Forward to ``ipex.llm.functional.silu_and_mul(x, out)``."""
        ipex.llm.functional.silu_and_mul(x, out)

    @staticmethod
    def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Forward to ``ipex.llm.functional.gelu_and_mul(x, out)``."""
        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Forward to ``ipex.llm.functional.gelu_and_mul(x, out)``."""
        # NOTE(review): this calls the same kernel as ``gelu_and_mul`` and
        # passes no tanh-approximation option -- confirm whether the IPEX
        # kernel should be asked for the tanh variant here.
        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
        """Return the tanh-approximation GELU of ``x``.

        ``gelu_fast`` denotes the tanh-based approximation, so the exact
        (erf) GELU that was previously returned differed numerically from
        the reference op.
        """
        return torch.nn.functional.gelu(x, approximate="tanh")

    @staticmethod
    def gelu_new(x: torch.Tensor) -> torch.Tensor:
        """Return the tanh-approximation GELU of ``x``.

        ``gelu_new`` is ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 *
        x**3)))``, which is what ``approximate="tanh"`` computes.
        """
        return torch.nn.functional.gelu(x, approximate="tanh")

    @staticmethod
    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
        """Forward to ``ipex.llm.functional.gelu_quick(x, out)``."""
        ipex.llm.functional.gelu_quick(x, out)

    @staticmethod
    def _single_query_kv_attention(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
    ) -> None:
        """Shared body of ``paged_attention_v1`` and ``paged_attention_v2``."""
        # The kernel takes the number of query heads per KV head, derived
        # from the head dimension of ``out``.
        num_queries_per_tokens = out.size(1) // num_kv_heads
        ipex.llm.modules.PagedAttention.single_query_kv_attention(
            out,
            query.contiguous(),
            key_cache.view_as(value_cache),
            value_cache,
            num_queries_per_tokens,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def paged_attention_v1(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        """Run IPEX single-query paged attention, writing into ``out``.

        Only ``kv_cache_dtype == "auto"`` is accepted. ``k_scale``,
        ``v_scale``, ``tp_rank`` and the ``blocksparse_*`` arguments are
        accepted but not forwarded to the kernel.
        """
        assert kv_cache_dtype == "auto"
        ipex_ops._single_query_kv_attention(
            out,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def paged_attention_v2(
        out: torch.Tensor,
        exp_sum: torch.Tensor,
        max_logits: torch.Tensor,
        tmp_out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        """Run IPEX single-query paged attention, writing into ``out``.

        Issues the same kernel call as ``paged_attention_v1``; ``exp_sum``,
        ``max_logits`` and ``tmp_out`` are accepted but not used, as are
        ``k_scale``, ``v_scale``, ``tp_rank`` and the ``blocksparse_*``
        arguments. Only ``kv_cache_dtype == "auto"`` is accepted.
        """
        assert kv_cache_dtype == "auto"
        ipex_ops._single_query_kv_attention(
            out,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def rotary_embedding(
        positions: torch.Tensor,  # [batch_size, seq_len]
        query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
        key: torch.Tensor,  # [batch_size, seq_len, num_kv_heads*head_size]
        head_size: int,
        cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
        is_neox: bool,
    ) -> None:
        """Forward to ``ipex.llm.functional.rotary_embedding_batched``.

        The rotary dimension passed to the kernel is the second dimension
        of ``cos_sin_cache``.
        """
        rot_dim = cos_sin_cache.size(1)
        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
                                                     head_size, cos_sin_cache,
                                                     is_neox, rot_dim)

    @staticmethod
    def rms_norm(input: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> torch.Tensor:
        """Return ``ipex.llm.functional.rms_norm(input, weight, epsilon)``."""
        return ipex.llm.functional.rms_norm(input, weight, epsilon)

    @staticmethod
    def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, epsilon: float) -> None:
        """Call IPEX ``add_rms_norm`` and copy its result into ``input``."""
        normed = ipex.llm.functional.add_rms_norm(residual, input, weight,
                                                  None, epsilon, True)
        input.copy_(normed)

    @staticmethod
    def varlen_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        seqlen_q: torch.Tensor,
        seqlen_k: torch.Tensor,
        alibi_slopes: Optional[torch.Tensor],
        max_seqlen_q: int,
        max_seqlen_k: int,
        pdropout: float,
        softmax_scale: float,
        zero_tensors: bool,
        is_causal: bool,
        return_softmax: bool,
        gen_: torch.Generator,
        window_size_left: float,
        window_size_right: float,
        logits_soft_cap: float,
    ) -> None:
        """Forward to ``ipex.llm.functional.varlen_attention``.

        The argument list depends on the IPEX build, detected from whether
        ``ipex.__version__`` ends with ``"cpu"``. The CPU call omits
        ``alibi_slopes``, the window sizes and ``logits_soft_cap``, so those
        must be unset on that path.

        Raises:
            ValueError: On a CPU build when ``logits_soft_cap`` is non-zero.
        """
        is_cpu_build = ipex.__version__.endswith("cpu")
        if is_cpu_build:
            # Validate before doing any tensor work, as the original did.
            if logits_soft_cap != 0.0:
                raise ValueError("IPEX CPU does not support logits_soft_cap")
            assert alibi_slopes is None
            assert window_size_left < 0 and window_size_right < 0

        q = query.contiguous()
        k = key.contiguous()
        v = value.contiguous()
        q_lens = seqlen_q.int()
        k_lens = seqlen_k.int()

        if is_cpu_build:
            ipex.llm.functional.varlen_attention(q, k, v, out, q_lens, k_lens,
                                                 max_seqlen_q, max_seqlen_k,
                                                 pdropout, softmax_scale,
                                                 zero_tensors, is_causal,
                                                 return_softmax, gen_)
            return

        # XPU build
        ipex.llm.functional.varlen_attention(
            q, k, v, out, q_lens, k_lens, alibi_slopes, max_seqlen_q,
            max_seqlen_k, pdropout, softmax_scale, zero_tensors, is_causal,
            return_softmax, gen_, window_size_left, window_size_right,
            logits_soft_cap)

    @staticmethod
    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
        """Forward to ``ipex.llm.modules.PagedAttention.reshape_and_cache``.

        Only ``kv_cache_dtype == "auto"`` is accepted; ``k_scale`` and
        ``v_scale`` are not forwarded.
        """
        assert kv_cache_dtype == "auto"
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping)

    @staticmethod
    def reshape_and_cache_flash(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: Optional[torch.Tensor] = None,
        v_scale: Optional[torch.Tensor] = None,
        k_scale_float: float = 1.0,
        v_scale_float: float = 1.0,
    ) -> None:
        """Forward to IPEX ``PagedAttention.reshape_and_cache_flash``.

        The float scales (``k_scale_float`` / ``v_scale_float``) are passed
        to the kernel; the tensor-valued ``k_scale`` / ``v_scale`` are not.
        """
        ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
            key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
            k_scale_float, v_scale_float)

    @staticmethod
    def flash_attn_varlen_func(
        out: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        seqused_k: torch.Tensor,  # we don't support this in ipex kernel
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        causal: bool,
        block_table: torch.Tensor,
        alibi_slopes: Optional[torch.Tensor],
        window_size: Optional[list[int]] = None,
        softcap: Optional[float] = 0.0,
        cu_seqlens_k: Optional[torch.Tensor] = None,
        # The following parameters are not used in ipex kernel currently,
        # we keep API compatible to CUDA's.
        scheduler_metadata=None,
        fa_version: int = 2,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        num_splits=0,
        s_aux: Optional[torch.Tensor] = None,
    ):
        """Forward to IPEX ``PagedAttention.flash_attn_varlen_func``.

        When ``cu_seqlens_k`` is not supplied it is built from ``seqused_k``
        as a zero-prefixed cumulative sum cast to ``int32``. ``window_size``
        defaults to ``(-1, -1)`` and must otherwise have exactly two entries.
        ``k_scale`` and ``v_scale`` are always passed as ``1.0``.

        Returns:
            Whatever the IPEX kernel call returns.
        """
        if cu_seqlens_k is None:
            # cu_seqlens_k is not used in ipex kernel.
            leading_zero = torch.tensor([0],
                                        device=seqused_k.device,
                                        dtype=torch.int32)
            cu_seqlens_k = torch.cat(
                [leading_zero,
                 torch.cumsum(seqused_k, dim=0)]).to(torch.int32)

        real_window_size: tuple[int, int] = (-1, -1)
        if window_size is not None:
            assert len(window_size) == 2
            real_window_size = (window_size[0], window_size[1])

        window_left, window_right = real_window_size
        return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            q.contiguous(),
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            softmax_scale,
            causal,
            block_table,
            alibi_slopes,
            softcap=softcap,
            window_size_left=window_left,
            window_size_right=window_right,
            k_scale=1.0,
            v_scale=1.0,
        )

    @staticmethod
    def get_scheduler_metadata(
        batch_size,
        max_seqlen_q,
        max_seqlen_k,
        num_heads_q,
        num_heads_kv,
        headdim,
        cache_seqlens: torch.Tensor,
        qkv_dtype=torch.bfloat16,
        headdim_v=None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_k_new: Optional[torch.Tensor] = None,
        cache_leftpad: Optional[torch.Tensor] = None,
        page_size: Optional[int] = None,
        max_seqlen_k_new=0,
        causal=False,
        window_size=(-1, -1),  # -1 means infinite context window
        has_softcap=False,
        num_splits=0,  # Can be tuned for speed
        pack_gqa=None,  # Can be tuned for speed
        sm_margin=0,  # Can be tuned if some SMs are used for communication
    ) -> None:
        """Placeholder: log a one-time warning and return ``None``.

        All arguments are accepted for signature compatibility and ignored.
        """
        logger.warning_once(
            "get_scheduler_metadata is not implemented for ipex_ops, "
            "returning None.")
        return None

    @staticmethod
    def copy_blocks(key_caches: list[torch.Tensor],
                    value_caches: list[torch.Tensor],
                    block_mapping: torch.Tensor) -> None:
        """Forward to ``torch.xpu.copy_blocks``."""
        torch.xpu.copy_blocks(  # type: ignore
            key_caches,
            value_caches,
            block_mapping,
        )

    @staticmethod
    def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                    block_mapping: torch.Tensor) -> None:
        """Forward to ``torch.xpu.swap_blocks``."""
        torch.xpu.swap_blocks(src, dst, block_mapping)  # type: ignore

    @staticmethod
    def scaled_fp8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        num_token_padding: Optional[int] = None,
        scale_ub: Optional[torch.Tensor] = None,
        use_per_token_if_dynamic: bool = False,
        output: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize input tensor to FP8 and return quantized tensor and scale.

        This function is designed for both static and dynamic quantization:
        If you provide the scale, it will use static scaling and if you omit
        it, the scale will be determined dynamically. Currently, XPU platform
        only supports dynamic quantization. The function also allows optional
        padding of the output tensors for downstream kernels that will benefit
        from padding.

        Args:
            input: The input tensor to be quantized to FP8
            scale: Optional scaling factor for the FP8 quantization; must be
                None here because only dynamic quantization is supported.
            scale_ub: Optional upper bound for scaling factor in dynamic
                per token case (not used by this implementation).
            num_token_padding: If specified, pad the first dimension
                of the output to at least this value.
            use_per_token_if_dynamic: Whether to do per_tensor or per_token
                in the dynamic quantization case; must be False here.
            output: Optional preallocated FP8 output tensor; cannot be
                combined with num_token_padding.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
                scaling factor.
        """
        # This code assumes batch_dim and num_tokens are flattened
        assert (input.ndim == 2)
        shape: Union[tuple[int, int], torch.Size] = input.shape
        out_dtype: torch.dtype = current_platform.fp8_dtype()
        if num_token_padding:
            shape = (max(num_token_padding, input.shape[0]), shape[1])
        if output is None:
            output = torch.empty(shape, device=input.device, dtype=out_dtype)
        else:
            assert num_token_padding is None, \
                "padding not supported if output passed in"
            assert output.dtype == out_dtype
        assert scale is None, "only dynamic fp8 quantization supported on XPU"
        assert not use_per_token_if_dynamic, (
            "per token dynamic fp8 quantization not supported on XPU")
        # The kernel is given a one-element float32 tensor to hold the scale.
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        torch.ops.torch_ipex.dynamic_scaled_fp8_quant(output, input, scale)

        return output, scale
|
||||||
BIN
vllm/_moe_C.abi3.so
Normal file
BIN
vllm/_moe_C.abi3.so
Normal file
Binary file not shown.
34
vllm/_version.py
Normal file
34
vllm/_version.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# file generated by setuptools-scm
|
||||||
|
# don't change, don't track in version control
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"__version__",
|
||||||
|
"__version_tuple__",
|
||||||
|
"version",
|
||||||
|
"version_tuple",
|
||||||
|
"__commit_id__",
|
||||||
|
"commit_id",
|
||||||
|
]
|
||||||
|
|
||||||
|
TYPE_CHECKING = False
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
||||||
|
COMMIT_ID = Union[str, None]
|
||||||
|
else:
|
||||||
|
VERSION_TUPLE = object
|
||||||
|
COMMIT_ID = object
|
||||||
|
|
||||||
|
version: str
|
||||||
|
__version__: str
|
||||||
|
__version_tuple__: VERSION_TUPLE
|
||||||
|
version_tuple: VERSION_TUPLE
|
||||||
|
commit_id: COMMIT_ID
|
||||||
|
__commit_id__: COMMIT_ID
|
||||||
|
|
||||||
|
__version__ = version = '0.11.0'
|
||||||
|
__version_tuple__ = version_tuple = (0, 11, 0)
|
||||||
|
|
||||||
|
__commit_id__ = commit_id = 'gf71952c1c'
|
||||||
0
vllm/assets/__init__.py
Normal file
0
vllm/assets/__init__.py
Normal file
BIN
vllm/assets/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/assets/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/assets/__pycache__/audio.cpython-310.pyc
Normal file
BIN
vllm/assets/__pycache__/audio.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/assets/__pycache__/base.cpython-310.pyc
Normal file
BIN
vllm/assets/__pycache__/base.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/assets/__pycache__/image.cpython-310.pyc
Normal file
BIN
vllm/assets/__pycache__/image.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/assets/__pycache__/video.cpython-310.pyc
Normal file
BIN
vllm/assets/__pycache__/video.cpython-310.pyc
Normal file
Binary file not shown.
45
vllm/assets/audio.py
Normal file
45
vllm/assets/audio.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
|
from vllm.utils import PlaceholderModule
|
||||||
|
|
||||||
|
from .base import VLLM_S3_BUCKET_URL, get_vllm_public_assets
|
||||||
|
|
||||||
|
try:
|
||||||
|
import librosa
|
||||||
|
except ImportError:
|
||||||
|
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||||
|
|
||||||
|
ASSET_DIR = "multimodal_asset"
|
||||||
|
|
||||||
|
AudioAssetName = Literal["winning_call", "mary_had_lamb"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class AudioAsset:
    """A named ``.ogg`` audio sample from the vLLM public asset bucket."""

    # One of the names allowed by ``AudioAssetName``.
    name: AudioAssetName

    @property
    def filename(self) -> str:
        """File name of the asset: the asset name plus ``.ogg``."""
        return f"{self.name}.ogg"

    @property
    def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
        """Fetch the asset and return ``librosa.load(path, sr=None)``."""
        return librosa.load(self.get_local_path(), sr=None)

    def get_local_path(self) -> Path:
        """Return the path given by ``get_vllm_public_assets`` for this file."""
        return get_vllm_public_assets(filename=self.filename,
                                      s3_prefix=ASSET_DIR)

    @property
    def url(self) -> str:
        """URL of the asset: ``ASSET_DIR/<filename>`` joined onto the bucket URL."""
        return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.filename}")
|
||||||
41
vllm/assets/base.py
Normal file
41
vllm/assets/base.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.connections import global_http_connection
|
||||||
|
|
||||||
|
VLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache_dir() -> Path:
    """Return the asset cache directory, creating it if it does not exist.

    The location is read from ``envs.VLLM_ASSETS_CACHE``.
    """
    cache_root = Path(envs.VLLM_ASSETS_CACHE)
    # Create missing parents too; an already existing directory is fine.
    cache_root.mkdir(parents=True, exist_ok=True)
    return cache_root
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
def get_vllm_public_assets(filename: str,
                           s3_prefix: Optional[str] = None) -> Path:
    """
    Return the local path of an asset, downloading it on first use.

    The file is stored under ``<cache dir>/vllm_public_assets/<filename>``.
    If it is not there yet, it is fetched from ``VLLM_S3_BUCKET_URL``,
    with ``s3_prefix`` (when given) prepended to the remote file name.

    Args:
        filename: Name of the asset file, also used as the local file name.
        s3_prefix: Optional remote directory prefix inside the bucket.

    Returns:
        Path to the asset inside the local cache directory.
    """
    local_dir = get_cache_dir() / "vllm_public_assets"
    local_dir.mkdir(parents=True, exist_ok=True)

    local_file = local_dir / filename
    if local_file.exists():
        # Already present on disk: nothing to download.
        return local_file

    remote_key = filename if s3_prefix is None else s3_prefix + "/" + filename
    global_http_connection.download_file(
        f"{VLLM_S3_BUCKET_URL}/{remote_key}",
        local_file,
        timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT)
    return local_file
|
||||||
50
vllm/assets/image.py
Normal file
50
vllm/assets/image.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from .base import get_vllm_public_assets
|
||||||
|
|
||||||
|
VLM_IMAGES_DIR = "vision_model_images"
|
||||||
|
|
||||||
|
ImageAssetName = Literal["stop_sign", "cherry_blossom", "hato",
|
||||||
|
"2560px-Gfp-wisconsin-madison-the-nature-boardwalk",
|
||||||
|
"Grayscale_8bits_palette_sample_image",
|
||||||
|
"1280px-Venn_diagram_rgb", "RGBA_comp", "237-400x300",
|
||||||
|
"231-200x300", "27-500x500", "17-150x600",
|
||||||
|
"handelsblatt-preview", "paper-11"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ImageAsset:
    """A named image from the vLLM public asset bucket."""

    # One of the names allowed by ``ImageAssetName``.
    name: ImageAssetName

    def get_path(self, ext: str) -> Path:
        """
        Return the path given by ``get_vllm_public_assets`` for
        ``<name>.<ext>`` under the ``VLM_IMAGES_DIR`` prefix.
        """
        asset_file = f"{self.name}.{ext}"
        return get_vllm_public_assets(filename=asset_file,
                                      s3_prefix=VLM_IMAGES_DIR)

    @property
    def pil_image(self, ext="jpg") -> Image.Image:
        """
        Open the asset with PIL.

        Property access cannot supply ``ext``, so the ``"jpg"`` default is
        what gets used.
        """
        return Image.open(self.get_path(ext))

    @property
    def image_embeds(self) -> torch.Tensor:
        """
        Image embeddings, only used for testing purposes with llava 1.5.

        Loaded from the ``.pt`` variant of the asset onto the CPU.
        """
        return torch.load(self.get_path('pt'),
                          map_location="cpu",
                          weights_only=True)

    def read_bytes(self, ext: str) -> bytes:
        """Return the raw bytes of the ``<name>.<ext>`` asset file."""
        return Path(self.get_path(ext)).read_bytes()
|
||||||
145
vllm/assets/video.py
Normal file
145
vllm/assets/video.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Any, ClassVar, Literal, Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from vllm.utils import PlaceholderModule
|
||||||
|
|
||||||
|
from .base import get_cache_dir
|
||||||
|
|
||||||
|
try:
|
||||||
|
import librosa
|
||||||
|
except ImportError:
|
||||||
|
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
def download_video_asset(filename: str) -> str:
    """
    Return a local path for a video file from the huggingface dataset
    repo: raushan-testing-hf/videos-test

    A file already present at ``<cache dir>/video-example-data/<filename>``
    is used as-is; otherwise the path returned by ``hf_hub_download`` is
    returned. Results are memoized per ``filename`` by ``lru_cache``.
    """
    target_dir = get_cache_dir() / "video-example-data"
    target_dir.mkdir(parents=True, exist_ok=True)

    local_copy = target_dir / filename
    if local_copy.exists():
        return str(local_copy)

    return hf_hub_download(
        repo_id="raushan-testing-hf/videos-test",
        filename=filename,
        repo_type="dataset",
        cache_dir=target_dir,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
    """
    Decode frames of a video file into a single stacked RGB array.

    Args:
        path: Path of the video file to open with OpenCV.
        num_frames: Number of frames to sample, evenly spaced over the
            reported frame count. A non-positive value selects every frame.

    Returns:
        The selected frames, converted from BGR to RGB and stacked along a
        new leading axis.

    Raises:
        ValueError: If the video cannot be opened, if fewer than
            ``num_frames`` frames could be decoded, or if no frame could be
            decoded at all.
    """
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file {path}")

    frames = []
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        num_frames = num_frames if num_frames > 0 else total_frames
        # A set makes the per-frame membership test O(1) instead of scanning
        # the index array for every grabbed frame.
        wanted_indices = set(
            np.linspace(0, total_frames - 1, num_frames, dtype=int).tolist())

        for idx in range(total_frames):
            ok = cap.grab()  # next img
            if not ok:
                break
            if idx in wanted_indices:  # only decompress needed
                ret, frame = cap.retrieve()
                if ret:
                    # OpenCV uses BGR format, we need to convert it to RGB
                    # for PIL and transformers compatibility
                    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always free the decoder/file handle, even if decoding raised.
        cap.release()

    # Validate before stacking: np.stack on an empty list raises an opaque
    # error that hides the real cause.
    if len(frames) < num_frames:
        raise ValueError(f"Could not read enough frames from video file {path}"
                         f" (expected {num_frames} frames, got {len(frames)})")
    if not frames:
        raise ValueError(f"Could not read any frames from video file {path}")
    return np.stack(frames)
|
||||||
|
|
||||||
|
|
||||||
|
def video_to_pil_images_list(path: str,
                             num_frames: int = -1) -> list[Image.Image]:
    """
    Decode a video with ``video_to_ndarrays`` and wrap each decoded frame
    in a PIL image.
    """
    decoded = video_to_ndarrays(path, num_frames)
    return list(map(Image.fromarray, decoded))
|
||||||
|
|
||||||
|
|
||||||
|
def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]:
    """
    Build a metadata dict describing a video file.

    Args:
        path: Path of the video file to open with OpenCV.
        num_frames: Requested number of frames. ``-1``, or any value larger
            than the reported frame count, is replaced by the frame count.

    Returns:
        A dict with the keys ``total_num_frames``, ``fps``, ``duration``,
        ``video_backend``, ``frames_indices`` and ``do_sample_frames``.

    Raises:
        ValueError: If the video cannot be opened.
    """
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file {path}")

    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Only the two properties above are needed; free the handle now.
        cap.release()

    # Guard against a zero/unknown fps to avoid a division by zero.
    duration = total_frames / fps if fps > 0 else 0

    if num_frames == -1 or num_frames > total_frames:
        num_frames = total_frames

    metadata = {
        "total_num_frames": num_frames,
        "fps": fps,
        "duration": duration,
        "video_backend": "opencv",
        "frames_indices": list(range(num_frames)),
        # extra field used to control hf processor's video
        # sampling behavior
        "do_sample_frames": num_frames == total_frames,
    }
    return metadata
|
||||||
|
|
||||||
|
|
||||||
|
# Closed set of video asset names; each must have an entry in
# ``VideoAsset._NAME_TO_FILE``.
VideoAssetName = Literal["baby_reading"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class VideoAsset:
    """A named example video plus the number of frames to sample from it."""

    name: VideoAssetName
    num_frames: int = -1

    # Maps each asset name to the file name requested from the dataset repo.
    _NAME_TO_FILE: ClassVar[dict[VideoAssetName, str]] = {
        "baby_reading": "sample_demo_1.mp4",
    }

    @property
    def filename(self) -> str:
        """File name registered for ``self.name``."""
        return self._NAME_TO_FILE[self.name]

    @property
    def video_path(self) -> str:
        """Local path of the video, as returned by ``download_video_asset``."""
        return download_video_asset(self.filename)

    @property
    def pil_images(self) -> list[Image.Image]:
        """Sampled frames as PIL images."""
        return video_to_pil_images_list(self.video_path, self.num_frames)

    @property
    def np_ndarrays(self) -> npt.NDArray:
        """Sampled frames as one stacked array."""
        return video_to_ndarrays(self.video_path, self.num_frames)

    @property
    def metadata(self) -> dict[str, Any]:
        """Metadata dict produced by ``video_get_metadata``."""
        return video_get_metadata(self.video_path, self.num_frames)

    def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
        """
        Read audio data from the video asset, used in Qwen2.5-Omni examples.

        See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
        """
        # librosa.load returns (samples, sample_rate); only samples are used.
        samples, _ = librosa.load(self.video_path, sr=sampling_rate)
        return samples
|
||||||
15
vllm/attention/__init__.py
Normal file
15
vllm/attention/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||||
|
AttentionMetadata, AttentionType)
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
|
||||||
|
# Names re-exported as the public API of ``vllm.attention``.
__all__ = [
    "Attention",
    "AttentionBackend",
    "AttentionMetadata",
    "AttentionType",
    "get_attn_backend",
]
|
||||||
BIN
vllm/attention/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/attention/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/attention/__pycache__/layer.cpython-310.pyc
Normal file
BIN
vllm/attention/__pycache__/layer.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/attention/__pycache__/selector.cpython-310.pyc
Normal file
BIN
vllm/attention/__pycache__/selector.cpython-310.pyc
Normal file
Binary file not shown.
0
vllm/attention/backends/__init__.py
Normal file
0
vllm/attention/backends/__init__.py
Normal file
BIN
vllm/attention/backends/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/attention/backends/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/attention/backends/__pycache__/abstract.cpython-310.pyc
Normal file
BIN
vllm/attention/backends/__pycache__/abstract.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/attention/backends/__pycache__/utils.cpython-310.pyc
Normal file
BIN
vllm/attention/backends/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
204
vllm/attention/backends/abstract.py
Normal file
204
vllm/attention/backends/abstract.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionType:
    """
    Attention type.
    Use string to be compatible with `torch.compile`.

    The members are plain ``str`` class attributes, not an ``Enum``.
    """
    DECODER = "decoder"
    """Decoder attention between previous layer Q/K/V."""
    ENCODER = "encoder"
    """Encoder attention between previous layer Q/K/V for encoder-decoder."""
    ENCODER_ONLY = "encoder_only"
    """Encoder attention between previous layer Q/K/V."""
    ENCODER_DECODER = "encoder_decoder"
    """Attention between dec. Q and enc. K/V for encoder-decoder."""
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    A backend is used as a class (all members are static or class methods):
    it names itself and hands out the implementation, metadata and metadata
    builder classes that belong together.
    """
    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    # Whether this backend supports receiving pre-quantized query input.
    # If True, the attention layer will handle query quantization instead
    # of the backend, allowing torch.compile to fuse quantization with
    # previous operations.
    # Needs to be worked through for all backends
    # https://github.com/vllm-project/vllm/issues/25584
    supports_quant_query_input: bool = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name identifying this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        """Return the ``AttentionImpl`` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the ``AttentionMetadata`` subclass used by this backend."""
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        """Instantiate ``get_metadata_cls()`` with the given arguments."""
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        """Return the metadata builder class used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> Tuple[int, ...]:
        """Return the KV cache tensor shape for the given sizes.

        NOTE(review): the axis order is backend-specific and not visible
        here; confirm against the concrete backend.
        """
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> Tuple[int, ...]:
        """Not abstract: raises ``NotImplementedError`` unless overridden."""
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        """Return ``(module name, qualified class name)`` of this backend."""
        return (cls.__module__, cls.__qualname__)
|
||||||
|
|
||||||
|
|
||||||
|
# Empty base class; used below as the bound of the ``T`` type variable and
# as the return type of ``AttentionBackend.make_metadata``.
class AttentionMetadata:
    pass
|
||||||
|
|
||||||
|
|
||||||
|
# Type variable for the metadata type an ``AttentionImpl`` is generic over.
T = TypeVar("T", bound=AttentionMetadata)
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionLayer(Protocol):
    """Structural type of the ``layer`` argument taken by
    ``AttentionImpl.forward``: the quantization scales plus a ``forward``
    method."""

    # Scales held as tensors.
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    # The q/k/v scales again as Python floats.
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionImpl(ABC, Generic[T]):
    """Abstract base of attention implementations, generic over the
    metadata type ``T`` they consume."""

    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # some attention backends might not always want to return lse
    # even if they can return lse (for efficiency reasons)
    need_to_return_lse_for_decode: bool = False

    dcp_world_size: int
    dcp_rank: int

    def __new__(cls, *args, **kwargs):
        """Create the instance and fill in its DCP fields before any
        subclass ``__init__`` runs."""
        # use __new__ so that all subclasses will call this
        instance = super().__new__(cls)
        try:
            from vllm.distributed.parallel_state import get_dcp_group
            instance.dcp_world_size = get_dcp_group().world_size
            instance.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            instance.dcp_world_size = 1
            instance.dcp_rank = 0
        # lse is only needed with more than one DCP rank, and only when the
        # implementation declares it can provide it.
        instance.need_to_return_lse_for_decode = (
            instance.dcp_world_size > 1
            and instance.can_return_lse_for_decode)
        return instance

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        """Constructor signature that concrete implementations provide."""
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention; must be overridden by concrete implementations."""
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: QuantKey):
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        :param quant_key: QuantKey object that describes the quantization op
        :return: is fusion supported for this type of quantization
        """
        return False
|
||||||
|
|
||||||
|
|
||||||
|
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    """``AttentionImpl`` variant whose ``forward`` takes
    ``hidden_states_or_cq``, ``kv_c_normed`` and ``k_pe`` in place of
    query/key/value."""

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True for every dtype string other than the literal "auto"."""
    is_auto = kv_cache_dtype == "auto"
    return not is_auto
|
||||||
33
vllm/attention/backends/utils.py
Normal file
33
vllm/attention/backends/utils.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Attention backend utils"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)

# Sentinel slot id. NOTE(review): presumably marks padded/unused entries of a
# slot mapping; its use is not visible in this file, so confirm with callers.
PAD_SLOT_ID = -1
|
||||||
|
|
||||||
|
|
||||||
|
# Dimension values copied from a model's ``hf_text_config`` by
# ``get_mla_dims``; each field mirrors the config attribute of the same name.
@dataclass
class MLADims:
    # Optional: ``get_mla_dims`` stores None when the config lacks it.
    q_lora_rank: Optional[int]
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    v_head_dim: int
|
||||||
|
|
||||||
|
|
||||||
|
def get_mla_dims(model_config: ModelConfig) -> MLADims:
    """
    Collect the MLA dimension attributes of ``model_config.hf_text_config``
    into an ``MLADims``. ``q_lora_rank`` falls back to None when the config
    does not define it; the other attributes are required.
    """
    text_cfg = model_config.hf_text_config
    # Only q_lora_rank is optional on the config.
    q_rank = getattr(text_cfg, "q_lora_rank", None)
    kv_rank = text_cfg.kv_lora_rank
    nope_dim = text_cfg.qk_nope_head_dim
    rope_dim = text_cfg.qk_rope_head_dim
    v_dim = text_cfg.v_head_dim
    return MLADims(
        q_lora_rank=q_rank,
        kv_lora_rank=kv_rank,
        qk_nope_head_dim=nope_dim,
        qk_rope_head_dim=rope_dim,
        v_head_dim=v_dim,
    )
|
||||||
645
vllm/attention/layer.py
Normal file
645
vllm/attention/layer.py
Normal file
@@ -0,0 +1,645 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Attention layer."""
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.attention import AttentionType
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
|
||||||
|
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
|
||||||
|
from vllm.config import CacheConfig, get_current_vllm_config
|
||||||
|
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||||
|
has_kv_transfer_group,
|
||||||
|
is_v1_kv_transfer_group)
|
||||||
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||||
|
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
GroupShape)
|
||||||
|
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||||
|
from vllm.platforms import _Backend, current_platform
|
||||||
|
from vllm.utils import GiB_bytes, direct_register_custom_op
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
# Cached result of check_xformers_availability(); None until first computed.
USE_XFORMERS_OPS = None
# One-element tuple holding torch's ``cudagraph_unsafe`` tag, or an empty
# tuple when the installed torch does not define that tag.
try:
    tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
except AttributeError:
    tag_cudagraph_unsafe = ()  # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
||||||
|
def check_xformers_availability():
    """
    Report whether ``xformers.ops`` can be used, caching the answer in the
    module-level ``USE_XFORMERS_OPS``.

    Returns:
        False on CUDA devices with capability 100 or higher, or when
        ``xformers.ops`` cannot be located; True otherwise. Later calls
        return the cached value without re-checking.
    """
    global USE_XFORMERS_OPS
    if USE_XFORMERS_OPS is not None:
        return USE_XFORMERS_OPS

    if current_platform.is_cuda() and current_platform.has_device_capability(
            100):
        # Xformers FA is not compatible with B200
        USE_XFORMERS_OPS = False
    else:
        try:
            from importlib.util import find_spec

            # find_spec raises ImportError when the parent package is
            # missing, but returns None when only the submodule is missing,
            # so the return value must be checked as well.
            USE_XFORMERS_OPS = find_spec("xformers.ops") is not None
        except ImportError:
            USE_XFORMERS_OPS = False

    # the warning only needs to be shown once
    if not USE_XFORMERS_OPS:
        logger.warning("Xformers is not available, falling back.")

    return USE_XFORMERS_OPS
|
||||||
|
|
||||||
|
|
||||||
|
def check_upstream_fa_availability(dtype: torch.dtype):
    """
    Return whether upstream flash-attn 2 is usable: requires a float16 or
    bfloat16 ``dtype``, a CUDA platform with device capability 80 or higher,
    and ``transformers`` reporting flash-attn 2 as available.
    """
    if dtype not in (torch.float16, torch.bfloat16):
        return False
    if not current_platform.is_cuda():
        return False
    if not current_platform.has_device_capability(80):
        return False

    from transformers.utils import is_flash_attn_2_available
    return is_flash_attn_2_available()
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module, AttentionLayerBase):
|
||||||
|
"""Attention layer.
|
||||||
|
|
||||||
|
This class takes query, key, and value tensors as input. The input tensors
|
||||||
|
can either contain prompt tokens or generation tokens.
|
||||||
|
The class does the following:
|
||||||
|
|
||||||
|
1. Store the input key and value tensors in the KV cache.
|
||||||
|
2. Perform (multi-head/multi-query/grouped-query) attention.
|
||||||
|
3. Return the output tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
scale: float,
|
||||||
|
num_kv_heads: Optional[int] = None,
|
||||||
|
alibi_slopes: Optional[List[float]] = None,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
logits_soft_cap: Optional[float] = None,
|
||||||
|
per_layer_sliding_window: Optional[int] = None,
|
||||||
|
use_mla: bool = False,
|
||||||
|
use_sparse: bool = False,
|
||||||
|
prefix: str = "",
|
||||||
|
attn_type: str = AttentionType.DECODER,
|
||||||
|
kv_sharing_target_layer_name: Optional[str] = None,
|
||||||
|
attn_backend: Optional[type[AttentionBackend]] = None,
|
||||||
|
**extra_impl_args,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
The KV cache is stored inside this class and is accessed via
|
||||||
|
`self.kv_cache`.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
if per_layer_sliding_window is not None:
|
||||||
|
# per-layer sliding window
|
||||||
|
sliding_window = per_layer_sliding_window
|
||||||
|
elif cache_config is not None:
|
||||||
|
# model-level sliding window
|
||||||
|
sliding_window = cache_config.sliding_window
|
||||||
|
else:
|
||||||
|
sliding_window = None
|
||||||
|
|
||||||
|
if cache_config is not None:
|
||||||
|
kv_cache_dtype = cache_config.cache_dtype
|
||||||
|
block_size = cache_config.block_size
|
||||||
|
calculate_kv_scales = cache_config.calculate_kv_scales
|
||||||
|
else:
|
||||||
|
kv_cache_dtype = "auto"
|
||||||
|
block_size = 16
|
||||||
|
calculate_kv_scales = False
|
||||||
|
if num_kv_heads is None:
|
||||||
|
num_kv_heads = num_heads
|
||||||
|
assert num_heads % num_kv_heads == 0, \
|
||||||
|
f"num_heads ({num_heads}) is not " \
|
||||||
|
f"divisible by num_kv_heads ({num_kv_heads})"
|
||||||
|
|
||||||
|
# The default k/v_scale is set to 1.0. This is ignored
|
||||||
|
# when kv-cache is not fp8, and should be used with
|
||||||
|
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
|
||||||
|
# expect the pre-quantized k/v_scale to be loaded along
|
||||||
|
# with the model weights.
|
||||||
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
|
self.calculate_kv_scales = calculate_kv_scales
|
||||||
|
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||||
|
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||||
|
# FlashAttn doesn't support quantizing the kv-cache only
|
||||||
|
# but requires q to be quantized as well.
|
||||||
|
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||||
|
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||||
|
|
||||||
|
# We also keep q/k/v_scale on host (cpu) memory for attention
|
||||||
|
# backends that require the scales to be on host instead of on device.
|
||||||
|
# e.g. Flashinfer
|
||||||
|
self._q_scale_float = 1.0
|
||||||
|
self._k_scale_float = 1.0
|
||||||
|
self._v_scale_float = 1.0
|
||||||
|
|
||||||
|
# The output scale on host memory. This should be the input scale of
|
||||||
|
# the quant op after this attention layer.
|
||||||
|
self._o_scale_float: Optional[float] = None
|
||||||
|
|
||||||
|
self.use_mla = use_mla
|
||||||
|
self.use_sparse = use_sparse
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_size = head_size
|
||||||
|
self.num_kv_heads = num_kv_heads
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
self.has_sink = extra_impl_args.get("sinks") is not None
|
||||||
|
|
||||||
|
quant_method = quant_config.get_quant_method(
|
||||||
|
self, prefix=prefix) if quant_config else None
|
||||||
|
if quant_method is not None and not isinstance(
|
||||||
|
quant_method, UnquantizedLinearMethod):
|
||||||
|
assert isinstance(quant_method, BaseKVCacheMethod)
|
||||||
|
# TODO (mgoin): kv cache dtype should be specified in the FP8
|
||||||
|
# checkpoint config and become the "auto" behavior
|
||||||
|
if self.kv_cache_dtype == "fp8_e5m2":
|
||||||
|
raise ValueError("fp8_e5m2 kv-cache is not supported with "
|
||||||
|
"fp8 checkpoints.")
|
||||||
|
# If quantization is enabled, we make "k_scale" and "v_scale"
|
||||||
|
# parameters so that it can be loaded from the model checkpoint.
|
||||||
|
# The k/v_scale will then be converted back to native float32
|
||||||
|
# values after weight loading.
|
||||||
|
self.quant_method = quant_method
|
||||||
|
self.quant_method.create_weights(self)
|
||||||
|
|
||||||
|
# During model initialization, the default dtype is set as the model
|
||||||
|
# weight and activation dtype.
|
||||||
|
dtype = torch.get_default_dtype()
|
||||||
|
if attn_backend is None:
|
||||||
|
self.attn_backend = get_attn_backend(head_size,
|
||||||
|
dtype,
|
||||||
|
kv_cache_dtype,
|
||||||
|
block_size,
|
||||||
|
use_mla=use_mla,
|
||||||
|
has_sink=self.has_sink,
|
||||||
|
use_sparse=use_sparse)
|
||||||
|
else:
|
||||||
|
self.attn_backend = attn_backend
|
||||||
|
|
||||||
|
impl_cls = self.attn_backend.get_impl_cls()
|
||||||
|
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
||||||
|
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||||
|
logits_soft_cap, attn_type,
|
||||||
|
kv_sharing_target_layer_name, **extra_impl_args)
|
||||||
|
self.backend = backend_name_to_enum(self.attn_backend.get_name())
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
|
||||||
|
# torch.compile works by registering the attention as one giant
|
||||||
|
# opaque custom op. For other platforms, we directly call them
|
||||||
|
# and let torch.compile handle them.
|
||||||
|
self.use_direct_call = not current_platform.opaque_attention_op()
|
||||||
|
|
||||||
|
self.use_output = self.attn_backend.accept_output_buffer
|
||||||
|
compilation_config = get_current_vllm_config().compilation_config
|
||||||
|
if prefix in compilation_config.static_forward_context:
|
||||||
|
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||||
|
compilation_config.static_forward_context[prefix] = self
|
||||||
|
self.layer_name = prefix
|
||||||
|
self.attn_type = attn_type
|
||||||
|
|
||||||
|
if kv_sharing_target_layer_name is not None:
|
||||||
|
validate_kv_sharing_target(
|
||||||
|
prefix,
|
||||||
|
kv_sharing_target_layer_name,
|
||||||
|
compilation_config.static_forward_context,
|
||||||
|
)
|
||||||
|
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||||
|
|
||||||
|
# use a placeholder kv cache tensor during init, which will be replaced
|
||||||
|
# by bind_kv_cache
|
||||||
|
# this variable will not be accessed if use_direct_call is True
|
||||||
|
self.kv_cache = [
|
||||||
|
torch.tensor([]) for _ in range(get_current_vllm_config(
|
||||||
|
).parallel_config.pipeline_parallel_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT,
|
||||||
|
dtype=torch.float32)
|
||||||
|
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT,
|
||||||
|
dtype=torch.float32)
|
||||||
|
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT,
|
||||||
|
dtype=torch.float32)
|
||||||
|
except torch.cuda.OutOfMemoryError as e:
|
||||||
|
logger.error(
|
||||||
|
"Failed to initialize attention q/k/v range constants: %s", e)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
logger.debug("CUDA device: %s", torch.cuda.current_device())
|
||||||
|
logger.debug("Allocated: %.2f GiB",
|
||||||
|
torch.cuda.memory_allocated() / GiB_bytes)
|
||||||
|
logger.debug("Reserved: %.2f GiB",
|
||||||
|
torch.cuda.memory_reserved() / GiB_bytes)
|
||||||
|
raise RuntimeError(
|
||||||
|
"Failed to initialize q/k/v range constants. "
|
||||||
|
"This may be caused by insufficient memory to allocate "
|
||||||
|
"kv cache.") from e
|
||||||
|
|
||||||
|
# for attn backends supporting query quantization
|
||||||
|
self.query_quant = None
|
||||||
|
if self.kv_cache_dtype.startswith(
|
||||||
|
"fp8") and self.attn_backend.supports_quant_query_input:
|
||||||
|
self.query_quant = QuantFP8(static=True,
|
||||||
|
group_shape=GroupShape.PER_TENSOR)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
# For some alternate attention backends like MLA the attention output
|
||||||
|
# shape does not match the query shape, so we optionally let the model
|
||||||
|
# definition specify the output tensor shape.
|
||||||
|
output_shape: Optional[torch.Size] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
The KV cache is stored inside this class and is accessed via
|
||||||
|
`self.kv_cache`.
|
||||||
|
|
||||||
|
Attention metadata (`attn_metadata`) is set using a context manager in
|
||||||
|
the model runner's `execute_model` method. It is accessed via forward
|
||||||
|
context using
|
||||||
|
`vllm.forward_context.get_forward_context().attn_metadata`.
|
||||||
|
"""
|
||||||
|
if self.calculate_kv_scales:
|
||||||
|
attn_metadata = get_forward_context().attn_metadata
|
||||||
|
if attn_metadata.enable_kv_scales_calculation:
|
||||||
|
self.calc_kv_scales(query, key, value)
|
||||||
|
|
||||||
|
output_dtype = query.dtype
|
||||||
|
if self.query_quant is not None:
|
||||||
|
# quantizing with a simple torch operation enables
|
||||||
|
# torch.compile to fuse this into previous ops
|
||||||
|
# which reduces overheads during decoding.
|
||||||
|
# Otherwise queries are quantized using custom ops
|
||||||
|
# which causes decoding overheads
|
||||||
|
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
|
||||||
|
query, _ = self.query_quant(query, self._q_scale)
|
||||||
|
|
||||||
|
if self.use_output:
|
||||||
|
output_shape = (output_shape
|
||||||
|
if output_shape is not None else query.shape)
|
||||||
|
output = torch.zeros(output_shape,
|
||||||
|
dtype=output_dtype,
|
||||||
|
device=query.device)
|
||||||
|
hidden_size = output_shape[-1]
|
||||||
|
# We skip reshaping query, key and value tensors for the MLA
|
||||||
|
# backend since these tensors have different semantics and are
|
||||||
|
# processed differently.
|
||||||
|
if not self.use_mla:
|
||||||
|
# Reshape the query, key, and value tensors.
|
||||||
|
# NOTE(woosuk): We do this outside the custom op to minimize the
|
||||||
|
# CPU overheads from the non-CUDA-graph regions.
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
output = output.view(-1, self.num_heads, self.head_size)
|
||||||
|
if key is not None:
|
||||||
|
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||||
|
if value is not None:
|
||||||
|
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||||
|
if self.use_direct_call:
|
||||||
|
forward_context: ForwardContext = get_forward_context()
|
||||||
|
attn_metadata = forward_context.attn_metadata
|
||||||
|
if isinstance(attn_metadata, dict):
|
||||||
|
attn_metadata = attn_metadata[self.layer_name]
|
||||||
|
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||||
|
self.impl.forward(self,
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
self_kv_cache,
|
||||||
|
attn_metadata,
|
||||||
|
output=output)
|
||||||
|
else:
|
||||||
|
torch.ops.vllm.unified_attention_with_output(
|
||||||
|
query, key, value, output, self.layer_name)
|
||||||
|
return output.view(-1, hidden_size)
|
||||||
|
else:
|
||||||
|
if self.use_direct_call:
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
attn_metadata = forward_context.attn_metadata
|
||||||
|
if isinstance(attn_metadata, dict):
|
||||||
|
attn_metadata = attn_metadata[self.layer_name]
|
||||||
|
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||||
|
return self.impl.forward(self, query, key, value,
|
||||||
|
self_kv_cache, attn_metadata)
|
||||||
|
else:
|
||||||
|
return torch.ops.vllm.unified_attention(
|
||||||
|
query, key, value, self.layer_name)
|
||||||
|
|
||||||
|
def calc_kv_scales(self, query, key, value):
    """Compute the q/k/v scales from the given activations.

    Each scale is the largest absolute value of the matching tensor divided
    by the corresponding ``*_range`` attribute. The result is copied in
    place into ``_q_scale`` / ``_k_scale`` / ``_v_scale`` and mirrored as a
    Python float in ``_q_scale_float`` / ``_k_scale_float`` /
    ``_v_scale_float``.
    """
    # In-place copies keep the existing scale tensor objects alive.
    q_peak = torch.abs(query).max()
    self._q_scale.copy_(q_peak / self.q_range)
    k_peak = torch.abs(key).max()
    self._k_scale.copy_(k_peak / self.k_range)
    v_peak = torch.abs(value).max()
    self._v_scale.copy_(v_peak / self.v_range)

    # Float mirrors of the freshly written tensors.
    self._q_scale_float = self._q_scale.item()
    self._k_scale_float = self._k_scale.item()
    self._v_scale_float = self._v_scale.item()

    # The scales are computed a single time; clear the request flag.
    self.calculate_kv_scales = False
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
    """Summarize the wrapped implementation as a comma-separated string.

    Lists head size, number of heads, number of KV heads and scale as read
    from ``self.impl``, followed by the implementation's class name.
    """
    fields = (
        f"head_size={self.impl.head_size}",  # type: ignore
        f"num_heads={self.impl.num_heads}",  # type: ignore
        f"num_kv_heads={self.impl.num_kv_heads}",  # type: ignore
        f"scale={self.impl.scale}",  # type: ignore
        f"backend={self.impl.__class__.__name__}",
    )
    return ", ".join(fields)
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, act_dtype: torch.dtype):
    """Run post-load processing on the attention implementation.

    Forwards ``act_dtype`` to ``self.impl.process_weights_after_loading``
    when the implementation defines that method. Afterwards, for the
    FlashInfer backend, converts ``self.impl.sinks`` to float32 if it is
    set and stored in a different dtype.
    """
    if hasattr(self.impl, "process_weights_after_loading"):
        self.impl.process_weights_after_loading(act_dtype)

    # FlashInfer requires attention sinks to be float32
    needs_sink_check = (self.backend == _Backend.FLASHINFER
                        and hasattr(self.impl, 'sinks'))
    if not needs_sink_check:
        return

    # Imported lazily, only on the FlashInfer path.
    from vllm.v1.attention.backends.flashinfer import FlashInferImpl
    assert isinstance(self.impl, FlashInferImpl)

    sinks = self.impl.sinks
    if sinks is None:
        return
    if sinks.dtype != torch.float32:
        self.impl.sinks = self.impl.sinks.to(torch.float32)
|
||||||
|
|
||||||
|
def get_attn_backend(self) -> type[AttentionBackend]:
    """Return the attention backend class stored in ``self.attn_backend``."""
    backend_cls = self.attn_backend
    return backend_cls
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Module):
    """Multi-head attention with no KV cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        """Record the head layout and choose the attention backend.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Softmax scale handed to the attention kernel.
            num_kv_heads: Number of key/value heads; falls back to
                ``num_heads`` when omitted. Must divide ``num_heads``.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        if num_kv_heads is None:
            num_kv_heads = num_heads
        self.num_kv_heads = num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0, \
            f"num_heads ({self.num_heads}) is not " \
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # While the model is being initialized the default dtype is the
        # model's weight/activation dtype.
        model_dtype = torch.get_default_dtype()

        # Ask the selector which backend to use for ViT attention.
        selected = get_vit_attn_backend(head_size=head_size,
                                        dtype=model_dtype)

        # An auto-selected backend may be upgraded to upstream flash
        # attention when that is available; a natively selected FLASH_ATTN
        # is used as is.
        use_upstream_fa = False
        if (selected != _Backend.FLASH_ATTN
                and check_upstream_fa_availability(model_dtype)):
            selected = _Backend.FLASH_ATTN
            use_upstream_fa = True

        if current_platform.is_rocm() or current_platform.is_xpu():
            # currently, only torch_sdpa is supported on rocm/xpu
            resolved = _Backend.TORCH_SDPA
        elif selected in {
                _Backend.TORCH_SDPA,
                _Backend.XFORMERS,
                _Backend.PALLAS,
                _Backend.ROCM_AITER_FA,
                _Backend.FLASH_ATTN,
        }:
            resolved = selected
        else:
            # Anything this layer does not handle falls back to SDPA.
            resolved = _Backend.TORCH_SDPA

        if (resolved == _Backend.XFORMERS
                and not check_xformers_availability()):
            resolved = _Backend.TORCH_SDPA
        self.attn_backend = resolved

        if self.attn_backend == _Backend.FLASH_ATTN:
            # Bind the varlen kernel from whichever package was chosen.
            if use_upstream_fa:
                from flash_attn import flash_attn_varlen_func
            else:
                from vllm.vllm_flash_attn import flash_attn_varlen_func
            self._flash_attn_varlen_func = flash_attn_varlen_func

        logger.info_once(
            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
            f"use_upstream_fa: {use_upstream_fa}")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention and return a (batch, q_len, -1) tensor.

        Inputs are shaped (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size).
        """
        batch_size, num_q = query.size()[:2]
        num_kv = key.size(1)

        query = query.view(batch_size, num_q, self.num_heads,
                           self.head_size)
        key = key.view(batch_size, num_kv, self.num_kv_heads,
                       self.head_size)
        value = value.view(batch_size, num_kv, self.num_kv_heads,
                           self.head_size)

        repeats = self.num_queries_per_kv
        if repeats > 1:
            # MQA / GQA: expand the KV heads to line up with query heads.
            key = torch.repeat_interleave(key, repeats, dim=2)
            value = torch.repeat_interleave(value, repeats, dim=2)

        if self.attn_backend == _Backend.FLASH_ATTN:
            # Every sequence has the same length, so the cumulative
            # offsets are evenly spaced multiples of that length.
            cu_seqlens_q = torch.arange(0, (batch_size + 1) * num_q,
                                        step=num_q,
                                        dtype=torch.int32,
                                        device=query.device)
            cu_seqlens_k = torch.arange(0, (batch_size + 1) * num_kv,
                                        step=num_kv,
                                        dtype=torch.int32,
                                        device=key.device)

            out = self._flash_attn_varlen_func(
                query.flatten(0, 1),
                key.flatten(0, 1),
                value.flatten(0, 1),
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=num_q,
                max_seqlen_k=num_kv,
                softmax_scale=self.scale,
            )
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(
                query, key, value, scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Swap the sequence and head axes for SDPA, then swap back.
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.PALLAS:
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            from torch_xla.experimental.custom_kernel import flash_attention
            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.ROCM_AITER_FA:
            from aiter import flash_attn_varlen_func

            # ROCm Flash Attention expects (batch, seq, heads, head_dim)
            out = flash_attn_varlen_func(query,
                                         key,
                                         value,
                                         softmax_scale=self.scale)
        else:
            # No ViT attention path exists for this backend.
            raise NotImplementedError(
                f"ViT attention hasn't supported {self.attn_backend} "
                f"backend yet.")

        return out.reshape(batch_size, num_q, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_kv_layer_from_connector(layer_name: str):
    """Call the KV connector's ``wait_for_layer_load`` for ``layer_name``.

    Returns without doing anything when there is no KV transfer group, when
    the group is not a v1 group, or when the current forward context has no
    attention metadata.
    """
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    connector = get_kv_transfer_group()

    attn_metadata = get_forward_context().attn_metadata
    if attn_metadata is None:
        return
    assert isinstance(attn_metadata, dict)
    connector.wait_for_layer_load(layer_name)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
):
    """Pass one layer's KV cache to the KV connector's ``save_kv_layer``.

    Returns without doing anything when there is no KV transfer group, when
    the group is not a v1 group, or when the current forward context has no
    attention metadata. Otherwise the connector receives the layer name, the
    cache and that layer's entry of the per-layer attention metadata.
    """
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    connector = get_kv_transfer_group()

    attn_metadata = get_forward_context().attn_metadata
    if attn_metadata is None:
        return
    assert isinstance(attn_metadata, dict)
    layer_metadata = attn_metadata[layer_name]
    connector.save_kv_layer(layer_name, kv_cache_layer, layer_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the attention layer registered under ``layer_name``.

    The layer, its KV cache and its attention metadata are looked up in the
    current forward context. The KV-connector hooks are invoked before and
    after the implementation's ``forward`` call, whose result is returned.
    """
    wait_for_kv_layer_from_connector(layer_name)

    ctx = get_forward_context()
    metadata = ctx.attn_metadata
    if isinstance(metadata, dict):
        # Per-layer metadata: select this layer's entry.
        metadata = metadata[layer_name]
    layer = ctx.no_compile_layers[layer_name]
    kv_cache = layer.kv_cache[ctx.virtual_engine]
    output = layer.impl.forward(layer, query, key, value, kv_cache,
                                metadata)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return output
|
||||||
|
|
||||||
|
|
||||||
|
def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation paired with ``unified_attention``.

    Performs no attention: returns an uninitialized, contiguous tensor with
    the same shape, dtype and device as ``query``.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
# Register ``unified_attention`` as a custom op, with
# ``unified_attention_fake`` as its fake implementation.
direct_register_custom_op(op_name="unified_attention",
                          op_func=unified_attention,
                          fake_impl=unified_attention_fake,
                          tags=tag_cudagraph_unsafe)
|
||||||
|
|
||||||
|
|
||||||
|
def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Run the attention layer registered under ``layer_name`` with ``output``.

    The layer, its KV cache and its attention metadata are looked up in the
    current forward context. ``output``, ``output_scale`` and
    ``output_block_scale`` are forwarded to the implementation's ``forward``
    call; nothing is returned. The KV-connector hooks are invoked before
    and after that call.
    """
    wait_for_kv_layer_from_connector(layer_name)

    ctx = get_forward_context()
    metadata = ctx.attn_metadata
    if isinstance(metadata, dict):
        # Per-layer metadata: select this layer's entry.
        metadata = metadata[layer_name]
    layer = ctx.no_compile_layers[layer_name]
    kv_cache = layer.kv_cache[ctx.virtual_engine]
    layer.impl.forward(layer,
                       query,
                       key,
                       value,
                       kv_cache,
                       metadata,
                       output=output,
                       output_scale=output_scale,
                       output_block_scale=output_block_scale)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||||
|
|
||||||
|
|
||||||
|
def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Fake implementation paired with ``unified_attention_with_output``.

    Takes the same arguments as the real op, leaves them untouched and
    returns ``None``.
    """
    return None
|
||||||
|
|
||||||
|
|
||||||
|
# Register ``unified_attention_with_output`` as a custom op. ``output`` and
# ``output_block_scale`` are declared as arguments the op mutates.
direct_register_custom_op(op_name="unified_attention_with_output",
                          op_func=unified_attention_with_output,
                          mutates_args=["output", "output_block_scale"],
                          fake_impl=unified_attention_with_output_fake,
                          tags=tag_cudagraph_unsafe)
|
||||||
0
vllm/attention/layers/__init__.py
Normal file
0
vllm/attention/layers/__init__.py
Normal file
BIN
vllm/attention/layers/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/attention/layers/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
93
vllm/attention/layers/chunked_local_attention.py
Normal file
93
vllm/attention/layers/chunked_local_attention.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import functools
|
||||||
|
from typing import ClassVar, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||||
|
AttentionMetadata)
|
||||||
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
from vllm.config import CacheConfig, QuantizationConfig
|
||||||
|
from vllm.v1.attention.backends.utils import (
|
||||||
|
AttentionCGSupport, CommonAttentionMetadata,
|
||||||
|
make_local_attention_virtual_batches, subclass_attention_backend)
|
||||||
|
|
||||||
|
from ..layer import Attention
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache
def create_chunked_local_attention_backend(
    underlying_attn_backend: AttentionBackend,
    attention_chunk_size: int,
    block_size: int,
) -> type[AttentionBackend]:
    """Derive a chunked-local-attention variant of an attention backend.

    The returned backend is produced by ``subclass_attention_backend`` with
    a metadata builder that first rewrites the common attention metadata
    via ``make_local_attention_virtual_batches`` and then delegates to the
    underlying builder. Results are memoized per argument combination.
    """
    base_builder_cls = underlying_attn_backend.get_builder_cls()

    class ChunkedLocalAttentionBuilder(base_builder_cls):  # type: ignore
        # This builder declares no CUDA graph support.
        cudagraph_support: ClassVar[AttentionCGSupport] = \
            AttentionCGSupport.NEVER

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            local_metadata = make_local_attention_virtual_batches(
                attention_chunk_size, common_attn_metadata, block_size)
            return super().build(common_prefix_len, local_metadata,
                                 fast_build)

    return subclass_attention_backend(
        name_prefix=(f"ChunkedLocalAttention_{attention_chunk_size}_"
                     f"{block_size}_"),
        attention_backend_cls=underlying_attn_backend,
        builder_cls=ChunkedLocalAttentionBuilder)
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkedLocalAttention(Attention):
    """Attention layer that uses a chunked-local-attention backend under V1."""

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 attention_chunk_size: int,
                 num_kv_heads: Optional[int] = None,
                 alibi_slopes: Optional[List[float]] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 kv_sharing_target_layer_name: Optional[str] = None,
                 prefix: str = ""):
        """Pick the backend, then initialize the base ``Attention`` layer.

        Without a ``cache_config`` the KV cache dtype is ``"auto"`` and the
        block size is 16. Under V1 the selected backend is wrapped by
        ``create_chunked_local_attention_backend``; otherwise no backend
        override is passed to the base class.
        """
        dtype = torch.get_default_dtype()
        if cache_config is None:
            kv_cache_dtype, block_size = "auto", 16
        else:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size

        # in v0 the local attention is handled inside the backends, so
        # there is nothing to override unless V1 is enabled.
        attn_backend = None
        if envs.VLLM_USE_V1:
            underlying_attn_backend = get_attn_backend(
                head_size, dtype, kv_cache_dtype, block_size)
            attn_backend = create_chunked_local_attention_backend(
                underlying_attn_backend, attention_chunk_size, block_size)

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=attn_backend)
|
||||||
162
vllm/attention/layers/cross_attention.py
Normal file
162
vllm/attention/layers/cross_attention.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import functools
|
||||||
|
from copy import copy
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||||
|
AttentionMetadata, AttentionType)
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils import cdiv
|
||||||
|
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||||
|
subclass_attention_backend)
|
||||||
|
from vllm.v1.kv_cache_interface import CrossAttentionSpec
|
||||||
|
|
||||||
|
# Module-level logger, created from this module's name.
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_max_encoder_len(vllm_config: "VllmConfig") -> int:
|
||||||
|
"""Gets the max number of encoder input tokens from the config.
|
||||||
|
"""
|
||||||
|
sc = vllm_config.scheduler_config
|
||||||
|
assert sc and isinstance(sc.max_num_encoder_input_tokens, int), \
|
||||||
|
"max_num_encoder_input_tokens must be int for enc-dec models"
|
||||||
|
return sc.max_num_encoder_input_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cross_slot_mapping(encoder_seq_lens: np.ndarray,
                            block_table_tensor: torch.Tensor,
                            kv_cache_spec: CrossAttentionSpec,
                            device: torch.device) -> torch.Tensor:
    """Build the concatenated cross-attention slot mapping.

    For each request with a non-zero encoder length, every encoder position
    ``i`` maps to ``block_id * block_size + i % block_size``, where
    ``block_id`` is entry ``i // block_size`` of that request's row in
    ``block_table_tensor``. Returns an empty int64 tensor on ``device``
    when no request has encoder tokens.
    """
    block_size = kv_cache_spec.block_size
    per_request = []

    # Only requests with a non-zero encoder length contribute. Most
    # parallel requests are expected to be decoding, so few iterations
    # are expected here.
    for req_index in np.nonzero(encoder_seq_lens)[0]:
        seq_len = encoder_seq_lens[req_index].item()

        # Keep just the leading blocks that cover ``seq_len`` tokens.
        num_blocks = cdiv(seq_len, block_size)
        block_ids = block_table_tensor[req_index][:num_blocks]

        # Token position -> (block index, offset inside block) -> slot.
        positions = torch.arange(seq_len, dtype=torch.int64, device=device)
        block_indices = positions // block_size
        offsets = positions % block_size
        per_request.append(block_ids[block_indices] * block_size + offsets)

    if not per_request:
        return torch.empty(0, dtype=torch.int64, device=device)
    return torch.cat(per_request)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache
def create_cross_attention_backend(
    underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
    """Derive a cross-attention variant of an attention backend.

    The returned backend is produced by ``subclass_attention_backend`` with
    a metadata builder that works on a shallow copy of the common metadata:
    it disables causal masking, sets every sequence length to the maximum
    encoder length and recomputes the slot mapping from the encoder
    sequence lengths before delegating to the underlying builder. Results
    are memoized per underlying backend.
    """
    base_builder_cls = underlying_attn_backend.get_builder_cls()

    class CrossAttentionBuilder(base_builder_cls):  # type: ignore

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            # Shallow copy, so attributes are rebound on the copy only.
            cross_metadata = copy(common_attn_metadata)
            cross_metadata.causal = False

            max_encoder_len = _get_max_encoder_len(self.vllm_config)
            cross_metadata.max_seq_len = max_encoder_len

            # Every request gets the maximum encoder length, on the
            # builder's device and on the CPU.
            seq_lens_shape = (cross_metadata.num_reqs, )
            cross_metadata.seq_lens = torch.full(seq_lens_shape,
                                                 max_encoder_len,
                                                 dtype=torch.int32,
                                                 device=self.device)
            cross_metadata.seq_lens_cpu = torch.full(seq_lens_shape,
                                                     max_encoder_len,
                                                     dtype=torch.int32,
                                                     device="cpu")
            cross_metadata.slot_mapping = _get_cross_slot_mapping(
                cross_metadata.encoder_seq_lens,
                cross_metadata.block_table_tensor, self.kv_cache_spec,
                self.device)
            return super().build(common_prefix_len, cross_metadata,
                                 fast_build)

    return subclass_attention_backend(
        name_prefix="CrossAttention_",
        attention_backend_cls=underlying_attn_backend,
        builder_cls=CrossAttentionBuilder)
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttention(Attention):
    """
    Cross-attention layer for encoder-decoder models: decoder queries
    attend over encoder keys/values.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 cache_config: Optional[CacheConfig] = None,
                 attn_type: Optional[str] = None,
                 **kwargs):
        """Pick the backend, then initialize the base ``Attention`` layer.

        Without a ``cache_config`` the KV cache dtype is ``"auto"`` and the
        block size is 16. ``attn_type``, when given, must be
        ``AttentionType.ENCODER_DECODER``; the base class always receives
        that type. Remaining keyword arguments are forwarded unchanged.
        """
        dtype = torch.get_default_dtype()

        if cache_config is None:
            kv_cache_dtype, block_size = "auto", 16
        else:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size

        # in v0 cross attention is handled inside the backends, so there
        # is nothing to override unless V1 is enabled.
        attn_backend = None
        if envs.VLLM_USE_V1:
            underlying_attn_backend = get_attn_backend(
                head_size, dtype, kv_cache_dtype, block_size)
            attn_backend = create_cross_attention_backend(
                underlying_attn_backend)

        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_DECODER, (
                "CrossAttention only supports AttentionType.ENCODER_DECODER")

        super().__init__(num_heads=num_heads,
                         head_size=head_size,
                         scale=scale,
                         cache_config=cache_config,
                         attn_backend=attn_backend,
                         attn_type=AttentionType.ENCODER_DECODER,
                         **kwargs)
|
||||||
86
vllm/attention/layers/encoder_only_attention.py
Normal file
86
vllm/attention/layers/encoder_only_attention.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import functools
|
||||||
|
from copy import copy
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||||
|
AttentionMetadata, AttentionType)
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
from vllm.config import CacheConfig
|
||||||
|
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||||
|
subclass_attention_backend)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache
def create_encoder_only_attention_backend(
    underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
    """Derive an encoder-only variant of an attention backend.

    The returned backend is produced by ``subclass_attention_backend`` with
    a metadata builder that disables causal masking on a shallow copy of
    the common metadata before delegating to the underlying builder.
    Results are memoized per underlying backend.
    """
    base_builder_cls = underlying_attn_backend.get_builder_cls()

    class EncoderOnlyAttentionBuilder(base_builder_cls):  # type: ignore

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            # Shallow copy, so ``causal`` is rebound on the copy only.
            non_causal_metadata = copy(common_attn_metadata)
            non_causal_metadata.causal = False
            return super().build(common_prefix_len, non_causal_metadata,
                                 fast_build)

    return subclass_attention_backend(
        name_prefix="EncoderOnlyAttention_",
        attention_backend_cls=underlying_attn_backend,
        builder_cls=EncoderOnlyAttentionBuilder)
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderOnlyAttention(Attention):
    """
    Encoder attention layer; a special case that needs no KV cache.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 cache_config: Optional[CacheConfig] = None,
                 attn_type: Optional[str] = None,
                 **kwargs):
        """Pick the backend, then initialize the base ``Attention`` layer.

        Without a ``cache_config`` the KV cache dtype is ``"auto"`` and the
        block size is 16. ``attn_type``, when given, must be
        ``AttentionType.ENCODER_ONLY``; the base class always receives that
        type. Remaining keyword arguments are forwarded unchanged.
        """
        dtype = torch.get_default_dtype()

        if cache_config is None:
            kv_cache_dtype, block_size = "auto", 16
        else:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size

        # in v0 encoder only attention is handled inside the backends, so
        # there is nothing to override unless V1 is enabled.
        attn_backend = None
        if envs.VLLM_USE_V1:
            underlying_attn_backend = get_attn_backend(
                head_size, dtype, kv_cache_dtype, block_size)
            attn_backend = create_encoder_only_attention_backend(
                underlying_attn_backend)

        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_ONLY, \
                "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"

        super().__init__(num_heads=num_heads,
                         head_size=head_size,
                         scale=scale,
                         cache_config=cache_config,
                         attn_backend=attn_backend,
                         attn_type=AttentionType.ENCODER_ONLY,
                         **kwargs)
|
||||||
0
vllm/attention/ops/__init__.py
Normal file
0
vllm/attention/ops/__init__.py
Normal file
BIN
vllm/attention/ops/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/attention/ops/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/attention/ops/__pycache__/common.cpython-310.pyc
Normal file
BIN
vllm/attention/ops/__pycache__/common.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/attention/ops/__pycache__/flashmla.cpython-310.pyc
Normal file
BIN
vllm/attention/ops/__pycache__/flashmla.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/attention/ops/__pycache__/merge_attn_states.cpython-310.pyc
Normal file
BIN
vllm/attention/ops/__pycache__/merge_attn_states.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/attention/ops/__pycache__/paged_attn.cpython-310.pyc
Normal file
BIN
vllm/attention/ops/__pycache__/paged_attn.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/attention/ops/__pycache__/prefix_prefill.cpython-310.pyc
Normal file
BIN
vllm/attention/ops/__pycache__/prefix_prefill.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/attention/ops/__pycache__/rocm_aiter_mla.cpython-310.pyc
Normal file
BIN
vllm/attention/ops/__pycache__/rocm_aiter_mla.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
405
vllm/attention/ops/chunked_prefill_paged_decode.py
Normal file
405
vllm/attention/ops/chunked_prefill_paged_decode.py
Normal file
@@ -0,0 +1,405 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Authors:
|
||||||
|
# - Burkhard Ringlein <ngl@zurich.ibm.com>
|
||||||
|
# - Jan van Lunteren <jvl@zurich.ibm.com>
|
||||||
|
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
|
||||||
|
# - Thomas Parnell <tpa@zurich.ibm.com>
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
from .prefix_prefill import context_attention_fwd
|
||||||
|
|
||||||
|
# finfo of the current platform's FP8 dtype; its min/max are the default
# values of the FP8_MIN / FP8_MAX arguments of kernel_paged_attention_2d.
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def cdiv_fn(x, y):
    # Ceiling division of x by y using integer arithmetic only: add
    # (y - 1) to the numerator, then floor-divide.
    rounded_up = x + y - 1
    return rounded_up // y
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def kernel_paged_attention_2d(
        output_ptr,  # [num_tokens, num_query_heads, head_size]
        query_ptr,  # [num_tokens, num_query_heads, head_size]
        key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
        value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
        sink_ptr,  # [num_query_heads]
        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
        seq_lens_ptr,  # [num_seqs]
        alibi_slopes_ptr,  # [num_query_heads]
        scale,  # float32, softmax scale applied to Q @ K
        k_scale,  # pointer to the K dequant scale (read with tl.load)
        v_scale,  # pointer to the V dequant scale (read with tl.load)
        out_scale_inv,  # read with tl.load, only when USE_FP8 is set
        num_query_heads: tl.constexpr,  # int
        num_queries_per_kv: tl.constexpr,  # int
        num_queries_per_kv_padded: tl.constexpr,  # int
        block_table_stride: tl.int64,  # int
        query_stride_0: tl.int64,  # int
        query_stride_1: tl.int64,  # int, should be equal to head_size
        output_stride_0: tl.int64,  # int
        output_stride_1: tl.int64,  # int, should be equal to head_size
        BLOCK_SIZE: tl.constexpr,  # int
        HEAD_SIZE: tl.constexpr,  # int
        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
        USE_ALIBI_SLOPES: tl.constexpr,  # bool
        SLIDING_WINDOW: tl.constexpr,  # int
        x: tl.constexpr,  # int
        stride_k_cache_0: tl.int64,  # int
        stride_k_cache_1: tl.int64,  # int
        stride_k_cache_2: tl.int64,  # int
        stride_k_cache_3: tl.int64,  # int
        stride_k_cache_4: tl.int64,  # int
        stride_v_cache_0: tl.int64,  # int
        stride_v_cache_1: tl.int64,  # int
        stride_v_cache_2: tl.int64,  # int
        stride_v_cache_3: tl.int64,  # int
        filter_by_query_len: tl.constexpr,  # bool
        query_start_len_ptr,  # [num_seqs+1]
        USE_SINKS: tl.constexpr,  # bool
        USE_FP8: tl.constexpr,
        FP8_MIN: tl.constexpr = float8_info.min,
        FP8_MAX: tl.constexpr = float8_info.max):
    """Single-query (decode) attention over a paged KV cache.

    Launched on a 2D grid: axis 0 indexes the sequence, axis 1 the KV head.
    Each program handles every query head that maps onto its KV head, walks
    the sequence's KV blocks through the block table, and accumulates the
    attention output with a running (online) softmax.

    Only one query token per sequence is processed. When
    `filter_by_query_len` is set, programs whose sequence has more than one
    query token return immediately without writing any output.
    """
    seq_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)

    if filter_by_query_len:
        # query_start_len_ptr holds cumulative query offsets; the difference
        # of two neighbouring entries is this sequence's query length.
        cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
        cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx +
                                              1)
        cur_batch_query_len = cur_batch_in_all_stop_index \
            - cur_batch_in_all_start_index
        # More than one query token: not handled here, leave output as is.
        if cur_batch_query_len > 1:
            return
    else:
        # Without filtering, token index and sequence index coincide.
        cur_batch_in_all_start_index = seq_idx

    # All query heads sharing this KV head, padded up to
    # num_queries_per_kv_padded lanes.
    query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
        0, num_queries_per_kv_padded)

    query_offset = (cur_batch_in_all_start_index * query_stride_0 +
                    query_head_idx[:, None] * query_stride_1)

    # Disable the padding lanes: heads beyond this KV head's group or beyond
    # the total number of query heads.
    head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
    head_mask = head_mask & (query_head_idx < num_query_heads)

    # Disable the padding columns between HEAD_SIZE and HEAD_SIZE_PADDED.
    dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
                        0).to(tl.int1)

    # Q : (num_queries_per_kv_padded, HEAD_SIZE_PADDED)
    Q = tl.load(
        query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
        mask=dim_mask[None, :] & head_mask[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    # Running maximum M: starts at -inf, or at the per-head sink value so the
    # sink takes part in the softmax normalisation.
    if not USE_SINKS:
        M = tl.full([num_queries_per_kv_padded],
                    float("-inf"),
                    dtype=tl.float32)
    else:
        M = tl.load(
            sink_ptr + query_head_idx,
            mask=head_mask,
            other=float("-inf"),
        ).to(dtype=tl.float32)

    # Running denominator L and un-normalised output accumulator. With sinks,
    # the initial 1.0 equals exp(sink - M); without sinks M is -inf, so the
    # first rescale factor alpha is 0 once a finite score has been seen.
    L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
    acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
                   dtype=tl.float32)

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx,
                              mask=head_mask,
                              other=0.0)

    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)

    # iterate through tiles
    for j in range(0, num_blocks):

        # Translate the logical block index j to a physical cache block.
        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)

        offs_n = tl.arange(0, BLOCK_SIZE)
        offs_d = tl.arange(0, HEAD_SIZE_PADDED)

        v_offset = (physical_block_idx * stride_v_cache_0 +
                    kv_head_idx * stride_v_cache_1 +
                    offs_d[None, :] * stride_v_cache_2 +
                    offs_n[:, None] * stride_v_cache_3)

        # The key cache splits the head dimension into (head_size // x, x),
        # hence the // x and % x terms.
        k_offset = (physical_block_idx * stride_k_cache_0 +
                    kv_head_idx * stride_k_cache_1 +
                    (offs_d[:, None] // x) * stride_k_cache_2 +
                    offs_n[None, :] * stride_k_cache_3 +
                    (offs_d[:, None] % x) * stride_k_cache_4)

        # K : (HEAD_SIZE_PADDED, BLOCK_SIZE)
        K_load = tl.load(key_cache_ptr + k_offset,
                         mask=dim_mask[:, None],
                         other=0.0)

        # FP8 caches are dequantised to the query dtype before the matmul.
        if K_load.dtype.is_fp8():
            K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (BLOCK_SIZE, HEAD_SIZE_PADDED)
        V_load = tl.load(value_cache_ptr + v_offset,
                         mask=dim_mask[None, :],
                         other=0.0)

        if V_load.dtype.is_fp8():
            V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        # Absolute token positions of this tile; positions >= seq_len are
        # the unused tail of the last block.
        seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
        seq_mask = seq_offset[None, :] < boundary

        # S : (num_queries_per_kv_padded, BLOCK_SIZE)
        # Masked entries start at -inf so they contribute exp(-inf) = 0.
        S = tl.where(head_mask[:, None] & seq_mask, 0.0,
                     float("-inf")).to(tl.float32)
        S += scale * tl.dot(Q, K)

        # The query token is treated as sitting at position seq_len - 1.
        context_len = seq_len - 1

        if SLIDING_WINDOW > 0:
            # Keys outside the window get a large negative score (-10000,
            # not -inf).
            S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S,
                         -10000)

        if USE_ALIBI_SLOPES:
            # Bias proportional to the (non-positive) key-to-query distance.
            S += alibi_slope[:, None] * (seq_offset - context_len)

        # compute running maximum
        # m_j : (num_queries_per_kv_padded,)
        m_j = tl.maximum(M, tl.max(S, axis=1))

        # P : (num_queries_per_kv_padded, BLOCK_SIZE)
        P = tl.exp(S - m_j[:, None])

        # l_j : (num_queries_per_kv_padded,)
        l_j = tl.sum(P, axis=1)

        # alpha : (num_queries_per_kv_padded,)
        # Rescales earlier partial results from the old maximum to the new.
        alpha = tl.exp(M - m_j)

        # acc : (num_queries_per_kv_padded, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (num_queries_per_kv_padded, HEAD_SIZE_PADDED)
        acc += tl.dot(P.to(V.dtype), V)

    # epilogue: normalise by the softmax denominator
    acc = acc / L[:, None]
    if USE_FP8:
        # Scale into the FP8 range and saturate instead of overflowing.
        acc = acc * tl.load(out_scale_inv)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)

    output_offset = (cur_batch_in_all_start_index * output_stride_0 +
                     query_head_idx * output_stride_1)

    tl.store(
        output_ptr + output_offset[:, None] +
        tl.arange(0, HEAD_SIZE_PADDED)[None, :],
        acc,
        mask=dim_mask[None, :] & head_mask[:, None],
    )
|
||||||
|
|
||||||
|
|
||||||
|
def chunked_prefill_paged_decode(
    query,
    key,
    value,
    output,
    kv_cache_dtype,
    key_cache,
    value_cache,
    block_table,
    query_start_loc,
    seq_lens,
    max_seq_len,
    max_query_len,
    k_scale,
    v_scale,
    alibi_slopes=None,
    sliding_window=None,
    sm_scale=None,
    output_scale=None,
    # Optional tensor for sinks
    sinks=None,
):
    """Run attention for a batch that mixes prefill and decode sequences.

    Sequences with more than one query token are handled by
    `context_attention_fwd` (called with `skip_decode=True`). The remaining
    single-token sequences are handled by a paged-attention decode kernel:
    the ROCm custom op when `use_rocm_custom_paged_attention` selects it,
    otherwise the Triton kernel `kernel_paged_attention_2d`.

    Results are written into `output` in place; nothing is returned.

    Args:
        query: [num_tokens, num_query_heads, head_size].
        key: new keys; only `key.shape[1]` (num_kv_heads) is read here, the
            tensor itself is forwarded to `context_attention_fwd`.
        value: new values, forwarded to `context_attention_fwd`.
        output: destination tensor, same layout as `query`.
        kv_cache_dtype: cache dtype string; values containing "fp8" make
            the caches be re-viewed as the matching FP8 torch dtype.
        key_cache: [num_blks, num_kv_heads, head_size // x, blk_size, x].
        value_cache: [num_blks, num_kv_heads, head_size, blk_size].
        block_table: [num_seqs, max_num_blocks_per_seq].
        query_start_loc: [num_seqs + 1] cumulative query offsets.
        seq_lens: [num_seqs] total length of each sequence.
        max_seq_len: maximum entry of `seq_lens`.
        max_query_len: maximum query length; the prefill path only runs
            when it is greater than 1.
        k_scale: key dequantisation scale.
        v_scale: value dequantisation scale.
        alibi_slopes: optional [num_query_heads] ALiBi slopes.
        sliding_window: optional window size; None or <= 0 disables it.
        sm_scale: softmax scale; defaults to 1 / sqrt(head_size).
        output_scale: optional FP8 output scale; enables FP8 output
            clamping in the Triton kernel.
        sinks: optional [num_query_heads] attention-sink values.

    Raises:
        ValueError: if `kv_cache_dtype` names an unsupported FP8 variant.
    """

    if sm_scale is None:
        # Scaled dot-product attention normalises by the head dimension
        # (query.shape[2]), not by the number of heads (query.shape[1]).
        sm_scale = 1.0 / (query.shape[2]**0.5)

    use_alibi_slopes = alibi_slopes is not None

    if sliding_window is None or sliding_window <= 0:
        sliding_window = 0

    if max_query_len > 1:
        # Prefill part of the batch; decode sequences are skipped there and
        # handled below.
        context_attention_fwd(
            q=query,
            k=key,
            v=value,
            o=output,
            kv_cache_dtype=kv_cache_dtype,
            k_cache=key_cache,
            v_cache=value_cache,
            b_loc=block_table,
            b_start_loc=query_start_loc,
            b_seq_len=seq_lens,
            max_seq_len=max_seq_len,
            max_input_len=max_query_len,
            k_scale=k_scale,
            v_scale=v_scale,
            alibi_slopes=alibi_slopes,
            sliding_window=sliding_window,
            sm_scale=sm_scale,
            skip_decode=True,
            fp8_out_scale=output_scale,
            sinks=sinks,
        )

    block_size = value_cache.shape[3]
    num_seqs = len(seq_lens)
    num_query_heads = query.shape[1]
    num_kv_heads = key.shape[1]
    num_queries_per_kv = query.shape[1] // key.shape[1]
    head_size = query.shape[2]

    # Conversion of FP8 Tensor from uint8 storage to
    # appropriate torch.dtype for interpretation by Triton
    if "fp8" in kv_cache_dtype:
        assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
        assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]

        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            target_dtype = current_platform.fp8_dtype()
        elif kv_cache_dtype == "fp8_e5m2":
            target_dtype = torch.float8_e5m2
        else:
            # Single formatted message (the two-argument form rendered as a
            # tuple).
            raise ValueError(f"Unsupported FP8 dtype: {kv_cache_dtype}")

        key_cache = key_cache.view(target_dtype)
        value_cache = value_cache.view(target_dtype)

    # The kernel's Q tile is at least 16 rows and a power of two.
    num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
                                    16)

    from vllm.platforms.rocm import use_rocm_custom_paged_attention
    use_custom = use_rocm_custom_paged_attention(
        query.dtype,
        head_size,
        block_size,
        num_queries_per_kv,
        max_seq_len,
        sliding_window,
        kv_cache_dtype,
        alibi_slopes,
        sinks,
    )
    if use_custom:
        _PARTITION_SIZE_ROCM = 256
        max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                              _PARTITION_SIZE_ROCM)
        assert _PARTITION_SIZE_ROCM % block_size == 0
        total_num_seq = block_table.shape[0]
        # Per-partition scratch buffers handed to the ROCm op.
        tmp_output = torch.empty(
            size=(total_num_seq, num_query_heads, max_num_partitions,
                  head_size),
            dtype=query.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(total_num_seq, num_query_heads, max_num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

        ops.paged_attention_rocm(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale=sm_scale,
            block_tables=block_table,
            seq_lens=seq_lens,
            query_start_loc=query_start_loc,
            block_size=block_size,
            max_seq_len=max_seq_len,
            alibi_slopes=alibi_slopes,
            kv_cache_dtype=kv_cache_dtype,
            k_scale=k_scale,
            v_scale=v_scale,
            fp8_out_scale=output_scale,
        )
    else:
        # One Triton program per (sequence, kv head).
        kernel_paged_attention_2d[(
            num_seqs,
            num_kv_heads,
        )](
            output_ptr=output,
            query_ptr=query,
            key_cache_ptr=key_cache,
            value_cache_ptr=value_cache,
            sink_ptr=sinks,
            block_tables_ptr=block_table,
            seq_lens_ptr=seq_lens,
            alibi_slopes_ptr=alibi_slopes,
            scale=sm_scale,
            k_scale=k_scale,
            v_scale=v_scale,
            out_scale_inv=1.0 /
            output_scale if output_scale is not None else 1.0,
            num_query_heads=num_query_heads,
            num_queries_per_kv=num_queries_per_kv,
            num_queries_per_kv_padded=num_queries_per_kv_padded,
            block_table_stride=block_table.stride(0),
            query_stride_0=query.stride(0),
            query_stride_1=query.stride(1),
            output_stride_0=output.stride(0),
            output_stride_1=output.stride(1),
            BLOCK_SIZE=block_size,
            HEAD_SIZE=head_size,
            HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
            USE_ALIBI_SLOPES=use_alibi_slopes,
            SLIDING_WINDOW=sliding_window,
            x=key_cache.shape[4],
            stride_k_cache_0=key_cache.stride(0),
            stride_k_cache_1=key_cache.stride(1),
            stride_k_cache_2=key_cache.stride(2),
            stride_k_cache_3=key_cache.stride(3),
            stride_k_cache_4=key_cache.stride(4),
            stride_v_cache_0=value_cache.stride(0),
            stride_v_cache_1=value_cache.stride(1),
            stride_v_cache_2=value_cache.stride(2),
            stride_v_cache_3=value_cache.stride(3),
            filter_by_query_len=True,
            query_start_len_ptr=query_start_loc,
            USE_SINKS=sinks is not None,
            USE_FP8=output_scale is not None,
        )
|
||||||
345
vllm/attention/ops/common.py
Normal file
345
vllm/attention/ops/common.py
Normal file
@@ -0,0 +1,345 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.distributed.parallel_state import GroupCoordinator
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr,
                                vlse_ptr, outputs_stride_B, outputs_stride_H,
                                outputs_stride_D, lses_stride_N, lses_stride_B,
                                lses_stride_H, lse_idx, HEAD_DIM: tl.constexpr,
                                N_ROUNDED: tl.constexpr):
    """
    Apply the all-gathered lses to correct each local rank's attention
    output. A cross-rank reduction is still needed afterwards to obtain
    the final attention output.

    One program handles one (batch, head) pair: it combines the N per-rank
    log-sum-exp values into a single lse, stores it, and rescales this
    rank's output row by exp(lse[lse_idx] - lse).

    Args:
        outputs_ptr (triton.PointerType):
            Pointer to input tensor of shape [ B, H, D ]
        lses_ptr (triton.PointerType):
            Pointer to input tensor of shape [ N, B, H ]
        new_output_ptr (triton.PointerType):
            Pointer to output tensor of shape [ B, H, D ]
        vlse_ptr (triton.PointerType):
            Pointer to output tensor of shape [ B, H ]
        lse_idx:
            Index along N of the lse that belongs to this rank's output.
    """
    batch_idx = tl.program_id(axis=0).to(tl.int64)
    head_idx = tl.program_id(axis=1).to(tl.int64)
    d_offsets = tl.arange(0, HEAD_DIM)
    num_n_offsets = tl.arange(0, N_ROUNDED)

    # shape = [N]
    lse_offsets = num_n_offsets * lses_stride_N + batch_idx * \
        lses_stride_B + head_idx * lses_stride_H

    # calc final lse
    lse = tl.load(lses_ptr + lse_offsets)
    # Treat NaN (lse != lse) and +inf entries as -inf so they contribute
    # nothing to the sum below.
    lse = tl.where((lse != lse) | (lse == float('inf')), -float('inf'), lse)
    # log-sum-exp over the N entries, shifted by the max for stability.
    lse_max = tl.max(lse, axis=0)
    lse -= lse_max
    lse_exp = tl.exp(lse)
    lse_acc = tl.sum(lse_exp, axis=0)
    lse = tl.log(lse_acc)
    lse += lse_max

    # The combined lse is stored using the B/H strides of lses, i.e. vlse is
    # addressed with the same layout as one [B, H] slice of lses.
    lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
    tl.store(vlse_ptr + lse_offsets, lse)

    # shape = [D]
    output_offsets = batch_idx * outputs_stride_B + \
                    head_idx * outputs_stride_H + \
                    d_offsets * outputs_stride_D

    # correct output
    lse_offset = lse_idx * lses_stride_N + batch_idx * \
        lses_stride_B + head_idx * lses_stride_H
    lse_tmp = tl.load(lses_ptr + lse_offset)
    lse_finally = lse_tmp - lse
    # NaN / +inf differences are mapped to -inf, so the factor becomes 0.
    lse_finally = tl.where(
        (lse_finally != lse_finally) | (lse_finally == float('inf')),
        -float('inf'), lse_finally)
    factor = tl.exp(lse_finally)
    output = tl.load(outputs_ptr + output_offsets)
    output = output * factor

    # new_output_ptr is addressed with the same offsets as outputs_ptr.
    tl.store(new_output_ptr + output_offsets, output)
|
||||||
|
|
||||||
|
|
||||||
|
class CPTritonContext:
    """Holds the handle returned by the first kernel launch for reuse.

    The first `call_kernel` launches `kernel` with both regular and constant
    arguments and keeps what that launch returns. Every later call launches
    the stored handle with the regular arguments only, which avoids
    recompilation of the Triton JIT.
    """

    def __init__(self):
        # Filled in by the first call_kernel invocation.
        self.inner_kernel = None

    def call_kernel(self, kernel, grid, *regular_args, **const_args):
        """Launch `kernel` on `grid`, reusing the stored handle if present."""
        cached = self.inner_kernel
        if cached is not None:
            # The stored handle takes only the regular arguments.
            cached[grid](*regular_args)
            return
        self.inner_kernel = kernel[grid](*regular_args, **const_args)
|
||||||
|
|
||||||
|
|
||||||
|
def correct_attn_out(
        out: torch.Tensor, lses: torch.Tensor, cp_rank: int,
        ctx: CPTritonContext) -> tuple[torch.Tensor, torch.Tensor]:
    """Rescale a rank's attention output with the all-gathered lses.

    `out` is passed to the kernel as both input and output, so it is
    corrected in place and the returned tensor is the same object.

    Args:
        out: Tensor of shape [ B, H, D ]
        lses: Tensor of shape [ N, B, H ]
        cp_rank: Current rank in the context-parallel group
        ctx: Triton context to avoid recompilation; a fresh one is created
            when None is passed

    Returns:
        Tuple of (out, lse) with corrected attention and final log-sum-exp.
    """
    ctx = CPTritonContext() if ctx is None else ctx

    # Receives the combined [B, H] log-sum-exp written by the kernel.
    final_lse = torch.empty_like(lses[0])

    # One program per (batch, head) pair.
    num_batches, num_heads = out.shape[0], out.shape[1]
    ctx.call_kernel(
        _correct_attn_cp_out_kernel,
        (num_batches, num_heads, 1),
        out,
        out,
        lses,
        final_lse,
        *out.stride(),
        *lses.stride(),
        cp_rank,
        HEAD_DIM=out.shape[-1],
        N_ROUNDED=lses.shape[0],
    )
    return out, final_lse
|
||||||
|
|
||||||
|
|
||||||
|
def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
                     cp_attn_lse: torch.Tensor,
                     cp_group: GroupCoordinator,
                     ctx: CPTritonContext = None):
    """All-gather the lses, correct the local output, then reduce-scatter.

    Args:
        cp_attn_out: [ B, H, D ] local attention output.
        cp_attn_lse: [ B, H ] local log-sum-exp.
        cp_group: context-parallel group used for the collectives.
        ctx: optional Triton context reused across calls to avoid
            recompilation; a fresh one is created when None.

    Returns:
        `cp_attn_out` unchanged when the group has a single rank; otherwise
        the result of reduce-scattering the corrected output along dim 1.
    """
    if cp_group.world_size == 1:
        return cp_attn_out

    if ctx is None:
        ctx = CPTritonContext()

    cp_attn_lse = cp_attn_lse.contiguous()
    # View the gathered tensor as [ N, B, H ] directly; no scratch tensor is
    # needed just to describe the target shape.
    lses = cp_group.all_gather(cp_attn_lse, dim=0).view(
        (cp_group.world_size, ) + cp_attn_lse.shape)
    out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
    out = cp_group.reduce_scatter(out, dim=1)
    return out
|
||||||
|
|
||||||
|
|
||||||
|
# Copies variable-length rows of a flat [N, D] tensor into a padded
# [B, Lmax, D] tensor. Each program covers a BLOCK_T x BLOCK_D tile of one
# batch entry. Both tensors are addressed as dense row-major (row stride D).
# N is accepted but not read inside the kernel.
@triton.jit
def _pack_seq_kernel(
        x_ptr,  # [N, D]
        out_ptr,  # [B, Lmax, D]
        lengths_ptr,  # *i32, [B]
        N: tl.constexpr,
        D: tl.constexpr,
        Lmax: tl.constexpr,
        PAD_VALUE: tl.constexpr,
        BLOCK_T: tl.constexpr,  # timesteps per program
        BLOCK_D: tl.constexpr  # features per program
):
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # Compute start index and sequence length from cumulative lengths
    # (a serial prefix sum over the lengths of all earlier batch entries).
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    # valid time positions for this block
    t_mask = off_t < Lmax

    # compute input row indices for valid (b, t)
    in_row = in_start + off_t
    valid_row = (off_t < seq_len) & t_mask

    # Pointers
    # x_ptr: row-major [N, D]
    x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]

    # out_ptr: row-major [B, Lmax, D]
    out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:,
                                                   None] * D + off_d[None, :]

    # Initialize with PAD (cast will occur as needed based on out_ptr dtype)
    d_mask = off_d[None, :] < D
    pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
    tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)

    # Load & write only where within seq_len; rows past seq_len keep the
    # padding value stored above.
    x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask)
    tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def pack_seq_triton(x: torch.Tensor,
                    lengths: torch.Tensor,
                    pad_value: float = -float('inf'),
                    block_t: int = 64,
                    block_d: int = 64) -> torch.Tensor:
    """
    Pack sequences of different lengths into a batched tensor.

    Args:
        x: [N, ...] - input tensor where N is total number of tokens
        lengths: [B] - sequence lengths for each batch
        pad_value: value to use for padding
        block_t: block size for time dimension
        block_d: block size for feature dimension

    Returns:
        packed: [B, Lmax, ...] - packed tensor, where Lmax is the largest
            entry of `lengths`; positions past a sequence's length hold
            `pad_value`
    """

    # Handle multi-dimensional input by reshaping to (N, -1)
    original_shape = x.shape
    if len(original_shape) > 2:
        N = original_shape[0]
        x_reshaped = x.reshape(N, -1)
        D = x_reshaped.shape[1]
    else:
        N, D = x.shape
        x_reshaped = x

    # The kernel addresses x as dense row-major [N, D] (row stride D), so a
    # strided view must be materialised first. No-op when already contiguous.
    x_reshaped = x_reshaped.contiguous()

    B = lengths.numel()
    Lmax = int(lengths.max().item())

    # Starts are computed inside the kernel from lengths

    out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype)

    # One program per (batch entry, time tile, feature tile).
    grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
    _pack_seq_kernel[grid](x_reshaped,
                           out,
                           lengths.int(),
                           N,
                           D,
                           Lmax,
                           PAD_VALUE=float(pad_value),
                           BLOCK_T=block_t,
                           BLOCK_D=block_d,
                           num_warps=4,
                           num_stages=2)

    # Reshape output back to original dimensions (except first dimension)
    if len(original_shape) > 2:
        output_shape = (B, Lmax) + original_shape[1:]
        out = out.reshape(output_shape)

    return out
|
||||||
|
|
||||||
|
|
||||||
|
# Copies the valid rows of a padded [B, Lmax, D] tensor back into a flat
# [N, D] tensor. Each program covers a BLOCK_T x BLOCK_D tile of one batch
# entry. Both tensors are addressed as dense row-major (row stride D).
# B is accepted but not read inside the kernel.
@triton.jit
def _unpack_seq_triton_kernel(
        packed_ptr,  # [B, Lmax, D]
        out_ptr,  # [N, D]
        lengths_ptr,  # *i32, [B]
        B: tl.constexpr,
        Lmax: tl.constexpr,
        D: tl.constexpr,
        BLOCK_T: tl.constexpr,  # timesteps per program
        BLOCK_D: tl.constexpr  # features per program
):
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # bounds: compute start from cumulative lengths
    # (a serial prefix sum over the lengths of all earlier batch entries).
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    # valid time positions for this block
    t_mask = off_t < Lmax
    valid_row = (off_t < seq_len) & t_mask

    # compute output row indices for valid (b, t)
    out_row = in_start + off_t

    # Pointers
    # packed_ptr: row-major [B, Lmax, D]
    packed_row_ptr = packed_ptr + (pid_b * Lmax +
                                   off_t)[:, None] * D + off_d[None, :]

    # out_ptr: row-major [N, D]
    out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]

    # Load from packed tensor and store to output; padding rows
    # (off_t >= seq_len) are masked out of both the load and the store.
    d_mask = off_d[None, :] < D
    packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask)
    tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def unpack_seq_triton(packed_tensor: torch.Tensor,
                      lengths: torch.Tensor,
                      block_t: int = 64,
                      block_d: int = 64) -> torch.Tensor:
    """
    Unpack a packed decode query tensor back to the original format.
    Efficient Triton implementation.

    Args:
        packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton
        lengths: [B] - sequence lengths for each batch
        block_t: block size for time dimension
        block_d: block size for feature dimension

    Returns:
        unpacked_tensor: [N, ...] where N = sum(lengths)
    """

    # Handle multi-dimensional input by reshaping to (B, Lmax, -1)
    original_shape = packed_tensor.shape
    if len(original_shape) > 3:
        B, Lmax = original_shape[:2]
        packed_reshaped = packed_tensor.reshape(B, Lmax, -1)
        D = packed_reshaped.shape[2]
    else:
        B, Lmax, D = packed_tensor.shape
        packed_reshaped = packed_tensor

    # The kernel addresses the packed tensor as dense row-major [B, Lmax, D],
    # so a strided view must be materialised first. No-op when already
    # contiguous.
    packed_reshaped = packed_reshaped.contiguous()

    # Calculate total number of elements
    N = int(lengths.sum().item())

    out = torch.empty((N, D),
                      device=packed_tensor.device,
                      dtype=packed_tensor.dtype)

    # One program per (batch entry, time tile, feature tile).
    grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
    _unpack_seq_triton_kernel[grid](packed_reshaped,
                                    out,
                                    lengths.int(),
                                    B,
                                    Lmax,
                                    D,
                                    BLOCK_T=block_t,
                                    BLOCK_D=block_d,
                                    num_warps=4,
                                    num_stages=2)

    # Reshape output back to original dimensions (except first dimension)
    if len(original_shape) > 3:
        output_shape = (N, ) + original_shape[2:]
        out = out.reshape(output_shape)

    return out
|
||||||
192
vllm/attention/ops/flashmla.py
Normal file
192
vllm/attention/ops/flashmla.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# Probe for the optional compiled FlashMLA extensions. Both flags stay False
# off CUDA or when the corresponding module cannot be imported.
_flashmla_C_AVAILABLE = False
_flashmla_extension_C_AVAILABLE = False

if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401
    except ImportError:
        pass
    else:
        _flashmla_C_AVAILABLE = True

    try:
        import vllm._flashmla_extension_C  # noqa: F401
    except ImportError:
        pass
    else:
        _flashmla_extension_C_AVAILABLE = True
|
||||||
|
|
||||||
|
|
||||||
|
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """
    Report whether FlashMLA can be used on the current platform.

    Return: is_supported_flag, unsupported_reason (optional).
    The reason is None exactly when the flag is True.
    """
    reason: Optional[str] = None
    if not current_platform.is_cuda():
        reason = "FlashMLA is only supported on CUDA devices."
    elif current_platform.get_device_capability()[0] != 9:
        # Only compute-capability major version 9 is accepted.
        reason = "FlashMLA is only supported on Hopper devices."
    elif not _flashmla_C_AVAILABLE:
        reason = ("vllm._flashmla_C is not available, likely was not "
                  "compiled due to insufficient nvcc version or a supported "
                  "arch (only sm90a currently) was not in the list of "
                  "target arches to compile for.")
    return reason is None, reason
|
||||||
|
|
||||||
|
|
||||||
|
def get_mla_metadata(
        cache_seqlens: torch.Tensor,
        num_q_tokens_per_head_k: int,
        num_heads_k: int,
        num_heads_q: Optional[int] = None,
        is_fp8_kvcache: bool = False,
        topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute FlashMLA decoding metadata via the compiled `_flashmla_C` op.

    Arguments:
    - cache_seqlens: (batch_size), dtype torch.int32.
    - num_q_tokens_per_head_k:
            Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
    - num_heads_k: The number of k heads.
    - num_heads_q:
            The number of q heads.
            This argument is optional when sparse attention is not enabled
    - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
    - topk: If not None, sparse attention will be enabled,
            and only tokens in the `indices` array
            passed to `flash_mla_with_kvcache_sm90` will be attended to.

    Returns:
    - tile_scheduler_metadata:
            (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
    - num_splits: (batch_size + 1), dtype torch.int32.
    """
    # All arguments are forwarded positionally, in the op's expected order.
    metadata_op = torch.ops._flashmla_C.get_mla_decoding_metadata
    return metadata_op(cache_seqlens, num_q_tokens_per_head_k, num_heads_k,
                       num_heads_q, is_fp8_kvcache, topk)
|
||||||
|
|
||||||
|
|
||||||
|
def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run FlashMLA attention against a paged KV cache.

    Two custom ops back this function. When `indices` is None and `q` has
    one-byte elements, `_flashmla_extension_C.fwd_kvcache_mla_fp8` is used
    and receives `descale_q` / `descale_k`. In every other case
    `_flashmla_C.fwd_kvcache_mla` is used and receives `is_fp8_kvcache`
    and `indices` instead.

    Arguments:
    - q: (batch_size, seq_len_q, num_heads_q, head_dim).
    - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
    - block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
    - cache_seqlens: (batch_size), torch.int32.
    - head_dim_v: Head dimension of v.
    - tile_scheduler_metadata:
        (num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
        returned by get_mla_metadata.
    - num_splits:
        (batch_size + 1), torch.int32, returned by get_mla_metadata.
    - softmax_scale: float.
        The scale of QK^T before applying softmax.
        Defaults to 1 / sqrt(head_dim), with head_dim taken from q.shape[-1].
    - causal: bool. Whether to apply causal attention mask.
        Must be False when `indices` is given.
    - descale_q: (batch_size),
        torch.float32. Descaling factors for Q, used for fp8 quantization.
    - descale_k: (batch_size),
        torch.float32. Descaling factors for K, used for fp8 quantization.
        descale_q and descale_k must be both None or both provided.
    - is_fp8_kvcache: bool.
        Whether the k_cache and v_cache are in fp8 format.
        For the format of FP8 KV cache, please refer to README.md
    - indices: (batch_size, seq_len_q, topk), torch.int32.
        If not None, sparse attention will be enabled,
        and only tokens in the `indices` array will be attended to.
        Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
        For details about how to set up `indices`, please refer to README.md.

    Returns:
    - out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
    - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    sparse_enabled = indices is not None

    if softmax_scale is None:
        softmax_scale = q.shape[-1]**(-0.5)

    # NOTE (zyongye): sparse attention is also causal since it only attends
    # to the tokens before, but here `causal` should not be specified.
    assert not (sparse_enabled and causal), \
        "causal must be `false` if sparse attention is enabled."

    has_descale_q = descale_q is not None
    has_descale_k = descale_k is not None
    assert has_descale_q == has_descale_k, \
        "descale_q and descale_k should be both None or both not None"

    # Dense attention with one-byte query elements takes the fp8 op.
    if not sparse_enabled and q.element_size() == 1:
        fp8_op = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8
        out, softmax_lse = fp8_op(q, k_cache, head_dim_v, cache_seqlens,
                                  block_table, softmax_scale, causal,
                                  tile_scheduler_metadata, num_splits,
                                  descale_q, descale_k)
        return out, softmax_lse

    default_op = torch.ops._flashmla_C.fwd_kvcache_mla
    out, softmax_lse = default_op(q, k_cache, head_dim_v, cache_seqlens,
                                  block_table, softmax_scale, causal,
                                  tile_scheduler_metadata, num_splits,
                                  is_fp8_kvcache, indices)
    return out, softmax_lse
|
||||||
|
|
||||||
|
|
||||||
|
def flash_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel.

    Forwards all arguments, in order, to the
    `_flashmla_C.sparse_prefill_fwd` custom op and returns its result.

    Args:
    - q: [s_q, h_q, d_qk], bfloat16
    - kv: [s_kv, h_kv, d_qk], bfloat16
    - indices: [s_q, h_kv, topk], int32.
        Invalid indices should be set to -1 or numbers >= s_kv
    - sm_scale: float
    - d_v: The dimension of value vectors. Can only be 512

    Returns:
    - (output, max_logits, lse)
        For the definition of output, max_logits and lse,
        please refer to README.md
        - output: [s_q, h_q, d_v], bfloat16
        - max_logits: [s_q, h_q], float
        - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    return torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices, sm_scale,
                                                    d_v)
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# TODO: Add fake functions
|
||||||
|
#
|
||||||
|
# @register_fake("_flashmla_C::get_mla_metadata")
|
||||||
|
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# return ....
|
||||||
|
#
|
||||||
|
# @register_fake("_flashmla_C::fwd_kvcache_mla")
|
||||||
|
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# return ....
|
||||||
|
#
|
||||||
43
vllm/attention/ops/merge_attn_states.py
Normal file
43
vllm/attention/ops/merge_attn_states.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
|
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: Optional[torch.Tensor] = None,
) -> None:
    """Merge prefix and suffix attention states, picking a backend.

    The custom CUDA kernel from `vllm._custom_ops` is used when the platform
    is CUDA and `output` has a dtype and head size that kernel accepts;
    otherwise the Triton implementation is used. Both backends are called
    with the same six arguments.
    """

    def _cuda_kernel_applicable() -> bool:
        if not current_platform.is_cuda():
            return False
        # NOTE(DefTruth): Currently, the custom merge_attn_states CUDA kernel
        # does not support FP8 dtypes; those fall back to the Triton kernel.
        if output.dtype not in (torch.float32, torch.half, torch.bfloat16):
            return False
        # NOTE(DefTruth): Currently, the custom merge_attn_states CUDA kernel
        # loads/stores 128b (16 bytes) per memory issue within a thread.
        # Namely, the headsize (headdim) must be a multiple of pack_size
        # (float32 -> 4, half/bfloat16 -> 8).
        pack_size = 4 if output.dtype == torch.float32 else 8
        head_size = output.shape[2]  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
        return head_size % pack_size == 0

    if _cuda_kernel_applicable():
        from vllm._custom_ops import merge_attn_states as merge_impl
    else:
        from vllm.attention.ops.triton_merge_attn_states import (
            merge_attn_states as merge_impl)

    return merge_impl(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse, output_lse)
|
||||||
262
vllm/attention/ops/paged_attn.py
Normal file
262
vllm/attention/ops/paged_attn.py
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import HAS_TRITON
|
||||||
|
|
||||||
|
if current_platform.is_cuda_alike():
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
elif current_platform.is_xpu():
|
||||||
|
from vllm._ipex_ops import ipex_ops as ops
|
||||||
|
|
||||||
|
if HAS_TRITON:
|
||||||
|
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||||
|
|
||||||
|
# Number of tokens per partition; PagedAttention.forward_decode uses it to
# derive the partition count and the scratch buffers for the V2 kernel.
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class PagedAttentionMetadata:
    """Metadata for PagedAttention."""
    # Shape (batch_size,). For each sequence, its length, i.e. the total
    # number of tokens seen so far.
    seq_lens_tensor: Optional[torch.Tensor]
    # Maximum sequence length in the batch. 0 if it is a prefill-only batch.
    max_decode_seq_len: int
    # Shape (batch_size, max_blocks_per_seq).
    # Block addresses per sequence (seq id -> list of physical blocks).
    # E.g., [0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd blocks
    # of the kv cache. Each block can contain up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if it is
    # cuda-graph captured.
    block_tables: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class PagedAttention:
    """Static helpers around the paged-attention custom ops.

    Every method is a `staticmethod`; the class only serves as a namespace.
    The kernels themselves come from the module-level `ops` object and from
    `context_attention_fwd`.
    """

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the list of head sizes reported as supported."""
        return [32, 64, 80, 96, 112, 120, 128, 192, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> Tuple[int, ...]:
        """Return the shape of the combined KV cache tensor.

        The leading dimension of size 2 separates keys (index 0) from values
        (index 1); each block is stored flattened as
        `block_size * num_kv_heads * head_size` elements.
        `cache_dtype_str` is accepted but does not influence the shape.
        """
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a combined KV cache into key and value views.

        Args:
            kv_cache: Tensor laid out as produced by `get_kv_cache_shape`.
            num_kv_heads: Number of KV heads.
            head_size: Size of each head.

        Returns:
            `(key_cache, value_cache)` where the key cache is viewed as
            `(num_blocks, num_kv_heads, head_size // x, -1, x)` and the value
            cache as `(num_blocks, num_kv_heads, head_size, -1)`.
        """
        # x is the number of cache elements that fit in 16 bytes.
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> None:
        """Write `key` / `value` into the paged caches.

        Delegates to `ops.reshape_and_cache`; `slot_mapping` is flattened
        before being passed on.
        """
        ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping.flatten(),
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> torch.Tensor:
        """Run decode-phase paged attention and return the output tensor.

        Chooses between the V1 and V2 paged-attention ops (see the heuristic
        below) and writes into a freshly allocated tensor shaped like
        `query`, which is unpacked as `(num_seqs, num_heads, head_size)`.

        Raises:
            AssertionError: if blocksparse attention is requested
                (`blocksparse_vert_stride > 1`) and `blocksparse_block_size`
                is not a positive multiple of the cache block size, or if the
                V2 path is taken and `_PARTITION_SIZE` is not a multiple of
                the cache block size.
        """
        # The block size is the last dimension of the 4-D value cache view.
        block_size = value_cache.shape[3]

        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            assert (blocksparse_block_size > 0 and
                    blocksparse_block_size % block_size == 0), \
                (f"{blocksparse_block_size=} needs to be a multiple of "
                 f"{block_size=} used in block_tables.")

        output = torch.empty_like(query)
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
                              _PARTITION_SIZE)
        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
        use_v1 = (max_seq_len <= 8192
                  and (max_num_partitions == 1 or num_seqs * num_heads > 512))

        if use_v1:
            # Run PagedAttention V1.
            ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                tp_rank,
                blocksparse_local_blocks,
                blocksparse_vert_stride,
                blocksparse_block_size,
                blocksparse_head_sliding_step,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            # Per-partition scratch buffers passed to the V2 op.
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                tp_rank,
                blocksparse_local_blocks,
                blocksparse_vert_stride,
                blocksparse_block_size,
                blocksparse_head_sliding_step,
            )
        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache_dtype: str,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        query_start_loc: torch.Tensor,
        seq_lens_tensor: torch.Tensor,
        max_query_len: int,
        alibi_slopes: Optional[torch.Tensor],
        sliding_window: Optional[int],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> torch.Tensor:
        """Run prefix (context) attention via `context_attention_fwd`.

        The result is written into a freshly allocated tensor shaped like
        `query`, which is returned.
        """
        output = torch.empty_like(query)
        # `context_attention_fwd` is always given None for max_seq_len here.
        max_seq_len = None
        context_attention_fwd(
            query,
            key,
            value,
            output,
            kv_cache_dtype,
            key_cache,
            value_cache,
            block_tables,
            # query_start_loc is (batch_size + 1,)
            query_start_loc,
            seq_lens_tensor,
            max_seq_len,
            max_query_len,
            k_scale,
            v_scale,
            alibi_slopes,
            sliding_window,
        )
        return output

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap blocks between two KV caches via `ops.swap_blocks`.

        The key plane (index 0) and the value plane (index 1) are handled by
        two separate calls that share the same `src_to_dst` mapping.
        """
        src_key_cache = src_kv_cache[0]
        dst_key_cache = dst_kv_cache[0]
        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)

        src_value_cache = src_kv_cache[1]
        dst_value_cache = dst_kv_cache[1]
        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks within each KV cache via `ops.copy_blocks`.

        Each entry of `kv_caches` is split into its key plane (index 0) and
        value plane (index 1) before the call.
        """
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||||
124
vllm/attention/ops/pallas_kv_cache_update.py
Normal file
124
vllm/attention/ops/pallas_kv_cache_update.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import functools
|
||||||
|
|
||||||
|
import jax
|
||||||
|
from jax.experimental import pallas as pl
|
||||||
|
from jax.experimental.pallas import tpu as pltpu
|
||||||
|
|
||||||
|
from vllm.utils import cdiv
|
||||||
|
|
||||||
|
|
||||||
|
def _kv_cache_update_kernel(
    # Prefetch
    slices_ref,  # [3, padded_num_slices], list of (kv_cache_start,
    # new_kv_start, slice_len)
    num_slices_ref,  # [1]
    # Input
    new_kv_hbm_ref,  # [num_tokens, num_combined_kv_heads, head_dim]
    kv_cache_hbm_ref,  # [total_num_pages * page_size, num_combined_kv_heads,
    # head_dim]
    # Output
    _,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    # Scratch
    scratch,  # [num_slices_per_block, page_size, num_combined_kv_heads,
    # head_dim]
    sem,
):
    """Copy a block of new-KV slices into the KV cache through scratch.

    Each program (indexed by `pl.program_id(0)`) handles
    `scratch.shape[0]` consecutive slice descriptors. For every descriptor
    it first copies `slice_len` rows from `new_kv_hbm_ref` into its own row
    of `scratch`, and after all of those copies have completed it copies the
    same rows from `scratch` into `kv_cache_hbm_ref`.

    The output ref is never referenced in the body; writes go through
    `kv_cache_hbm_ref` (presumably relying on the caller aliasing the cache
    input to the output -- confirm against the `pallas_call` set-up).
    """
    async_copies = []
    block_idx = pl.program_id(0)
    num_slices_per_block = scratch.shape[0]

    # Copy from new_kv_hbm_ref to scratch
    for i in range(num_slices_per_block):
        # Global index of the slice descriptor handled by this iteration.
        offset_i = i + block_idx * num_slices_per_block
        # Descriptors at or past num_slices_ref[0] are replaced by start 0
        # and length 0, so the copy built below moves zero rows.
        new_kv_start = jax.lax.select(offset_i < num_slices_ref[0],
                                      slices_ref[1, offset_i], 0)
        length = jax.lax.select(offset_i < num_slices_ref[0],
                                slices_ref[2, offset_i], 0)
        async_copy = pltpu.make_async_copy(
            new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
            scratch.at[i, pl.ds(0, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)

    # Wait for every inbound copy before reading scratch back out.
    for async_copy in async_copies:
        async_copy.wait()

    # Copy from scratch to kv_cache_hbm_ref
    async_copies.clear()
    for i in range(num_slices_per_block):
        offset_i = i + block_idx * num_slices_per_block
        # Same out-of-range handling as above, now for the destination start.
        kv_cache_start = jax.lax.select(offset_i < num_slices_ref[0],
                                        slices_ref[0, offset_i], 0)
        length = jax.lax.select(offset_i < num_slices_ref[0],
                                slices_ref[2, offset_i], 0)
        async_copy = pltpu.make_async_copy(
            scratch.at[i, pl.ds(0, length), ...],
            kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)
    # Wait for every outbound copy before the program finishes.
    for async_copy in async_copies:
        async_copy.wait()
|
||||||
|
|
||||||
|
|
||||||
|
@functools.partial(
    jax.jit,
    static_argnames=["page_size", "num_slices_per_block"],
)
def kv_cache_update(
    new_kv: jax.Array,  # [total_num_token, num_combined_kv_heads, head_dim]
    slices: jax.
    Array,  # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
    kv_cache: jax.
    Array,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    num_kv_update_slices: jax.Array,  # [1]
    *,
    page_size: int = 32,
    num_slices_per_block: int = 8,
):
    """Write the rows of `new_kv` described by `slices` into `kv_cache`.

    Builds a Pallas call around `_kv_cache_update_kernel` with `slices` and
    `num_kv_update_slices` as scalar-prefetch operands, and returns the
    first (only) output, which has the shape and dtype of `kv_cache`. The
    `kv_cache` operand is aliased to that output.

    `page_size` and `num_slices_per_block` are static under `jax.jit`; they
    size the scratch buffer, and `num_slices_per_block` also determines the
    grid size.
    """
    _num_tokens, num_combined_kv_heads, head_dim = new_kv.shape
    assert kv_cache.shape[1] == num_combined_kv_heads
    assert kv_cache.shape[2] == head_dim
    assert head_dim % 128 == 0
    # TODO: Add dynamic check to make sure that the all the slice lengths are
    # smaller or equal to page_size

    def any_space_spec():
        # A fresh spec per operand, each placed in the ANY memory space.
        return pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)

    prefetch_operands = [slices, num_kv_update_slices]

    staging_buffer = pltpu.VMEM(
        (num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
        new_kv.dtype,
    )

    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=len(prefetch_operands),
        in_specs=[any_space_spec(), any_space_spec()],
        out_specs=[any_space_spec()],
        grid=(cdiv(num_kv_update_slices[0], num_slices_per_block), ),
        scratch_shapes=[staging_buffer, pltpu.SemaphoreType.DMA],
    )

    update_call = pl.pallas_call(
        _kv_cache_update_kernel,
        grid_spec=grid_spec,
        out_shape=[jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)],
        # Operand order is (*prefetch_operands, new_kv, kv_cache), so the
        # cache sits at index len(prefetch_operands) + 1 and maps to output 0.
        input_output_aliases={len(prefetch_operands) + 1: 0},
    )

    outputs = update_call(*prefetch_operands, new_kv, kv_cache)
    return outputs[0]
|
||||||
928
vllm/attention/ops/prefix_prefill.py
Normal file
928
vllm/attention/ops/prefix_prefill.py
Normal file
@@ -0,0 +1,928 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
|
||||||
|
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
# Static kernel parameters, chosen once at import time from the current
# platform: BASE_BLOCK is 128 when `has_device_capability(80)` is true and 64
# otherwise; NUM_WARPS is 4 on ROCm and 8 elsewhere.
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
NUM_WARPS = 4 if current_platform.is_rocm() else 8

# To check compatibility: True when the device capability is exactly (7, 5).
IS_TURING = current_platform.get_device_capability() == (7, 5)
# Numeric limits of the platform's fp8 dtype; its min/max are the default
# values of the FP8_MIN / FP8_MAX parameters of `_fwd_kernel`.
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||||
|
|
||||||
|
|
||||||
|
# Here's an example autotuner config for this kernel. This config does provide
|
||||||
|
# a performance improvement, but dramatically increases first call latency in
|
||||||
|
# triton 3.2. Because of this tradeoff, it's currently commented out.
|
||||||
|
# @triton.autotune(
|
||||||
|
# configs=[
|
||||||
|
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \
|
||||||
|
# "num_unroll_cache": 4, \
|
||||||
|
# "num_unroll_request": 1 } | \
|
||||||
|
# ({"kpack": 2, "waves_per_eu": 2} \
|
||||||
|
# if current_platform.is_rocm() else {}), \
|
||||||
|
# num_warps=4, \
|
||||||
|
# num_stages=1)
|
||||||
|
# ],
|
||||||
|
# key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"]
|
||||||
|
# )
|
||||||
|
@triton.jit
def _fwd_kernel(Q,
                K,
                V,
                K_cache,
                V_cache,
                sink_ptr,
                B_Loc,
                sm_scale,
                k_scale,
                v_scale,
                out_scale_inv,
                B_Start_Loc,
                B_Seqlen,
                x: tl.constexpr,
                Out,
                stride_b_loc_b,
                stride_b_loc_s,
                stride_qbs,
                stride_qh,
                stride_qd,
                stride_kbs,
                stride_kh,
                stride_kd,
                stride_vbs,
                stride_vh,
                stride_vd,
                stride_obs,
                stride_oh,
                stride_od,
                stride_k_cache_bs,
                stride_k_cache_h,
                stride_k_cache_d,
                stride_k_cache_bl: tl.constexpr,
                stride_k_cache_x,
                stride_v_cache_bs,
                stride_v_cache_h,
                stride_v_cache_d,
                stride_v_cache_bl,
                num_queries_per_kv: tl.constexpr,
                IN_PRECISION: tl.constexpr,
                BLOCK_M: tl.constexpr,
                BLOCK_DMODEL: tl.constexpr,
                BLOCK_DMODEL_PADDED: tl.constexpr,
                BLOCK_SIZE: tl.constexpr,
                BLOCK_N: tl.constexpr,
                SLIDING_WINDOW: tl.constexpr,
                num_unroll_cache: tl.constexpr,
                num_unroll_request: tl.constexpr,
                SKIP_DECODE: tl.constexpr,
                USE_SINKS: tl.constexpr,
                USE_FP8: tl.constexpr,
                MAX_Q_LEN: tl.constexpr = 0,
                MAX_CTX_LEN: tl.constexpr = 0,
                FP8_MIN: tl.constexpr = float8_info.min,
                FP8_MAX: tl.constexpr = float8_info.max):
    # Prefix-prefill attention for one (batch, head, query-block) program.
    #
    # The kernel runs two passes that share one online-softmax state
    # (m_i, l_i, acc):
    #   1. the query block against the cached context, read from
    #      K_cache / V_cache through the block table B_Loc (no causal mask);
    #   2. the query block against the new tokens in K / V (causal mask).
    # The normalised result is stored to Out, optionally scaled by
    # out_scale_inv and clamped to [FP8_MIN, FP8_MAX] when USE_FP8 is set.
    # MAX_Q_LEN and MAX_CTX_LEN are not referenced in the body.

    # Program ids: axis 0 = sequence in batch, axis 1 = query head,
    # axis 2 = query block.
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    # KV head shared by num_queries_per_kv consecutive query heads.
    cur_kv_head = cur_head // num_queries_per_kv

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
    cur_batch_query_len = (cur_batch_in_all_stop_index -
                           cur_batch_in_all_start_index)
    # Context length = total sequence length minus the new query tokens.
    cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len

    # Optionally leave single-token (decode) requests untouched.
    if SKIP_DECODE and cur_batch_query_len == 1:
        return

    # start position inside of the query
    # generally, N goes over kv, while M goes over query_len
    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    # [BLOCK_SIZE]; starts at 0
    offs_bs_n = tl.arange(0, BLOCK_SIZE)
    # [N]; starts at 0
    offs_n = tl.arange(0, BLOCK_N)
    # [D]; starts at 0
    offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
    # [M]; starts at current position in query
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # [M,D]
    off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)

    # True for real head-dim lanes, False for the padding lanes between
    # BLOCK_DMODEL and BLOCK_DMODEL_PADDED.
    dim_mask = tl.where(
        tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
        0).to(tl.int1)  # [D]

    q = tl.load(Q + off_q,
                mask=dim_mask[None, :] &
                (offs_m[:, None] < cur_batch_query_len),
                other=0.0)  # [M,D]

    # initialize pointer to m and l
    # Without sinks the running max starts at -inf; with sinks it starts at
    # the value stored at sink_ptr[cur_head] (rows past the query length
    # still get -inf).
    if not USE_SINKS:
        m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    else:
        m_i = tl.load(
            sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64),
            mask=(offs_m < cur_batch_query_len),
            other=float("-inf"),
        ).to(dtype=tl.float32)

    # The running denominator starts at 1.0, not 0.
    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)  # [M,D]

    # compute query against context (no causal mask here)
    for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \
                            loop_unroll_factor=num_unroll_cache):
        start_n = tl.multiple_of(start_n, BLOCK_SIZE)
        # -- compute qk ----
        # Physical block number of the cache block covering start_n.
        # NOTE(review): bn comes from a single address yet is indexed as
        # bn[None, :] / bn[:, None] below -- confirm the intended shape.
        bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                     (start_n // BLOCK_SIZE) * stride_b_loc_s).to(tl.int64)
        # [D,BLOCK_SIZE]
        # The head dim of the key cache is split into (d // x, d % x).
        off_k = (
            bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
            (offs_d[:, None] // x) * stride_k_cache_d +
            ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl +
            (offs_d[:, None] % x) * stride_k_cache_x)

        # [BLOCK_SIZE,D]
        off_v = (bn[:, None] * stride_v_cache_bs +
                 cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 offs_bs_n[:, None] * stride_v_cache_bl)

        # Use a masked load only for a partial last block or a padded head
        # dim; otherwise load the whole tile unmasked.
        if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
            BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
            k_load = tl.load(
                K_cache + off_k,
                mask=dim_mask[:, None] &
                ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len),
                other=0.0)  # [D,N]
        else:
            k_load = tl.load(K_cache + off_k)

        # fp8 cache entries are dequantised with k_scale and cast to q's
        # dtype.
        if k_load.dtype.is_fp8():
            k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
        else:
            k = k_load

        qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32)  # [M,N]
        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
        # Positions past the context length get -inf.
        qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk,
                      float("-inf"))
        qk *= sm_scale
        if SLIDING_WINDOW > 0:
            # (cur_batch_ctx_len + offs_m[:, None]) are the positions of
            # Q entries in sequence
            # (start_n + offs_bs_n[None, :]) are the positions of
            # KV entries in sequence
            # So the condition makes sure each entry in Q only attends
            # to KV entries not more than SLIDING_WINDOW away.
            #
            # We can't use -inf here, because the
            # sliding window may lead to the entire row being masked.
            # This then makes m_ij contain -inf, which causes NaNs in
            # exp().
            qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
                          (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk,
                          -10000)

        # compute running maximum
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        # alpha rescales the previous accumulator to the new running max.
        alpha = tl.exp(m_i - m_ij)
        acc = acc * alpha[:, None]

        # update acc
        if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
            BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
            v_load = tl.load(
                V_cache + off_v,
                mask=dim_mask[None, :] &
                ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len),
                other=0.0)  # [N,D]
        else:
            v_load = tl.load(V_cache + off_v)

        if v_load.dtype.is_fp8():
            v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
        else:
            v = v_load
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
        # update m_i and l_i
        l_i = l_i * alpha + l_ij
        m_i = m_ij

    # Offsets into the new (non-cached) K and V tensors.
    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)
    off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
             offs_d[None, :] * stride_vd)
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # block_mask is 0 when we're already past the current query length
    block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

    # compute query against itself (with causal mask)
    # The upper bound is (start_m + 1) * BLOCK_M, or 0 when block_mask is 0,
    # in which case the loop body never runs.
    for start_n in tl.range(0, \
                            block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \
                            loop_unroll_factor=num_unroll_request):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_kbs,
                    mask=dim_mask[:, None] &
                    ((start_n + offs_n[None, :]) < cur_batch_query_len),
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
        qk *= sm_scale
        # apply causal mask
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))
        # Same finite sentinel as in the context pass for positions outside
        # the sliding window.
        if SLIDING_WINDOW > 0:
            qk = tl.where(
                offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
                qk, -10000)

        # compute running maximum
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        alpha = tl.exp(m_i - m_ij)
        acc = acc * alpha[:, None]

        # update acc
        v = tl.load(v_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_vbs,
                    mask=dim_mask[None, :] &
                    ((start_n + offs_n[:, None]) < cur_batch_query_len),
                    other=0.0)
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
        # update m_i and l_i
        l_i = l_i * alpha + l_ij
        m_i = m_ij

    # Final softmax normalisation.
    acc = acc / l_i[:, None]

    # initialize pointers to output
    off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
             cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    # For fp8 output, scale by out_scale_inv and clamp to the fp8 range.
    if USE_FP8:
        acc = acc * tl.load(out_scale_inv)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
    # Only real head-dim lanes and rows within the query length are written.
    tl.store(out_ptrs,
             acc,
             mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
    return
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _fwd_kernel_flash_attn_v2(
    Q,
    K,
    V,
    K_cache,
    V_cache,
    B_Loc,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    B_Ctxlen,
    block_size,
    x,
    Out,
    stride_b_loc_b,
    stride_b_loc_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    stride_k_cache_bs,
    stride_k_cache_h,
    stride_k_cache_d,
    stride_k_cache_bl,
    stride_k_cache_x,
    stride_v_cache_bs,
    stride_v_cache_h,
    stride_v_cache_d,
    stride_v_cache_bl,
    num_queries_per_kv: int,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Attend one block of query rows to cached context keys, then new keys.

    Grid mapping: ``program_id(0)`` picks the batch entry, ``program_id(1)``
    the query head and ``program_id(2)`` the ``BLOCK_M``-row query block.

    Phase 1 walks ``cur_batch_ctx_len`` positions stored in
    ``K_cache``/``V_cache``, resolving each position's cache block through
    ``B_Loc``. Phase 2 walks ``K``/``V`` starting at this batch entry's
    ``B_Start_Loc`` offset and applies a causal mask. Both phases share the
    running maximum ``m_i``, running denominator ``l_i`` and accumulator
    ``acc``.

    NOTE(review): the final ``acc / l_i`` division is commented out below,
    so ``Out`` receives the accumulator without softmax normalisation.
    Confirm whether this kernel is still used or whether that is intended.
    """
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    # Several query heads share one KV head.
    cur_kv_head = cur_head // num_queries_per_kv

    cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    # First query row handled by this program.
    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)

    # Rows at or beyond the query length (seq_len - ctx_len) are loaded as 0.
    q = tl.load(Q + off_q,
                mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
                other=0.0)

    # Online-softmax state: running max, running sum, output accumulator.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # Phase 1: keys/values held in the paged cache (the context prefix).
    for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        # Cache block id for each of the BLOCK_N positions; masked-off
        # positions read block 0 and are neutralised by the masks below.
        bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                     ((start_n + offs_n) // block_size) * stride_b_loc_s,
                     mask=(start_n + offs_n) < cur_batch_ctx_len,
                     other=0).to(tl.int64)
        # K_cache is addressed as (block, kv_head, d // x, slot, d % x).
        off_k = (
            bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
            (offs_d[:, None] // x) * stride_k_cache_d +
            ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
            (offs_d[:, None] % x) * stride_k_cache_x)
        # V_cache is addressed as (block, kv_head, d, slot).
        off_v = (bn[:, None] * stride_v_cache_bs +
                 cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
        k = tl.load(K_cache + off_k,
                    mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                    other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        # Positions past the context length must not contribute.
        qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                      float("-inf"))
        qk *= sm_scale

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        # alpha rescales previously accumulated values to the new maximum.
        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(V_cache + off_v,
                    mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                    other=0.0)

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # Phase 2: keys/values of the new tokens, read from K and V directly.
    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)
    off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
             offs_d[None, :] * stride_vd)
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # 0 when this query block lies entirely past the query length, which
    # makes the loop bound below 0 and skips phase 2.
    block_mask = tl.where(
        block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

    # Causality: only keys up to the end of this query block are visited.
    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_kbs,
                    mask=(start_n + offs_n[None, :])
                    < cur_batch_seq_len - cur_batch_ctx_len,
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        # Causal mask: a query row only sees keys at or before its own index.
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(v_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_vbs,
                    mask=(start_n + offs_n[:, None])
                    < cur_batch_seq_len - cur_batch_ctx_len,
                    other=0.0)

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # NOTE(review): normalisation is disabled here; see the docstring.
    # acc /= l_i[:, None]
    # initialize pointers to output
    off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
             cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    # Only rows that correspond to real query tokens are written.
    tl.store(out_ptrs,
             acc,
             mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
    return
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _fwd_kernel_alibi(
    Q,
    K,
    V,
    K_cache,
    V_cache,
    B_Loc,
    sm_scale,
    k_scale,
    v_scale,
    B_Start_Loc,
    B_Seqlen,
    Alibi_slopes,
    block_size,
    x,
    Out,
    stride_b_loc_b,
    stride_b_loc_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    stride_k_cache_bs,
    stride_k_cache_h,
    stride_k_cache_d,
    stride_k_cache_bl,
    stride_k_cache_x,
    stride_v_cache_bs,
    stride_v_cache_h,
    stride_v_cache_d,
    stride_v_cache_bl,
    num_queries_per_kv: int,
    IN_PRECISION: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,  # head size
    BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
    BLOCK_N: tl.constexpr,
    SKIP_DECODE: tl.constexpr,
):
    """Prefix-prefill attention with an ALiBi bias added to the scores.

    Grid mapping: ``program_id(0)`` picks the batch entry, ``program_id(1)``
    the query head and ``program_id(2)`` the ``BLOCK_M``-row query block.

    The query length of a batch entry is the difference of two consecutive
    ``B_Start_Loc`` values; the context (prefix) length is
    ``seq_len - query_len``. Phase 1 attends to the prefix stored in
    ``K_cache``/``V_cache`` (dequantising FP8 entries with
    ``k_scale``/``v_scale``); phase 2 attends causally to the new tokens in
    ``K``/``V``. The result is normalised by the softmax denominator and
    written to ``Out``.

    When ``SKIP_DECODE`` is set, batch entries whose query length is 1
    return immediately without writing anything.
    """
    # attn_bias[]
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    # Several query heads share one KV head.
    cur_kv_head = cur_head // num_queries_per_kv

    # cur_batch_seq_len: the length of prompts
    # cur_batch_ctx_len: the length of prefix
    # cur_batch_in_all_start_index: the start id of the dim=0
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
    cur_batch_query_len = (cur_batch_in_all_stop_index -
                           cur_batch_in_all_start_index)
    cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len

    if SKIP_DECODE and cur_batch_query_len == 1:
        return

    # First query row handled by this program.
    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)

    # True for real head dimensions, False for the power-of-two padding.
    dim_mask = tl.where(
        tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)

    q = tl.load(Q + off_q,
                mask=dim_mask[None, :] &
                (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
                other=0.0)

    # Online-softmax state: running max, running sum, output accumulator.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)

    alibi_slope = tl.load(Alibi_slopes + cur_head)
    # Query positions are offset by the context length; key positions for
    # the cached prefix start at 0.
    alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
    alibi_start_k = 0
    # Phase 1: keys/values held in the paged cache (the context prefix).
    for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        # Cache block id for each of the BLOCK_N positions; masked-off
        # positions read block 0 and are neutralised by the masks below.
        bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                     ((start_n + offs_n) // block_size) * stride_b_loc_s,
                     mask=(start_n + offs_n) < cur_batch_ctx_len,
                     other=0).to(tl.int64)
        # K_cache is addressed as (block, kv_head, d // x, slot, d % x).
        off_k = (
            bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
            (offs_d[:, None] // x) * stride_k_cache_d +
            ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
            (offs_d[:, None] % x) * stride_k_cache_x)
        # V_cache is addressed as (block, kv_head, d, slot).
        off_v = (bn[:, None] * stride_v_cache_bs +
                 cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
        k_load = tl.load(K_cache + off_k,
                         mask=dim_mask[:, None] &
                         ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                         other=0.0)  # [D,N]

        # FP8 cache entries are dequantised with k_scale and cast to q's dtype.
        if k_load.dtype.is_fp8():
            k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
        else:
            k = k_load

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
        # Positions past the context length must not contribute.
        qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                      float("-inf"))
        qk *= sm_scale

        # load alibi
        # bias = (key position - query position) * slope
        alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                 alibi_start_q[:, None]) * alibi_slope
        # Keep the bias only where it is <= 0 and the query position lies
        # inside the sequence; everything else is forced to -inf.
        alibi = tl.where(
            (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
            float("-inf"))
        qk += alibi
        alibi_start_k += BLOCK_N

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        # alpha rescales previously accumulated values to the new maximum.
        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v_load = tl.load(V_cache + off_v,
                         mask=dim_mask[None, :] &
                         ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                         other=0.0)
        # FP8 cache entries are dequantised with v_scale and cast to q's dtype.
        if v_load.dtype.is_fp8():
            v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
        else:
            v = v_load
        p = p.to(v.dtype)

        # NOTE(review): this dot hard-codes 'ieee' while the qk dot above
        # uses IN_PRECISION - confirm the difference is intentional.
        acc = tl.dot(p, v, acc=acc, input_precision='ieee')
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # Phase 2: keys/values of the new tokens, read from K and V directly.
    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)
    off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
             offs_d[None, :] * stride_vd)
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # 0 when this query block lies entirely past the query length, which
    # makes the loop bound below 0 and skips phase 2.
    block_mask = tl.where(
        block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

    # init alibi
    alibi_slope = tl.load(Alibi_slopes + cur_head)
    alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
    # Key positions of the new tokens continue after the cached prefix.
    alibi_start_k = cur_batch_ctx_len
    # # init debugger
    # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
    # offset_db_k = tl.arange(0, BLOCK_N)
    # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=dim_mask[:, None] & ((start_n + offs_n[None, :])
                                      < cur_batch_seq_len - cur_batch_ctx_len),
            other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.dot(q, k, acc=qk, input_precision='ieee')
        qk *= sm_scale
        # Causal mask: a query row only sees keys at or before its own index.
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))

        # load alibi
        alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                 alibi_start_q[:, None]) * alibi_slope
        alibi = tl.where(
            (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
            float("-inf"))
        qk += alibi
        alibi_start_k += BLOCK_N

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=dim_mask[None, :] & ((start_n + offs_n[:, None])
                                      < cur_batch_seq_len - cur_batch_ctx_len),
            other=0.0)
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision='ieee')
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # Final softmax normalisation by the running denominator.
    acc = acc / l_i[:, None]

    # initialize pointers to output
    off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
             cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    # Only real head dimensions of real query rows are written.
    tl.store(out_ptrs,
             acc,
             mask=dim_mask[None, :] &
             (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
    return
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
def context_attention_fwd(q,
                          k,
                          v,
                          o,
                          kv_cache_dtype: str,
                          k_cache,
                          v_cache,
                          b_loc,
                          b_start_loc,
                          b_seq_len,
                          max_seq_len,
                          max_input_len,
                          k_scale: torch.Tensor,
                          v_scale: torch.Tensor,
                          alibi_slopes=None,
                          sliding_window=None,
                          sm_scale=None,
                          skip_decode=False,
                          fp8_out_scale=None,
                          sinks=None):
    """Launch the Triton prefix-prefill attention kernel and write into ``o``.

    Dispatches to ``_fwd_kernel_alibi`` when ``alibi_slopes`` is given and to
    ``_fwd_kernel`` otherwise.

    Args:
        q, k, v: Query/key/value tensors. Their last dimensions must be
            equal (head size); dimension 1 is the head count, and
            ``q.shape[1] // k.shape[1]`` query heads share one KV head.
        o: Output tensor, written in place by the kernel.
        kv_cache_dtype: Cache dtype string. Values containing ``"fp8"``
            cause the caches to be reinterpreted as the matching FP8 dtype.
        k_cache: Key cache, laid out as
            ``[num_blocks, num_kv_heads, head_size/x, block_size, x]``.
        v_cache: Value cache, laid out as
            ``[num_blocks, num_kv_heads, head_size, block_size]``.
        b_loc: Two-dimensional block table passed to the kernel.
        b_start_loc: Start offsets; must have ``batch + 1`` entries.
        b_seq_len: Per-batch sequence lengths; its length defines ``batch``.
        max_seq_len: Accepted for interface compatibility; not used here.
        max_input_len: Longest query length; sizes the third grid axis.
        k_scale, v_scale: Scale tensors forwarded to the kernel.
        alibi_slopes: Optional ALiBi slopes; selects the ALiBi kernel.
        sliding_window: Window size; ``None`` or ``<= 0`` disables it.
        sm_scale: Softmax scale; defaults to ``1 / sqrt(head_size)``.
        skip_decode: Forwarded to the kernel as ``SKIP_DECODE``.
        fp8_out_scale: When given, its reciprocal is passed to the kernel
            and FP8 output is enabled. Not supported with ALiBi.
        sinks: Optional attention sinks. Not supported with ALiBi.

    Raises:
        ValueError: For an unrecognised FP8 ``kv_cache_dtype``, or when a
            ``uint8`` cache is combined with ``kv_cache_dtype == "auto"``.
        AssertionError: On head-size, batch-size or cache-dtype mismatches,
            or when ``sinks``/``fp8_out_scale`` are combined with ALiBi.
    """
    q_dtype_is_f32 = q.dtype is torch.float32

    # Turing does not have tensor cores for float32 multiplication, so use
    # 'ieee' as a fallback to keep the Triton kernels working. There is also
    # a warning in vllm/config.py to inform users about this fallback
    # implementation.
    IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None

    # Conversion of FP8 Tensor from uint8 storage to
    # appropriate torch.dtype for interpretation by Triton
    if "fp8" in kv_cache_dtype:
        assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
        assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]

        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            target_dtype = current_platform.fp8_dtype()
        elif kv_cache_dtype == "fp8_e5m2":
            target_dtype = torch.float8_e5m2
        else:
            # Single formatted message instead of a two-element args tuple.
            raise ValueError(f"Unsupported FP8 dtype: {kv_cache_dtype}")

        k_cache = k_cache.view(target_dtype)
        v_cache = v_cache.view(target_dtype)

    # Parenthesised so that the "auto" condition applies to both caches;
    # `and` binds tighter than `or`, so without the parentheses a uint8
    # k_cache raised regardless of kv_cache_dtype.
    if ((k_cache.dtype == torch.uint8 or v_cache.dtype == torch.uint8)
            and kv_cache_dtype == "auto"):
        raise ValueError("kv_cache_dtype='auto' unsupported for "
                         "FP8 KV Cache prefill kernel")

    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    # round up Lk to a power of 2 - this is required for Triton block size
    Lk_padded = triton.next_power_of_2(Lk)

    if sm_scale is None:
        sm_scale = 1.0 / (Lq**0.5)
    batch, head = b_seq_len.shape[0], q.shape[1]
    num_queries_per_kv = q.shape[1] // k.shape[1]

    assert batch + 1 == len(b_start_loc)

    # 0 means "disable"
    if sliding_window is None or sliding_window <= 0:
        sliding_window = 0

    if alibi_slopes is not None:
        assert sinks is None, "Sinks arg is not supported with alibi"
        assert fp8_out_scale is None, "FP8 output not supported with alibi"
        # Halve the block size for fp32 inputs because they increase the
        # kernel's use of GPU shared memory.
        BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
        # Grid axes: batch, head, query block.
        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
        _fwd_kernel_alibi[grid](
            q,
            k,
            v,
            k_cache,
            v_cache,
            b_loc,
            sm_scale,
            k_scale,
            v_scale,
            b_start_loc,
            b_seq_len,
            alibi_slopes,
            v_cache.shape[3],
            k_cache.shape[4],
            o,
            b_loc.stride(0),
            b_loc.stride(1),
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            # k_cache: [num_blocks, num_kv_heads, head_size/x, block_size, x]
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            k_cache.stride(4),
            # v_cache: [num_blocks, num_kv_heads, head_size, block_size]
            v_cache.stride(0),
            v_cache.stride(1),
            v_cache.stride(2),
            v_cache.stride(3),
            num_queries_per_kv=num_queries_per_kv,
            IN_PRECISION=IN_PRECISION,
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_DMODEL_PADDED=Lk_padded,
            BLOCK_N=BLOCK,
            SKIP_DECODE=skip_decode,
            num_warps=NUM_WARPS,
            num_stages=1,
        )
        return

    extra_kargs = {}
    if current_platform.is_rocm():
        extra_kargs = {"kpack": 1, "waves_per_eu": 2}

    # Grid axes: batch, head, query block (sized from the kernel's BLOCK_M).
    grid = lambda META: (batch, head,
                         triton.cdiv(max_input_len, META["BLOCK_M"]))
    _fwd_kernel[grid](
        q,
        k,
        v,
        k_cache,
        v_cache,
        sinks,
        b_loc,
        sm_scale,
        k_scale,
        v_scale,
        1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0,
        b_start_loc,
        b_seq_len,
        k_cache.shape[4],
        o,
        b_loc.stride(0),
        b_loc.stride(1),
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        # k_cache: [num_blocks, num_kv_heads, head_size/x, block_size, x]
        k_cache.stride(0),
        k_cache.stride(1),
        k_cache.stride(2),
        k_cache.stride(3),
        k_cache.stride(4),
        # v_cache: [num_blocks, num_kv_heads, head_size, block_size]
        v_cache.stride(0),
        v_cache.stride(1),
        v_cache.stride(2),
        v_cache.stride(3),
        BLOCK_SIZE=v_cache.shape[3],
        num_queries_per_kv=num_queries_per_kv,
        IN_PRECISION=IN_PRECISION,
        BLOCK_DMODEL=Lk,
        BLOCK_DMODEL_PADDED=Lk_padded,
        SLIDING_WINDOW=sliding_window,
        SKIP_DECODE=skip_decode,
        USE_FP8=fp8_out_scale is not None,
        BLOCK_M=128,
        BLOCK_N=64,
        num_unroll_cache=4,
        num_unroll_request=1,
        num_warps=4,
        num_stages=1,
        USE_SINKS=sinks is not None,
        **extra_kargs)
    return
|
||||||
104
vllm/attention/ops/rocm_aiter_mla.py
Normal file
104
vllm/attention/ops/rocm_aiter_mla.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||||
|
|
||||||
|
|
||||||
|
def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
                           max_block_per_batch: int,
                           device: torch.device) -> tuple[torch.Tensor, ...]:
    """Allocate the int32 metadata buffers used by AITER MLA decode.

    Args:
        max_batch_size: Largest batch the buffers must accommodate.
        block_size: Fill value for every entry of the last-page-length
            buffer.
        max_block_per_batch: Number of block-index slots reserved per
            batch entry.
        device: Device on which all four buffers are allocated.

    Returns:
        ``(paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens,
        qo_indptr)`` with sizes ``max_batch_size * max_block_per_batch``,
        ``max_batch_size + 1``, ``max_batch_size`` and
        ``max_batch_size + 1``. All are zero-filled except
        ``paged_kv_last_page_lens``, which is filled with ``block_size``.
    """
    paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch,
                                   dtype=torch.int32,
                                   device=device)
    paged_kv_indptr = torch.zeros(max_batch_size + 1,
                                  dtype=torch.int32,
                                  device=device)
    # Allocate on `device` like the sibling buffers; previously this tensor
    # was silently created on the default (CPU) device.
    paged_kv_last_page_lens = torch.full((max_batch_size, ),
                                         block_size,
                                         dtype=torch.int32,
                                         device=device)
    qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr
|
||||||
|
|
||||||
|
|
||||||
|
def aiter_mla_decode_fwd(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    sm_scale: float,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    logit_cap: float = 0.0,
):
    """Invoke the registered ``vllm.rocm_aiter_mla_decode_fwd`` custom op.

    ``kv_buffer`` is viewed as ``(-1, 1, 1, q.shape[-1])`` before being
    handed to the op. Nothing is returned; the op is registered with ``o``
    as its mutated argument.
    """
    decode_op = torch.ops.vllm.rocm_aiter_mla_decode_fwd
    last_dim = q.shape[-1]
    flat_kv = kv_buffer.view(-1, 1, 1, last_dim)
    decode_op(
        q,
        flat_kv,
        o,
        qo_indptr,
        max_seqlen_qo,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        sm_scale=sm_scale,
        logit_cap=logit_cap,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def mla_decode_fwd_impl(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    """Real implementation of the custom op: call AITER's ``mla_decode_fwd``.

    The AITER import is deferred to call time. Note the argument order
    expected by AITER: ``max_seqlen_qo`` follows the three KV index tensors.
    """
    from aiter.mla import mla_decode_fwd

    last_dim = q.shape[-1]
    flat_kv = kv_buffer.view(-1, 1, 1, last_dim)
    mla_decode_fwd(
        q,
        flat_kv,
        o,
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        max_seqlen_qo,
        sm_scale=sm_scale,
        logit_cap=logit_cap,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def mla_decode_fwd_fake(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    """Fake implementation of the custom op: accept the arguments, do nothing.

    Mirrors the signature of ``mla_decode_fwd_impl`` and leaves every
    argument untouched.
    """
    return None
|
||||||
|
|
||||||
|
|
||||||
|
# Register the AITER MLA decode custom op; this only happens on ROCm.
if current_platform.is_rocm():
    if is_torch_equal_or_newer("2.7.0"):
        tags = ()
    else:
        # A flat tuple of tags. A trailing comma after the parenthesis
        # previously wrapped this in a second tuple, i.e. ((tag,),).
        tags = (torch.Tag.needs_fixed_stride_order, )
    direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd",
                              op_func=mla_decode_fwd_impl,
                              mutates_args=["o"],
                              fake_impl=mla_decode_fwd_fake,
                              tags=tags)
|
||||||
102
vllm/attention/ops/rocm_aiter_paged_attn.py
Normal file
102
vllm/attention/ops/rocm_aiter_paged_attn.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiter as rocm_aiter
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.ops.paged_attn import PagedAttention
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import cdiv
|
||||||
|
|
||||||
|
# FP8 torch dtype reported by the current platform, resolved once at import.
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
|
||||||
|
|
||||||
|
class AITERPagedAttention(PagedAttention):
    """Paged attention that routes int8/FP8 KV caches through AITER ops.

    Any ``kv_cache_dtype`` other than ``"int8"``, ``"fp8"`` or
    ``"fp8_e4m3"`` is forwarded unchanged to ``PagedAttention``.
    """

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> None:
        """Write ``key``/``value`` into the paged caches.

        For the quantized cache dtypes the caches are reinterpreted as the
        matching storage dtype and filled by AITER's
        ``reshape_and_cache_with_pertoken_quant``; otherwise the base-class
        writer is used.
        """
        handled_by_aiter = kv_cache_dtype in ("int8", "fp8", "fp8_e4m3")
        if not handled_by_aiter:
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )
            return

        if "fp8" in kv_cache_dtype:
            storage_dtype = FP8_DTYPE
        else:
            storage_dtype = torch.int8
        quant_key_cache = key_cache.view(storage_dtype)
        quant_value_cache = value_cache.view(storage_dtype)

        rocm_aiter.reshape_and_cache_with_pertoken_quant(
            key,
            value,
            quant_key_cache,
            quant_value_cache,
            k_scale,
            v_scale,
            slot_mapping.flatten(),
            True,
        )

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> torch.Tensor:
        """Run decode attention and return a new tensor shaped like ``query``.

        Non-quantized cache dtypes are delegated to the base class. For the
        quantized ones, AITER's ``pa_fwd_asm`` fills a freshly allocated
        output tensor.
        """
        if kv_cache_dtype not in ("int8", "fp8", "fp8_e4m3"):
            return PagedAttention.forward_decode(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                block_tables=block_tables,
                seq_lens=seq_lens,
                max_seq_len=max_seq_len,
                kv_cache_dtype=kv_cache_dtype,
                num_kv_heads=num_kv_heads,
                scale=scale,
                alibi_slopes=alibi_slopes,
                k_scale=k_scale,
                v_scale=v_scale,
                tp_rank=tp_rank,
                blocksparse_local_blocks=blocksparse_local_blocks,
                blocksparse_vert_stride=blocksparse_vert_stride,
                blocksparse_block_size=blocksparse_block_size,
                blocksparse_head_sliding_step=blocksparse_head_sliding_step,
            )

        if "fp8" in kv_cache_dtype:
            fp8_dtype = current_platform.fp8_dtype()
            key_cache = key_cache.view(fp8_dtype)
            value_cache = value_cache.view(fp8_dtype)

        blocksparse_requested = (blocksparse_vert_stride is not None
                                 and blocksparse_vert_stride > 1)
        if blocksparse_requested:
            # Only the block-size compatibility is validated here; the
            # kernel call below is the same with or without blocksparse.
            block_size = value_cache.size(-1)
            assert (blocksparse_block_size > 0 and
                    blocksparse_block_size % block_size == 0), \
                (f"{blocksparse_block_size=} needs to be a multiple of"
                 f"{block_size=} used in block_tables.")

        block_size = value_cache.shape[3]
        max_num_blocks_per_seq = cdiv(max_seq_len, block_size)
        output = torch.empty_like(query)

        rocm_aiter.pa_fwd_asm(
            query,
            key_cache,
            value_cache,
            block_tables,
            seq_lens,
            max_num_blocks_per_seq,
            k_scale,
            v_scale,
            output,
        )
        return output
|
||||||
691
vllm/attention/ops/triton_decode_attention.py
Normal file
691
vllm/attention/ops/triton_decode_attention.py
Normal file
@@ -0,0 +1,691 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
|
||||||
|
# which was originally adapted from
|
||||||
|
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
|
||||||
|
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
|
||||||
|
|
||||||
|
# Changes:
|
||||||
|
# - Add support for page size >= 1.
|
||||||
|
|
||||||
|
# Copyright 2025 vLLM Team
|
||||||
|
# Copyright 2023-2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
Memory-efficient attention for decoding.
|
||||||
|
It supports page size >= 1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
# Platform flag read by the launchers in this module to choose ROCm-specific
# block sizes and compile options.
is_hip_ = current_platform.is_rocm()

logger = logging.getLogger(__name__)

# Only Triton releases older than 3.2.0 get this notice. The message it refers
# to does not affect performance or accuracy.
if version.parse(triton.__version__) < version.parse('3.2.0'):
    logger.warning(
        "The following error message 'operation scheduled before its operands' "
        "can be ignored.")
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def tanh(x):
    """Return tanh(x), computed as the scaled sigmoid 2 * sigmoid(2x) - 1."""
    doubled = 2 * x
    return 2 * tl.sigmoid(doubled) - 1
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _fwd_kernel_stage1(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    """Stage 1 of split-KV decode attention, one query head per program.

    Program ids 0/1/2 select the batch entry, the query head and the KV
    split. The program walks its slice of the sequence in tiles of BLOCK_N
    tokens, translating token positions to cache slots through
    ``Req_to_tokens``, and maintains an online softmax (running max, running
    sum and a weighted V accumulator). If the slice is non-empty it stores
    the normalised partial output in the first ``Lv`` slots of ``Att_Out``
    and the slice's log-sum-exp in slot ``Lv``.
    """
    batch_id = tl.program_id(0)
    head_id = tl.program_id(1)
    split_id = tl.program_id(2)

    # kv_group_num consecutive query heads share one KV head.
    kv_head_id = head_id // kv_group_num

    dim_qk = tl.arange(0, BLOCK_DMODEL)
    dim_v = tl.arange(0, BLOCK_DV)
    # BLOCK_DMODEL / BLOCK_DV may exceed the real head dims; mask the excess.
    valid_qk = dim_qk < Lk
    valid_v = dim_v < Lv
    seq_len = tl.load(B_Seqlen + batch_id)
    req_id = batch_id

    q_offsets = batch_id * stride_qbs + head_id * stride_qh + dim_qk
    q = tl.load(Q + q_offsets, mask=valid_qk, other=0.0)

    # Partition [0, seq_len) into NUM_KV_SPLITS contiguous slices.
    split_len = tl.cdiv(seq_len, NUM_KV_SPLITS)
    kv_begin = split_len * split_id
    kv_end = tl.minimum(kv_begin + split_len, seq_len)

    run_max = -float("inf")
    run_sum = 0.0
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    if kv_end > kv_begin:
        for tile_start in range(kv_begin, kv_end, BLOCK_N):
            tok = tile_start + tl.arange(0, BLOCK_N)
            in_range = tok < kv_end
            page = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * req_id +
                tok // PAGE_SIZE,
                mask=in_range,
                other=0,
            )
            # Physical slot = page base + offset of the token within the page.
            slot = page * PAGE_SIZE + tok % PAGE_SIZE
            k_offsets = (slot[:, None] * stride_buf_kbs +
                         kv_head_id * stride_buf_kh + dim_qk[None, :])
            k = tl.load(
                K_Buffer + k_offsets,
                mask=in_range[:, None] & valid_qk[None, :],
                other=0.0,
            )
            scores = tl.sum(q[None, :] * k, 1)
            scores *= sm_scale

            if logit_cap > 0:
                scores = logit_cap * tanh(scores / logit_cap)

            # Out-of-range tokens must contribute exp(-inf) = 0.
            scores = tl.where(in_range, scores, float("-inf"))

            v_offsets = (slot[:, None] * stride_buf_vbs +
                         kv_head_id * stride_buf_vh + dim_v[None, :])
            v = tl.load(
                V_Buffer + v_offsets,
                mask=in_range[:, None] & valid_v[None, :],
                other=0.0,
            )

            # Online-softmax update: rescale the old state to the new max.
            new_max = tl.maximum(tl.max(scores, 0), run_max)
            rescale = tl.exp(run_max - new_max)
            probs = tl.exp(scores - new_max)
            acc *= rescale
            acc += tl.sum(probs[:, None] * v, 0)

            run_sum = run_sum * rescale + tl.sum(probs, 0)
            run_max = new_max

        out_base = (batch_id * stride_mid_ob + head_id * stride_mid_oh +
                    split_id * stride_mid_os)

        tl.store(
            Att_Out + (out_base + dim_v),
            acc / run_sum,
            mask=valid_v,
        )

        # Log-sum-exp of this slice lives right after the Lv output values.
        tl.store(
            Att_Out + (out_base + Lv),
            run_max + tl.log(run_sum),
        )
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_att_m_fwd(
    q,
    k_buffer,
    v_buffer,
    att_out,
    Req_to_tokens,
    B_Seqlen,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap,
):
    """Launch ``_fwd_kernel_stage1`` on a (batch, head, kv_split) grid.

    Head dimensions are read from the last axis of ``k_buffer`` and
    ``v_buffer`` and the KV-head count from their second-to-last axis.
    Per-split partial results are written into ``att_out``.
    """
    batch, head_num = q.shape[0], q.shape[1]
    head_dim_qk = k_buffer.shape[-1]
    head_dim_v = v_buffer.shape[-1]
    kv_group_num = head_num // k_buffer.shape[-2]

    # KV tokens handled per loop iteration; a smaller tile is used on ROCm.
    kv_tile = 8 if is_hip_ else 64

    if kv_group_num == 1:
        warps = 4
    elif is_hip_:
        warps = 1
    else:
        warps = 2

    launch_grid = (batch, head_num, num_kv_splits)

    _fwd_kernel_stage1[launch_grid](
        q,
        k_buffer,
        v_buffer,
        sm_scale,
        Req_to_tokens,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        # KV buffers: assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM).
        k_buffer.stride(-3),
        k_buffer.stride(-2),
        v_buffer.stride(-3),
        v_buffer.stride(-2),
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=kv_group_num,
        BLOCK_DMODEL=triton.next_power_of_2(head_dim_qk),
        BLOCK_DV=triton.next_power_of_2(head_dim_v),
        BLOCK_N=kv_tile,
        NUM_KV_SPLITS=num_kv_splits,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        num_warps=warps,
        num_stages=2,
        Lk=head_dim_qk,
        Lv=head_dim_v,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _fwd_grouped_kernel_stage1(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    """Stage 1 of split-KV decode attention for a block of query heads.

    Same online-softmax scheme as ``_fwd_kernel_stage1``, but each program
    handles up to BLOCK_H query heads that share one KV head, so the QK and
    PV products are computed with ``tl.dot``. When BLOCK_DPE > 0 the head
    dimension is processed as BLOCK_DMODEL leading lanes plus BLOCK_DPE
    trailing lanes whose dot product is added to the scores. For a non-empty
    slice, the normalised partial output goes to the first ``Lv`` slots of
    ``Att_Out`` and the log-sum-exp to slot ``Lv``.
    """
    batch_id = tl.program_id(0)
    head_block_id = tl.program_id(1)
    kv_head_id = head_block_id // tl.cdiv(kv_group_num, BLOCK_H)
    split_id = tl.program_id(2)

    # Real query heads per program: min(BLOCK_H, kv_group_num).
    if kv_group_num > BLOCK_H:
        VALID_BLOCK_H: tl.constexpr = BLOCK_H
    else:
        VALID_BLOCK_H: tl.constexpr = kv_group_num
    heads = head_block_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    # Lanes beyond this program's heads (or beyond q_head_num) are padding.
    valid_h = heads < (head_block_id + 1) * VALID_BLOCK_H
    valid_h = valid_h & (heads < q_head_num)

    dim_qk = tl.arange(0, BLOCK_DMODEL)
    dim_v = tl.arange(0, BLOCK_DV)
    valid_qk = dim_qk < Lk
    valid_v = dim_v < Lv
    seq_len = tl.load(B_Seqlen + batch_id)
    req_id = batch_id

    q_offsets = (batch_id * stride_qbs + heads[:, None] * stride_qh +
                 dim_qk[None, :])
    q = tl.load(Q + q_offsets,
                mask=valid_h[:, None] & valid_qk[None, :],
                other=0.0)

    if BLOCK_DPE > 0:
        dim_pe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        valid_pe = dim_pe < Lk
        qpe_offsets = (batch_id * stride_qbs + heads[:, None] * stride_qh +
                       dim_pe[None, :])
        qpe = tl.load(Q + qpe_offsets,
                      mask=valid_h[:, None] & valid_pe[None, :],
                      other=0.0)

    # Partition [0, seq_len) into NUM_KV_SPLITS contiguous slices.
    split_len = tl.cdiv(seq_len, NUM_KV_SPLITS)
    kv_begin = split_len * split_id
    kv_end = tl.minimum(kv_begin + split_len, seq_len)

    run_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    run_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

    if kv_end > kv_begin:
        for tile_start in range(kv_begin, kv_end, BLOCK_N):
            tok = tile_start + tl.arange(0, BLOCK_N)
            in_range = tok < kv_end
            page = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * req_id +
                tok // PAGE_SIZE,
                mask=in_range,
                other=0,
            )
            slot = page * PAGE_SIZE + tok % PAGE_SIZE
            # K is gathered as (dim, token) so it can be the RHS of tl.dot.
            k_offsets = (slot[None, :] * stride_buf_kbs +
                         kv_head_id * stride_buf_kh + dim_qk[:, None])
            k = tl.load(
                K_Buffer + k_offsets,
                mask=in_range[None, :] & valid_qk[:, None],
                other=0.0,
            )
            scores = tl.dot(q, k.to(q.dtype))
            if BLOCK_DPE > 0:
                kpe_offsets = (slot[None, :] * stride_buf_kbs +
                               kv_head_id * stride_buf_kh +
                               dim_pe[:, None])
                kpe = tl.load(
                    K_Buffer + kpe_offsets,
                    mask=in_range[None, :] & valid_pe[:, None],
                    other=0.0,
                )
                scores += tl.dot(qpe, kpe.to(qpe.dtype))
            scores *= sm_scale

            if logit_cap > 0:
                scores = logit_cap * tanh(scores / logit_cap)

            # Padding heads and out-of-range tokens get -inf scores.
            scores = tl.where(valid_h[:, None] & in_range[None, :], scores,
                              float("-inf"))

            v_offsets = (slot[:, None] * stride_buf_vbs +
                         kv_head_id * stride_buf_vh + dim_v[None, :])
            v = tl.load(
                V_Buffer + v_offsets,
                mask=in_range[:, None] & valid_v[None, :],
                other=0.0,
            )

            # Online-softmax update, one row per query head.
            new_max = tl.maximum(tl.max(scores, 1), run_max)
            rescale = tl.exp(run_max - new_max)
            probs = tl.exp(scores - new_max[:, None])
            acc *= rescale[:, None]
            acc += tl.dot(probs.to(v.dtype), v)

            run_sum = run_sum * rescale + tl.sum(probs, 1)
            run_max = new_max

        out_offsets = (batch_id * stride_mid_ob +
                       heads[:, None] * stride_mid_oh +
                       split_id * stride_mid_os + dim_v[None, :])

        tl.store(
            Att_Out + out_offsets,
            acc / run_sum[:, None],
            mask=valid_h[:, None] & valid_v[None, :],
        )

        # Log-sum-exp of this slice lives right after the Lv output values.
        lse_offsets = (batch_id * stride_mid_ob + heads * stride_mid_oh +
                       split_id * stride_mid_os + Lv)

        tl.store(
            Att_Out + lse_offsets,
            run_max + tl.log(run_sum),
            mask=valid_h,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_grouped_att_m_fwd(
    q,
    k_buffer,
    v_buffer,
    att_out,
    Req_to_tokens,
    B_Seqlen,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap,
):
    """Launch ``_fwd_grouped_kernel_stage1`` (a block of heads per program).

    Head dims of 576 and 288 are split into a power-of-two main part
    (512 / 256) plus a 64 / 32 wide remainder passed as ``BLOCK_DPE``; any
    other head dim is padded to the next power of two with no remainder.
    Per-split partial results are written into ``att_out``.
    """
    head_dim_qk = k_buffer.shape[-1]
    head_dim_v = v_buffer.shape[-1]

    # [TODO] work around shmem limit on MI3xx
    kv_tile = 16 if (is_hip_ and head_dim_qk >= 576) else 32

    if head_dim_qk == 576:
        block_dmodel, block_dpe = 512, 64
    elif head_dim_qk == 288:
        block_dmodel, block_dpe = 256, 32
    else:
        block_dmodel, block_dpe = triton.next_power_of_2(head_dim_qk), 0
    block_dv = triton.next_power_of_2(head_dim_v)

    batch, head_num = q.shape[0], q.shape[1]
    kv_group_num = head_num // k_buffer.shape[-2]

    heads_per_block = 16
    launch_grid = (
        batch,
        triton.cdiv(head_num, min(heads_per_block, kv_group_num)),
        num_kv_splits,
    )

    if is_hip_:
        # https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        backend_opts = {
            "waves_per_eu": 1,
            "matrix_instr_nonkdim": 16,
            "kpack": 2
        }
        stages = 1
    else:
        backend_opts = {}
        stages = 2

    _fwd_grouped_kernel_stage1[launch_grid](
        q,
        k_buffer,
        v_buffer,
        sm_scale,
        Req_to_tokens,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        # KV buffers: assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM).
        k_buffer.stride(-3),
        k_buffer.stride(-2),
        v_buffer.stride(-3),
        v_buffer.stride(-2),
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
        BLOCK_DMODEL=block_dmodel,
        BLOCK_DPE=block_dpe,
        BLOCK_DV=block_dv,
        BLOCK_N=kv_tile,
        BLOCK_H=heads_per_block,
        NUM_KV_SPLITS=num_kv_splits,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        num_warps=4,
        num_stages=stages,
        Lk=head_dim_qk,
        Lv=head_dim_v,
        **backend_opts,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _fwd_kernel_stage2(
    Mid_O,
    o,
    lse,
    B_Seqlen,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_obs,
    stride_oh,
    stride_lse_bs,
    NUM_KV_SPLITS: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
):
    """Stage 2: merge the per-split partial results of one (batch, head).

    For every non-empty KV split, reads the partial output (first ``Lv``
    slots of ``Mid_O``) and its log-sum-exp (slot ``Lv``), combines them
    with a running-max rescaling, then writes the final output to ``o`` and
    the overall log-sum-exp to ``lse``.
    """
    batch_id = tl.program_id(0)
    head_id = tl.program_id(1)

    seq_len = tl.load(B_Seqlen + batch_id)

    dim_v = tl.arange(0, BLOCK_DV)
    valid_v = dim_v < Lv

    run_sum = 0.0
    run_max = -float("inf")
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    part_base = batch_id * stride_mid_ob + head_id * stride_mid_oh
    part_out_offsets = part_base + dim_v
    part_lse_offset = part_base + Lv

    # Same partitioning of [0, seq_len) as in stage 1.
    split_len = tl.cdiv(seq_len, NUM_KV_SPLITS)

    for split_id in range(0, NUM_KV_SPLITS):
        kv_begin = split_len * split_id
        kv_end = tl.minimum(kv_begin + split_len, seq_len)

        # Empty splits were never written by stage 1, so skip them.
        if kv_end > kv_begin:
            part_out = tl.load(Mid_O + part_out_offsets +
                               split_id * stride_mid_os,
                               mask=valid_v,
                               other=0.0)
            part_lse = tl.load(Mid_O + part_lse_offset +
                               split_id * stride_mid_os)
            new_max = tl.maximum(part_lse, run_max)

            rescale = tl.exp(run_max - new_max)
            acc *= rescale
            weight = tl.exp(part_lse - new_max)
            acc += weight * part_out

            run_sum = run_sum * rescale + weight
            run_max = new_max

    tl.store(
        o + batch_id * stride_obs + head_id * stride_oh + dim_v,
        acc / run_sum,
        mask=valid_v,
    )
    final_lse = run_max + tl.log(run_sum)
    tl.store(
        lse + batch_id * stride_lse_bs + head_id,
        final_lse,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_softmax_reducev_fwd(
    logits,
    q,
    o,
    lse,
    v_buffer,
    b_seq_len,
    num_kv_splits,
):
    """Launch ``_fwd_kernel_stage2`` on a (batch, head) grid.

    ``logits`` holds the stage-1 partial results. The merged attention
    output is written to ``o`` and the log-sum-exp to ``lse``.
    """
    batch, head_num = q.shape[0], q.shape[1]
    head_dim_v = v_buffer.shape[-1]

    if is_hip_:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        backend_opts = {
            "waves_per_eu": 4,
            "matrix_instr_nonkdim": 16,
            "kpack": 2
        }
    else:
        backend_opts = {}

    launch_grid = (batch, head_num)
    _fwd_kernel_stage2[launch_grid](
        logits,
        o,
        lse,
        b_seq_len,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        lse.stride(0),
        NUM_KV_SPLITS=num_kv_splits,
        BLOCK_DV=triton.next_power_of_2(head_dim_v),
        Lv=head_dim_v,
        num_warps=4,
        num_stages=2,
        **backend_opts,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def decode_attention_fwd_normal(
    q,
    k_buffer,
    v_buffer,
    o,
    lse,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    """Two-stage decode attention using the one-head-per-program kernel.

    Stage 1 fills ``attn_logits`` with per-split partial results. Stage 2
    merges them into ``o`` and ``lse``.
    """
    # Stage 1: per-split partial attention.
    _decode_att_m_fwd(
        q,
        k_buffer,
        v_buffer,
        attn_logits,
        req_to_token,
        b_seq_len,
        num_kv_splits,
        sm_scale,
        page_size,
        logit_cap,
    )
    # Stage 2: reduce across splits.
    _decode_softmax_reducev_fwd(
        attn_logits,
        q,
        o,
        lse,
        v_buffer,
        b_seq_len,
        num_kv_splits,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def decode_attention_fwd_grouped(
    q,
    k_buffer,
    v_buffer,
    o,
    lse,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    """Two-stage decode attention using the grouped-heads kernel.

    Stage 1 fills ``attn_logits`` with per-split partial results. Stage 2
    merges them into ``o`` and ``lse``.
    """
    # Stage 1: per-split partial attention for blocks of query heads.
    _decode_grouped_att_m_fwd(
        q,
        k_buffer,
        v_buffer,
        attn_logits,
        req_to_token,
        b_seq_len,
        num_kv_splits,
        sm_scale,
        page_size,
        logit_cap,
    )
    # Stage 2: reduce across splits.
    _decode_softmax_reducev_fwd(
        attn_logits,
        q,
        o,
        lse,
        v_buffer,
        b_seq_len,
        num_kv_splits,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    lse,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size=1,
    logit_cap=0.0,
):
    """Run decode attention, choosing the kernel from the head grouping.

    When the number of query heads equals the number of KV heads (MHA) the
    one-head-per-program path is used. Otherwise (GQA/MQA/MLA) the grouped
    path is used. ``attn_logits.shape[2]`` must equal ``num_kv_splits``.
    """
    assert num_kv_splits == attn_logits.shape[2]
    kv_group_num = q.shape[1] // v_buffer.shape[-2]

    if kv_group_num == 1:
        # MHA
        attention_impl = decode_attention_fwd_normal
    else:
        # GQA/MQA/MLA
        attention_impl = decode_attention_fwd_grouped

    attention_impl(
        q,
        k_buffer,
        v_buffer,
        o,
        lse,
        req_to_token,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
        page_size,
        logit_cap,
    )
|
||||||
984
vllm/attention/ops/triton_flash_attention.py
Normal file
984
vllm/attention/ops/triton_flash_attention.py
Normal file
@@ -0,0 +1,984 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Fused Attention
|
||||||
|
===============
|
||||||
|
|
||||||
|
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
|
||||||
|
(https://tridao.me/publications/flash2/flash2.pdf)
|
||||||
|
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
|
||||||
|
|
||||||
|
Features supported:
|
||||||
|
|
||||||
|
1) Fwd with causal masking
|
||||||
|
2) Any sequence lengths without padding (currently fwd kernel only)
|
||||||
|
3) Support for different sequence lengths for q and k
|
||||||
|
4) Nested tensor API currently does not support dropout or bias.
|
||||||
|
|
||||||
|
Not currently supported:
|
||||||
|
|
||||||
|
1) Non power of two head dims
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
# Avoid misleading ROCm warning: the ROCm platform module is only imported
# when actually running on ROCm; elsewhere a stub that always answers False
# stands in for on_gfx1x.
if current_platform.is_rocm():
    from vllm.platforms.rocm import on_gfx1x
else:
    on_gfx1x = lambda *args, **kwargs: False

torch_dtype: tl.constexpr = torch.float16
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def cdiv_fn(x, y):
    """Return x / y rounded up, computed as (x + y - 1) // y.

    Intended for positive integer operands.
    """
    padded = x + y - 1
    return padded // y
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def max_fn(x, y):
    """Return ``tl.math.max(x, y)``."""
    larger = tl.math.max(x, y)
    return larger
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
    """Return an (m, n) grid of RNG offsets.

    Element (i, j) is ``philox_offset + i * stride + j``. ``philox_seed``
    and ``dropout_p`` are accepted but not used here.
    """
    row_ids = tl.arange(0, m)
    col_ids = tl.arange(0, n)
    row_part = row_ids[:, None] * stride
    return philox_offset + row_part + col_ids[None, :]
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
    """Return an (m, n) block of ``tl.rand`` values for dropout.

    The offsets come from ``dropout_offsets`` and are cast to uint32 before
    being passed to ``tl.rand`` together with ``philox_seed``.
    """
    offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
                              stride)
    rng_offsets = offsets.to(tl.uint32)
    # TODO: use tl.randint for better performance
    return tl.rand(philox_seed, rng_offsets)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
    """Return an (m, n) boolean keep-mask for dropout.

    An element is kept when its random draw is strictly greater than
    ``dropout_p``.
    """
    draws = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
    return draws > dropout_p
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def load_fn(block_ptr, first, second, pad):
    """Load a block, bounds-checking only the requested dimensions.

    ``first`` / ``second`` enable the boundary check on dimension 0 / 1.
    ``pad`` is forwarded as the padding option whenever a check is enabled.
    With both flags off the block is loaded without any check.
    """
    if first and second:
        block = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
    elif first:
        block = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
    elif second:
        block = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
    else:
        block = tl.load(block_ptr)
    return block
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    q,
    K_block_ptr,
    V_block_ptr,
    start_m,
    actual_seqlen_k,
    dropout_p,
    philox_seed,
    batch_philox_offset,
    encoded_softmax_block_ptr,
    block_min,
    block_max,
    offs_n_causal,
    masked_blocks,
    n_extra_tokens,
    bias_ptr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    OFFS_M: tl.constexpr,
    OFFS_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    MASK_STEPS: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    RETURN_ENCODED_SOFTMAX: tl.constexpr,
    PADDED_HEAD: tl.constexpr,
    USE_FP8: tl.constexpr,
    qk_scale,
    p_descale,
):
    """Inner K/V loop of the fused attention forward pass.

    Iterates over key/value blocks in [block_min, block_max) in steps of
    BLOCK_N, updating the output accumulator ``acc``, the running softmax
    denominator ``l_i`` and the running row maximum ``m_i`` with a base-2
    online softmax. Returns the updated ``(acc, l_i, m_i)``.
    """
    # loop over k, v, and update accumulator
    for start_n in range(block_min, block_max, BLOCK_N):
        # For padded blocks, we will overrun the tensor size if
        # we load all BLOCK_N. For others, the blocks are all within range.
        k = load_fn(
            K_block_ptr,
            PADDED_HEAD,
            MASK_STEPS and (n_extra_tokens != 0),
            "zero",
        )
        if PRE_LOAD_V:
            v = load_fn(
                V_block_ptr,
                MASK_STEPS and (n_extra_tokens != 0),
                PADDED_HEAD,
                "zero",
            )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # We start from end of seqlen_k so only the first iteration would need
        # to be checked for padding if it is not a multiple of block_n
        # TODO: This can be optimized to only be true for the padded block.
        if MASK_STEPS:  # noqa: SIM102
            # If this is the last block / iteration, we want to
            # mask if the sequence length is not a multiple of block size
            # a solution is to always do BLOCK_M // BLOCK_N + 1 steps
            # if not is_modulo_mn. last step might get wasted but that is okay.
            # check if this masking works for that case.
            if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
                seqlen_rows = tl.full([BLOCK_M],
                                      actual_seqlen_k,
                                      dtype=tl.int32)
                key_pos = start_n + OFFS_N[None, :]
                in_bounds = key_pos < seqlen_rows[:, None]
                qk = tl.where(in_bounds, qk, float("-inf"))
        if IS_CAUSAL:
            causal_boundary = start_n + offs_n_causal
            causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
            qk = tl.where(causal_mask, qk, float("-inf"))
        # -- compute qk ----
        # Added onto the (possibly -inf) mask values set above.
        qk += tl.dot(q, k)
        if USE_FP8:
            qk *= qk_scale
        if bias_ptr is not None:
            bias = load_fn(bias_ptr, False, MASK_STEPS
                           and (n_extra_tokens != 0), "zero")
            # While bias is added after multiplying qk with sm_scale, our
            # optimization to use 2^x instead of e^x results in an additional
            # scale factor of log2(e) which we must also multiply the bias with.
            qk += bias * 1.44269504089
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)

        # CAVEAT: Must update l_ij before applying dropout
        l_ij = tl.sum(p, 1)
        if ENABLE_DROPOUT:
            philox_offset = (batch_philox_offset +
                             start_m * BLOCK_M * actual_seqlen_k + start_n -
                             BLOCK_N)
            keep_mask = dropout_mask(
                philox_seed,
                philox_offset,
                dropout_p,
                BLOCK_M,
                BLOCK_N,
                actual_seqlen_k,
            )
            if RETURN_ENCODED_SOFTMAX:
                # Dropped entries are stored negated.
                tl.store(
                    encoded_softmax_block_ptr,
                    tl.where(keep_mask, p,
                             -p).to(encoded_softmax_block_ptr.type.element_ty),
                )
            p = tl.where(keep_mask, p, 0.0)
        elif RETURN_ENCODED_SOFTMAX:
            tl.store(
                encoded_softmax_block_ptr,
                p.to(encoded_softmax_block_ptr.type.element_ty),
            )
        # -- update output accumulator --
        alpha = tl.math.exp2(m_i - m_ij)
        acc = acc * alpha[:, None]
        if not PRE_LOAD_V:
            v = load_fn(
                V_block_ptr,
                MASK_STEPS and (n_extra_tokens != 0),
                PADDED_HEAD,
                "zero",
            )
        # -- update m_i and l_i
        l_i = l_i * alpha + l_ij
        m_i = m_ij

        if USE_FP8:
            p *= p_descale

        acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)

        # Advance all block pointers to the next BLOCK_N keys.
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
                                                   (0, BLOCK_N))
    return acc, l_i, m_i
|
||||||
|
|
||||||
|
|
||||||
|
def get_cdna_autotune_configs():
    """Return ``(configs, key)`` for autotuning: the non-gfx1x config list.

    ``configs`` is a list of ``triton.Config`` objects, all with
    ``num_stages=1``. ``key`` is the list of argument names whose values
    select an autotune entry.
    """
    # (BLOCK_M, BLOCK_N, waves_per_eu, PRE_LOAD_V, num_warps)
    candidates = [
        (256, 64, 2, False, 8),
        (128, 128, 2, False, 4),
        (256, 128, 2, False, 8),
        (128, 64, 1, False, 4),
        (128, 64, 3, True, 4),
        (128, 64, 3, False, 4),
        (64, 64, 4, False, 8),
        (32, 32, 4, False, 8),
        # TODO: This config fails with head_size not pow2 with data mismatches.
        # (32, 16, 1, False, 4),
        #
        # Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
        # (16, 16, 1, False, 4),
    ]
    configs = [
        triton.Config(
            {
                'BLOCK_M': block_m,
                'BLOCK_N': block_n,
                'waves_per_eu': waves,
                'PRE_LOAD_V': pre_load_v
            },
            num_stages=1,
            num_warps=warps) for block_m, block_n, waves, pre_load_v, warps in
        candidates
    ]
    return configs, ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
|
||||||
|
|
||||||
|
|
||||||
|
def get_rdna_autotune_configs():
    """Return the autotune search space and autotune keys for RDNA GPUs.

    Returns:
        A 2-tuple ``(configs, keys)``: the list of ``triton.Config``
        candidates and the list of kernel-argument names used as the
        autotune key.
    """
    # Every candidate shares BLOCK_M=32, PRE_LOAD_V=False, one stage and two
    # warps; only BLOCK_N and waves_per_eu vary. Order: (32, 4), (32, 2),
    # (16, 4), (16, 2).
    #
    # The 16x16 tiles (waves_per_eu 4 and 2, plus the waves_per_eu=1
    # fall-back config) are intentionally not part of the search space:
    # they fail in AccelerateAMDMatmul (Triton) assert when using FP8.
    configs = [
        triton.Config(
            {
                'BLOCK_M': 32,
                'BLOCK_N': block_n,
                'waves_per_eu': waves,
                'PRE_LOAD_V': False
            },
            num_stages=1,
            num_warps=2) for block_n in (32, 16) for waves in (4, 2)
    ]
    keys = ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
    return configs, keys
|
||||||
|
|
||||||
|
|
||||||
|
def get_autotune_configs():
    """Pick the autotune configuration set for the current GPU family.

    Returns the result of ``get_rdna_autotune_configs()`` when ``on_gfx1x()``
    is true and the result of ``get_cdna_autotune_configs()`` otherwise.
    """
    if not on_gfx1x():
        return get_cdna_autotune_configs()
    return get_rdna_autotune_configs()
|
||||||
|
|
||||||
|
|
||||||
|
# Resolved once at import time; both values are passed to the
# @triton.autotune decorator on attn_fwd.
autotune_configs, autotune_keys = get_autotune_configs()

# Numeric limits of the platform's FP8 dtype. Used as the FP8_MIN / FP8_MAX
# defaults of attn_fwd and for clamping before host-side FP8 casts.
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||||
|
|
||||||
|
|
||||||
|
@triton.autotune(
    configs=autotune_configs,
    key=autotune_keys,
)
@triton.jit
def attn_fwd(
    Q,
    K,
    V,
    bias,
    sm_scale,
    q_scale,
    k_scale,
    v_scale,
    p_scale,
    p_descale,
    o_descale,
    L,
    Out,
    stride_qz: tl.int64,
    stride_qh: tl.int64,
    stride_qm: tl.int64,
    stride_qk: tl.int64,
    stride_kz: tl.int64,
    stride_kh: tl.int64,
    stride_kn: tl.int64,
    stride_kk: tl.int64,
    stride_vz: tl.int64,
    stride_vh: tl.int64,
    stride_vk: tl.int64,
    stride_vn: tl.int64,
    stride_oz: tl.int64,
    stride_oh: tl.int64,
    stride_om: tl.int64,
    stride_on: tl.int64,
    stride_bz: tl.int64,
    stride_bh: tl.int64,
    stride_bm: tl.int64,
    stride_bn: tl.int64,
    cu_seqlens_q,
    cu_seqlens_k,
    dropout_p,
    philox_seed,
    philox_offset_base,
    encoded_softmax,
    HQ: tl.constexpr,
    HK: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr,
    MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    USE_FP8: tl.constexpr,
    USE_FP8_OUT: tl.constexpr,
    BLOCK_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BIAS_TYPE: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    RETURN_ENCODED_SOFTMAX: tl.constexpr,
    FP8_MIN: tl.constexpr = float8_info.min,
    FP8_MAX: tl.constexpr = float8_info.max,
):
    """Attention forward kernel.

    Each program handles one tile of BLOCK_M query rows (program_id 0) for
    one query head (program_id 1) of one batch entry / sequence
    (program_id 2) and writes the corresponding tile of ``Out``.

    Outline:
      1. Resolve this sequence's start offsets and lengths, from
         ``cu_seqlens_q`` / ``cu_seqlens_k`` when VARLEN, otherwise from
         MAX_SEQLENS_Q / MAX_SEQLENS_K.
      2. Work out how many BLOCK_N key blocks are needed, returning early
         when causal masking leaves none.
      3. Build block pointers for Q, K, V and, when enabled, the bias and
         encoded-softmax tensors.
      4. Run ``_attn_fwd_inner`` (defined elsewhere in this file) first over
         the blocks that need no masking, then over the masked blocks.
      5. Normalize the accumulator, apply the optional dropout / FP8 output
         scaling and store it with boundary checks.

    ``L`` is accepted but not used by the live code: the log-sum-exp
    write-back is commented out. ``q_scale``, ``k_scale``, ``v_scale``,
    ``p_scale`` and ``p_descale`` are only read when USE_FP8 is set, and
    ``o_descale`` only when USE_FP8_OUT is set.
    """
    start_m = tl.program_id(0)
    off_h_q = tl.program_id(1)
    off_z = tl.program_id(2)
    # Row indices of this tile within the sequence, and column indices of
    # one key block.
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    if VARLEN:
        cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
        cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
        seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
        # We have a one-size-fits-all grid in id(0). Some seqlens might be too
        # small for all start_m so for those we return early.
        if start_m * BLOCK_M > seqlen_q:
            return
        cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
        cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
        seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
    else:
        cu_seqlens_q_start = 0
        cu_seqlens_k_start = 0
        seqlen_q = MAX_SEQLENS_Q
        seqlen_k = MAX_SEQLENS_K

    # Now we compute whether we need to exit early due to causal masking.
    # This is because for seqlen_q > seqlen_k, M rows of the attn scores
    # are completely masked, resulting in 0s written to the output, and
    # inf written to LSE. We don't need to do any GEMMs in this case.
    # This block of code determines what N is, and if this WG is operating
    # on those M rows.
    n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
    if IS_CAUSAL:
        # If seqlen_q == seqlen_k, the attn scores are a square matrix.
        # If seqlen_q != seqlen_k, attn scores are rectangular which means
        # the causal mask boundary is bottom right aligned, and ends at either
        # the top edge (seqlen_q < seqlen_k) or left edge.
        # This captures the decrease in n_blocks if we have a rectangular attn
        # matrix
        n_blocks_seqlen = cdiv_fn(
            (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
        # This is what adjusts the block_max for the current WG, only
        # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
        n_blocks = min(n_blocks, n_blocks_seqlen)
        # If we have no blocks after adjusting for seqlen deltas, this WG is
        # part of the blocks that are all 0. We exit early.
        if n_blocks <= 0:
            o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
                        off_h_q * stride_oh)
            O_block_ptr = tl.make_block_ptr(
                base=Out + o_offset,
                shape=(seqlen_q, BLOCK_DMODEL),
                strides=(stride_om, stride_on),
                offsets=(start_m * BLOCK_M, 0),
                block_shape=(BLOCK_M, BLOCK_DMODEL),
                order=(1, 0),
            )
            acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
            # NOTE(review): the stores below are commented out, so nothing is
            # written to Out or L on this path; O_block_ptr and acc above are
            # currently unused.
            # We still need to write 0s to the result
            # tl.store(O_block_ptr,
            # acc.to(Out.type.element_ty), boundary_check=(0,1))
            # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
            #          + offs_m
            # We store inf to LSE, not -inf because in the bwd pass,
            # we subtract this
            # from qk which makes it -inf, such that exp(qk - inf) = 0
            # for these masked blocks.
            # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
            # tl.store(l_ptrs, l)
            # TODO: Should dropout and return encoded softmax be handled here?
            return

    # If MQA / GQA, set the K and V head offsets appropriately.
    GROUP_SIZE: tl.constexpr = HQ // HK
    off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q

    # Non-zero when seqlen_k is not a whole number of BLOCK_N blocks; passed
    # to the masked pass below.
    n_extra_tokens = 0
    if seqlen_k < BLOCK_N:
        n_extra_tokens = BLOCK_N - seqlen_k
    elif seqlen_k % BLOCK_N:
        n_extra_tokens = seqlen_k % BLOCK_N
    # True when the head dimension was padded up to BLOCK_DMODEL.
    padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL

    # Compute pointers for all the tensors used in this kernel.
    q_offset = (off_z * stride_qz + off_h_q * stride_qh +
                cu_seqlens_q_start * stride_qm)
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    k_offset = (off_z * stride_kz + off_h_k * stride_kh +
                cu_seqlens_k_start * stride_kn)
    # K is addressed as (head_dim, seqlen_k), i.e. already transposed.
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    v_offset = (off_z * stride_vz + off_h_k * stride_vh +
                cu_seqlens_k_start * stride_vk)
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    if BIAS_TYPE != 0:
        # NOTE(review): only the head offset is applied to the bias base;
        # stride_bz is not used here.
        bias_ptr = tl.make_block_ptr(
            base=bias + off_h_q * stride_bh,
            shape=(seqlen_q, seqlen_k),
            strides=(stride_bm, stride_bn),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
    else:
        bias_ptr = None
    if ENABLE_DROPOUT:
        batch_philox_offset = philox_offset_base \
            + (off_z * HQ + off_h_q) \
            * seqlen_q * seqlen_k
    else:
        batch_philox_offset = 0
    # We can ask to return the dropout mask without actually doing any dropout.
    # In this case, we return an invalid pointer to indicate the mask is not
    # valid.
    # TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
    if RETURN_ENCODED_SOFTMAX:
        encoded_softmax_block_ptr = tl.make_block_ptr(
            base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
            shape=(seqlen_q, seqlen_k),
            strides=(seqlen_k, 1),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
    else:
        encoded_softmax_block_ptr = 0
    # Initialize the running row maximum (m_i), the running normalizer (l_i)
    # and the output accumulator.
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use 2^x in the loop as we do not
    # have native e^x support in HW.
    qk_scale = sm_scale * 1.44269504089
    # Q is loaded once at the beginning and shared by all N blocks.
    q = load_fn(Q_block_ptr, True, padded_head, "zero")
    if not USE_FP8:
        # Fold the softmax scale into Q up front.
        q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
        acc_scale = 1.0
    else:
        # Q stays unscaled; the Q/K scales are folded into qk_scale and the
        # P/V scales are applied to the accumulator in the epilogue.
        qk_scale *= q_scale * k_scale
        acc_scale = p_scale * v_scale

    # Here we compute how many full and masked blocks we have.
    padded_block_k = n_extra_tokens != 0
    is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
    if IS_CAUSAL:
        # There are always at least BLOCK_M // BLOCK_N masked blocks.
        # Additionally there might be one more due to dissimilar seqlens.
        masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
    else:
        # Padding on Q does not need to be masked in the FA loop.
        masked_blocks = padded_block_k
    # if IS_CAUSAL, not is_modulo_mn does not always result in an additional
    # block. In this case we might exceed n_blocks so pick the min.
    masked_blocks = min(masked_blocks, n_blocks)
    n_full_blocks = n_blocks - masked_blocks
    block_min = 0
    block_max = n_blocks * BLOCK_N
    # Compute for full blocks. Here we set causal to false regardless of its
    # value because there is no masking. Similarly we do not need padding.
    if n_full_blocks > 0:
        block_max = (n_blocks - masked_blocks) * BLOCK_N
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            encoded_softmax_block_ptr,
            # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
            block_min,
            block_max,
            0,
            0,
            0,
            bias_ptr,
            # IS_CAUSAL, ....
            False,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            offs_m,
            offs_n,
            # _, MASK_STEPS, ...
            PRE_LOAD_V,
            False,
            ENABLE_DROPOUT,
            RETURN_ENCODED_SOFTMAX,
            padded_head,
            USE_FP8,
            qk_scale,
            p_descale,
        )
        # The masked pass continues where the full blocks ended.
        block_min = block_max
        block_max = n_blocks * BLOCK_N

    tl.debug_barrier()
    # Remaining blocks, if any, are the masked ones (MASK_STEPS is passed as
    # True below).
    if masked_blocks > 0:
        offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
        # Skip past the blocks already consumed by the first pass.
        K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
            # NOTE(review): advances by n_full_blocks, not
            # n_full_blocks * BLOCK_N like the pointers above - confirm.
            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
                                                   (0, n_full_blocks))
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            encoded_softmax_block_ptr,
            block_min,
            block_max,
            offs_n_causal,
            masked_blocks,
            n_extra_tokens,
            bias_ptr,
            IS_CAUSAL,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            offs_m,
            offs_n,
            # _, MASK_STEPS, ...
            PRE_LOAD_V,
            True,
            ENABLE_DROPOUT,
            RETURN_ENCODED_SOFTMAX,
            padded_head,
            USE_FP8,
            qk_scale,
            p_descale,
        )
    # epilogue

    if USE_FP8:
        acc *= acc_scale
    # Normalize by the accumulated softmax denominator.
    acc = acc / l_i[:, None]
    if ENABLE_DROPOUT:
        acc = acc / (1 - dropout_p)
    # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
    # then we have one block with a row of all NaNs which come from computing
    # softmax over a row of all -infs (-inf - inf = NaN). We check for that here
    # and store 0s where there are NaNs as these rows should've been zeroed out.
    end_m_idx = (start_m + 1) * BLOCK_M
    start_m_idx = start_m * BLOCK_M
    causal_start_idx = seqlen_q - seqlen_k
    if USE_FP8_OUT:
        # Scale into the FP8 output range and clamp before the cast below.
        acc *= o_descale
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
    acc = acc.to(Out.type.element_ty)
    if IS_CAUSAL:  # noqa: SIM102
        if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
            # Keep rows at or past causal_start_idx; zero the rows before it.
            out_mask_boundary = tl.full((BLOCK_DMODEL, ),
                                        causal_start_idx,
                                        dtype=tl.int32)
            mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
            out_ptrs_mask = (mask_m_offsets[:, None]
                             >= out_mask_boundary[None, :])
            z = tl.zeros((1, ), tl.float32)
            acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
    # write back LSE (currently disabled)
    # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
    # If seqlen_q not multiple of BLOCK_M, we need to mask out the last
    # few rows. This is only true for the last M block. For others,
    # overflow_size will be -ve
    # overflow_size = end_m_idx - seqlen_q
    # if overflow_size > 0:
    #    boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
    #    # This is a > check because mask being 0 blocks the store.
    #    l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
    #    tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
    # else:
    #    tl.store(l_ptrs, m_i + tl.math.log2(l_i))

    # write back O
    o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
                off_h_q * stride_oh)
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # Need boundary check on this to make sure the padding from the
    # Q and KV tensors in both dims are not part of what we store back.
    # TODO: Do the boundary check optionally.
    tl.store(O_block_ptr, acc, boundary_check=(0, 1))
|
||||||
|
|
||||||
|
|
||||||
|
def check_args(
    q,
    k,
    v,
    o,
    varlen=True,
    max_seqlens=None,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
):
    """Validate the tensors and metadata handed to the attention kernel.

    Args:
        q, k, v: query / key / value tensors. They must have the same rank
            (3 when ``varlen`` is true, 4 otherwise), the same dtype and the
            same last dimension; ``k`` and ``v`` must have identical shapes.
        o: output tensor; must have the same shape as ``q``.
        varlen: whether the inputs use the packed variable-length layout.
        max_seqlens: required and positive when ``varlen`` is false.
        cu_seqlens_q, cu_seqlens_k: required when ``varlen`` is true and
            must have equal lengths.

    Raises:
        AssertionError: if any of the constraints above is violated, if the
            head size exceeds 256, or if the number of query heads is not a
            multiple of the number of key heads.
    """
    assert q.dim() == k.dim() and q.dim() == v.dim()
    if varlen:
        assert q.dim() == 3
        _, nheads_q, head_size = q.shape
        _, nheads_k, _ = k.shape
        assert cu_seqlens_q is not None
        assert cu_seqlens_k is not None
        assert len(cu_seqlens_q) == len(cu_seqlens_k)
    else:
        assert q.dim() == 4
        _, nheads_q, _, head_size = q.shape
        _, nheads_k, _, _ = k.shape
        # max_seqlens defaults to None. Test for that first so a missing
        # value fails as an assertion like every other check here, instead
        # of raising TypeError from `None > 0`.
        assert max_seqlens is not None and max_seqlens > 0
    assert k.shape == v.shape
    assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
    # TODO: Change assert if we support qkl f8 and v f16
    assert q.dtype == k.dtype and q.dtype == v.dtype
    assert head_size <= 256
    assert o.shape == q.shape
    # Grouped-query attention needs a whole number of query heads per KV head.
    assert (nheads_q % nheads_k) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class _attention(torch.autograd.Function):
    """Autograd function that launches the ``attn_fwd`` Triton kernel.

    Only ``forward`` is defined. Inputs use the packed variable-length
    layout: ``q`` is ``[total_q, nheads_q, head_size]``, ``k`` and ``v`` are
    ``[total_k, nheads_k, head_size]``, and sequence boundaries come from
    ``cu_seqlens_q`` / ``cu_seqlens_k``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        o,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlens_q,
        max_seqlens_k,
        causal=False,
        sm_scale=1.0,
        bias=None,
        fp8_scales=None,
        fp8_out_scale=None,
    ):
        """Run the attention forward pass.

        Args:
            ctx: autograd context; launch parameters are stashed on it.
            q, k, v: packed query / key / value tensors.
            o: output tensor, or None to allocate one shaped like ``q`` with
                ``v``'s dtype.
            cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths; the
                batch size is ``len(cu_seqlens_q) - 1``.
            max_seqlens_q, max_seqlens_k: longest query / key sequence.
                ``max_seqlens_q`` sizes the launch grid.
            causal: whether to apply causal masking.
            sm_scale: softmax scale.
            bias: optional 4-D bias tensor.
            fp8_scales: optional ``(q_scale, k_scale, v_scale, p_scale)``.
                When given, inputs not already in the platform FP8 dtype are
                divided by their scale, clamped and cast to FP8.
            fp8_out_scale: optional tensor holding the output scale; when
                given the kernel produces FP8-range output.

        Returns:
            ``(o, encoded_softmax)``; ``encoded_softmax`` is always None.
        """
        if fp8_scales is not None:
            use_fp8 = True
            (q_scale, k_scale, v_scale, p_scale) = fp8_scales
            float8 = current_platform.fp8_dtype()

            def check_and_convert(t, scale):
                # Quantize `t` to FP8 unless it already is FP8.
                if t.dtype != float8:
                    descale = 1.0 / scale
                    ts = (t * descale).clamp(min=float8_info.min,
                                             max=float8_info.max)
                    return ts.to(float8)
                else:
                    return t

            q = check_and_convert(q, q_scale)
            k = check_and_convert(k, k_scale)
            v = check_and_convert(v, v_scale)
        else:
            use_fp8 = False
            q_scale = k_scale = v_scale = p_scale = 1.0

        if o is None:
            o = torch.empty_like(q, dtype=v.dtype)

        check_args(
            q,
            k,
            v,
            o,
            varlen=True,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
        )

        # Only the varlen layout is supported here. The kernel expects
        # strides in (batch, head, seq, dim) order; the batch stride is 0
        # because sequences are packed along dim 0 and located through
        # cu_seqlens instead.
        _, nheads_q, head_size = q.shape
        _, nheads_k, _ = k.shape
        batch = len(cu_seqlens_q) - 1
        q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
        k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
        v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
        o_strides = (0, o.stride(1), o.stride(0), o.stride(2))

        # Pad the head size up to the smallest supported power of two
        # (>= 32). The candidates are kept in an ordered tuple: iterating a
        # set gives no ordering guarantee, so "first one larger than
        # head_size" could otherwise pick a needlessly large padding.
        supported_head_dims = (32, 64, 128, 256)
        padded_d_model = next(
            (dim for dim in supported_head_dims if dim >= head_size), None)
        assert padded_d_model is not None

        grid = lambda META: (
            triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
            nheads_q,
            batch,
        )

        encoded_softmax = None

        # Seed the RNG so we get reproducible results for testing.
        philox_seed = 0x1BF52
        philox_offset = 0x1D4B42

        if bias is not None:
            bias_strides = (
                bias.stride(0),
                bias.stride(1),
                bias.stride(2),
                bias.stride(3),
            )
        else:
            bias_strides = (0, 0, 0, 0)

        p_descale = 1.0 / p_scale
        o_descale = 1.0 / fp8_out_scale.item(
        ) if fp8_out_scale is not None else 1.0

        # MAX_SEQLENS_* are constexpr kernel arguments; on gfx1x they are
        # passed as 0.
        is_gfx1x = on_gfx1x()
        arg_max_seqlens_q = 0 if is_gfx1x else max_seqlens_q
        arg_max_seqlens_k = 0 if is_gfx1x else max_seqlens_k

        attn_fwd[grid](
            q,
            k,
            v,
            bias,
            sm_scale,
            q_scale,
            k_scale,
            v_scale,
            p_scale,
            p_descale,
            o_descale,
            None,
            o,
            *q_strides,
            *k_strides,
            *v_strides,
            *o_strides,
            *bias_strides,
            cu_seqlens_q,
            cu_seqlens_k,
            dropout_p=0.0,
            philox_seed=philox_seed,
            philox_offset_base=philox_offset,
            encoded_softmax=encoded_softmax,
            HQ=nheads_q,
            HK=nheads_k,
            ACTUAL_BLOCK_DMODEL=head_size,
            MAX_SEQLENS_Q=arg_max_seqlens_q,
            MAX_SEQLENS_K=arg_max_seqlens_k,
            IS_CAUSAL=causal,
            VARLEN=True,
            BLOCK_DMODEL=padded_d_model,
            BIAS_TYPE=0 if bias is None else 1,
            ENABLE_DROPOUT=False,
            RETURN_ENCODED_SOFTMAX=False,
            USE_FP8=use_fp8,
            USE_FP8_OUT=fp8_out_scale is not None,
        )

        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = head_size
        ctx.causal = causal
        ctx.dropout_p = 0.0
        ctx.philox_seed = philox_seed
        ctx.philox_offset = philox_offset
        ctx.encoded_softmax = encoded_softmax
        ctx.return_encoded_softmax = False
        return o, encoded_softmax
|
||||||
|
|
||||||
|
|
||||||
|
# Callable entry point: invokes _attention.forward through
# torch.autograd.Function.apply.
triton_attention = _attention.apply
|
||||||
97
vllm/attention/ops/triton_merge_attn_states.py
Normal file
97
vllm/attention/ops/triton_merge_attn_states.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
|
||||||
|
# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||||
|
# can be used to combine partial attention results (in the split-KV case)
|
||||||
|
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: Optional[torch.Tensor] = None,
) -> None:
    """Merge prefix and suffix partial attention results into ``output``.

    Launches ``merge_attn_states_kernel`` with one program per
    (token, query head) pair. The result is written in place into
    ``output`` and, when ``output_lse`` is given, the merged log-sum-exp is
    written into it as well.
    """
    # Grid dimensions come from the first two dims of `output`.
    launch_grid = (output.shape[0], output.shape[1])
    head_size = output.shape[2]
    # The kernel works on a power-of-two wide vector and masks the tail.
    padded_head_size = triton.next_power_of_2(head_size)
    write_lse = output_lse is not None

    # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
    merge_attn_states_kernel[launch_grid](
        output,
        output_lse,
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        head_size,
        padded_head_size,
        write_lse,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def merge_attn_states_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    output_lse,  # [NUM_HEADS, NUM_TOKENS]
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse,  # [NUM_HEADS, NUM_TOKENS]
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse,  # [NUM_HEADS, NUM_TOKENS]
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,
    OUTPUT_LSE: tl.constexpr,
):
    # One program merges the prefix and suffix results of a single
    # (token, head) pair, weighting each by its share of the combined
    # sum-of-exponentials.
    token_idx = tl.program_id(0)
    num_tokens = tl.num_programs(0)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    prefix_lse_ptr = prefix_lse + head_idx * num_tokens + token_idx
    suffix_lse_ptr = suffix_lse + head_idx * num_tokens + token_idx
    p_lse = tl.load(prefix_lse_ptr)
    s_lse = tl.load(suffix_lse_ptr)

    # FA2 and FA3 disagree on the LSE reported when the sum-exp is 0 (which
    # namely arises with 0-length seqlens): FA3 returns -inf, FA2 returns inf.
    # An inf is therefore assumed to come from FA2 and is converted to -inf
    # for consistency and correctness. Inf generally doesn't make sense in
    # this context outside of undefined-behavior/FA2-case, so this is
    # considered a safe assumption.
    p_lse = float('-inf') if p_lse == float('inf') else p_lse
    s_lse = float('-inf') if s_lse == float('inf') else s_lse

    # Shift both LSEs by their maximum before exponentiating.
    max_lse = tl.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    # These exponentials are reused below for the scale factors.
    p_se = tl.exp(p_lse)
    s_se = tl.exp(s_lse)
    out_se = (p_se + s_se)

    if OUTPUT_LSE:
        merged_lse = tl.log(out_se) + max_lse
        tl.store(output_lse + head_idx * num_tokens + token_idx, merged_lse)

    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    prefix_ptrs = (prefix_output + token_idx * num_heads * HEAD_SIZE +
                   head_idx * HEAD_SIZE + head_arange)
    suffix_ptrs = (suffix_output + token_idx * num_heads * HEAD_SIZE +
                   head_idx * HEAD_SIZE + head_arange)
    p_out = tl.load(prefix_ptrs, mask=head_mask)
    s_out = tl.load(suffix_ptrs, mask=head_mask)

    # NOTE(woosuk): Be careful with the numerical stability.
    # We should compute the scale first, and then multiply it with the output.
    # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
    p_scale = p_se / out_se
    s_scale = s_se / out_se
    merged = p_out * p_scale + s_out * s_scale
    output_ptrs = (output + token_idx * num_heads * HEAD_SIZE +
                   head_idx * HEAD_SIZE + head_arange)
    tl.store(output_ptrs, merged, mask=head_mask)
|
||||||
175
vllm/attention/ops/triton_reshape_and_cache_flash.py
Normal file
175
vllm/attention/ops/triton_reshape_and_cache_flash.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def reshape_and_cache_kernel_flash(
    key_ptr,  # [num_tokens, num_heads, head_size]
    value_ptr,  # [num_tokens, num_heads, head_size]
    key_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
    value_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
    slot_mapping_ptr,  # [num_tokens]
    k_scale,  # float32
    v_scale,  # float32
    # strides
    key_stride: tl.int64,
    value_stride: tl.int64,
    block_stride: tl.int64,
    page_stride: tl.int64,
    num_heads: tl.constexpr,
    head_size: tl.constexpr,
    block_size: tl.constexpr,
    # FP8 flags
    FP8_KV_CACHE: tl.constexpr,
    # tune parameters
    TILE_SIZE: tl.constexpr,
):
    # Copies one TILE_SIZE-wide slice of a token's flattened key and value
    # (num_heads * head_size elements) into the cache slot chosen by
    # slot_mapping. Axis 0 of the grid selects the token, axis 1 the tile.
    token_idx = tl.program_id(axis=0)
    slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
    if slot_idx < 0:
        # Padding token that should be ignored.
        return

    tile_i = tl.program_id(axis=1)
    tile_pos = tile_i * TILE_SIZE + tl.arange(0, TILE_SIZE)
    # Positions past the end of the flattened (head, dim) row are masked.
    in_bounds = tile_pos < (num_heads * head_size)

    # Split the flat slot index into (cache block, position within block).
    block_idx = slot_idx // block_size
    block_offset = slot_idx % block_size
    cache_offset = block_idx * block_stride + block_offset * page_stride

    key_src = key_ptr + token_idx * key_stride + tile_pos
    value_src = value_ptr + token_idx * value_stride + tile_pos

    # [TILE_SIZE]
    key_load = tl.load(key_src, mask=in_bounds)
    if FP8_KV_CACHE:
        if key_load.dtype.is_fp8():
            key_tile = key_load
        else:
            # tl.store will do the correct implicit cast to fp8,
            # based on the key_cache_ptr.dtype.element_ty
            key_tile = key_load / tl.load(k_scale)
    else:
        key_tile = key_load

    # [TILE_SIZE]
    value_load = tl.load(value_src, mask=in_bounds)
    if FP8_KV_CACHE:
        if value_load.dtype.is_fp8():
            value_tile = value_load
        else:
            # tl.store will do the correct implicit cast to fp8,
            # based on the value_cache_ptr.dtype.element_ty
            value_tile = value_load / tl.load(v_scale)
    else:
        value_tile = value_load

    tl.store(key_cache_ptr + cache_offset + tile_pos,
             key_tile,
             mask=in_bounds)
    tl.store(value_cache_ptr + cache_offset + tile_pos,
             value_tile,
             mask=in_bounds)
|
||||||
|
|
||||||
|
|
||||||
|
def triton_reshape_and_cache_flash(
    key: torch.Tensor,  # [num_tokens, num_heads, head_size]
    value: torch.Tensor,  # [num_tokens, num_heads, head_size]
    # [num_blocks, block_size, num_heads, head_size]
    key_cache: torch.Tensor,
    # [num_blocks, block_size, num_heads, head_size]
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,  # [num_tokens]
    kv_cache_dtype: str,  # "auto", "fp8"
    k_scale: torch.Tensor,  # float32
    v_scale: torch.Tensor,  # float32
):
    """Scatter per-token keys/values into the paged KV cache with Triton.

    Launches ``reshape_and_cache_kernel_flash`` on a 2D grid of
    ``(num_tokens, cdiv(num_heads * head_size, TILE_SIZE))`` programs. For
    each token the kernel reads its slot from ``slot_mapping`` (negative
    slots are skipped as padding) and copies the token's key and value
    data into ``key_cache`` / ``value_cache`` in place.

    Args:
        key: Keys to store, ``[num_tokens, num_heads, head_size]``.
        value: Values to store, ``[num_tokens, num_heads, head_size]``.
        key_cache: Destination key cache,
            ``[num_blocks, block_size, num_heads, head_size]``.
        value_cache: Destination value cache, same layout as ``key_cache``.
        slot_mapping: Flat cache slot per token, ``[num_tokens]``.
        kv_cache_dtype: ``"auto"`` or a string starting with ``"fp8"``.
        k_scale: Scale tensor the kernel divides keys by when storing to an
            fp8 cache from non-fp8 input.
        v_scale: Same as ``k_scale``, for values.

    Raises:
        AssertionError: If the cache heads are not contiguous, if
            ``kv_cache_dtype`` is unsupported, or if the effective cache
            dtype is ``torch.uint8``.
    """
    num_tokens = key.shape[0]
    num_heads = key.shape[1]
    head_size = key.shape[2]
    block_size = key_cache.shape[1]
    # number of elements per token that have to be copied
    n = num_heads * head_size

    key_stride = key.stride()[0]
    value_stride = value.stride()[0]
    block_stride = key_cache.stride()[0]
    page_stride = key_cache.stride()[1]

    # The kernel addresses all heads of one token as a single flat run of
    # `n` elements, which is only valid when the heads are packed.
    head_stride = key_cache.stride()[2]
    assert head_stride == head_size, "only continous heads are supported"

    assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), \
        f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
    kv_cache_torch_dtype = current_platform.fp8_dtype() if \
        kv_cache_dtype.startswith("fp8") else key_cache.dtype

    if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith(
            "fp8"):
        # to avoid erroneous implicit cast in triton kernel (tl.store to uint8)
        # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
        key_cache = key_cache.view(kv_cache_torch_dtype)
        value_cache = value_cache.view(kv_cache_torch_dtype)
    # Check the effective torch dtype of the cache. The previous check
    # compared the `kv_cache_dtype` *string* with `torch.uint8`, which is
    # always unequal, so the assertion could never fire.
    assert kv_cache_torch_dtype != torch.uint8, \
        "explicit fp8 cast and store to " \
        "uint8 is not supported by triton reshape_and_cache_flash"

    FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
    # The middle fragment must be an f-string, otherwise the placeholder is
    # printed literally instead of the offending dtype.
    assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
        torch.float8_e4m3fn, torch.float8_e5m2, torch.uint8,
        torch.float8_e4m3fnuz], \
        "unsupported dtype of KV cache tensor, got " \
        f"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " \
        "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."

    # heuristics instead of autotuning
    TILE_SIZE = min(2048, triton.next_power_of_2(n))
    if torch.version.hip or torch.version.xpu:
        num_stages = 4
        num_warps = 8
    else:  # cuda
        num_stages = 10
        num_warps = 16
        # use a smaller tile on devices with compute capability major < 9
        if torch.cuda.get_device_capability(key.device)[0] < 9:
            TILE_SIZE = min(512, TILE_SIZE)

    # TODO(ngl): maybe replace with static launch grid to avoid overhead if
    # using cudagraphs
    grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))

    reshape_and_cache_kernel_flash[grid](
        key_ptr=key,
        value_ptr=value,
        key_cache_ptr=key_cache,
        value_cache_ptr=value_cache,
        slot_mapping_ptr=slot_mapping,
        k_scale=k_scale,
        v_scale=v_scale,
        # strides
        key_stride=key_stride,
        value_stride=value_stride,
        block_stride=block_stride,
        page_stride=page_stride,
        num_heads=num_heads,
        head_size=head_size,
        block_size=block_size,
        # FP8 flags
        FP8_KV_CACHE=FP8_KV_CACHE,
        # autotune parameters
        TILE_SIZE=TILE_SIZE,
        num_warps=num_warps,
        num_stages=num_stages,
    )
|
||||||
894
vllm/attention/ops/triton_unified_attention.py
Normal file
894
vllm/attention/ops/triton_unified_attention.py
Normal file
@@ -0,0 +1,894 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Authors:
|
||||||
|
# - Burkhard Ringlein <ngl@zurich.ibm.com>
|
||||||
|
# - Jan van Lunteren <jvl@zurich.ibm.com>
|
||||||
|
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
|
||||||
|
# - Thomas Parnell <tpa@zurich.ibm.com>
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
# Module-level logger.
logger = init_logger(__name__)
# Numeric limits of the current platform's FP8 dtype. Its min/max are the
# default values of the FP8_MIN / FP8_MAX kernel parameters used to clamp
# FP8 outputs.
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def cdiv_fn(x, y):
    """Return ``(x + y - 1) // y``, i.e. ``x / y`` rounded up.

    The rounding-up interpretation holds for positive integer operands.
    """
    bumped = x + y - 1
    return bumped // y
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def apply_softcap(S, x):
    """Soft-cap the scores ``S`` with cap ``x``.

    Evaluates ``x * (e^(S/x) - e^(-S/x)) / (e^(S/x) + e^(-S/x))``, which is
    mathematically ``x * tanh(S / x)``.
    """
    ratio = S / x
    e_pos = tl.exp(ratio)
    e_neg = tl.exp(-ratio)
    numer = x * (e_pos - e_neg)
    denom = e_pos + e_neg
    return numer / denom
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,
                 BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr):
    """Binary-search the sequence that contains ``target_idx``.

    Searches the entries ``0 .. num_seqs - 1`` of ``query_start_len_ptr``
    and returns the last index whose search key is ``<= target_idx``
    (``-1`` if there is none).

    With ``use_q_block_mode`` the search key of entry ``i`` is
    ``start[i] // BLOCK_Q + i`` (``target_idx`` is then a global q-block
    index); otherwise the key is ``start[i]`` itself (``target_idx`` is a
    token index).
    """
    lo: tl.int32 = 0
    hi = num_seqs
    while lo < hi:
        probe = (lo + hi) // 2
        start = tl.load(query_start_len_ptr + probe)
        # constexpr flag: only one of the two branches is compiled
        if use_q_block_mode:
            probe_key = start // BLOCK_Q + probe
        else:
            probe_key = start

        if probe_key > target_idx:
            # the answer lies strictly before `probe`
            hi = probe
        else:
            # `probe` is a candidate; continue to its right
            lo = probe + 1

    return lo - 1
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def kernel_unified_attention_2d(
    output_ptr,  # [num_tokens, num_query_heads, head_size]
    query_ptr,  # [num_tokens, num_query_heads, head_size]
    key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
    value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
    sink_ptr,  # [num_query_heads]
    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
    seq_lens_ptr,  # [num_seqs]
    alibi_slopes_ptr,  # [num_query_heads]
    qq_bias_ptr,  # [num_query_tokens, num_query_tokens]
    scale,  # float32
    k_scale,  # float32, read through tl.load (passed as a pointer)
    v_scale,  # float32, read through tl.load (passed as a pointer)
    out_scale,  # float32, read through tl.load; multiplied into the output
    softcap,  # float32
    num_query_heads: tl.constexpr,  # int
    num_queries_per_kv: tl.constexpr,  # int
    block_table_stride: tl.int64,  # int
    query_stride_0: tl.int64,  # int
    query_stride_1: tl.int64,  # int, should be equal to head_size
    output_stride_0: tl.int64,  # int
    output_stride_1: tl.int64,  # int, should be equal to head_size
    qq_bias_stride_0: tl.int64,  # int
    BLOCK_SIZE: tl.constexpr,  # int
    TILE_SIZE: tl.constexpr,  # int must be power of 2
    HEAD_SIZE: tl.constexpr,  # int
    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
    USE_ALIBI_SLOPES: tl.constexpr,  # bool
    USE_QQ_BIAS: tl.constexpr,  # bool
    USE_SOFTCAP: tl.constexpr,  # bool
    USE_SINKS: tl.constexpr,  # bool
    SLIDING_WINDOW: tl.constexpr,  # int
    stride_k_cache_0: tl.int64,  # int
    stride_k_cache_1: tl.int64,  # int
    stride_k_cache_2: tl.int64,  # int
    stride_k_cache_3: tl.constexpr,  # int
    stride_v_cache_0: tl.int64,  # int
    stride_v_cache_1: tl.int64,  # int
    stride_v_cache_2: tl.int64,  # int
    stride_v_cache_3: tl.constexpr,  # int
    query_start_len_ptr,  # [num_seqs+1]
    BLOCK_Q: tl.constexpr,  # int
    num_seqs: tl.int32,
    BLOCK_M: tl.constexpr,  # int
    USE_FP8: tl.constexpr,  # bool
    FP8_MIN: tl.constexpr = float8_info.min,
    FP8_MAX: tl.constexpr = float8_info.max,
):
    """Causal paged attention; one program per (query block, KV head).

    Program axis 0 is a global query-block index, axis 1 the KV head. The
    program locates its sequence with ``find_seq_idx``, loads a
    ``(BLOCK_M, HEAD_SIZE_PADDED)`` query tile (row ``m`` is query position
    ``m // num_queries_per_kv`` of the block and query head
    ``m % num_queries_per_kv`` of this KV head), then walks the KV cache in
    tiles of ``TILE_SIZE`` positions through the block table, maintaining
    a running maximum ``M``, running sum ``L`` and accumulator ``acc``
    (online softmax). The normalized result is stored to ``output_ptr``.

    Optional features are selected by the constexpr flags: attention
    sinks, ALiBi slopes, query-query bias, soft-capping, sliding window
    and FP8 output scaling/clamping.
    """
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)

    # sequence this query block belongs to (search in q-block mode)
    seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs,
                           BLOCK_Q, True)

    # global index of the first q-block of this sequence; same key formula
    # as used inside find_seq_idx in q-block mode
    q_block_start_idx = tl.load(query_start_len_ptr +
                                seq_idx) // BLOCK_Q + seq_idx

    q_block_local_idx = q_block_global_idx - q_block_start_idx

    # token range [start, stop) of this sequence's queries
    cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
    cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

    cur_batch_query_len = cur_batch_in_all_stop_index \
        - cur_batch_in_all_start_index

    # this q-block starts past the end of the sequence's queries: no work
    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
        return

    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)
    offs_t = tl.arange(0, TILE_SIZE)
    # query position (within the sequence's queries) of every row
    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

    # token index and query-head index of every row
    query_offset_0 = cur_batch_in_all_start_index + query_pos
    query_offset_1 = kv_head_idx * num_queries_per_kv + \
        offs_m % num_queries_per_kv
    query_offset = (query_offset_0[:, None] * query_stride_0 +
                    query_offset_1[:, None] * query_stride_1 + offs_d[None, :])

    # masks for head-dim padding, rows beyond the query length, and rows
    # beyond the number of query heads
    dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
    query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
    query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

    # Q : (BLOCK_M, HEAD_SIZE_PADDED)
    Q = tl.load(
        query_ptr + query_offset,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    # running maximum: -inf, or the per-head sink value when sinks are used
    if not USE_SINKS:
        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    else:
        M = tl.load(
            sink_ptr + query_offset_1,
            mask=query_mask_1,
            other=float("-inf"),
        ).to(dtype=tl.float32)

    # running sum starts at 1.0. When M starts at -inf, the first rescale
    # factor alpha = exp(M - m_j) is 0 and removes this initial value.
    L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # context length for this particular sequence
    context_len = seq_len - cur_batch_query_len

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1,
                              mask=query_mask_1,
                              other=0.0)

    # query-query attention bias
    if USE_QQ_BIAS:
        qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
                            )  # shape: [BLOCK_M, 1]

    # compute the length of the longest sequence prefix spanned by any
    # query token in the current q_block (q_block_local_idx)
    max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
        BLOCK_M - 1) // num_queries_per_kv + 1

    # adjust for potential padding in the last q_block by considering the
    # actual sequence length
    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)

    # calculate the number of tiles that need to be processed to
    # cover the longest sequence prefix (due to causal masking, tiles beyond
    # this prefix can be skipped)
    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)

    # ---- Sliding-window tile pruning --------------------
    # Default: keep previous global behavior
    tile_start = 0
    tile_end = num_tiles
    if SLIDING_WINDOW > 0:
        # Query rows covered by this Q-block
        qpos_lo = q_block_local_idx * BLOCK_Q
        qpos_hi = tl.minimum(
            qpos_lo + (BLOCK_M - 1) // num_queries_per_kv,
            cur_batch_query_len - 1,
        )
        # For sliding window, each query position q can only attend to
        # keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs]
        # where q_abs = context_len + q
        # The union of allowed key positions for this Q-block is:
        # [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi]
        first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1
        last_allowed_key = context_len + qpos_hi
        # Convert to tile indices and clamp
        tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE)
        tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles)

    # iterate through tiles (now limited to the sliding window range)
    for j in range(tile_start, tile_end):
        # absolute key positions covered by this tile
        seq_offset = j * TILE_SIZE + offs_t
        tile_mask = seq_offset < max_seq_prefix_len

        # translate logical positions to physical cache blocks.
        # NOTE(review): this load is not masked by tile_mask; it presumably
        # relies on the block-table row being long enough - confirm.
        physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
                                     seq_offset // BLOCK_SIZE).to(tl.int64)

        v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
                    kv_head_idx * stride_v_cache_2 +
                    offs_d[None, :] * stride_v_cache_3 +
                    (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)

        k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
                    kv_head_idx * stride_k_cache_2 +
                    offs_d[:, None] * stride_k_cache_3 +
                    (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)

        # K : (HEAD_SIZE_PADDED, TILE_SIZE)
        K_load = tl.load(key_cache_ptr + k_offset,
                         mask=dim_mask[:, None] & tile_mask[None, :],
                         other=0.0)

        # an fp8 cache is dequantized with k_scale unless Q is fp8 as well
        if K_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                K = K_load
            else:
                K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (TILE_SIZE, HEAD_SIZE_PADDED)
        V_load = tl.load(value_cache_ptr + v_offset,
                         mask=dim_mask[None, :] & tile_mask[:, None],
                         other=0.0)

        # same dequantization rule as for K, using v_scale
        if V_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                V = V_load
            else:
                V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        # causal mask: a query at absolute position context_len + query_pos
        # may attend to keys at positions <= its own
        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1

        # S : (BLOCK_M, TILE_SIZE)
        S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)

        S += scale * tl.dot(Q, K)

        if USE_SOFTCAP:
            S = apply_softcap(S, softcap)

        # padded rows and non-causal positions get -inf
        S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
                     S, float("-inf"))

        # keys at distance >= SLIDING_WINDOW from the query are masked out
        if SLIDING_WINDOW > 0:
            S = tl.where((context_len + query_pos[:, None] - seq_offset)
                         < SLIDING_WINDOW, S, float("-inf"))

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

        if USE_QQ_BIAS:
            # compute key positions relative to query section
            key_rel_pos = seq_offset - context_len  # shape: [TILE_SIZE]
            # load bias only for keys that correspond to queries
            # NOTE(review): the upper bound is the row stride of qq_bias,
            # presumably equal to its number of columns - confirm.
            is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
            qq_bias = tl.load(
                qq_bias_row_ptrs + key_rel_pos[None, :],
                mask=is_query_key[None, :],  # avoid OOB for context keys
                other=0.0,
            )
            S += qq_bias

        # compute running maximum
        # m_j : (BLOCK_M,)
        m_j = tl.maximum(M, tl.max(S, axis=1))

        # For sliding window there's a chance the max is -inf due to masking of
        # the entire row. In this case we need to set m_j 0 to avoid NaN
        m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

        # P : (BLOCK_M, TILE_SIZE)
        P = tl.exp(S - m_j[:, None])

        # l_j : (BLOCK_M,)
        l_j = tl.sum(P, axis=1)

        # alpha : (BLOCK_M, ), rescales previous state to the new maximum
        alpha = tl.exp(M - m_j)

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc += tl.dot(P.to(V.dtype), V)

    # epilogue: normalize by the softmax denominator
    acc = acc / L[:, None]
    if USE_FP8:
        # scale and clamp to the representable FP8 range before the store
        acc = acc * tl.load(out_scale)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)

    output_offset = (query_offset_0[:, None] * output_stride_0 +
                     query_offset_1[:, None] * output_stride_1 +
                     offs_d[None, :])

    tl.store(
        output_ptr + output_offset,
        acc,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
    )
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def kernel_unified_attention_3d(
    segm_output_ptr,
    # [num_tokens, num_query_heads, num_segments, head_size]
    # (the last dimension is addressed with stride HEAD_SIZE_PADDED below)
    segm_max_ptr,  # [num_tokens, num_query_heads, num_segments]
    segm_expsum_ptr,  # [num_tokens, num_query_heads, num_segments]
    query_ptr,  # [num_tokens, num_query_heads, head_size]
    # NOTE(review): both caches are indexed below with four strides as
    # (block, position in block, kv head, head dim), i.e. the same
    # [num_blks, blk_size, num_kv_heads, head_size] layout as in the 2D
    # kernel - confirm against the caller.
    key_cache_ptr,
    value_cache_ptr,
    sink_ptr,  # [num_query_heads]
    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
    seq_lens_ptr,  # [num_seqs]
    alibi_slopes_ptr,  # [num_query_heads]
    qq_bias_ptr,  # [num_query_tokens, num_query_tokens]
    scale,  # float32
    k_scale,  # float32, read through tl.load (passed as a pointer)
    v_scale,  # float32, read through tl.load (passed as a pointer)
    softcap,  # float32
    num_query_heads: tl.constexpr,  # int
    num_queries_per_kv: tl.constexpr,  # int
    block_table_stride: tl.int64,  # int
    query_stride_0: tl.int64,  # int
    query_stride_1: tl.int64,  # int, should be equal to head_size
    qq_bias_stride_0: tl.int64,  # int
    BLOCK_SIZE: tl.constexpr,  # int
    TILE_SIZE: tl.constexpr,  # int, must be power of 2
    HEAD_SIZE: tl.constexpr,  # int
    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
    USE_ALIBI_SLOPES: tl.constexpr,  # bool
    USE_QQ_BIAS: tl.constexpr,  # bool
    USE_SOFTCAP: tl.constexpr,  # bool
    USE_SINKS: tl.constexpr,  # bool
    SLIDING_WINDOW: tl.constexpr,  # int
    stride_k_cache_0: tl.int64,  # int
    stride_k_cache_1: tl.int64,  # int
    stride_k_cache_2: tl.int64,  # int
    stride_k_cache_3: tl.constexpr,  # int
    stride_v_cache_0: tl.int64,  # int
    stride_v_cache_1: tl.int64,  # int
    stride_v_cache_2: tl.int64,  # int
    stride_v_cache_3: tl.constexpr,  # int
    query_start_len_ptr,  # [num_seqs+1]
    BLOCK_Q: tl.constexpr,  # int
    num_seqs: tl.int32,
    BLOCK_M: tl.constexpr,  # int
    NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
):
    """Segmented causal paged attention; partial results per KV segment.

    Same per-tile computation as the 2D kernel, but with a third program
    axis: program axis 0 is the global query-block index, axis 1 the KV
    head and axis 2 the segment index. The tiles of a sequence are split
    into ``NUM_SEGMENTS_PER_SEQ`` contiguous segments and each program
    processes only the tiles of its own segment.

    The kernel does not normalize its result. It stores the unnormalized
    accumulator to ``segm_output_ptr`` and the running maximum / running
    sum to ``segm_max_ptr`` / ``segm_expsum_ptr`` so that a separate
    reduction step can combine the segments.
    """
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)
    segm_idx = tl.program_id(2)

    # sequence this query block belongs to (search in q-block mode)
    seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs,
                           BLOCK_Q, True)

    # global index of the first q-block of this sequence
    q_block_start_idx = tl.load(query_start_len_ptr +
                                seq_idx) // BLOCK_Q + seq_idx

    q_block_local_idx = q_block_global_idx - q_block_start_idx

    # token range [start, stop) of this sequence's queries
    cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
    cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

    cur_batch_query_len = cur_batch_in_all_stop_index \
        - cur_batch_in_all_start_index

    # this q-block starts past the end of the sequence's queries: no work
    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
        return

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # number of segments for this particular sequence
    num_segments = NUM_SEGMENTS_PER_SEQ
    tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)

    # this segment starts past the end of the sequence: no work, and
    # nothing is written for it
    if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len:
        return

    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)
    offs_t = tl.arange(0, TILE_SIZE)
    # query position (within the sequence's queries) of every row
    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

    # token index and query-head index of every row
    query_offset_0 = cur_batch_in_all_start_index + query_pos
    query_offset_1 = kv_head_idx * num_queries_per_kv + \
        offs_m % num_queries_per_kv
    query_offset = (query_offset_0[:, None] * query_stride_0 +
                    query_offset_1[:, None] * query_stride_1 + offs_d[None, :])

    # masks for head-dim padding, rows beyond the query length, and rows
    # beyond the number of query heads
    dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
    query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
    query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

    # Q : (BLOCK_M, HEAD_SIZE_PADDED)
    Q = tl.load(
        query_ptr + query_offset,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    # only segment 0 starts from the sink value, so the sink is taken into
    # account once per row; all other segments start from -inf
    if USE_SINKS:
        if segm_idx == 0:
            M = tl.load(
                sink_ptr + query_offset_1,
                mask=query_mask_1,
                other=float("-inf"),
            ).to(dtype=tl.float32)
        else:
            M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    else:
        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)

    # running sum starts at 1.0. When M starts at -inf, the first rescale
    # factor alpha = exp(M - m_j) is 0 and removes this initial value.
    L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

    # context length for this particular sequence
    context_len = seq_len - cur_batch_query_len

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1,
                              mask=query_mask_1,
                              other=0.0)

    # query-query attention bias
    if USE_QQ_BIAS:
        qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
                            )  # shape: [BLOCK_M, 1]

    # compute the length of the longest sequence prefix spanned by any
    # query token in the current q_block (q_block_local_idx)
    max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
        BLOCK_M - 1) // num_queries_per_kv + 1

    # adjust for potential padding in the last q_block by considering the
    # actual sequence length
    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)

    # calculate the number of tiles that need to be processed to
    # cover the longest sequence prefix (due to causal masking, tiles beyond
    # this prefix can be skipped)
    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)

    # iterate through tiles within current segment
    for j in range(
            segm_idx * tiles_per_segment,
            min((segm_idx + 1) * tiles_per_segment, num_tiles),
    ):
        # absolute key positions covered by this tile
        seq_offset = j * TILE_SIZE + offs_t
        tile_mask = seq_offset < max_seq_prefix_len

        # translate logical positions to physical cache blocks.
        # NOTE(review): this load is not masked by tile_mask; it presumably
        # relies on the block-table row being long enough - confirm.
        physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
                                     seq_offset // BLOCK_SIZE).to(tl.int64)

        v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
                    kv_head_idx * stride_v_cache_2 +
                    offs_d[None, :] * stride_v_cache_3 +
                    (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)

        k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
                    kv_head_idx * stride_k_cache_2 +
                    offs_d[:, None] * stride_k_cache_3 +
                    (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)

        # K : (HEAD_SIZE_PADDED, TILE_SIZE)
        K_load = tl.load(key_cache_ptr + k_offset,
                         mask=dim_mask[:, None] & tile_mask[None, :],
                         other=0.0)

        # an fp8 cache is dequantized with k_scale unless Q is fp8 as well
        if K_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                K = K_load
            else:
                K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (TILE_SIZE, HEAD_SIZE_PADDED)
        V_load = tl.load(value_cache_ptr + v_offset,
                         mask=dim_mask[None, :] & tile_mask[:, None],
                         other=0.0)

        # same dequantization rule as for K, using v_scale
        if V_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                V = V_load
            else:
                V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        # causal mask: a query at absolute position context_len + query_pos
        # may attend to keys at positions <= its own
        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1

        # S : (BLOCK_M, TILE_SIZE)
        S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
        S += scale * tl.dot(Q, K)

        if USE_SOFTCAP:
            S = apply_softcap(S, softcap)

        # padded rows and non-causal positions get -inf
        S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
                     S, float("-inf"))

        # keys at distance >= SLIDING_WINDOW from the query are masked out
        if SLIDING_WINDOW > 0:
            S = tl.where((context_len + query_pos[:, None] - seq_offset)
                         < SLIDING_WINDOW, S, float("-inf"))

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

        if USE_QQ_BIAS:
            # compute key positions relative to query section
            key_rel_pos = seq_offset - context_len  # shape: [TILE_SIZE]
            # load bias only for keys that correspond to queries
            # NOTE(review): the upper bound is the row stride of qq_bias,
            # presumably equal to its number of columns - confirm.
            is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
            qq_bias = tl.load(
                qq_bias_row_ptrs + key_rel_pos[None, :],
                mask=is_query_key[None, :],  # avoid OOB for context keys
                other=0.0,
            )
            S += qq_bias

        # compute running maximum
        # m_j : (BLOCK_M,)
        m_j = tl.maximum(M, tl.max(S, axis=1))

        # For sliding window there's a chance the max is -inf due to masking of
        # the entire row. In this case we need to set m_j 0 to avoid NaN
        m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

        # P : (BLOCK_M, TILE_SIZE,)
        P = tl.exp(S - m_j[:, None])

        # l_j : (BLOCK_M,)
        l_j = tl.sum(P, axis=1)

        # alpha : (BLOCK_M, ), rescales previous state to the new maximum
        alpha = tl.exp(M - m_j)

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc += tl.dot(P.to(V.dtype), V)

    # store the unnormalized partial output of this segment; the buffer is
    # addressed as [token, query head, segment, HEAD_SIZE_PADDED]
    segm_output_offset = (
        query_offset_0[:, None].to(tl.int64) *
        (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :])
    tl.store(
        segm_output_ptr + segm_output_offset,
        acc,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
    )
    # store this segment's running maximum and running sum, addressed as
    # [token, query head, segment]
    segm_offset = (query_offset_0.to(tl.int64) *
                   (num_query_heads * NUM_SEGMENTS_PER_SEQ) +
                   query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx)
    tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1)
    tl.store(segm_expsum_ptr + segm_offset,
             L,
             mask=query_mask_0 & query_mask_1)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def reduce_segments(
    output_ptr,  # [num_tokens, num_query_heads, head_size]
    segm_output_ptr,
    # [num_tokens, num_query_heads, max_num_segments, head_size]
    # (the last dimension is addressed with stride HEAD_SIZE_PADDED below)
    segm_max_ptr,  # [num_tokens, num_query_heads, max_num_segments]
    segm_expsum_ptr,  # [num_tokens, num_query_heads, max_num_segments]
    seq_lens_ptr,  # [num_seqs]
    num_seqs,  # int
    num_query_heads: tl.constexpr,  # int
    out_scale_inv,  # float32, read through tl.load (passed as a pointer)
    output_stride_0: tl.int64,  # int
    output_stride_1: tl.int64,  # int, should be equal to head_size
    block_table_stride: tl.int64,  # int, not used in this kernel
    TILE_SIZE: tl.constexpr,  # int
    HEAD_SIZE: tl.constexpr,  # int
    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
    query_start_len_ptr,  # [num_seqs+1]
    BLOCK_Q: tl.constexpr,  # int
    NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
    USE_FP8: tl.constexpr,  # bool
    FP8_MIN: tl.constexpr = float8_info.min,
    FP8_MAX: tl.constexpr = float8_info.max,
):
    """Merge per-segment partial attention results into the final output.

    One program handles one (query token, query head) pair. It loads the
    per-segment maxima, exp-sums and unnormalized outputs, rescales every
    segment to the overall maximum, sums the segments and divides by the
    overall exp-sum (yielding 0.0 where that sum is 0). With ``USE_FP8``
    the result is additionally multiplied by the value loaded from
    ``out_scale_inv`` and clamped to ``[FP8_MIN, FP8_MAX]``.
    """
    query_token_idx = tl.program_id(0)
    query_head_idx = tl.program_id(1)

    segm_ids = tl.arange(0, NUM_SEGMENTS_PER_SEQ)
    dim_ids = tl.arange(0, HEAD_SIZE_PADDED)

    # sequence of this token (search in plain token mode)
    seq_idx = find_seq_idx(query_start_len_ptr, query_token_idx, num_seqs,
                           BLOCK_Q, False)

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # tiles per segment, then the number of segments that start inside
    # the sequence
    tiles_per_segment = cdiv_fn(seq_len, NUM_SEGMENTS_PER_SEQ * TILE_SIZE)
    act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE)

    # masks for the loads below: populated segments / unpadded head dims
    segm_mask = segm_ids < tl.full(
        [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32)
    dim_mask = tl.where(dim_ids < HEAD_SIZE, 1, 0).to(tl.int1)

    # offsets of this (token, head) pair in the two kinds of buffers
    token_idx64 = query_token_idx.to(tl.int64)
    segm_offset = (token_idx64 * (num_query_heads * NUM_SEGMENTS_PER_SEQ) +
                   query_head_idx * NUM_SEGMENTS_PER_SEQ + segm_ids)
    segm_output_offset = (
        token_idx64 *
        (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        segm_ids[:, None] * HEAD_SIZE_PADDED + dim_ids[None, :])

    # segment maxima and the factor that rescales each segment to the
    # overall maximum
    segm_max = tl.load(segm_max_ptr + segm_offset,
                       mask=segm_mask,
                       other=float("-inf"))
    overall_max = tl.max(segm_max)
    rescale = tl.exp(segm_max - overall_max)

    # rescaled exp sums
    segm_expsum = tl.load(segm_expsum_ptr + segm_offset,
                          mask=segm_mask,
                          other=0.0)
    overall_expsum = tl.sum(segm_expsum * rescale)

    # rescaled and summed segment outputs
    segm_output = tl.load(
        segm_output_ptr + segm_output_offset,
        mask=segm_mask[:, None] & dim_mask[None, :],
        other=0.0,
    )
    acc_sum = tl.sum(segm_output * rescale[:, None], axis=0)

    # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
    acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)

    if USE_FP8:
        acc = acc * tl.load(out_scale_inv)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)

    # write result
    output_offset = (query_token_idx * output_stride_0 +
                     query_head_idx * output_stride_1 + dim_ids)
    tl.store(output_ptr + output_offset, acc, mask=dim_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def unified_attention(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    max_seqlen_q,
    seqused_k,
    max_seqlen_k,
    softmax_scale,
    causal,
    window_size,
    block_table,
    softcap,
    q_descale,
    k_descale,
    v_descale,
    alibi_slopes=None,
    output_scale=None,
    qq_bias=None,
    # Optional tensor for sinks
    sinks=None,
):
    """Launch the Triton unified-attention kernels, writing the result to `out`.

    Two launch strategies are used:

    * `kernel_unified_attention_2d` when `max_seqlen_q > 1` or when
      `total_num_q_blocks * num_kv_heads > 128`;
    * otherwise `kernel_unified_attention_3d`, which writes per-segment
      partial results into temporary float32 buffers, followed by
      `reduce_segments`, which combines them into `out`.

    Shapes as read by this function: `q.shape` is interpreted as
    (total query tokens, num_query_heads, head_size), `k.shape[2]` as
    num_kv_heads and `v.shape[1]` as the KV-cache block size; `k` and `v`
    must expose four strides.

    Args:
        q, k, v: query tensor and key/value caches (see shapes above).
        out: output tensor; its first two strides are passed to the kernels.
        cu_seqlens_q: forwarded as `query_start_len_ptr`.
        max_seqlen_q: only used to choose between the 2D and 3D paths.
        seqused_k: forwarded as `seq_lens_ptr`; its length is `num_seqs`.
        max_seqlen_k: not used in this function body.
        softmax_scale: forwarded as `scale`.
        causal: must be truthy (asserted).
        window_size: `window_size[0] + 1` is passed as `SLIDING_WINDOW`.
        block_table: forwarded with its stride(0).
        softcap: forwarded; soft-capping is enabled when `softcap > 0`.
        q_descale: must be None (asserted).
        k_descale, v_descale: forwarded as `k_scale` / `v_scale`.
        alibi_slopes: optional; enables `USE_ALIBI_SLOPES` when given.
        output_scale: optional; when given, `USE_FP8` is enabled and
            `1 / output_scale` is handed to the kernel.
        qq_bias: optional; enables `USE_QQ_BIAS` when given.
        sinks: optional; `sinks.shape[0]` must equal `q.shape[1]`.

    Returns:
        None. Results are written into `out` by the kernels.
    """

    assert causal, "Only causal attention is supported"
    assert q_descale is None, "Q scales not supported"

    if sinks is not None:
        assert sinks.shape[0] == q.shape[1], \
            "Sinks must be num_query_heads size"

    use_alibi_slopes = alibi_slopes is not None
    use_qq_bias = qq_bias is not None

    block_size = v.shape[1]
    num_seqs = len(seqused_k)
    num_query_heads = q.shape[1]
    num_kv_heads = k.shape[2]
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_size = q.shape[2]

    # BLOCK_M is at least 16 and always >= num_queries_per_kv, so BLOCK_Q
    # (query tokens handled per block) is always >= 1.
    BLOCK_M = 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(
        num_queries_per_kv)
    BLOCK_Q = BLOCK_M // num_queries_per_kv

    # Ideally we would launch the kernel with:
    #   \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks.
    # However, it is slow to realize the query_lens on cpu.
    # Instead we use upper-bound:
    #   \sum_i[ceil(query_len[i] / BLOCK_Q)]
    #   <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1]
    #    = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs
    #   <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs
    #    = floor(q.shape[0] / BLOCK_Q) + num_seqs
    total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs

    # Assigning default tile sizes for prefill and decode.
    # Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
    # and at least 16 for all other data types.
    TILE_SIZE_PREFILL = 32
    TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32

    # if batch contains a prefill
    # (the 2D kernel is also used when the launch grid is already large)
    if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
        kernel_unified_attention_2d[(
            total_num_q_blocks,
            num_kv_heads,
        )](
            output_ptr=out,
            query_ptr=q,
            key_cache_ptr=k,
            value_cache_ptr=v,
            sink_ptr=sinks,
            block_tables_ptr=block_table,
            seq_lens_ptr=seqused_k,
            alibi_slopes_ptr=alibi_slopes,
            qq_bias_ptr=qq_bias,
            scale=softmax_scale,
            k_scale=k_descale,
            v_scale=v_descale,
            out_scale=1 / output_scale if output_scale is not None else 1.0,
            softcap=softcap,
            num_query_heads=num_query_heads,
            num_queries_per_kv=num_queries_per_kv,
            block_table_stride=block_table.stride(0),
            query_stride_0=q.stride(0),
            query_stride_1=q.stride(1),
            output_stride_0=out.stride(0),
            output_stride_1=out.stride(1),
            qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
            BLOCK_SIZE=block_size,
            TILE_SIZE=TILE_SIZE_PREFILL,
            HEAD_SIZE=head_size,
            HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
            USE_ALIBI_SLOPES=use_alibi_slopes,
            USE_QQ_BIAS=use_qq_bias,
            USE_SOFTCAP=(softcap > 0),
            USE_SINKS=(sinks is not None),
            SLIDING_WINDOW=(1 + window_size[0]),
            stride_k_cache_0=k.stride(0),
            stride_k_cache_1=k.stride(1),
            stride_k_cache_2=k.stride(2),
            stride_k_cache_3=k.stride(3),
            stride_v_cache_0=v.stride(0),
            stride_v_cache_1=v.stride(1),
            stride_v_cache_2=v.stride(2),
            stride_v_cache_3=v.stride(3),
            query_start_len_ptr=cu_seqlens_q,
            BLOCK_Q=BLOCK_Q,
            num_seqs=num_seqs,
            BLOCK_M=BLOCK_M,
            USE_FP8=output_scale is not None,
        )
    else:
        # for initial version, NUM_SEGMENTS = 16 is chosen as a default
        # value that showed good performance in tests
        NUM_SEGMENTS = 16

        # Per-(token, head, segment) scratch buffers in float32: partial
        # outputs (padded to a power-of-two head size), segment maxima and
        # segment exp-sums. They are filled by the 3D kernel and consumed by
        # reduce_segments below.
        segm_output = torch.empty(
            q.shape[0],
            num_query_heads,
            NUM_SEGMENTS,
            triton.next_power_of_2(head_size),
            dtype=torch.float32,
            device=q.device,
        )
        segm_max = torch.empty(
            q.shape[0],
            num_query_heads,
            NUM_SEGMENTS,
            dtype=torch.float32,
            device=q.device,
        )
        segm_expsum = torch.empty(
            q.shape[0],
            num_query_heads,
            NUM_SEGMENTS,
            dtype=torch.float32,
            device=q.device,
        )

        # The third grid dimension is the segment index.
        kernel_unified_attention_3d[(
            total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
                segm_output_ptr=segm_output,
                segm_max_ptr=segm_max,
                segm_expsum_ptr=segm_expsum,
                query_ptr=q,
                key_cache_ptr=k,
                value_cache_ptr=v,
                sink_ptr=sinks,
                block_tables_ptr=block_table,
                seq_lens_ptr=seqused_k,
                alibi_slopes_ptr=alibi_slopes,
                qq_bias_ptr=qq_bias,
                scale=softmax_scale,
                k_scale=k_descale,
                v_scale=v_descale,
                softcap=softcap,
                num_query_heads=num_query_heads,
                num_queries_per_kv=num_queries_per_kv,
                block_table_stride=block_table.stride(0),
                query_stride_0=q.stride(0),
                query_stride_1=q.stride(1),
                qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
                BLOCK_SIZE=block_size,
                TILE_SIZE=TILE_SIZE_DECODE,
                HEAD_SIZE=head_size,
                HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
                USE_ALIBI_SLOPES=use_alibi_slopes,
                USE_QQ_BIAS=use_qq_bias,
                USE_SOFTCAP=(softcap > 0),
                USE_SINKS=(sinks is not None),
                SLIDING_WINDOW=(1 + window_size[0]),
                stride_k_cache_0=k.stride(0),
                stride_k_cache_1=k.stride(1),
                stride_k_cache_2=k.stride(2),
                stride_k_cache_3=k.stride(3),
                stride_v_cache_0=v.stride(0),
                stride_v_cache_1=v.stride(1),
                stride_v_cache_2=v.stride(2),
                stride_v_cache_3=v.stride(3),
                query_start_len_ptr=cu_seqlens_q,
                BLOCK_Q=BLOCK_Q,
                num_seqs=num_seqs,
                BLOCK_M=BLOCK_M,
                NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
            )
        # One program per (query token, query head) merges the segments.
        # NOTE(review): reduce_segments reads out_scale_inv with tl.load when
        # USE_FP8 is set, so output_scale is presumably a tensor - confirm.
        reduce_segments[(q.shape[0], num_query_heads)](
            output_ptr=out,
            segm_output_ptr=segm_output,
            segm_max_ptr=segm_max,
            segm_expsum_ptr=segm_expsum,
            seq_lens_ptr=seqused_k,
            num_seqs=num_seqs,
            num_query_heads=num_query_heads,
            out_scale_inv=1 /
            output_scale if output_scale is not None else 1.0,
            output_stride_0=out.stride(0),
            output_stride_1=out.stride(1),
            block_table_stride=block_table.stride(0),
            TILE_SIZE=TILE_SIZE_DECODE,
            HEAD_SIZE=head_size,
            HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
            query_start_len_ptr=cu_seqlens_q,
            BLOCK_Q=BLOCK_Q,
            NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
            USE_FP8=output_scale is not None,
        )
|
||||||
245
vllm/attention/selector.py
Normal file
245
vllm/attention/selector.py
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import os
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import cache
|
||||||
|
from typing import Generator, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import _Backend, current_platform
|
||||||
|
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    """
    Look up the `_Backend` member whose name is `backend_name`.

    Returns:
    * _Backend: the matching member when `backend_name` is a member name
    * None: when no member has that name (an invalid in-tree type, or an
            out-of-tree platform is loaded)
    """
    assert backend_name is not None
    # `__members__` maps member names to members; `.get` yields None on a miss.
    return _Backend.__members__.get(backend_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_env_variable_attn_backend() -> Optional[_Backend]:
    '''
    Read the attention-backend override from the environment variable
    named by STR_BACKEND_ENV_VAR.

    Returns:

    * the result of backend_name_to_enum() for the variable's value when
      the variable is set
    * None when the variable is not set
    '''
    env_value = os.environ.get(STR_BACKEND_ENV_VAR)
    if env_value is None:
        return None
    return backend_name_to_enum(env_value)
|
||||||
|
|
||||||
|
|
||||||
|
# Global state that allows a particular choice of backend
# to be forced, overriding the logic which auto-selects
# a backend based on system & workload configuration
# (auto-selection is the default behavior while this variable is None).
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: Optional[_Backend] = None
|
||||||
|
|
||||||
|
|
||||||
|
def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
    '''
    Set the module-wide forced attention backend.

    Passing `None` clears the override, which re-enables automatic
    backend selection.

    Arguments:

    * attn_backend: backend to force (None to revert to auto)
    '''
    # Rebind the module-level variable rather than creating a local.
    global forced_attn_backend
    forced_attn_backend = attn_backend
|
||||||
|
|
||||||
|
|
||||||
|
def get_global_forced_attn_backend() -> Optional[_Backend]:
    '''
    Return the backend currently stored in the module-level
    `forced_attn_backend` override; None means auto-selection is active.
    '''
    return forced_attn_backend
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _IsSupported:
|
||||||
|
can_import: bool
|
||||||
|
head_size: bool
|
||||||
|
dtype: bool
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return self.can_import and self.head_size and self.dtype
|
||||||
|
|
||||||
|
|
||||||
|
def is_attn_backend_supported(
    attn_backend: Union[str, type[AttentionBackend]],
    head_size: int,
    dtype: torch.dtype,
    *,
    allow_import_error: bool = True,
) -> _IsSupported:
    """Check whether a backend can be imported and accepts `head_size`/`dtype`.

    Args:
        attn_backend: a backend class, or a qualified name that is resolved
            with `resolve_obj_by_qualname`.
        head_size: head size to check.
        dtype: dtype to check.
        allow_import_error: when True, an ImportError while resolving a
            qualified name yields an all-False result; when False the
            ImportError is re-raised.

    Returns:
        An `_IsSupported` with one flag per check.

    Raises:
        NotImplementedError: if the backend exposes neither
            `get_supported_head_sizes` nor `validate_head_size`, or does not
            expose `get_supported_dtypes`.
    """
    if isinstance(attn_backend, str):
        try:
            attn_backend = resolve_obj_by_qualname(attn_backend)
        except ImportError:
            if allow_import_error:
                return _IsSupported(can_import=False,
                                    head_size=False,
                                    dtype=False)
            raise

    assert isinstance(attn_backend, type)

    # TODO: Update the interface once V0 is removed
    # Prefer an explicit list of supported sizes; otherwise fall back to a
    # validator that signals rejection by raising.
    list_head_sizes = getattr(attn_backend, "get_supported_head_sizes", None)
    if list_head_sizes:
        is_head_size_supported = head_size in list_head_sizes()
    else:
        check_head_size = getattr(attn_backend, "validate_head_size", None)
        if not check_head_size:
            raise NotImplementedError(f"{attn_backend.__name__} does not support "
                                      "head size validation")
        try:
            check_head_size(head_size)
        except Exception:
            is_head_size_supported = False
        else:
            is_head_size_supported = True

    list_dtypes = getattr(attn_backend, "get_supported_dtypes", None)
    if not list_dtypes:
        raise NotImplementedError(f"{attn_backend.__name__} does not support "
                                  "dtype validation")
    is_dtype_supported = dtype in list_dtypes()

    return _IsSupported(
        can_import=True,
        head_size=is_head_size_supported,
        dtype=is_dtype_supported,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it.

    All arguments are forwarded unchanged to `_cached_get_attn_backend`,
    together with the current value of `envs.VLLM_USE_V1`.
    """
    # envs.VLLM_USE_V1 is read here, outside the cached helper, and passed
    # in explicitly: reading envs.* behind a cache decorator could return a
    # stale value if the setting changes between calls.
    use_v1 = envs.VLLM_USE_V1
    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        block_size=block_size,
        use_v1=use_v1,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@cache
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    use_v1: bool = False,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]:
    """Resolve the attention backend class for the given configuration.

    Precedence of the requested backend: the globally forced backend, then
    the `envs.VLLM_ATTENTION_BACKEND` setting, otherwise None (the platform
    decides). The platform returns a qualified class name that is then
    resolved to the class object.

    NOTE(review): results are memoized per argument tuple, but the forced
    backend and the environment setting are read inside this cached body,
    so a later change to either is not seen for arguments already cached.

    Raises:
        ValueError: if the environment setting names an unknown backend, or
            the platform returns no backend class.
    """
    # THE GLOBALLY FORCED BACKEND OVERRIDES THE VLLM_ATTENTION_BACKEND
    # ENVIRONMENT VARIABLE.
    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()

    if selected_backend is None:
        # Nothing forced: fall back to the environment setting, if any.
        env_backend_name: Optional[str] = envs.VLLM_ATTENTION_BACKEND
        if env_backend_name is not None:
            if env_backend_name.endswith("_VLLM_V1"):
                logger.warning(
                    "The suffix '_VLLM_V1' in the environment variable "
                    "%s is no longer necessary as V0 backends have been "
                    "deprecated. Please remove this suffix from your "
                    "environment variable setting.", STR_BACKEND_ENV_VAR)
                env_backend_name = env_backend_name.removesuffix("_VLLM_V1")

            selected_backend = backend_name_to_enum(env_backend_name)
            if selected_backend is None:
                raise ValueError(
                    f"Invalid attention backend: '{env_backend_name}'. "
                    f"Valid backends are: {list(_Backend.__members__.keys())}")

    # get device-specific attn_backend
    backend_qualname = current_platform.get_attn_backend_cls(
        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
        use_mla, has_sink, use_sparse)
    if not backend_qualname:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}")

    return resolve_obj_by_qualname(backend_qualname)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def global_force_attn_backend_context_manager(
        attn_backend: _Backend) -> Generator[None, None, None]:
    '''
    Temporarily force a vLLM attention backend for the duration of a
    `with` block. Whatever override was active on entry (possibly none)
    is put back on exit, including when the block raises.

    Arguments:

    * attn_backend: attention backend to force

    Returns:

    * Generator
    '''
    # Remember the override that was active before entering.
    previous_backend = get_global_forced_attn_backend()
    global_force_attn_backend(attn_backend)

    try:
        # Run the body of the `with` statement.
        yield
    finally:
        # Restore the earlier override (None means auto-selection).
        global_force_attn_backend(previous_backend)
|
||||||
0
vllm/attention/utils/__init__.py
Normal file
0
vllm/attention/utils/__init__.py
Normal file
BIN
vllm/attention/utils/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/attention/utils/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/attention/utils/__pycache__/fa_utils.cpython-310.pyc
Normal file
BIN
vllm/attention/utils/__pycache__/fa_utils.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
85
vllm/attention/utils/fa_utils.py
Normal file
85
vllm/attention/utils/fa_utils.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# Bind the platform-specific implementations at import time so the rest of
# the code can import `reshape_and_cache_flash`, `flash_attn_varlen_func`
# and `get_scheduler_metadata` from this module.
# On platforms that are neither CUDA nor XPU none of these names is defined
# here; is_flash_attn_varlen_func_available() reports that condition.
if current_platform.is_cuda():
    from vllm import _custom_ops as ops
    reshape_and_cache_flash = ops.reshape_and_cache_flash
    from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                      get_scheduler_metadata)
elif current_platform.is_xpu():
    # On XPU all three entry points come from the IPEX ops wrapper.
    from vllm._ipex_ops import ipex_ops as ops
    reshape_and_cache_flash = ops.reshape_and_cache_flash
    flash_attn_varlen_func = ops.flash_attn_varlen_func
    get_scheduler_metadata = ops.get_scheduler_metadata
|
||||||
|
|
||||||
|
|
||||||
|
def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
    """Choose the FlashAttention version (2 or 3) to use on this platform.

    Selection order: a platform default (3 on compute-capability major 9
    when FA3 is supported, else 2), then the `envs.VLLM_FLASH_ATTN_VERSION`
    override, then fallbacks to 2 for combinations that cannot use FA3
    (capability major 10, or ALiBi).

    Args:
        requires_alibi: whether the caller needs ALiBi; FA3 is never
            selected in that case.

    Returns:
        2 on XPU. Otherwise the selected version, or None when the flash
        attention interface cannot be imported, the device capability is
        unknown, the override is not 2 or 3, or the selected version is
        not supported.
    """
    # import here to avoid circular dependencies
    from vllm.platforms import current_platform
    if current_platform.is_xpu():
        return 2
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason, is_fa_version_supported)
        device_capability = current_platform.get_device_capability()

        # Explicit checks instead of `assert`: assert statements are removed
        # under `python -O`, which would let an invalid result escape.
        if device_capability is None:
            return None

        # 1. default version depending on platform
        fa_version = 3 if (device_capability.major == 9
                           and is_fa_version_supported(3)) else 2

        # 2. override if passed by environment (read once)
        requested_version = envs.VLLM_FLASH_ATTN_VERSION
        if requested_version is not None:
            if requested_version not in (2, 3):
                return None
            fa_version = requested_version

        # 3. fallback for unsupported combinations
        if device_capability.major == 10 and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 on Blackwell platform, "
                "defaulting to FA version 2.")
            fa_version = 2

        if requires_alibi and fa_version == 3:
            logger.warning_once("Cannot use FA version 3 with ALiBi, "
                                "defaulting to FA version 2.")
            fa_version = 2

        if not is_fa_version_supported(fa_version):
            logger.error("Cannot use FA version %d: not supported due to %s",
                         fa_version, fa_version_unsupported_reason(fa_version))
            return None

        return fa_version
    except (ImportError, AssertionError):
        # AssertionError is still caught because the imported helpers may
        # assert internally.
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attn_supports_fp8() -> bool:
    """Return True when FA version 3 is selected on a capability-major-9 device.

    The device capability is only queried when `get_flash_attn_version()`
    returns 3.
    """
    if get_flash_attn_version() != 3:
        return False
    return current_platform.get_device_capability().major == 9
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attn_supports_mla():
    """Report whether FA version 3 is usable on a capability-major-9 CUDA device.

    Returns False on non-CUDA platforms and when the flash attention
    interface cannot be imported or raises an AssertionError.
    """
    from vllm.platforms import current_platform
    if not current_platform.is_cuda():
        return False
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            is_fa_version_supported)
        return is_fa_version_supported(3) \
            and current_platform.get_device_capability()[0] == 9
    except (ImportError, AssertionError):
        return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_flash_attn_varlen_func_available() -> bool:
    """Return whether the current platform is CUDA or XPU.

    Those are the two platforms for which this module binds
    `flash_attn_varlen_func` at import time.
    """
    on_cuda = current_platform.is_cuda()
    return on_cuda or current_platform.is_xpu()
|
||||||
33
vllm/attention/utils/kv_sharing_utils.py
Normal file
33
vllm/attention/utils/kv_sharing_utils.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
def validate_kv_sharing_target(current_layer_name, target_layer_name,
                               static_forward_context):
    """Validate that `target_layer_name` may share its KV cache with the current layer.

    Args:
        current_layer_name: name of the layer that wants to reuse a KV cache.
        target_layer_name: name of the layer whose KV cache would be reused.
        static_forward_context: mapping from layer name to a layer object
            exposing `attn_type`; it must contain `current_layer_name`.

    Raises:
        ValueError: if the target is the current layer, is missing from
            `static_forward_context` (reported either as not coming before
            the current layer or as not being an Attention layer), or has a
            different `attn_type` than the current layer.
    """
    msg_prefix = (f"Specified KV sharing target layer for {current_layer_name} "
                  f"is not valid: target layer {target_layer_name} ")

    if current_layer_name == target_layer_name:
        raise ValueError(msg_prefix +
                         "cannot be the same as the current layer.")

    if target_layer_name not in static_forward_context:
        from vllm.model_executor.models.utils import extract_layer_index

        # If target layer name is not in the static fwd context, it means either
        # a) the target layer does not come BEFORE the current layer, or
        # b) the target layer is not an Attention layer that exists in the model
        current_layer_idx = extract_layer_index(current_layer_name)
        target_layer_idx = extract_layer_index(target_layer_name)
        if current_layer_idx <= target_layer_idx:
            reason = "must come before the current layer."
        else:
            reason = "is not a valid Attention layer in the model."
        raise ValueError(msg_prefix + reason)

    # Currently KV sharing is only supported between layers of the same type
    target_attn_type = static_forward_context[target_layer_name].attn_type
    expected = static_forward_context[current_layer_name].attn_type
    if target_attn_type == expected:
        return
    raise ValueError(
        msg_prefix +
        f"must be the same type as the current layer ({expected}).")
|
||||||
87
vllm/beam_search.py
Normal file
87
vllm/beam_search.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
|
|
||||||
|
from vllm.logprobs import Logprob
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.multimodal import MultiModalDataDict
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BeamSearchSequence:
    """A sequence for beam search.

    It keeps track of the tokens and the log probability of the sequence.
    The text field is optional and will only be filled when the sequence is
    about to be returned to the user.
    """
    # The tokens include the prompt.
    tokens: list[int]
    # Presumably one {token_id: Logprob} mapping per generated token -
    # TODO confirm against the producer of these sequences.
    logprobs: list[dict[int, Logprob]]
    # Optional LoRA request carried along with the sequence.
    lora_request: Optional[LoRARequest] = None
    # Cumulative log probability; this is the value that
    # create_sort_beams_key_function() feeds into the beam score.
    cum_logprob: float = 0.0
    # Filled only when the sequence is about to be returned (see docstring).
    text: Optional[str] = None
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
    # Optional multimodal inputs and processor kwargs carried with the
    # sequence; neither is read in this module.
    multi_modal_data: Optional["MultiModalDataDict"] = None
    mm_processor_kwargs: Optional[dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BeamSearchOutput:
    """The output of beam search.

    It contains the list of the best beam search sequences.
    The length of the list is equal to the beam width.
    """
    # Best sequences found by the search (one entry per beam).
    sequences: list[BeamSearchSequence]
|
||||||
|
|
||||||
|
|
||||||
|
class BeamSearchInstance:
    """Beam-search state for one prompt: the live beams and the completed ones."""

    def __init__(
        self,
        prompt_tokens: list[int],
        lora_request: Optional[LoRARequest] = None,
        logprobs: Optional[list[dict[int, Logprob]]] = None,
        **kwargs,
    ):
        """Start with a single beam holding the prompt.

        Args:
            prompt_tokens: prompt token ids; used as the seed beam's tokens
                (the list object itself is passed through, not copied).
            lora_request: optional LoRA request stored on the seed beam.
            logprobs: optional initial logprobs; copied into a new list.
            **kwargs: forwarded to `BeamSearchSequence`.
        """
        # Shallow-copy the caller's list so the seed beam has its own list.
        seed_logprobs = list(logprobs) if logprobs is not None else []
        seed_beam = BeamSearchSequence(
            tokens=prompt_tokens,
            logprobs=seed_logprobs,
            lora_request=lora_request,
            **kwargs,
        )
        self.beams: list[BeamSearchSequence] = [seed_beam]
        self.completed: list[BeamSearchSequence] = []
|
||||||
|
|
||||||
|
|
||||||
|
def get_beam_search_score(
    tokens: list[int],
    cumulative_logprob: float,
    eos_token_id: int,
    length_penalty: float = 1.0,
) -> float:
    """Calculate the beam search score with length penalty.

    The score is `cumulative_logprob / (seq_len ** length_penalty)`, where
    `seq_len` is `len(tokens)` minus one if the last token is
    `eos_token_id`. The effective length is clamped to at least 1 so that an
    empty sequence, or one consisting only of the EOS token, yields
    `cumulative_logprob` instead of raising.

    Adapted from

    https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938

    Args:
        tokens: token ids of the sequence.
        cumulative_logprob: summed log probability of the sequence.
        eos_token_id: id of the end-of-sequence token.
        length_penalty: exponent applied to the sequence length.

    Returns:
        The length-penalized score.
    """
    seq_len = len(tokens)
    # `tokens and ...` guards the index on an empty list.
    if tokens and tokens[-1] == eos_token_id:
        seq_len -= 1

    # Avoid 0 ** penalty == 0 in the denominator.
    seq_len = max(seq_len, 1)

    return cumulative_logprob / (seq_len**length_penalty)
|
||||||
|
|
||||||
|
|
||||||
|
def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
    """Build a sort-key function that scores a `BeamSearchSequence`.

    The returned callable closes over `eos_token_id` and `length_penalty`
    and returns `get_beam_search_score` for the sequence's tokens and
    cumulative log probability.
    """

    def sort_beams_key(x: BeamSearchSequence) -> float:
        score = get_beam_search_score(
            tokens=x.tokens,
            cumulative_logprob=x.cum_logprob,
            eos_token_id=eos_token_id,
            length_penalty=length_penalty,
        )
        return score

    return sort_beams_key
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user