init
BIN
vllm/_C.abi3.so
Normal file
Binary file not shown.
102
vllm/__init__.py
Normal file
@@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""

# The version.py should be independent library, and we always import the
# version library first. Such assumption is critical for some customization.
from .version import __version__, __version_tuple__  # isort:skip

import typing

# The environment variables override should be imported before any other
# modules to ensure that the environment variables are set before any
# other modules are imported.
import vllm.env_override  # noqa: F401

MODULE_ATTRS = {
    "bc_linter_skip": "._bc_linter:bc_linter_skip",
    "bc_linter_include": "._bc_linter:bc_linter_include",
    "AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
    "EngineArgs": ".engine.arg_utils:EngineArgs",
    "AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
    "LLMEngine": ".engine.llm_engine:LLMEngine",
    "LLM": ".entrypoints.llm:LLM",
    "initialize_ray_cluster": ".executor.ray_utils:initialize_ray_cluster",
    "PromptType": ".inputs:PromptType",
    "TextPrompt": ".inputs:TextPrompt",
    "TokensPrompt": ".inputs:TokensPrompt",
    "ModelRegistry": ".model_executor.models:ModelRegistry",
    "SamplingParams": ".sampling_params:SamplingParams",
    "PoolingParams": ".pooling_params:PoolingParams",
    "ClassificationOutput": ".outputs:ClassificationOutput",
    "ClassificationRequestOutput": ".outputs:ClassificationRequestOutput",
    "CompletionOutput": ".outputs:CompletionOutput",
    "EmbeddingOutput": ".outputs:EmbeddingOutput",
    "EmbeddingRequestOutput": ".outputs:EmbeddingRequestOutput",
    "PoolingOutput": ".outputs:PoolingOutput",
    "PoolingRequestOutput": ".outputs:PoolingRequestOutput",
    "RequestOutput": ".outputs:RequestOutput",
    "ScoringOutput": ".outputs:ScoringOutput",
    "ScoringRequestOutput": ".outputs:ScoringRequestOutput",
}

if typing.TYPE_CHECKING:
    from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.engine.llm_engine import LLMEngine
    from vllm.entrypoints.llm import LLM
    from vllm.executor.ray_utils import initialize_ray_cluster
    from vllm.inputs import PromptType, TextPrompt, TokensPrompt
    from vllm.model_executor.models import ModelRegistry
    from vllm.outputs import (ClassificationOutput,
                              ClassificationRequestOutput, CompletionOutput,
                              EmbeddingOutput, EmbeddingRequestOutput,
                              PoolingOutput, PoolingRequestOutput,
                              RequestOutput, ScoringOutput,
                              ScoringRequestOutput)
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams

    from ._bc_linter import bc_linter_include, bc_linter_skip
else:

    def __getattr__(name: str) -> typing.Any:
        from importlib import import_module

        if name in MODULE_ATTRS:
            module_name, attr_name = MODULE_ATTRS[name].split(":")
            module = import_module(module_name, __package__)
            return getattr(module, attr_name)
        else:
            raise AttributeError(
                f'module {__package__} has no attribute {name}')


__all__ = [
    "__version__",
    "bc_linter_skip",
    "bc_linter_include",
    "__version_tuple__",
    "LLM",
    "ModelRegistry",
    "PromptType",
    "TextPrompt",
    "TokensPrompt",
    "SamplingParams",
    "RequestOutput",
    "CompletionOutput",
    "PoolingOutput",
    "PoolingRequestOutput",
    "EmbeddingOutput",
    "EmbeddingRequestOutput",
    "ClassificationOutput",
    "ClassificationRequestOutput",
    "ScoringOutput",
    "ScoringRequestOutput",
    "LLMEngine",
    "EngineArgs",
    "AsyncLLMEngine",
    "AsyncEngineArgs",
    "initialize_ray_cluster",
    "PoolingParams",
]
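Note: a minimal usage sketch of the lazy module-attribute mechanism above; attribute access on the package resolves through MODULE_ATTRS at first use (the model name and prompt are placeholders, not part of this commit):

    import vllm

    # Resolved lazily via __getattr__ -> MODULE_ATTRS:
    params = vllm.SamplingParams(temperature=0.0, max_tokens=8)
    llm = vllm.LLM(model="facebook/opt-125m")  # placeholder model name
    print(llm.generate(["Hello"], params)[0].outputs[0].text)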
BIN
vllm/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/_bc_linter.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/_custom_ops.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/_version.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/beam_search.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/connections.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/env_override.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/envs.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/logger.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/logits_process.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/logprobs.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/outputs.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/pooling_params.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/sampling_params.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/scalar_type.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/sequence.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/tasks.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/test_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/tracing.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/version.cpython-312.pyc
Normal file
Binary file not shown.
59
vllm/_bc_linter.py
Normal file
@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# vllm/_bc_linter.py
from __future__ import annotations

from typing import Any, Callable, TypeVar, overload

T = TypeVar("T")


@overload
def bc_linter_skip(obj: T) -> T:
    ...


@overload
def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]:
    ...


def bc_linter_skip(obj: Any = None, *, reason: str | None = None):
    """
    No-op decorator to mark symbols/files for BC-linter suppression.

    Usage:
        @bc_linter_skip
        def legacy_api(...): ...
    """

    def _wrap(x: T) -> T:
        return x

    return _wrap if obj is None else obj


@overload
def bc_linter_include(obj: T) -> T:
    ...


@overload
def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]:
    ...


def bc_linter_include(obj: Any = None, *, reason: str | None = None):
    """
    Usage:
        @bc_linter_include
        def public_api(...): ...
    """

    def _wrap(x: T) -> T:
        return x

    return _wrap if obj is None else obj


__all__ = ["bc_linter_skip", "bc_linter_include"]
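Note: a small sketch of both decorator forms defined above; the function names are illustrative only:

    from vllm import bc_linter_include, bc_linter_skip

    @bc_linter_include              # bare form: returns the function unchanged
    def public_api() -> None: ...

    @bc_linter_skip(reason="superseded")   # parenthesized form: returns a no-op wrapper
    def legacy_api() -> None: ...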
2044
vllm/_custom_ops.py
Normal file
File diff suppressed because it is too large
393
vllm/_ipex_ops.py
Normal file
@@ -0,0 +1,393 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, Union

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

try:
    import intel_extension_for_pytorch as ipex
except ImportError as e:
    logger.debug("Import error msg: %s", e.msg)


class ipex_ops:

    @staticmethod
    def _reshape_activation_tensor(
            x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        num = x.size(0)
        d = x.size(1) // 2
        x = x.reshape(num, 2, d)
        x1, x2 = torch.chunk(x, chunks=2, dim=1)
        x1 = x1.reshape(num, d)
        x2 = x2.reshape(num, d)
        return x1, x2

    @staticmethod
    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.silu_and_mul(x, out)

    @staticmethod
    def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(x)

    @staticmethod
    def gelu_new(x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(x)

    @staticmethod
    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.gelu_quick(x, out)

    @staticmethod
    def paged_attention_v1(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        ipex.llm.modules.PagedAttention.single_query_kv_attention(
            out,
            query.contiguous(),
            key_cache.view_as(value_cache),
            value_cache,
            num_queries_per_tokens,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def paged_attention_v2(
        out: torch.Tensor,
        exp_sum: torch.Tensor,
        max_logits: torch.Tensor,
        tmp_out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        ipex.llm.modules.PagedAttention.single_query_kv_attention(
            out,
            query.contiguous(),
            key_cache.view_as(value_cache),
            value_cache,
            num_queries_per_tokens,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def rotary_embedding(
        positions: torch.Tensor,  # [batch_size, seq_len]
        query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
        key: torch.Tensor,  # [batch_size, seq_len, num_kv_heads*head_size]
        head_size: int,
        cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
        is_neox: bool,
    ) -> None:
        rot_dim = cos_sin_cache.size(1)
        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
                                                     head_size, cos_sin_cache,
                                                     is_neox, rot_dim)

    @staticmethod
    def rms_norm(input: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> torch.Tensor:
        return ipex.llm.functional.rms_norm(input, weight, epsilon)

    @staticmethod
    def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, epsilon: float) -> None:
        tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
                                               epsilon, True)
        input.copy_(tmp)

    @staticmethod
    def varlen_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        seqlen_q: torch.Tensor,
        seqlen_k: torch.Tensor,
        alibi_slopes: Optional[torch.Tensor],
        max_seqlen_q: int,
        max_seqlen_k: int,
        pdropout: float,
        softmax_scale: float,
        zero_tensors: bool,
        is_causal: bool,
        return_softmax: bool,
        gen_: torch.Generator,
        window_size_left: float,
        window_size_right: float,
        logits_soft_cap: float,
    ) -> None:
        if ipex.__version__.endswith("cpu"):
            if logits_soft_cap != 0.0:
                raise ValueError("IPEX CPU does not support logits_soft_cap")
            assert alibi_slopes is None
            assert window_size_left < 0 and window_size_right < 0
            ipex.llm.functional.varlen_attention(query.contiguous(),
                                                 key.contiguous(),
                                                 value.contiguous(), out,
                                                 seqlen_q.int(),
                                                 seqlen_k.int(), max_seqlen_q,
                                                 max_seqlen_k, pdropout,
                                                 softmax_scale, zero_tensors,
                                                 is_causal, return_softmax,
                                                 gen_)
        else:  # XPU build
            ipex.llm.functional.varlen_attention(
                query.contiguous(), key.contiguous(), value.contiguous(), out,
                seqlen_q.int(), seqlen_k.int(), alibi_slopes, max_seqlen_q,
                max_seqlen_k, pdropout, softmax_scale, zero_tensors, is_causal,
                return_softmax, gen_, window_size_left, window_size_right,
                logits_soft_cap)

    @staticmethod
    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
        assert kv_cache_dtype == "auto"
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping)

    @staticmethod
    def reshape_and_cache_flash(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: Optional[torch.Tensor] = None,
        v_scale: Optional[torch.Tensor] = None,
        k_scale_float: float = 1.0,
        v_scale_float: float = 1.0,
    ) -> None:
        ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
            key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
            k_scale_float, v_scale_float)

    @staticmethod
    def flash_attn_varlen_func(
        out: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        seqused_k: torch.Tensor,  # we don't support this in ipex kernel
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        causal: bool,
        block_table: torch.Tensor,
        alibi_slopes: Optional[torch.Tensor],
        window_size: Optional[list[int]] = None,
        softcap: Optional[float] = 0.0,
        cu_seqlens_k: Optional[torch.Tensor] = None,
        # The following parameters are not used in ipex kernel currently,
        # we keep API compatible to CUDA's.
        scheduler_metadata=None,
        fa_version: int = 2,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        num_splits=0,
        s_aux: Optional[torch.Tensor] = None,
    ):
        if cu_seqlens_k is None:
            # cu_seqlens_k is not used in ipex kernel.
            cu_seqlens_k = torch.cumsum(seqused_k, dim=0)
            cu_seqlens_k = torch.cat([
                torch.tensor([0], device=seqused_k.device, dtype=torch.int32),
                cu_seqlens_k
            ]).to(torch.int32)

        real_window_size: tuple[int, int]
        if window_size is None:
            real_window_size = (-1, -1)
        else:
            assert len(window_size) == 2
            real_window_size = (window_size[0], window_size[1])
        return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            q.contiguous(),
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            softmax_scale,
            causal,
            block_table,
            alibi_slopes,
            softcap=softcap,
            window_size_left=real_window_size[0],
            window_size_right=real_window_size[1],
            k_scale=1.0,
            v_scale=1.0,
        )

    @staticmethod
    def get_scheduler_metadata(
        batch_size,
        max_seqlen_q,
        max_seqlen_k,
        num_heads_q,
        num_heads_kv,
        headdim,
        cache_seqlens: torch.Tensor,
        qkv_dtype=torch.bfloat16,
        headdim_v=None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_k_new: Optional[torch.Tensor] = None,
        cache_leftpad: Optional[torch.Tensor] = None,
        page_size: Optional[int] = None,
        max_seqlen_k_new=0,
        causal=False,
        window_size=(-1, -1),  # -1 means infinite context window
        has_softcap=False,
        num_splits=0,  # Can be tuned for speed
        pack_gqa=None,  # Can be tuned for speed
        sm_margin=0,  # Can be tuned if some SMs are used for communication
    ) -> None:
        logger.warning_once(
            "get_scheduler_metadata is not implemented for ipex_ops, "
            "returning None.")
        return None

    @staticmethod
    def copy_blocks(key_caches: list[torch.Tensor],
                    value_caches: list[torch.Tensor],
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.copy_blocks(  # type: ignore
            key_caches,
            value_caches,
            block_mapping,
        )

    @staticmethod
    def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.swap_blocks(src, dst, block_mapping)  # type: ignore

    @staticmethod
    def scaled_fp8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        num_token_padding: Optional[int] = None,
        scale_ub: Optional[torch.Tensor] = None,
        use_per_token_if_dynamic: bool = False,
        output: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize input tensor to FP8 and return quantized tensor and scale.

        This function is designed for both static and dynamic quantization:
        If you provide the scale, it will use static scaling and if you omit
        it, the scale will be determined dynamically. Currently, XPU platform
        only supports dynamic quantization. The function also allows optional
        padding of the output tensors for downstream kernels that will benefit
        from padding.

        Args:
            input: The input tensor to be quantized to FP8
            scale: Optional scaling factor for the FP8 quantization
            scale_ub: Optional upper bound for scaling factor in dynamic
                per token case
            num_token_padding: If specified, pad the first dimension
                of the output to at least this value.
            use_per_token_if_dynamic: Whether to do per_tensor or per_token
                in the dynamic quantization case.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
                scaling factor.
        """
        # This code assumes batch_dim and num_tokens are flattened
        assert (input.ndim == 2)
        shape: Union[tuple[int, int], torch.Size] = input.shape
        out_dtype: torch.dtype = current_platform.fp8_dtype()
        if num_token_padding:
            shape = (max(num_token_padding, input.shape[0]), shape[1])
        if output is None:
            output = torch.empty(shape, device=input.device, dtype=out_dtype)
        else:
            assert num_token_padding is None, \
                "padding not supported if output passed in"
            assert output.dtype == out_dtype
        assert scale is None, "only dynamic fp8 quantization supported on XPU"
        assert not use_per_token_if_dynamic, (
            "per token dynamic fp8 quantization not supported on XPU")
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        torch.ops.torch_ipex.dynamic_scaled_fp8_quant(output, input, scale)

        return output, scale
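Note: a hypothetical sketch of the dynamic FP8 quantization path above; it assumes an XPU build of PyTorch with intel_extension_for_pytorch installed, which is not part of this commit:

    import torch
    from vllm._ipex_ops import ipex_ops

    x = torch.randn(4, 128, device="xpu", dtype=torch.bfloat16)
    q, scale = ipex_ops.scaled_fp8_quant(x)  # dynamic per-tensor scale on XPU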
34
vllm/_version.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# file generated by setuptools-scm
|
||||
# don't change, don't track in version control
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"__version_tuple__",
|
||||
"version",
|
||||
"version_tuple",
|
||||
"__commit_id__",
|
||||
"commit_id",
|
||||
]
|
||||
|
||||
TYPE_CHECKING = False
|
||||
if TYPE_CHECKING:
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
||||
COMMIT_ID = Union[str, None]
|
||||
else:
|
||||
VERSION_TUPLE = object
|
||||
COMMIT_ID = object
|
||||
|
||||
version: str
|
||||
__version__: str
|
||||
__version_tuple__: VERSION_TUPLE
|
||||
version_tuple: VERSION_TUPLE
|
||||
commit_id: COMMIT_ID
|
||||
__commit_id__: COMMIT_ID
|
||||
|
||||
__version__ = version = '0.11.0'
|
||||
__version_tuple__ = version_tuple = (0, 11, 0)
|
||||
|
||||
__commit_id__ = commit_id = None
|
||||
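Note: the version metadata above is re-exported at the package root, e.g.:

    import vllm

    print(vllm.__version__)        # '0.11.0'
    print(vllm.__version_tuple__)  # (0, 11, 0)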
0
vllm/assets/__init__.py
Normal file
BIN
vllm/assets/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/assets/__pycache__/base.cpython-312.pyc
Normal file
Binary file not shown.
45
vllm/assets/audio.py
Normal file
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from pathlib import Path
from typing import Literal
from urllib.parse import urljoin

import numpy.typing as npt

from vllm.utils import PlaceholderModule

from .base import VLLM_S3_BUCKET_URL, get_vllm_public_assets

try:
    import librosa
except ImportError:
    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]

ASSET_DIR = "multimodal_asset"

AudioAssetName = Literal["winning_call", "mary_had_lamb"]


@dataclass(frozen=True)
class AudioAsset:
    name: AudioAssetName

    @property
    def filename(self) -> str:
        return f"{self.name}.ogg"

    @property
    def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
        audio_path = get_vllm_public_assets(filename=self.filename,
                                            s3_prefix=ASSET_DIR)
        return librosa.load(audio_path, sr=None)

    def get_local_path(self) -> Path:
        return get_vllm_public_assets(filename=self.filename,
                                      s3_prefix=ASSET_DIR)

    @property
    def url(self) -> str:
        return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")
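Note: a minimal usage sketch of the asset class above (requires librosa; the file is downloaded on first access):

    from vllm.assets.audio import AudioAsset

    asset = AudioAsset("winning_call")
    audio, sr = asset.audio_and_sample_rate  # fetches the .ogg into the asset cache
    print(asset.url, sr)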
41
vllm/assets/base.py
Normal file
@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from functools import lru_cache
from pathlib import Path
from typing import Optional

import vllm.envs as envs
from vllm.connections import global_http_connection

VLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"


def get_cache_dir() -> Path:
    """Get the path to the cache for storing downloaded assets."""
    path = Path(envs.VLLM_ASSETS_CACHE)
    path.mkdir(parents=True, exist_ok=True)

    return path


@lru_cache
def get_vllm_public_assets(filename: str,
                           s3_prefix: Optional[str] = None) -> Path:
    """
    Download an asset file from ``s3://vllm-public-assets``
    and return the path to the downloaded file.
    """
    asset_directory = get_cache_dir() / "vllm_public_assets"
    asset_directory.mkdir(parents=True, exist_ok=True)

    asset_path = asset_directory / filename
    if not asset_path.exists():
        if s3_prefix is not None:
            filename = s3_prefix + "/" + filename
        global_http_connection.download_file(
            f"{VLLM_S3_BUCKET_URL}/{filename}",
            asset_path,
            timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT)

    return asset_path
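Note: a sketch of fetching a public asset directly with the helper above; the filename and prefix are illustrative (they match the image assets defined later in this commit), and the cache location is controlled by VLLM_ASSETS_CACHE:

    from vllm.assets.base import get_vllm_public_assets

    path = get_vllm_public_assets("stop_sign.jpg", s3_prefix="vision_model_images")
    print(path)  # e.g. <VLLM_ASSETS_CACHE>/vllm_public_assets/stop_sign.jpg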
50
vllm/assets/image.py
Normal file
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import torch
from PIL import Image

from .base import get_vllm_public_assets

VLM_IMAGES_DIR = "vision_model_images"

ImageAssetName = Literal["stop_sign", "cherry_blossom", "hato",
                         "2560px-Gfp-wisconsin-madison-the-nature-boardwalk",
                         "Grayscale_8bits_palette_sample_image",
                         "1280px-Venn_diagram_rgb", "RGBA_comp", "237-400x300",
                         "231-200x300", "27-500x500", "17-150x600",
                         "handelsblatt-preview", "paper-11"]


@dataclass(frozen=True)
class ImageAsset:
    name: ImageAssetName

    def get_path(self, ext: str) -> Path:
        """
        Return s3 path for given image.
        """
        return get_vllm_public_assets(filename=f"{self.name}.{ext}",
                                      s3_prefix=VLM_IMAGES_DIR)

    @property
    def pil_image(self, ext="jpg") -> Image.Image:

        image_path = self.get_path(ext)
        return Image.open(image_path)

    @property
    def image_embeds(self) -> torch.Tensor:
        """
        Image embeddings, only used for testing purposes with llava 1.5.
        """
        image_path = self.get_path('pt')
        return torch.load(image_path, map_location="cpu", weights_only=True)

    def read_bytes(self, ext: str) -> bytes:
        p = Path(self.get_path(ext))
        return p.read_bytes()
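Note: a minimal usage sketch of the image asset class above (the .jpg is downloaded on first access):

    from vllm.assets.image import ImageAsset

    img = ImageAsset("cherry_blossom").pil_image  # PIL.Image.Image
    print(img.size)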
145
vllm/assets/video.py
Normal file
@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from functools import lru_cache
from typing import Any, ClassVar, Literal, Optional

import cv2
import numpy as np
import numpy.typing as npt
from huggingface_hub import hf_hub_download
from PIL import Image

from vllm.utils import PlaceholderModule

from .base import get_cache_dir

try:
    import librosa
except ImportError:
    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]


@lru_cache
def download_video_asset(filename: str) -> str:
    """
    Download and open an image from huggingface
    repo: raushan-testing-hf/videos-test
    """
    video_directory = get_cache_dir() / "video-example-data"
    video_directory.mkdir(parents=True, exist_ok=True)

    video_path = video_directory / filename
    video_path_str = str(video_path)
    if not video_path.exists():
        video_path_str = hf_hub_download(
            repo_id="raushan-testing-hf/videos-test",
            filename=filename,
            repo_type="dataset",
            cache_dir=video_directory,
        )
    return video_path_str


def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file {path}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []

    num_frames = num_frames if num_frames > 0 else total_frames
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    for idx in range(total_frames):
        ok = cap.grab()  # next img
        if not ok:
            break
        if idx in frame_indices:  # only decompress needed
            ret, frame = cap.retrieve()
            if ret:
                # OpenCV uses BGR format, we need to convert it to RGB
                # for PIL and transformers compatibility
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    frames = np.stack(frames)
    if len(frames) < num_frames:
        raise ValueError(f"Could not read enough frames from video file {path}"
                         f" (expected {num_frames} frames, got {len(frames)})")
    return frames


def video_to_pil_images_list(path: str,
                             num_frames: int = -1) -> list[Image.Image]:
    frames = video_to_ndarrays(path, num_frames)
    return [Image.fromarray(frame) for frame in frames]


def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]:
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file {path}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    duration = total_frames / fps if fps > 0 else 0

    if num_frames == -1 or num_frames > total_frames:
        num_frames = total_frames

    metadata = {
        "total_num_frames": num_frames,
        "fps": fps,
        "duration": duration,
        "video_backend": "opencv",
        "frames_indices": list(range(num_frames)),
        # extra field used to control hf processor's video
        # sampling behavior
        "do_sample_frames": num_frames == total_frames,
    }
    return metadata


VideoAssetName = Literal["baby_reading"]


@dataclass(frozen=True)
class VideoAsset:
    name: VideoAssetName
    num_frames: int = -1

    _NAME_TO_FILE: ClassVar[dict[VideoAssetName, str]] = {
        "baby_reading": "sample_demo_1.mp4",
    }

    @property
    def filename(self) -> str:
        return self._NAME_TO_FILE[self.name]

    @property
    def video_path(self) -> str:
        return download_video_asset(self.filename)

    @property
    def pil_images(self) -> list[Image.Image]:
        ret = video_to_pil_images_list(self.video_path, self.num_frames)
        return ret

    @property
    def np_ndarrays(self) -> npt.NDArray:
        ret = video_to_ndarrays(self.video_path, self.num_frames)
        return ret

    @property
    def metadata(self) -> dict[str, Any]:
        ret = video_get_metadata(self.video_path, self.num_frames)
        return ret

    def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
        """
        Read audio data from the video asset, used in Qwen2.5-Omni examples.

        See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
        """
        return librosa.load(self.video_path, sr=sampling_rate)[0]
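Note: a minimal usage sketch of the video asset class above (downloads the sample clip from Hugging Face on first access; frame count is illustrative):

    from vllm.assets.video import VideoAsset

    asset = VideoAsset("baby_reading", num_frames=16)
    frames = asset.np_ndarrays  # 16 uniformly sampled RGB frames as an ndarray
    meta = asset.metadata       # fps, duration, do_sample_frames, ...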
15
vllm/attention/__init__.py
Normal file
@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

__all__ = [
    "Attention",
    "AttentionBackend",
    "AttentionMetadata",
    "AttentionType",
    "get_attn_backend",
]
0
vllm/attention/backends/__init__.py
Normal file
204
vllm/attention/backends/abstract.py
Normal file
@@ -0,0 +1,204 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey


class AttentionType:
    """
    Attention type.
    Use string to be compatible with `torch.compile`.
    """
    DECODER = "decoder"
    """Decoder attention between previous layer Q/K/V."""
    ENCODER = "encoder"
    """Encoder attention between previous layer Q/K/V for encoder-decoder."""
    ENCODER_ONLY = "encoder_only"
    """Encoder attention between previous layer Q/K/V."""
    ENCODER_DECODER = "encoder_decoder"
    """Attention between dec. Q and enc. K/V for encoder-decoder."""


class AttentionBackend(ABC):
    """Abstract class for attention backends."""
    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    # Whether this backend supports receiving pre-quantized query input.
    # If True, the attention layer will handle query quantization instead
    # of the backend, allowing torch.compile to fuse quantization with
    # previous operations.
    # Needs to be worked through for all backends
    # https://github.com/vllm-project/vllm/issues/25584
    supports_quant_query_input: bool = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> Tuple[int, ...]:
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        return (cls.__module__, cls.__qualname__)


class AttentionMetadata:
    pass


T = TypeVar("T", bound=AttentionMetadata)


class AttentionLayer(Protocol):

    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        ...


class AttentionImpl(ABC, Generic[T]):

    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # some attention backends might not always want to return lse
    # even if they can return lse (for efficiency reasons)
    need_to_return_lse_for_decode: bool = False

    dcp_world_size: int
    dcp_rank: int

    def __new__(cls, *args, **kwargs):
        # use __new__ so that all subclasses will call this
        self = super().__new__(cls)
        try:
            from vllm.distributed.parallel_state import get_dcp_group
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        self.need_to_return_lse_for_decode = self.dcp_world_size > 1 \
            and self.can_return_lse_for_decode
        return self

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: QuantKey):
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        :param quant_key: QuantKey object that describes the quantization op
        :return: is fusion supported for this type of quantization
        """
        return False


class MLAAttentionImpl(AttentionImpl[T], Generic[T]):

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    return kv_cache_dtype != "auto"
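Note: a hypothetical no-op backend illustrating the abstract interface above; the class name, the "DUMMY" identifier, and the KV cache layout are illustrative assumptions, not part of this commit:

    from typing import Tuple, Type

    from vllm.attention.backends.abstract import (AttentionBackend,
                                                  AttentionImpl,
                                                  AttentionMetadata)

    class DummyBackend(AttentionBackend):

        @staticmethod
        def get_name() -> str:
            return "DUMMY"

        @staticmethod
        def get_impl_cls() -> Type["AttentionImpl"]:
            # a real backend returns its AttentionImpl subclass here
            raise NotImplementedError

        @staticmethod
        def get_metadata_cls() -> Type["AttentionMetadata"]:
            return AttentionMetadata

        @staticmethod
        def get_builder_cls():
            raise NotImplementedError

        @staticmethod
        def get_kv_cache_shape(num_blocks: int, block_size: int,
                               num_kv_heads: int, head_size: int,
                               cache_dtype_str: str = "auto") -> Tuple[int, ...]:
            # assumed paged layout: K and V planes of [blocks, block, heads, head_size]
            return (2, num_blocks, block_size, num_kv_heads, head_size)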
33
vllm/attention/backends/utils.py
Normal file
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend utils"""
from dataclasses import dataclass
from typing import Optional

from vllm.config import ModelConfig
from vllm.logger import init_logger

logger = init_logger(__name__)

PAD_SLOT_ID = -1


@dataclass
class MLADims:
    q_lora_rank: Optional[int]
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    v_head_dim: int


def get_mla_dims(model_config: ModelConfig) -> MLADims:
    hf_text_config = model_config.hf_text_config

    return MLADims(
        q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
        kv_lora_rank=hf_text_config.kv_lora_rank,
        qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
        qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
        v_head_dim=hf_text_config.v_head_dim,
    )
645
vllm/attention/layer.py
Normal file
@@ -0,0 +1,645 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer."""
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
|
||||
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from vllm.utils import GiB_bytes, direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
USE_XFORMERS_OPS = None
|
||||
try:
|
||||
tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
|
||||
except AttributeError:
|
||||
tag_cudagraph_unsafe = () # type: ignore[assignment]
|
||||
|
||||
|
||||
def check_xformers_availability():
|
||||
global USE_XFORMERS_OPS
|
||||
if USE_XFORMERS_OPS is not None:
|
||||
return USE_XFORMERS_OPS
|
||||
|
||||
if current_platform.is_cuda() and current_platform.has_device_capability(
|
||||
100):
|
||||
# Xformers FA is not compatible with B200
|
||||
USE_XFORMERS_OPS = False
|
||||
else:
|
||||
try:
|
||||
from importlib.util import find_spec
|
||||
|
||||
find_spec("xformers.ops")
|
||||
USE_XFORMERS_OPS = True
|
||||
except ImportError:
|
||||
USE_XFORMERS_OPS = False
|
||||
|
||||
# the warning only needs to be shown once
|
||||
if not USE_XFORMERS_OPS:
|
||||
logger.warning("Xformers is not available, falling back.")
|
||||
|
||||
return USE_XFORMERS_OPS
|
||||
|
||||
|
||||
def check_upstream_fa_availability(dtype: torch.dtype):
|
||||
if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda(
|
||||
) and current_platform.has_device_capability(80):
|
||||
from transformers.utils import is_flash_attn_2_available
|
||||
return is_flash_attn_2_available()
|
||||
return False
|
||||
|
||||
|
||||
class Attention(nn.Module, AttentionLayerBase):
|
||||
"""Attention layer.
|
||||
|
||||
This class takes query, key, and value tensors as input. The input tensors
|
||||
can either contain prompt tokens or generation tokens.
|
||||
The class does the following:
|
||||
|
||||
1. Store the input key and value tensors in the KV cache.
|
||||
2. Perform (multi-head/multi-query/grouped-query) attention.
|
||||
3. Return the output tensor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
alibi_slopes: Optional[List[float]] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
per_layer_sliding_window: Optional[int] = None,
|
||||
use_mla: bool = False,
|
||||
use_sparse: bool = False,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
attn_backend: Optional[type[AttentionBackend]] = None,
|
||||
**extra_impl_args,
|
||||
) -> None:
|
||||
"""
|
||||
The KV cache is stored inside this class and is accessed via
|
||||
`self.kv_cache`.
|
||||
"""
|
||||
super().__init__()
|
||||
if per_layer_sliding_window is not None:
|
||||
# per-layer sliding window
|
||||
sliding_window = per_layer_sliding_window
|
||||
elif cache_config is not None:
|
||||
# model-level sliding window
|
||||
sliding_window = cache_config.sliding_window
|
||||
else:
|
||||
sliding_window = None
|
||||
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
calculate_kv_scales = cache_config.calculate_kv_scales
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
calculate_kv_scales = False
|
||||
if num_kv_heads is None:
|
||||
num_kv_heads = num_heads
|
||||
assert num_heads % num_kv_heads == 0, \
|
||||
f"num_heads ({num_heads}) is not " \
|
||||
f"divisible by num_kv_heads ({num_kv_heads})"
|
||||
|
||||
# The default k/v_scale is set to 1.0. This is ignored
|
||||
# when kv-cache is not fp8, and should be used with
|
||||
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
|
||||
# expect the pre-quantized k/v_scale to be loaded along
|
||||
# with the model weights.
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.calculate_kv_scales = calculate_kv_scales
|
||||
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
# FlashAttn doesn't support quantizing the kv-cache only
|
||||
# but requires q to be quantized as well.
|
||||
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
|
||||
# We also keep q/k/v_scale on host (cpu) memory for attention
|
||||
# backends that require the scales to be on host instead of on device.
|
||||
# e.g. Flashinfer
|
||||
self._q_scale_float = 1.0
|
||||
self._k_scale_float = 1.0
|
||||
self._v_scale_float = 1.0
|
||||
|
||||
# The output scale on host memory. This should be the input scale of
|
||||
# the quant op after this attention layer.
|
||||
self._o_scale_float: Optional[float] = None
|
||||
|
||||
self.use_mla = use_mla
|
||||
self.use_sparse = use_sparse
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.has_sink = extra_impl_args.get("sinks") is not None
|
||||
|
||||
quant_method = quant_config.get_quant_method(
|
||||
self, prefix=prefix) if quant_config else None
|
||||
if quant_method is not None and not isinstance(
|
||||
quant_method, UnquantizedLinearMethod):
|
||||
assert isinstance(quant_method, BaseKVCacheMethod)
|
||||
# TODO (mgoin): kv cache dtype should be specified in the FP8
|
||||
# checkpoint config and become the "auto" behavior
|
||||
if self.kv_cache_dtype == "fp8_e5m2":
|
||||
raise ValueError("fp8_e5m2 kv-cache is not supported with "
|
||||
"fp8 checkpoints.")
|
||||
# If quantization is enabled, we make "k_scale" and "v_scale"
|
||||
# parameters so that it can be loaded from the model checkpoint.
|
||||
# The k/v_scale will then be converted back to native float32
|
||||
# values after weight loading.
|
||||
self.quant_method = quant_method
|
||||
self.quant_method.create_weights(self)
|
||||
|
||||
# During model initialization, the default dtype is set as the model
|
||||
# weight and activation dtype.
|
||||
dtype = torch.get_default_dtype()
|
||||
if attn_backend is None:
|
||||
self.attn_backend = get_attn_backend(head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_mla=use_mla,
|
||||
has_sink=self.has_sink,
|
||||
use_sparse=use_sparse)
|
||||
else:
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
impl_cls = self.attn_backend.get_impl_cls()
|
||||
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **extra_impl_args)
|
||||
self.backend = backend_name_to_enum(self.attn_backend.get_name())
|
||||
self.dtype = dtype
|
||||
|
||||
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
|
||||
# torch.compile works by registering the attention as one giant
|
||||
# opaque custom op. For other platforms, we directly call them
|
||||
# and let torch.compile handle them.
|
||||
self.use_direct_call = not current_platform.opaque_attention_op()
|
||||
|
||||
self.use_output = self.attn_backend.accept_output_buffer
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
self.layer_name = prefix
|
||||
self.attn_type = attn_type
|
||||
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
validate_kv_sharing_target(
|
||||
prefix,
|
||||
kv_sharing_target_layer_name,
|
||||
compilation_config.static_forward_context,
|
||||
)
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
# use a placeholder kv cache tensor during init, which will be replaced
|
||||
# by bind_kv_cache
|
||||
# this variable will not be accessed if use_direct_call is True
|
||||
self.kv_cache = [
|
||||
torch.tensor([]) for _ in range(get_current_vllm_config(
|
||||
).parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
try:
|
||||
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT,
|
||||
dtype=torch.float32)
|
||||
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT,
|
||||
dtype=torch.float32)
|
||||
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT,
|
||||
dtype=torch.float32)
|
||||
except torch.cuda.OutOfMemoryError as e:
|
||||
logger.error(
|
||||
"Failed to initialize attention q/k/v range constants: %s", e)
|
||||
if torch.cuda.is_available():
|
||||
logger.debug("CUDA device: %s", torch.cuda.current_device())
|
||||
logger.debug("Allocated: %.2f GiB",
|
||||
torch.cuda.memory_allocated() / GiB_bytes)
|
||||
logger.debug("Reserved: %.2f GiB",
|
||||
torch.cuda.memory_reserved() / GiB_bytes)
|
||||
raise RuntimeError(
|
||||
"Failed to initialize q/k/v range constants. "
|
||||
"This may be caused by insufficient memory to allocate "
|
||||
"kv cache.") from e
|
||||
|
||||
# for attn backends supporting query quantization
|
||||
self.query_quant = None
|
||||
if self.kv_cache_dtype.startswith(
|
||||
"fp8") and self.attn_backend.supports_quant_query_input:
|
||||
self.query_quant = QuantFP8(static=True,
|
||||
group_shape=GroupShape.PER_TENSOR)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
# For some alternate attention backends like MLA the attention output
|
||||
# shape does not match the query shape, so we optionally let the model
|
||||
# definition specify the output tensor shape.
|
||||
output_shape: Optional[torch.Size] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
The KV cache is stored inside this class and is accessed via
|
||||
`self.kv_cache`.
|
||||
|
||||
Attention metadata (`attn_metadata`) is set using a context manager in
|
||||
the model runner's `execute_model` method. It is accessed via forward
|
||||
context using
|
||||
`vllm.forward_context.get_forward_context().attn_metadata`.
|
||||
"""
|
||||
if self.calculate_kv_scales:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata.enable_kv_scales_calculation:
|
||||
self.calc_kv_scales(query, key, value)
|
||||
|
||||
output_dtype = query.dtype
|
||||
if self.query_quant is not None:
|
||||
# quantizing with a simple torch operation enables
|
||||
# torch.compile to fuse this into previous ops
|
||||
# which reduces overheads during decoding.
|
||||
# Otherwise queries are quantized using custom ops
|
||||
# which causes decoding overheads
|
||||
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
|
||||
query, _ = self.query_quant(query, self._q_scale)
|
||||
|
||||
if self.use_output:
|
||||
output_shape = (output_shape
|
||||
if output_shape is not None else query.shape)
|
||||
output = torch.zeros(output_shape,
|
||||
dtype=output_dtype,
|
||||
device=query.device)
|
||||
hidden_size = output_shape[-1]
|
||||
# We skip reshaping query, key and value tensors for the MLA
|
||||
# backend since these tensors have different semantics and are
|
||||
# processed differently.
|
||||
if not self.use_mla:
|
||||
# Reshape the query, key, and value tensors.
|
||||
# NOTE(woosuk): We do this outside the custom op to minimize the
|
||||
# CPU overheads from the non-CUDA-graph regions.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
output = output.view(-1, self.num_heads, self.head_size)
|
||||
if key is not None:
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
if value is not None:
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
if self.use_direct_call:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self.impl.forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self_kv_cache,
|
||||
attn_metadata,
|
||||
output=output)
|
||||
else:
|
||||
torch.ops.vllm.unified_attention_with_output(
|
||||
query, key, value, output, self.layer_name)
|
||||
return output.view(-1, hidden_size)
|
||||
else:
|
||||
if self.use_direct_call:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
return self.impl.forward(self, query, key, value,
|
||||
self_kv_cache, attn_metadata)
|
||||
else:
|
||||
return torch.ops.vllm.unified_attention(
|
||||
query, key, value, self.layer_name)

    def calc_kv_scales(self, query, key, value):
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

        # FlashInfer requires attention sinks to be float32
        if (self.backend == _Backend.FLASHINFER
                and hasattr(self.impl, 'sinks')):
            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
            assert isinstance(self.impl, FlashInferImpl)
            if (self.impl.sinks is not None
                    and self.impl.sinks.dtype != torch.float32):
                self.impl.sinks = self.impl.sinks.to(torch.float32)

    def get_attn_backend(self) -> type[AttentionBackend]:
        return self.attn_backend

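# Note on dynamic KV scales (a simplified sketch, not executed here): when
# `calculate_kv_scales` is enabled, `calc_kv_scales` derives per-tensor absmax
# scales exactly once, on the first forward pass. For example, if q_range is
# 448 (the fp8-e4m3 maximum) and abs(query).max() is 112, the resulting
# q_scale is 112 / 448 = 0.25. The numbers are illustrative only.
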
class MultiHeadAttention(nn.Module):
|
||||
"""Multi-headed attention without any cache, used for ViT."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = scale
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0, \
|
||||
f"num_heads ({self.num_heads}) is not " \
|
||||
f"divisible by num_kv_heads ({self.num_kv_heads})"
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
# During model initialization, the default dtype is set as the model
|
||||
# weight and activation dtype.
|
||||
dtype = torch.get_default_dtype()
|
||||
|
||||
# Determine the attention backend
|
||||
backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
|
||||
|
||||
# Some auto-selected backends can be upgraded
|
||||
# to upstream flash attention if available.
|
||||
# If vllm native fa is selected, we use it directly.
|
||||
use_upstream_fa = False
|
||||
if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
|
||||
dtype):
|
||||
backend = _Backend.FLASH_ATTN
|
||||
use_upstream_fa = True
|
||||
|
||||
if current_platform.is_rocm() or current_platform.is_xpu():
|
||||
# currently, only torch_sdpa is supported on rocm/xpu
|
||||
self.attn_backend = _Backend.TORCH_SDPA
|
||||
else:
|
||||
|
||||
self.attn_backend = backend if backend in {
|
||||
_Backend.TORCH_SDPA,
|
||||
_Backend.XFORMERS,
|
||||
_Backend.PALLAS,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
_Backend.FLASH_ATTN,
|
||||
} else _Backend.TORCH_SDPA
|
||||
|
||||
if (self.attn_backend == _Backend.XFORMERS
|
||||
and not check_xformers_availability()):
|
||||
self.attn_backend = _Backend.TORCH_SDPA
|
||||
|
||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||
if use_upstream_fa:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
self._flash_attn_varlen_func = flash_attn_varlen_func
|
||||
else:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
self._flash_attn_varlen_func = flash_attn_varlen_func
|
||||
|
||||
logger.info_once(
|
||||
f"MultiHeadAttention attn_backend: {self.attn_backend}, "
|
||||
f"use_upstream_fa: {use_upstream_fa}")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Input shape:
|
||||
(batch_size x seq_len x hidden_size) or
|
||||
(batch_size x seq_len x num_heads x head_size)
|
||||
"""
|
||||
bsz, q_len = query.size()[:2]
|
||||
kv_len = key.size(1)
|
||||
|
||||
query = query.view(bsz, q_len, self.num_heads, self.head_size)
|
||||
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
|
||||
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
|
||||
|
||||
if (num_repeat := self.num_queries_per_kv) > 1:
|
||||
# Handle MQA and GQA
|
||||
key = torch.repeat_interleave(key, num_repeat, dim=2)
|
||||
value = torch.repeat_interleave(value, num_repeat, dim=2)
|
||||
|
||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
|
||||
step=q_len,
|
||||
dtype=torch.int32,
|
||||
device=query.device)
|
||||
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
|
||||
step=kv_len,
|
||||
dtype=torch.int32,
|
||||
device=key.device)
|
||||
|
||||
out = self._flash_attn_varlen_func(
|
||||
query.flatten(0, 1),
|
||||
key.flatten(0, 1),
|
||||
value.flatten(0, 1),
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=q_len,
|
||||
max_seqlen_k=kv_len,
|
||||
softmax_scale=self.scale,
|
||||
)
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
|
||||
out = xops.memory_efficient_attention_forward(query,
|
||||
key,
|
||||
value,
|
||||
scale=self.scale)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
query, key, value = (x.transpose(1, 2)
|
||||
for x in (query, key, value))
|
||||
out = F.scaled_dot_product_attention(query,
|
||||
key,
|
||||
value,
|
||||
scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
elif self.attn_backend == _Backend.PALLAS:
|
||||
query, key, value = (x.transpose(1, 2)
|
||||
for x in (query, key, value))
|
||||
from torch_xla.experimental.custom_kernel import flash_attention
|
||||
out = flash_attention(query, key, value, sm_scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
elif self.attn_backend == _Backend.ROCM_AITER_FA:
|
||||
from aiter import flash_attn_varlen_func
|
||||
|
||||
# ROCm Flash Attention expects (batch, seq, heads, head_dim)
|
||||
out = flash_attn_varlen_func(query,
|
||||
key,
|
||||
value,
|
||||
softmax_scale=self.scale)
|
||||
        else:
            # This backend is not yet supported for ViT attention.
            raise NotImplementedError(
                f"ViT attention does not support the {self.attn_backend} "
                f"backend yet.")
|
||||
|
||||
return out.reshape(bsz, q_len, -1)
|
||||
|
||||
|
||||
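# Usage sketch for MultiHeadAttention (hypothetical ViT-like shapes; the
# tensor sizes below are illustrative only):
#
#   attn = MultiHeadAttention(num_heads=16, head_size=64, scale=64**-0.5)
#   q = k = v = torch.randn(2, 197, 16 * 64)   # (batch, seq_len, hidden)
#   out = attn(q, k, v)                        # -> (2, 197, 1024)
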
def wait_for_kv_layer_from_connector(layer_name: str):
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
return
|
||||
|
||||
connector = get_kv_transfer_group()
|
||||
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if attn_metadata is None:
|
||||
return
|
||||
assert isinstance(attn_metadata, dict)
|
||||
connector.wait_for_layer_load(layer_name)
|
||||
|
||||
|
||||
def maybe_save_kv_layer_to_connector(
|
||||
layer_name: str,
|
||||
kv_cache_layer: List[torch.Tensor],
|
||||
):
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
return
|
||||
|
||||
connector = get_kv_transfer_group()
|
||||
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if attn_metadata is None:
|
||||
return
|
||||
assert isinstance(attn_metadata, dict)
|
||||
connector.save_kv_layer(layer_name, kv_cache_layer,
|
||||
attn_metadata[layer_name])
|
||||
|
||||
|
||||
def unified_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
output = self.impl.forward(self, query, key, value, kv_cache,
|
||||
attn_metadata)
|
||||
|
||||
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||
return output
|
||||
|
||||
|
||||
def unified_attention_fake(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(query).contiguous()
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="unified_attention",
|
||||
op_func=unified_attention,
|
||||
fake_impl=unified_attention_fake,
|
||||
tags=tag_cudagraph_unsafe,
|
||||
)
|
||||
|
||||
|
||||
def unified_attention_with_output(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self.impl.forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output=output,
|
||||
output_scale=output_scale,
|
||||
output_block_scale=output_block_scale)
|
||||
|
||||
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||
|
||||
|
||||
def unified_attention_with_output_fake(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="unified_attention_with_output",
|
||||
op_func=unified_attention_with_output,
|
||||
mutates_args=["output", "output_block_scale"],
|
||||
fake_impl=unified_attention_with_output_fake,
|
||||
tags=tag_cudagraph_unsafe,
|
||||
)
|
||||
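# Background note (sketch): the `fake_impl` registered above is what
# torch.compile uses to trace shapes without running the real kernel. For
# `unified_attention` it simply returns a tensor shaped like `query`, e.g.
# (the layer name below is illustrative only):
#
#   q = torch.empty(8, 32, 128)
#   unified_attention_fake(q, q, q, "layer.0")   # -> empty (8, 32, 128) tensor
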
0
vllm/attention/layers/__init__.py
Normal file
93
vllm/attention/layers/chunked_local_attention.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from typing import ClassVar, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata)
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig, QuantizationConfig
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport, CommonAttentionMetadata,
|
||||
make_local_attention_virtual_batches, subclass_attention_backend)
|
||||
|
||||
from ..layer import Attention
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def create_chunked_local_attention_backend(
|
||||
underlying_attn_backend: AttentionBackend,
|
||||
attention_chunk_size: int,
|
||||
block_size: int,
|
||||
) -> type[AttentionBackend]:
|
||||
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
|
||||
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
|
||||
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.NEVER
|
||||
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> AttentionMetadata:
|
||||
common_attn_metadata = make_local_attention_virtual_batches(
|
||||
attention_chunk_size, common_attn_metadata, block_size)
|
||||
return super().build(common_prefix_len, common_attn_metadata,
|
||||
fast_build)
|
||||
|
||||
attn_backend = subclass_attention_backend(
|
||||
name_prefix=prefix,
|
||||
attention_backend_cls=underlying_attn_backend,
|
||||
builder_cls=ChunkedLocalAttentionBuilder)
|
||||
|
||||
return attn_backend
|
||||
|
||||
|
||||
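# Usage sketch (hypothetical arguments; any concrete v1 backend would do): the
# factory wraps an existing backend's builder so that common attention
# metadata is rewritten into local-attention virtual batches before building.
#
#   backend_cls = create_chunked_local_attention_backend(
#       underlying_attn_backend=some_backend,   # e.g. a FlashAttention backend
#       attention_chunk_size=8192,
#       block_size=16)
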
class ChunkedLocalAttention(Attention):
|
||||
|
||||
def __init__(self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
attention_chunk_size: int,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
alibi_slopes: Optional[List[float]] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
prefix: str = ""):
|
||||
dtype = torch.get_default_dtype()
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
underlying_attn_backend = get_attn_backend(head_size, dtype,
|
||||
kv_cache_dtype,
|
||||
block_size)
|
||||
|
||||
attn_backend = create_chunked_local_attention_backend(
|
||||
underlying_attn_backend, attention_chunk_size, block_size)
|
||||
else:
|
||||
# in v0 the local attention is handled inside the backends
|
||||
attn_backend = None
|
||||
|
||||
super().__init__(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads,
|
||||
alibi_slopes=alibi_slopes,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
|
||||
attn_backend=attn_backend)
|
||||
162
vllm/attention/layers/cross_attention.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from copy import copy
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||
subclass_attention_backend)
|
||||
from vllm.v1.kv_cache_interface import CrossAttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_max_encoder_len(vllm_config: "VllmConfig") -> int:
|
||||
"""Gets the max number of encoder input tokens from the config.
|
||||
"""
|
||||
sc = vllm_config.scheduler_config
|
||||
assert sc and isinstance(sc.max_num_encoder_input_tokens, int), \
|
||||
"max_num_encoder_input_tokens must be int for enc-dec models"
|
||||
return sc.max_num_encoder_input_tokens
|
||||
|
||||
|
||||
def _get_cross_slot_mapping(encoder_seq_lens: np.ndarray,
|
||||
block_table_tensor: torch.Tensor,
|
||||
kv_cache_spec: CrossAttentionSpec,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
"""Get cross-attention slot mappings."""
|
||||
|
||||
block_size = kv_cache_spec.block_size
|
||||
slot_mappings = []
|
||||
|
||||
# Find indices with non-zero encoder sequence lengths
|
||||
# The majority of parallel requests will be running the
|
||||
# decoder, so this list should be relatively small.
|
||||
active_indices = np.nonzero(encoder_seq_lens)[0]
|
||||
|
||||
for req_index in active_indices:
|
||||
encoder_seq_len = encoder_seq_lens[req_index].item()
|
||||
|
||||
# Calculate the number of blocks needed for this request
|
||||
num_blocks_needed = cdiv(encoder_seq_len, block_size)
|
||||
|
||||
# Get the block IDs for this request from the tensor
|
||||
req_block_ids = block_table_tensor[req_index]
|
||||
|
||||
# Get only the blocks we need (first num_blocks_needed blocks)
|
||||
needed_block_ids = req_block_ids[:num_blocks_needed]
|
||||
|
||||
# All needed blocks are allocated
|
||||
i_values = torch.arange(encoder_seq_len,
|
||||
dtype=torch.int64,
|
||||
device=device)
|
||||
block_indices = i_values // block_size
|
||||
block_offsets = i_values % block_size
|
||||
block_numbers = needed_block_ids[block_indices]
|
||||
slot_mapping = block_numbers * block_size + block_offsets
|
||||
|
||||
slot_mappings.append(slot_mapping)
|
||||
|
||||
if slot_mappings:
|
||||
return torch.cat(slot_mappings)
|
||||
else:
|
||||
return torch.empty(0, dtype=torch.int64, device=device)
|
||||
|
||||
|
||||
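# Worked example for _get_cross_slot_mapping (illustrative numbers only): with
# block_size = 4, an encoder length of 6, and block IDs [7, 2] for a request,
# token i maps to slot block_ids[i // 4] * 4 + i % 4, i.e.
# [28, 29, 30, 31, 8, 9].
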
@functools.lru_cache
|
||||
def create_cross_attention_backend(
|
||||
underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
|
||||
prefix = "CrossAttention_"
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
|
||||
class CrossAttentionBuilder(underlying_builder): # type: ignore
|
||||
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> AttentionMetadata:
|
||||
new_metadata = copy(common_attn_metadata)
|
||||
new_metadata.causal = False
|
||||
max_encoder_len = _get_max_encoder_len(self.vllm_config)
|
||||
new_metadata.max_seq_len = max_encoder_len
|
||||
|
||||
new_metadata.seq_lens = torch.full(
|
||||
(new_metadata.num_reqs, ),
|
||||
max_encoder_len,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
new_metadata.seq_lens_cpu = torch.full(
|
||||
(new_metadata.num_reqs, ),
|
||||
max_encoder_len,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
new_metadata.slot_mapping = _get_cross_slot_mapping(
|
||||
new_metadata.encoder_seq_lens, new_metadata.block_table_tensor,
|
||||
self.kv_cache_spec, self.device)
|
||||
return super().build(common_prefix_len, new_metadata, fast_build)
|
||||
|
||||
attn_backend = subclass_attention_backend(
|
||||
name_prefix=prefix,
|
||||
attention_backend_cls=underlying_attn_backend,
|
||||
builder_cls=CrossAttentionBuilder)
|
||||
|
||||
return attn_backend
|
||||
|
||||
|
||||
class CrossAttention(Attention):
|
||||
"""
|
||||
Cross-attention for encoder-decoder models.
|
||||
Handles attention between decoder queries and encoder keys/values.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
attn_type: Optional[str] = None,
|
||||
**kwargs):
|
||||
dtype = torch.get_default_dtype()
|
||||
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
underlying_attn_backend = get_attn_backend(head_size, dtype,
|
||||
kv_cache_dtype,
|
||||
block_size)
|
||||
|
||||
attn_backend = create_cross_attention_backend(
|
||||
underlying_attn_backend)
|
||||
else:
|
||||
# in v0 cross attention is handled inside the backends
|
||||
attn_backend = None
|
||||
|
||||
if attn_type is not None:
|
||||
assert attn_type == AttentionType.ENCODER_DECODER, (
|
||||
"CrossAttention only supports AttentionType.ENCODER_DECODER")
|
||||
|
||||
super().__init__(num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
cache_config=cache_config,
|
||||
attn_backend=attn_backend,
|
||||
attn_type=AttentionType.ENCODER_DECODER,
|
||||
**kwargs)
|
||||
86
vllm/attention/layers/encoder_only_attention.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from copy import copy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||
subclass_attention_backend)
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def create_encoder_only_attention_backend(
|
||||
underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
|
||||
prefix = "EncoderOnlyAttention_"
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
|
||||
class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore
|
||||
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> AttentionMetadata:
|
||||
new_common_attn_metadata = copy(common_attn_metadata)
|
||||
new_common_attn_metadata.causal = False
|
||||
return super().build(common_prefix_len, new_common_attn_metadata,
|
||||
fast_build)
|
||||
|
||||
attn_backend = subclass_attention_backend(
|
||||
name_prefix=prefix,
|
||||
attention_backend_cls=underlying_attn_backend,
|
||||
builder_cls=EncoderOnlyAttentionBuilder)
|
||||
|
||||
return attn_backend
|
||||
|
||||
|
||||
class EncoderOnlyAttention(Attention):
|
||||
"""
|
||||
Encoder attention is a special case that doesn't need a KV Cache.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
attn_type: Optional[str] = None,
|
||||
**kwargs):
|
||||
dtype = torch.get_default_dtype()
|
||||
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
underlying_attn_backend = get_attn_backend(head_size, dtype,
|
||||
kv_cache_dtype,
|
||||
block_size)
|
||||
|
||||
attn_backend = create_encoder_only_attention_backend(
|
||||
underlying_attn_backend)
|
||||
else:
|
||||
# in v0 encoder only attention is handled inside the backends
|
||||
attn_backend = None
|
||||
|
||||
if attn_type is not None:
|
||||
assert attn_type == AttentionType.ENCODER_ONLY, \
|
||||
"EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
|
||||
|
||||
super().__init__(num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
cache_config=cache_config,
|
||||
attn_backend=attn_backend,
|
||||
attn_type=AttentionType.ENCODER_ONLY,
|
||||
**kwargs)
|
||||
0
vllm/attention/ops/__init__.py
Normal file
405
vllm/attention/ops/chunked_prefill_paged_decode.py
Normal file
@@ -0,0 +1,405 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Authors:
|
||||
# - Burkhard Ringlein <ngl@zurich.ibm.com>
|
||||
# - Jan van Lunteren <jvl@zurich.ibm.com>
|
||||
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
|
||||
# - Thomas Parnell <tpa@zurich.ibm.com>
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .prefix_prefill import context_attention_fwd
|
||||
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
@triton.jit
|
||||
def cdiv_fn(x, y):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_paged_attention_2d(
|
||||
output_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
query_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x]
|
||||
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size]
|
||||
sink_ptr, # [num_query_heads]
|
||||
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
alibi_slopes_ptr, # [num_query_heads]
|
||||
scale, # float32
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
out_scale_inv,
|
||||
num_query_heads: tl.constexpr, # int
|
||||
num_queries_per_kv: tl.constexpr, # int
|
||||
num_queries_per_kv_padded: tl.constexpr, # int
|
||||
block_table_stride: tl.int64, # int
|
||||
query_stride_0: tl.int64, # int
|
||||
query_stride_1: tl.int64, # int, should be equal to head_size
|
||||
output_stride_0: tl.int64, # int
|
||||
output_stride_1: tl.int64, # int, should be equal to head_size
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||
SLIDING_WINDOW: tl.constexpr, # int
|
||||
x: tl.constexpr, # int
|
||||
stride_k_cache_0: tl.int64, # int
|
||||
stride_k_cache_1: tl.int64, # int
|
||||
stride_k_cache_2: tl.int64, # int
|
||||
stride_k_cache_3: tl.int64, # int
|
||||
stride_k_cache_4: tl.int64, # int
|
||||
stride_v_cache_0: tl.int64, # int
|
||||
stride_v_cache_1: tl.int64, # int
|
||||
stride_v_cache_2: tl.int64, # int
|
||||
stride_v_cache_3: tl.int64, # int
|
||||
filter_by_query_len: tl.constexpr, # bool
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
USE_SINKS: tl.constexpr, # bool
|
||||
USE_FP8: tl.constexpr,
|
||||
FP8_MIN: tl.constexpr = float8_info.min,
|
||||
FP8_MAX: tl.constexpr = float8_info.max):
|
||||
seq_idx = tl.program_id(0)
|
||||
kv_head_idx = tl.program_id(1)
|
||||
|
||||
if filter_by_query_len:
|
||||
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
|
||||
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx +
|
||||
1)
|
||||
cur_batch_query_len = cur_batch_in_all_stop_index \
|
||||
- cur_batch_in_all_start_index
|
||||
if cur_batch_query_len > 1:
|
||||
return
|
||||
else:
|
||||
cur_batch_in_all_start_index = seq_idx
|
||||
|
||||
query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
|
||||
0, num_queries_per_kv_padded)
|
||||
|
||||
query_offset = (cur_batch_in_all_start_index * query_stride_0 +
|
||||
query_head_idx[:, None] * query_stride_1)
|
||||
|
||||
head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
|
||||
head_mask = head_mask & (query_head_idx < num_query_heads)
|
||||
|
||||
dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
|
||||
0).to(tl.int1)
|
||||
|
||||
# Q : (num_queries_per_kv, HEAD_SIZE,)
|
||||
Q = tl.load(
|
||||
query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
|
||||
mask=dim_mask[None, :] & head_mask[:, None],
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
block_table_offset = seq_idx * block_table_stride
|
||||
|
||||
if not USE_SINKS:
|
||||
M = tl.full([num_queries_per_kv_padded],
|
||||
float("-inf"),
|
||||
dtype=tl.float32)
|
||||
else:
|
||||
M = tl.load(
|
||||
sink_ptr + query_head_idx,
|
||||
mask=head_mask,
|
||||
other=float("-inf"),
|
||||
).to(dtype=tl.float32)
|
||||
|
||||
L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
|
||||
dtype=tl.float32)
|
||||
|
||||
# sequence len for this particular sequence
|
||||
seq_len = tl.load(seq_lens_ptr + seq_idx)
|
||||
|
||||
# alibi slope for this head
|
||||
if USE_ALIBI_SLOPES:
|
||||
alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx,
|
||||
mask=head_mask,
|
||||
other=0.0)
|
||||
|
||||
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
|
||||
|
||||
# iterate through tiles
|
||||
for j in range(0, num_blocks):
|
||||
|
||||
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
|
||||
|
||||
offs_n = tl.arange(0, BLOCK_SIZE)
|
||||
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
|
||||
|
||||
v_offset = (physical_block_idx * stride_v_cache_0 +
|
||||
kv_head_idx * stride_v_cache_1 +
|
||||
offs_d[None, :] * stride_v_cache_2 +
|
||||
offs_n[:, None] * stride_v_cache_3)
|
||||
|
||||
k_offset = (physical_block_idx * stride_k_cache_0 +
|
||||
kv_head_idx * stride_k_cache_1 +
|
||||
(offs_d[:, None] // x) * stride_k_cache_2 +
|
||||
offs_n[None, :] * stride_k_cache_3 +
|
||||
(offs_d[:, None] % x) * stride_k_cache_4)
|
||||
|
||||
# K : (HEAD_SIZE, BLOCK_SIZE)
|
||||
K_load = tl.load(key_cache_ptr + k_offset,
|
||||
mask=dim_mask[:, None],
|
||||
other=0.0)
|
||||
|
||||
if K_load.dtype.is_fp8():
|
||||
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
|
||||
else:
|
||||
K = K_load
|
||||
|
||||
# V : (BLOCK_SIZE, HEAD_SIZE)
|
||||
V_load = tl.load(value_cache_ptr + v_offset,
|
||||
mask=dim_mask[None, :],
|
||||
other=0.0)
|
||||
|
||||
if V_load.dtype.is_fp8():
|
||||
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
|
||||
else:
|
||||
V = V_load
|
||||
|
||||
seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
|
||||
seq_mask = seq_offset[None, :] < boundary
|
||||
|
||||
# S : (num_queries_per_kv, BLOCK_SIZE,)
|
||||
S = tl.where(head_mask[:, None] & seq_mask, 0.0,
|
||||
float("-inf")).to(tl.float32)
|
||||
S += scale * tl.dot(Q, K)
|
||||
|
||||
context_len = seq_len - 1
|
||||
|
||||
if SLIDING_WINDOW > 0:
|
||||
S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S,
|
||||
-10000)
|
||||
|
||||
if USE_ALIBI_SLOPES:
|
||||
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||
|
||||
# compute running maximum
|
||||
# m_j : (num_queries_per_kv,)
|
||||
m_j = tl.maximum(M, tl.max(S, axis=1))
|
||||
|
||||
# P : (num_queries_per_kv, BLOCK_SIZE,)
|
||||
P = tl.exp(S - m_j[:, None])
|
||||
|
||||
# l_j : (num_queries_per_kv,)
|
||||
l_j = tl.sum(P, axis=1)
|
||||
|
||||
# alpha : (num_queries_per_kv, )
|
||||
alpha = tl.exp(M - m_j)
|
||||
|
||||
# acc : (num_queries_per_kv, BLOCK_SIZE,)
|
||||
acc = acc * alpha[:, None]
|
||||
|
||||
# update constants
|
||||
L = L * alpha + l_j
|
||||
M = m_j
|
||||
|
||||
# acc : (num_queries_per_kv, BLOCK_SIZE,)
|
||||
acc += tl.dot(P.to(V.dtype), V)
|
||||
|
||||
# epilogue
|
||||
acc = acc / L[:, None]
|
||||
if USE_FP8:
|
||||
acc = acc * tl.load(out_scale_inv)
|
||||
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
|
||||
|
||||
output_offset = (cur_batch_in_all_start_index * output_stride_0 +
|
||||
query_head_idx * output_stride_1)
|
||||
|
||||
tl.store(
|
||||
output_ptr + output_offset[:, None] +
|
||||
tl.arange(0, HEAD_SIZE_PADDED)[None, :],
|
||||
acc,
|
||||
mask=dim_mask[None, :] & head_mask[:, None],
|
||||
)
|
||||
|
||||
|
||||
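# Note on the per-block update used in the loop above (a sketch of the
# standard online-softmax recurrence, written out for reference):
#
#   m_new = max(m_old, max_j(S_j))
#   acc   = acc * exp(m_old - m_new) + sum_j exp(S_j - m_new) * V_j
#   L     = L * exp(m_old - m_new) + sum_j exp(S_j - m_new)
#
# and the final output is acc / L, matching the epilogue of the kernel.
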
def chunked_prefill_paged_decode(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_table,
|
||||
query_start_loc,
|
||||
seq_lens,
|
||||
max_seq_len,
|
||||
max_query_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
sm_scale=None,
|
||||
output_scale=None,
|
||||
# Optional tensor for sinks
|
||||
sinks=None,
|
||||
):
|
||||
|
||||
if sm_scale is None:
|
||||
sm_scale = 1.0 / (query.shape[1]**0.5)
|
||||
|
||||
use_alibi_slopes = alibi_slopes is not None
|
||||
|
||||
if sliding_window is None or sliding_window <= 0:
|
||||
sliding_window = 0
|
||||
|
||||
if max_query_len > 1:
|
||||
context_attention_fwd(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
o=output,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
b_loc=block_table,
|
||||
b_start_loc=query_start_loc,
|
||||
b_seq_len=seq_lens,
|
||||
max_seq_len=max_seq_len,
|
||||
max_input_len=max_query_len,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
alibi_slopes=alibi_slopes,
|
||||
sliding_window=sliding_window,
|
||||
sm_scale=sm_scale,
|
||||
skip_decode=True,
|
||||
fp8_out_scale=output_scale,
|
||||
sinks=sinks,
|
||||
)
|
||||
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs = len(seq_lens)
|
||||
num_query_heads = query.shape[1]
|
||||
num_kv_heads = key.shape[1]
|
||||
num_queries_per_kv = query.shape[1] // key.shape[1]
|
||||
head_size = query.shape[2]
|
||||
|
||||
# Conversion of FP8 Tensor from uint8 storage to
|
||||
# appropriate torch.dtype for interpretation by Triton
|
||||
if "fp8" in kv_cache_dtype:
|
||||
assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
|
||||
assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
|
||||
|
||||
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
|
||||
target_dtype = current_platform.fp8_dtype()
|
||||
elif kv_cache_dtype == "fp8_e5m2":
|
||||
target_dtype = torch.float8_e5m2
|
||||
else:
|
||||
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
|
||||
|
||||
key_cache = key_cache.view(target_dtype)
|
||||
value_cache = value_cache.view(target_dtype)
|
||||
|
||||
num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
|
||||
16)
|
||||
|
||||
from vllm.platforms.rocm import use_rocm_custom_paged_attention
|
||||
use_custom = use_rocm_custom_paged_attention(
|
||||
query.dtype,
|
||||
head_size,
|
||||
block_size,
|
||||
num_queries_per_kv,
|
||||
max_seq_len,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
alibi_slopes,
|
||||
sinks,
|
||||
)
|
||||
if use_custom:
|
||||
_PARTITION_SIZE_ROCM = 256
|
||||
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
|
||||
_PARTITION_SIZE_ROCM)
|
||||
assert _PARTITION_SIZE_ROCM % block_size == 0
|
||||
total_num_seq = block_table.shape[0]
|
||||
tmp_output = torch.empty(
|
||||
size=(total_num_seq, num_query_heads, max_num_partitions,
|
||||
head_size),
|
||||
dtype=query.dtype,
|
||||
device=output.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(total_num_seq, num_query_heads, max_num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=output.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
|
||||
ops.paged_attention_rocm(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale=sm_scale,
|
||||
block_tables=block_table,
|
||||
seq_lens=seq_lens,
|
||||
query_start_loc=query_start_loc,
|
||||
block_size=block_size,
|
||||
max_seq_len=max_seq_len,
|
||||
alibi_slopes=alibi_slopes,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
fp8_out_scale=output_scale,
|
||||
)
|
||||
else:
|
||||
kernel_paged_attention_2d[(
|
||||
num_seqs,
|
||||
num_kv_heads,
|
||||
)](
|
||||
output_ptr=output,
|
||||
query_ptr=query,
|
||||
key_cache_ptr=key_cache,
|
||||
value_cache_ptr=value_cache,
|
||||
sink_ptr=sinks,
|
||||
block_tables_ptr=block_table,
|
||||
seq_lens_ptr=seq_lens,
|
||||
alibi_slopes_ptr=alibi_slopes,
|
||||
scale=sm_scale,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
out_scale_inv=1.0 /
|
||||
output_scale if output_scale is not None else 1.0,
|
||||
num_query_heads=num_query_heads,
|
||||
num_queries_per_kv=num_queries_per_kv,
|
||||
num_queries_per_kv_padded=num_queries_per_kv_padded,
|
||||
block_table_stride=block_table.stride(0),
|
||||
query_stride_0=query.stride(0),
|
||||
query_stride_1=query.stride(1),
|
||||
output_stride_0=output.stride(0),
|
||||
output_stride_1=output.stride(1),
|
||||
BLOCK_SIZE=block_size,
|
||||
HEAD_SIZE=head_size,
|
||||
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
|
||||
USE_ALIBI_SLOPES=use_alibi_slopes,
|
||||
SLIDING_WINDOW=sliding_window,
|
||||
x=key_cache.shape[4],
|
||||
stride_k_cache_0=key_cache.stride(0),
|
||||
stride_k_cache_1=key_cache.stride(1),
|
||||
stride_k_cache_2=key_cache.stride(2),
|
||||
stride_k_cache_3=key_cache.stride(3),
|
||||
stride_k_cache_4=key_cache.stride(4),
|
||||
stride_v_cache_0=value_cache.stride(0),
|
||||
stride_v_cache_1=value_cache.stride(1),
|
||||
stride_v_cache_2=value_cache.stride(2),
|
||||
stride_v_cache_3=value_cache.stride(3),
|
||||
filter_by_query_len=True,
|
||||
query_start_len_ptr=query_start_loc,
|
||||
USE_SINKS=sinks is not None,
|
||||
USE_FP8=output_scale is not None,
|
||||
)
|
||||
345
vllm/attention/ops/common.py
Normal file
@@ -0,0 +1,345 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr,
|
||||
vlse_ptr, outputs_stride_B, outputs_stride_H,
|
||||
outputs_stride_D, lses_stride_N, lses_stride_B,
|
||||
lses_stride_H, lse_idx, HEAD_DIM: tl.constexpr,
|
||||
N_ROUNDED: tl.constexpr):
|
||||
"""
|
||||
    Apply the all-gathered LSEs to correct each local rank's attention
    output. We still need to perform a cross-rank reduction to obtain the
    final attention output.
|
||||
|
||||
Args:
|
||||
outputs_ptr (triton.PointerType):
|
||||
Pointer to input tensor of shape [ B, H, D ]
|
||||
lses_ptr (triton.PointerType):
|
||||
Pointer to input tensor of shape [ N, B, H ]
|
||||
new_output_ptr (triton.PointerType):
|
||||
Pointer to output tensor of shape [ B, H, D ]
|
||||
vlse_ptr (triton.PointerType):
|
||||
Pointer to output tensor of shape [ B, H ]
|
||||
"""
|
||||
batch_idx = tl.program_id(axis=0).to(tl.int64)
|
||||
head_idx = tl.program_id(axis=1).to(tl.int64)
|
||||
d_offsets = tl.arange(0, HEAD_DIM)
|
||||
num_n_offsets = tl.arange(0, N_ROUNDED)
|
||||
|
||||
# shape = [N]
|
||||
lse_offsets = num_n_offsets * lses_stride_N + batch_idx * \
|
||||
lses_stride_B + head_idx * lses_stride_H
|
||||
|
||||
# calc final lse
|
||||
lse = tl.load(lses_ptr + lse_offsets)
|
||||
lse = tl.where((lse != lse) | (lse == float('inf')), -float('inf'), lse)
|
||||
lse_max = tl.max(lse, axis=0)
|
||||
lse -= lse_max
|
||||
lse_exp = tl.exp(lse)
|
||||
lse_acc = tl.sum(lse_exp, axis=0)
|
||||
lse = tl.log(lse_acc)
|
||||
lse += lse_max
|
||||
|
||||
lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
|
||||
tl.store(vlse_ptr + lse_offsets, lse)
|
||||
|
||||
# shape = [D]
|
||||
output_offsets = batch_idx * outputs_stride_B + \
|
||||
head_idx * outputs_stride_H + \
|
||||
d_offsets * outputs_stride_D
|
||||
|
||||
# correct output
|
||||
lse_offset = lse_idx * lses_stride_N + batch_idx * \
|
||||
lses_stride_B + head_idx * lses_stride_H
|
||||
lse_tmp = tl.load(lses_ptr + lse_offset)
|
||||
lse_finally = lse_tmp - lse
|
||||
lse_finally = tl.where(
|
||||
(lse_finally != lse_finally) | (lse_finally == float('inf')),
|
||||
-float('inf'), lse_finally)
|
||||
factor = tl.exp(lse_finally)
|
||||
output = tl.load(outputs_ptr + output_offsets)
|
||||
output = output * factor
|
||||
|
||||
tl.store(new_output_ptr + output_offsets, output)
|
||||
|
||||
|
||||
class CPTritonContext:
|
||||
""" The CPTritonContext is used to avoid recompilation of the Triton JIT.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.inner_kernel = None
|
||||
|
||||
def call_kernel(self, kernel, grid, *regular_args, **const_args):
|
||||
if self.inner_kernel is None:
|
||||
self.inner_kernel = kernel[grid](*regular_args, **const_args)
|
||||
else:
|
||||
self.inner_kernel[grid](*regular_args)
|
||||
|
||||
|
||||
def correct_attn_out(
|
||||
out: torch.Tensor, lses: torch.Tensor, cp_rank: int,
|
||||
ctx: CPTritonContext) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Correct the attention output using the all-gathered lses.
|
||||
|
||||
Args:
|
||||
out: Tensor of shape [ B, H, D ]
|
||||
lses: Tensor of shape [ N, B, H ]
|
||||
cp_rank: Current rank in the context-parallel group
|
||||
ctx: Triton context to avoid recompilation
|
||||
|
||||
Returns:
|
||||
Tuple of (out, lse) with corrected attention and final log-sum-exp.
|
||||
"""
|
||||
if ctx is None:
|
||||
ctx = CPTritonContext()
|
||||
|
||||
lse = torch.empty_like(lses[0])
|
||||
|
||||
grid = (out.shape[0], out.shape[1], 1)
|
||||
regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(),
|
||||
cp_rank)
|
||||
const_args = {
|
||||
"HEAD_DIM": out.shape[-1],
|
||||
"N_ROUNDED": lses.shape[0],
|
||||
}
|
||||
|
||||
ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args,
|
||||
**const_args)
|
||||
return out, lse
|
||||
|
||||
|
||||
def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
|
||||
cp_attn_lse: torch.Tensor,
|
||||
cp_group: GroupCoordinator,
|
||||
ctx: CPTritonContext = None):
|
||||
"""
|
||||
cp_attn_out: [ B, H, D ]
|
||||
cp_attn_lse: [ B, H ]
|
||||
"""
|
||||
if cp_group.world_size == 1:
|
||||
return cp_attn_out
|
||||
|
||||
if ctx is None:
|
||||
ctx = CPTritonContext()
|
||||
|
||||
lses = torch.empty((cp_group.world_size, ) + cp_attn_lse.shape,
|
||||
dtype=cp_attn_lse.dtype,
|
||||
device=cp_attn_lse.device)
|
||||
|
||||
cp_attn_lse = cp_attn_lse.contiguous()
|
||||
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
|
||||
out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
|
||||
out = cp_group.reduce_scatter(out, dim=1)
|
||||
return out
|
||||
|
||||
|
||||
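# Note on the correction (sketch): with per-rank partial outputs out_i and
# log-sum-exps lse_i, the combined result is
#
#   lse = log(sum_i exp(lse_i))
#   out = sum_i exp(lse_i - lse) * out_i
#
# `correct_attn_out` applies the per-rank factor exp(lse_i - lse) locally, and
# the reduce-scatter in `cp_lse_ag_out_rs` performs the cross-rank sum.
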
@triton.jit
|
||||
def _pack_seq_kernel(
|
||||
x_ptr, # [N, D]
|
||||
out_ptr, # [B, Lmax, D]
|
||||
lengths_ptr, # *i32, [B]
|
||||
N: tl.constexpr,
|
||||
D: tl.constexpr,
|
||||
Lmax: tl.constexpr,
|
||||
PAD_VALUE: tl.constexpr,
|
||||
BLOCK_T: tl.constexpr, # timesteps per program
|
||||
BLOCK_D: tl.constexpr # features per program
|
||||
):
|
||||
pid_b = tl.program_id(0) # batch id
|
||||
pid_t = tl.program_id(1) # block over time dimension
|
||||
pid_d = tl.program_id(2) # block over feature dimension
|
||||
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
|
||||
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
|
||||
|
||||
# Compute start index and sequence length from cumulative lengths
|
||||
in_start = 0
|
||||
for i in range(pid_b):
|
||||
in_start += tl.load(lengths_ptr + i)
|
||||
seq_len = tl.load(lengths_ptr + pid_b)
|
||||
|
||||
# valid time positions for this block
|
||||
t_mask = off_t < Lmax
|
||||
|
||||
# compute input row indices for valid (b, t)
|
||||
in_row = in_start + off_t
|
||||
valid_row = (off_t < seq_len) & t_mask
|
||||
|
||||
# Pointers
|
||||
# x_ptr: row-major [N, D]
|
||||
x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]
|
||||
|
||||
# out_ptr: row-major [B, Lmax, D]
|
||||
out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:,
|
||||
None] * D + off_d[None, :]
|
||||
|
||||
# Initialize with PAD (cast will occur as needed based on out_ptr dtype)
|
||||
d_mask = off_d[None, :] < D
|
||||
pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
|
||||
tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)
|
||||
|
||||
# Load & write only where within seq_len
|
||||
x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask)
|
||||
tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask)
|
||||
|
||||
|
||||
def pack_seq_triton(x: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
pad_value: float = -float('inf'),
|
||||
block_t: int = 64,
|
||||
block_d: int = 64) -> torch.Tensor:
|
||||
"""
|
||||
Pack sequences of different lengths into a batched tensor.
|
||||
|
||||
Args:
|
||||
x: [N, ...] - input tensor where N is total number of tokens
|
||||
lengths: [B] - sequence lengths for each batch
|
||||
pad_value: value to use for padding
|
||||
block_t: block size for time dimension
|
||||
block_d: block size for feature dimension
|
||||
|
||||
Returns:
|
||||
packed: [B, Lmax, ...] - packed tensor
|
||||
"""
|
||||
|
||||
# Handle multi-dimensional input by reshaping to (N, -1)
|
||||
original_shape = x.shape
|
||||
if len(original_shape) > 2:
|
||||
N = original_shape[0]
|
||||
x_reshaped = x.reshape(N, -1)
|
||||
D = x_reshaped.shape[1]
|
||||
else:
|
||||
N, D = x.shape
|
||||
x_reshaped = x
|
||||
|
||||
B = lengths.numel()
|
||||
Lmax = int(lengths.max().item())
|
||||
|
||||
# Starts are computed inside the kernel from lengths
|
||||
|
||||
out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype)
|
||||
|
||||
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
|
||||
_pack_seq_kernel[grid](x_reshaped,
|
||||
out,
|
||||
lengths.int(),
|
||||
N,
|
||||
D,
|
||||
Lmax,
|
||||
PAD_VALUE=float(pad_value),
|
||||
BLOCK_T=block_t,
|
||||
BLOCK_D=block_d,
|
||||
num_warps=4,
|
||||
num_stages=2)
|
||||
|
||||
# Reshape output back to original dimensions (except first dimension)
|
||||
if len(original_shape) > 2:
|
||||
output_shape = (B, Lmax) + original_shape[1:]
|
||||
out = out.reshape(output_shape)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _unpack_seq_triton_kernel(
|
||||
packed_ptr, # [B, Lmax, D]
|
||||
out_ptr, # [N, D]
|
||||
lengths_ptr, # *i32, [B]
|
||||
B: tl.constexpr,
|
||||
Lmax: tl.constexpr,
|
||||
D: tl.constexpr,
|
||||
BLOCK_T: tl.constexpr, # timesteps per program
|
||||
BLOCK_D: tl.constexpr # features per program
|
||||
):
|
||||
pid_b = tl.program_id(0) # batch id
|
||||
pid_t = tl.program_id(1) # block over time dimension
|
||||
pid_d = tl.program_id(2) # block over feature dimension
|
||||
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
|
||||
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
|
||||
|
||||
# bounds: compute start from cumulative lengths
|
||||
in_start = 0
|
||||
for i in range(pid_b):
|
||||
in_start += tl.load(lengths_ptr + i)
|
||||
seq_len = tl.load(lengths_ptr + pid_b)
|
||||
|
||||
# valid time positions for this block
|
||||
t_mask = off_t < Lmax
|
||||
valid_row = (off_t < seq_len) & t_mask
|
||||
|
||||
# compute output row indices for valid (b, t)
|
||||
out_row = in_start + off_t
|
||||
|
||||
# Pointers
|
||||
# packed_ptr: row-major [B, Lmax, D]
|
||||
packed_row_ptr = packed_ptr + (pid_b * Lmax +
|
||||
off_t)[:, None] * D + off_d[None, :]
|
||||
|
||||
# out_ptr: row-major [N, D]
|
||||
out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]
|
||||
|
||||
# Load from packed tensor and store to output
|
||||
d_mask = off_d[None, :] < D
|
||||
packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask)
|
||||
tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask)
|
||||
|
||||
|
||||
def unpack_seq_triton(packed_tensor: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
block_t: int = 64,
|
||||
block_d: int = 64) -> torch.Tensor:
|
||||
"""
|
||||
Unpack a packed decode query tensor back to the original format.
|
||||
Efficient Triton implementation.
|
||||
|
||||
Args:
|
||||
packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton
|
||||
lengths: [B] - sequence lengths for each batch
|
||||
block_t: block size for time dimension
|
||||
block_d: block size for feature dimension
|
||||
|
||||
Returns:
|
||||
unpacked_tensor: [N, ...] where N = sum(lengths)
|
||||
"""
|
||||
|
||||
# Handle multi-dimensional input by reshaping to (B, Lmax, -1)
|
||||
original_shape = packed_tensor.shape
|
||||
if len(original_shape) > 3:
|
||||
B, Lmax = original_shape[:2]
|
||||
packed_reshaped = packed_tensor.reshape(B, Lmax, -1)
|
||||
D = packed_reshaped.shape[2]
|
||||
else:
|
||||
B, Lmax, D = packed_tensor.shape
|
||||
packed_reshaped = packed_tensor
|
||||
|
||||
# Calculate total number of elements
|
||||
N = int(lengths.sum().item())
|
||||
|
||||
out = torch.empty((N, D),
|
||||
device=packed_tensor.device,
|
||||
dtype=packed_tensor.dtype)
|
||||
|
||||
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
|
||||
_unpack_seq_triton_kernel[grid](packed_reshaped,
|
||||
out,
|
||||
lengths.int(),
|
||||
B,
|
||||
Lmax,
|
||||
D,
|
||||
BLOCK_T=block_t,
|
||||
BLOCK_D=block_d,
|
||||
num_warps=4,
|
||||
num_stages=2)
|
||||
|
||||
# Reshape output back to original dimensions (except first dimension)
|
||||
if len(original_shape) > 3:
|
||||
output_shape = (N, ) + original_shape[2:]
|
||||
out = out.reshape(output_shape)
|
||||
|
||||
return out
|
||||
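# Round-trip sketch for pack_seq_triton / unpack_seq_triton (illustrative
# shapes; requires a CUDA device since both kernels are Triton):
#
#   x = torch.randn(7, 8, 64, device="cuda")        # 7 tokens total
#   lengths = torch.tensor([3, 4], device="cuda")   # two sequences
#   packed = pack_seq_triton(x, lengths)            # -> (2, 4, 8, 64), padded
#   restored = unpack_seq_triton(packed, lengths)   # -> (7, 8, 64) again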
192
vllm/attention/ops/flashmla.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if current_platform.is_cuda():
|
||||
try:
|
||||
import vllm._flashmla_C # noqa: F401
|
||||
_flashmla_C_AVAILABLE = True
|
||||
except ImportError:
|
||||
_flashmla_C_AVAILABLE = False
|
||||
else:
|
||||
_flashmla_C_AVAILABLE = False
|
||||
|
||||
if current_platform.is_cuda():
|
||||
try:
|
||||
import vllm._flashmla_extension_C # noqa: F401
|
||||
_flashmla_extension_C_AVAILABLE = True
|
||||
except ImportError:
|
||||
_flashmla_extension_C_AVAILABLE = False
|
||||
else:
|
||||
_flashmla_extension_C_AVAILABLE = False
|
||||
|
||||
|
||||
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Return: is_supported_flag, unsupported_reason (optional).
|
||||
"""
|
||||
if not current_platform.is_cuda():
|
||||
return False, "FlashMLA is only supported on CUDA devices."
|
||||
if current_platform.get_device_capability()[0] != 9:
|
||||
return False, "FlashMLA is only supported on Hopper devices."
|
||||
if not _flashmla_C_AVAILABLE:
|
||||
return False, "vllm._flashmla_C is not available, likely was not "\
|
||||
"compiled due to insufficient nvcc version or a supported arch "\
|
||||
"(only sm90a currently) was not in the list of target arches to "\
|
||||
"compile for."
|
||||
return True, None
|
||||
|
||||
|
||||
def get_mla_metadata(
|
||||
cache_seqlens: torch.Tensor,
|
||||
num_q_tokens_per_head_k: int,
|
||||
num_heads_k: int,
|
||||
num_heads_q: Optional[int] = None,
|
||||
is_fp8_kvcache: bool = False,
|
||||
topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
- cache_seqlens: (batch_size), dtype torch.int32.
|
||||
- num_q_tokens_per_head_k:
|
||||
Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
|
||||
- num_heads_k: The number of k heads.
|
||||
- num_heads_q:
|
||||
The number of q heads.
|
||||
This argument is optional when sparse attention is not enabled
|
||||
- is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
|
||||
- topk: If not None, sparse attention will be enabled,
|
||||
and only tokens in the `indices` array
|
||||
passed to `flash_mla_with_kvcache_sm90` will be attended to.
|
||||
|
||||
Returns:
|
||||
- tile_scheduler_metadata:
|
||||
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
|
||||
- num_splits: (batch_size + 1), dtype torch.int32.
|
||||
"""
|
||||
return torch.ops._flashmla_C.get_mla_decoding_metadata(
|
||||
cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
|
||||
is_fp8_kvcache, topk)
|
||||
|
||||
|
||||
def flash_mla_with_kvcache(
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
cache_seqlens: torch.Tensor,
|
||||
head_dim_v: int,
|
||||
tile_scheduler_metadata: torch.Tensor,
|
||||
num_splits: torch.Tensor,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
descale_q: Optional[torch.Tensor] = None,
|
||||
descale_k: Optional[torch.Tensor] = None,
|
||||
is_fp8_kvcache: bool = False,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
- q: (batch_size, seq_len_q, num_heads_q, head_dim).
|
||||
- k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
||||
- block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
||||
- cache_seqlens: (batch_size), torch.int32.
|
||||
- head_dim_v: Head dimension of v.
|
||||
- tile_scheduler_metadata:
|
||||
(num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
|
||||
returned by get_mla_metadata.
|
||||
- num_splits:
|
||||
(batch_size + 1), torch.int32, returned by get_mla_metadata.
|
||||
- softmax_scale: float.
|
||||
The scale of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(head_dim).
|
||||
- causal: bool. Whether to apply causal attention mask.
|
||||
- descale_q: (batch_size),
|
||||
torch.float32. Descaling factors for Q, used for fp8 quantization.
|
||||
- descale_k: (batch_size),
|
||||
torch.float32. Descaling factors for K, used for fp8 quantization.
|
||||
- is_fp8_kvcache: bool.
|
||||
Whether the k_cache and v_cache are in fp8 format.
|
||||
For the format of FP8 KV cache, please refer to README.md
|
||||
- indices: (batch_size, seq_len_q, topk), torch.int32.
|
||||
If not None, sparse attention will be enabled,
|
||||
and only tokens in the `indices` array will be attended to.
|
||||
Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
|
||||
For details about how to set up `indices`, please refer to README.md.
|
||||
|
||||
Returns:
|
||||
- out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||
- softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||
"""
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1]**(-0.5)
|
||||
if indices is not None:
|
||||
        # NOTE (zyongye): sparse attention is also causal, since it only
        # attends to earlier tokens, but `causal` should not be specified here.
|
||||
assert not causal, \
|
||||
"causal must be `false` if sparse attention is enabled."
|
||||
assert (descale_q is None) == (
|
||||
descale_k is None
|
||||
), "descale_q and descale_k should be both None or both not None"
|
||||
|
||||
if indices is None and q.element_size() == 1:
|
||||
out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
|
||||
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
|
||||
causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
|
||||
else:
|
||||
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
|
||||
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
|
||||
causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache,
|
||||
indices)
|
||||
return out, softmax_lse
|
||||
|
||||
|
||||
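# Decode-time usage sketch (variable names are placeholders; requires a Hopper
# GPU, see is_flashmla_supported()):
#
#   tile_md, num_splits = get_mla_metadata(
#       cache_seqlens,                                    # (batch,) int32
#       num_q_tokens_per_head_k=1 * num_heads_q // num_heads_k,
#       num_heads_k=num_heads_k)
#   out, lse = flash_mla_with_kvcache(
#       q, k_cache, block_table, cache_seqlens, head_dim_v,
#       tile_md, num_splits, causal=True)
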
def flash_mla_sparse_prefill(
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
sm_scale: float,
|
||||
d_v: int = 512,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Sparse attention prefill kernel
|
||||
|
||||
Args:
|
||||
- q: [s_q, h_q, d_qk], bfloat16
|
||||
- kv: [s_kv, h_kv, d_qk], bfloat16
|
||||
- indices: [s_q, h_kv, topk], int32.
|
||||
Invalid indices should be set to -1 or numbers >= s_kv
|
||||
- sm_scale: float
|
||||
- d_v: The dimension of value vectors. Can only be 512
|
||||
|
||||
Returns:
|
||||
- (output, max_logits, lse)
|
||||
About the definition of output,
|
||||
max_logits and lse, please refer to README.md
|
||||
- output: [s_q, h_q, d_v], bfloat16
|
||||
- max_logits: [s_q, h_q], float
|
||||
- lse: [s_q, h_q], float, 2-based log-sum-exp
|
||||
"""
|
||||
results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices,
|
||||
sm_scale, d_v)
|
||||
return results
|
||||
|
||||
|
||||
#
|
||||
# TODO: Add fake functions
|
||||
#
|
||||
# @register_fake("_flashmla_C::get_mla_metadata")
|
||||
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# return ....
|
||||
#
|
||||
# @register_fake("_flashmla_C::fwd_kvcache_mla")
|
||||
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# return ....
|
||||
#
|
||||
43
vllm/attention/ops/merge_attn_states.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def merge_attn_states(
|
||||
output: torch.Tensor,
|
||||
prefix_output: torch.Tensor,
|
||||
prefix_lse: torch.Tensor,
|
||||
suffix_output: torch.Tensor,
|
||||
suffix_lse: torch.Tensor,
|
||||
output_lse: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
|
||||
    # NOTE(DefTruth): Currently, the custom merge_attn_states CUDA kernel
    # does not support the FP8 dtype, so fall back to the Triton kernel.
|
||||
def supported_dtypes(o: torch.Tensor) -> bool:
|
||||
return o.dtype in [torch.float32, torch.half, torch.bfloat16]
|
||||
|
||||
    # NOTE(DefTruth): The custom merge_attn_states CUDA kernel loads/stores
    # 128 bits (16 bytes) per memory transaction per thread, so the head size
    # (headdim) must be a multiple of the pack size
    # (float32 -> 4, half/bfloat16 -> 8).
|
||||
def supported_headdim(o: torch.Tensor) -> bool:
|
||||
headdim = o.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
if o.dtype == torch.float32:
|
||||
return headdim % 4 == 0
|
||||
return headdim % 8 == 0
|
||||
|
||||
if (current_platform.is_cuda() and supported_dtypes(output)
|
||||
and supported_headdim(output)):
|
||||
from vllm._custom_ops import merge_attn_states
|
||||
return merge_attn_states(output, prefix_output, prefix_lse,
|
||||
suffix_output, suffix_lse, output_lse)
|
||||
else:
|
||||
from vllm.attention.ops.triton_merge_attn_states import (
|
||||
merge_attn_states)
|
||||
return merge_attn_states(output, prefix_output, prefix_lse,
|
||||
suffix_output, suffix_lse, output_lse)
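# Reference sketch (pure PyTorch, for illustration only) of the merge that the
# kernels above implement: two partial attention outputs over disjoint key
# ranges are combined via their log-sum-exp values. Assumed layouts are
# [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] for the outputs and [NUM_HEADS, NUM_TOKENS]
# for the LSEs; this helper is not used anywhere else in vLLM.
def _merge_attn_states_reference(
        prefix_output: torch.Tensor,  # [T, H, D]
        prefix_lse: torch.Tensor,  # [H, T]
        suffix_output: torch.Tensor,  # [T, H, D]
        suffix_lse: torch.Tensor,  # [H, T]
) -> torch.Tensor:
    # Work relative to the elementwise max LSE for numerical stability.
    max_lse = torch.maximum(prefix_lse, suffix_lse)
    p_weight = torch.exp(prefix_lse - max_lse)  # [H, T]
    s_weight = torch.exp(suffix_lse - max_lse)  # [H, T]
    denom = p_weight + s_weight
    # Reshape per-(head, token) weights to broadcast over the head dim: [T, H, 1].
    p_weight = (p_weight / denom).transpose(0, 1).unsqueeze(-1)
    s_weight = (s_weight / denom).transpose(0, 1).unsqueeze(-1)
    return prefix_output * p_weight + suffix_output * s_weight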
|
||||
262
vllm/attention/ops/paged_attn.py
Normal file
@@ -0,0 +1,262 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from vllm import _custom_ops as ops
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
|
||||
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
|
||||
@dataclass
|
||||
class PagedAttentionMetadata:
|
||||
"""Metadata for PagedAttention."""
|
||||
# (batch_size,). The length of each sequence, i.e. the total number of
# tokens seen so far for that sequence.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
# Maximum sequence length in the batch. 0 if it is a prefill-only batch.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
# in the kv cache. Each block can contain up to block_size tokens.
|
||||
# The 2nd dimension is padded up to max_blocks_per_seq if it is CUDA-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
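# Illustrative example (made-up values): with block_size=16, a 40-token
# sequence occupies 3 physical blocks, so one row of block_tables could be
# [7, 2, 9] -- tokens 0-15 live in block 7, 16-31 in block 2, 32-39 in
# block 9 -- and the matching entry of seq_lens_tensor would be 40.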
|
||||
|
||||
|
||||
class PagedAttention:
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [32, 64, 80, 96, 112, 120, 128, 192, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size * num_kv_heads * head_size)
|
||||
|
||||
@staticmethod
|
||||
def split_kv_cache(
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
x = 16 // kv_cache.element_size()
|
||||
num_blocks = kv_cache.shape[1]
|
||||
|
||||
key_cache = kv_cache[0]
|
||||
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
|
||||
-1, x)
|
||||
value_cache = kv_cache[1]
|
||||
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
|
||||
return key_cache, value_cache
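# Shape sketch (illustrative numbers): for a float16 cache, x = 16 // 2 = 8.
# With num_blocks=1024, num_kv_heads=8, head_size=128, block_size=16:
#   key_cache   -> [1024, 8, 128 // 8, 16, 8]  (the trailing x packs 16 bytes)
#   value_cache -> [1024, 8, 128, 16]
# The -1 in the key_cache view resolves to block_size.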
|
||||
|
||||
@staticmethod
|
||||
def write_to_paged_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
) -> None:
|
||||
ops.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping.flatten(),
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def forward_decode(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
kv_cache_dtype: str,
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
tp_rank: int = 0,
|
||||
blocksparse_local_blocks: int = 0,
|
||||
blocksparse_vert_stride: int = 0,
|
||||
blocksparse_block_size: int = 64,
|
||||
blocksparse_head_sliding_step: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
|
||||
# use blocksparse paged attention
|
||||
block_size = value_cache.size(-1)
|
||||
assert (blocksparse_block_size > 0 and
|
||||
blocksparse_block_size % block_size == 0), \
|
||||
(f"{blocksparse_block_size=} needs to be a multiple of"
|
||||
f"{block_size=} used in block_tables.")
|
||||
|
||||
output = torch.empty_like(query)
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
|
||||
_PARTITION_SIZE)
|
||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||
# sequences or heads is large, we use V1 since there is enough work
|
||||
# to parallelize.
|
||||
# TODO(woosuk): Tune this heuristic.
|
||||
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
|
||||
use_v1 = (max_seq_len <= 8192
|
||||
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
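# Example (illustrative): with _PARTITION_SIZE=512, a batch whose longest
# sequence is 4096 tokens gives max_num_partitions=8, so it takes the V1
# path only when num_seqs * num_heads > 512; a small batch with one very
# long sequence (or max_seq_len > 8192) falls through to V2.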
|
||||
|
||||
if use_v1:
|
||||
# Run PagedAttention V1.
|
||||
ops.paged_attention_v1(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
tp_rank,
|
||||
blocksparse_local_blocks,
|
||||
blocksparse_vert_stride,
|
||||
blocksparse_block_size,
|
||||
blocksparse_head_sliding_step,
|
||||
)
|
||||
else:
|
||||
# Run PagedAttention V2.
|
||||
assert _PARTITION_SIZE % block_size == 0
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||
dtype=output.dtype,
|
||||
device=output.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=output.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
tp_rank,
|
||||
blocksparse_local_blocks,
|
||||
blocksparse_vert_stride,
|
||||
blocksparse_block_size,
|
||||
blocksparse_head_sliding_step,
|
||||
)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def forward_prefix(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
seq_lens_tensor: torch.Tensor,
|
||||
max_query_len: int,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
sliding_window: Optional[int],
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
output = torch.empty_like(query)
|
||||
max_seq_len = None
|
||||
context_attention_fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
# query_start_loc is (batch_size + 1,)
|
||||
query_start_loc,
|
||||
seq_lens_tensor,
|
||||
max_seq_len,
|
||||
max_query_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
src_key_cache = src_kv_cache[0]
|
||||
dst_key_cache = dst_kv_cache[0]
|
||||
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
|
||||
|
||||
src_value_cache = src_kv_cache[1]
|
||||
dst_value_cache = dst_kv_cache[1]
|
||||
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||
ops.copy_blocks(key_caches, value_caches, src_to_dists)
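# Hypothetical end-to-end decode sketch (tensor names and shapes are
# placeholders, not part of this module): the new token's key/value is
# appended to the paged cache, then attention is computed over the cached
# blocks. scale ~= 1 / sqrt(head_size) is the usual choice.
#
#   key_cache, value_cache = PagedAttention.split_kv_cache(
#       kv_cache, num_kv_heads=8, head_size=128)
#   PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache,
#                                       slot_mapping, "auto", k_scale, v_scale)
#   out = PagedAttention.forward_decode(query, key_cache, value_cache,
#                                       block_tables, seq_lens, max_seq_len,
#                                       "auto", num_kv_heads=8,
#                                       scale=128 ** -0.5, alibi_slopes=None,
#                                       k_scale=k_scale, v_scale=v_scale)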
|
||||
124
vllm/attention/ops/pallas_kv_cache_update.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
|
||||
import jax
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import tpu as pltpu
|
||||
|
||||
from vllm.utils import cdiv
|
||||
|
||||
|
||||
def _kv_cache_update_kernel(
|
||||
# Prefetch
|
||||
slices_ref, # [3, padded_num_slices], list of (kv_cache_start,
|
||||
# new_kv_start, slice_len)
|
||||
num_slices_ref, # [1]
|
||||
# Input
|
||||
new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim]
|
||||
kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads,
|
||||
# head_dim]
|
||||
# Output
|
||||
_, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
|
||||
# Scratch
|
||||
scratch, # [num_slices_per_block, page_size, num_combined_kv_heads,
|
||||
# head_dim]
|
||||
sem,
|
||||
):
|
||||
async_copies = []
|
||||
block_idx = pl.program_id(0)
|
||||
num_slices_per_block = scratch.shape[0]
|
||||
|
||||
# Copy from new_kv_hbm_ref to scratch
|
||||
for i in range(num_slices_per_block):
|
||||
offset_i = i + block_idx * num_slices_per_block
|
||||
new_kv_start = jax.lax.select(offset_i < num_slices_ref[0],
|
||||
slices_ref[1, offset_i], 0)
|
||||
length = jax.lax.select(offset_i < num_slices_ref[0],
|
||||
slices_ref[2, offset_i], 0)
|
||||
async_copy = pltpu.make_async_copy(
|
||||
new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
|
||||
scratch.at[i, pl.ds(0, length), ...],
|
||||
sem,
|
||||
)
|
||||
async_copy.start()
|
||||
async_copies.append(async_copy)
|
||||
|
||||
for async_copy in async_copies:
|
||||
async_copy.wait()
|
||||
|
||||
# Copy from scratch to kv_cache_hbm_ref
|
||||
async_copies.clear()
|
||||
for i in range(num_slices_per_block):
|
||||
offset_i = i + block_idx * num_slices_per_block
|
||||
kv_cache_start = jax.lax.select(offset_i < num_slices_ref[0],
|
||||
slices_ref[0, offset_i], 0)
|
||||
length = jax.lax.select(offset_i < num_slices_ref[0],
|
||||
slices_ref[2, offset_i], 0)
|
||||
async_copy = pltpu.make_async_copy(
|
||||
scratch.at[i, pl.ds(0, length), ...],
|
||||
kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
|
||||
sem,
|
||||
)
|
||||
async_copy.start()
|
||||
async_copies.append(async_copy)
|
||||
for async_copy in async_copies:
|
||||
async_copy.wait()
|
||||
|
||||
|
||||
@functools.partial(
|
||||
jax.jit,
|
||||
static_argnames=["page_size", "num_slices_per_block"],
|
||||
)
|
||||
def kv_cache_update(
|
||||
new_kv: jax.Array, # [total_num_token, num_combined_kv_heads, head_dim]
|
||||
slices: jax.Array,  # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
kv_cache: jax.Array,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
|
||||
num_kv_update_slices: jax.Array, # [1]
|
||||
*,
|
||||
page_size: int = 32,
|
||||
num_slices_per_block: int = 8,
|
||||
):
|
||||
_, num_combined_kv_heads, head_dim = new_kv.shape
|
||||
assert kv_cache.shape[1] == num_combined_kv_heads
|
||||
assert kv_cache.shape[2] == head_dim
|
||||
assert head_dim % 128 == 0
|
||||
# TODO: Add dynamic check to make sure that the all the slice lengths are
|
||||
# smaller or equal to page_size
|
||||
|
||||
in_specs = [
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
]
|
||||
|
||||
out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
|
||||
out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]
|
||||
|
||||
scalar_prefetches = [slices, num_kv_update_slices]
|
||||
scratch = pltpu.VMEM(
|
||||
(num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
|
||||
new_kv.dtype,
|
||||
)
|
||||
|
||||
scratch_shapes = [
|
||||
scratch,
|
||||
pltpu.SemaphoreType.DMA,
|
||||
]
|
||||
|
||||
kernel = pl.pallas_call(
|
||||
_kv_cache_update_kernel,
|
||||
grid_spec=pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=len(scalar_prefetches),
|
||||
in_specs=in_specs,
|
||||
out_specs=out_specs,
|
||||
grid=(cdiv(num_kv_update_slices[0], num_slices_per_block), ),
|
||||
scratch_shapes=scratch_shapes,
|
||||
),
|
||||
out_shape=out_shape,
|
||||
input_output_aliases={len(scalar_prefetches) + 1: 0},
|
||||
)
|
||||
|
||||
return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
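# Illustrative slices layout (made-up values; jnp refers to jax.numpy): to
# write 48 new tokens whose first 32 land at cache offset 640 and the
# remaining 16 at offset 1024, each slice staying within one page of
# page_size=32 (in practice the second dimension is padded to a multiple of
# num_slices_per_block):
#
#   slices = jnp.array([[640, 1024],   # kv_cache_start
#                       [  0,   32],   # new_kv_start
#                       [ 32,   16]],  # slice_len
#                      dtype=jnp.int32)
#   num_kv_update_slices = jnp.array([2], dtype=jnp.int32)
#   kv_cache = kv_cache_update(new_kv, slices, kv_cache, num_kv_update_slices)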
|
||||
928
vllm/attention/ops/prefix_prefill.py
Normal file
@@ -0,0 +1,928 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
|
||||
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
# Static kernels parameters
|
||||
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
|
||||
NUM_WARPS = 4 if current_platform.is_rocm() else 8
|
||||
|
||||
# To check compatibility
|
||||
IS_TURING = current_platform.get_device_capability() == (7, 5)
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
# Here's an example autotuner config for this kernel. This config does provide
|
||||
# a performance improvement, but dramatically increases first call latency in
|
||||
# triton 3.2. Because of this tradeoff, it's currently commented out.
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \
|
||||
# "num_unroll_cache": 4, \
|
||||
# "num_unroll_request": 1 } | \
|
||||
# ({"kpack": 2, "waves_per_eu": 2} \
|
||||
# if current_platform.is_rocm() else {}), \
|
||||
# num_warps=4, \
|
||||
# num_stages=1)
|
||||
# ],
|
||||
# key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"]
|
||||
# )
|
||||
@triton.jit
|
||||
def _fwd_kernel(Q,
|
||||
K,
|
||||
V,
|
||||
K_cache,
|
||||
V_cache,
|
||||
sink_ptr,
|
||||
B_Loc,
|
||||
sm_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
out_scale_inv,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
x: tl.constexpr,
|
||||
Out,
|
||||
stride_b_loc_b,
|
||||
stride_b_loc_s,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
stride_k_cache_bs,
|
||||
stride_k_cache_h,
|
||||
stride_k_cache_d,
|
||||
stride_k_cache_bl: tl.constexpr,
|
||||
stride_k_cache_x,
|
||||
stride_v_cache_bs,
|
||||
stride_v_cache_h,
|
||||
stride_v_cache_d,
|
||||
stride_v_cache_bl,
|
||||
num_queries_per_kv: tl.constexpr,
|
||||
IN_PRECISION: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_DMODEL_PADDED: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SLIDING_WINDOW: tl.constexpr,
|
||||
num_unroll_cache: tl.constexpr,
|
||||
num_unroll_request: tl.constexpr,
|
||||
SKIP_DECODE: tl.constexpr,
|
||||
USE_SINKS: tl.constexpr,
|
||||
USE_FP8: tl.constexpr,
|
||||
MAX_Q_LEN: tl.constexpr = 0,
|
||||
MAX_CTX_LEN: tl.constexpr = 0,
|
||||
FP8_MIN: tl.constexpr = float8_info.min,
|
||||
FP8_MAX: tl.constexpr = float8_info.max):
|
||||
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
cur_kv_head = cur_head // num_queries_per_kv
|
||||
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
|
||||
cur_batch_query_len = (cur_batch_in_all_stop_index -
|
||||
cur_batch_in_all_start_index)
|
||||
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
|
||||
|
||||
if SKIP_DECODE and cur_batch_query_len == 1:
|
||||
return
|
||||
|
||||
# start position inside of the query
|
||||
# generally, N goes over kv, while M goes over query_len
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
# initialize offsets
|
||||
# [BLOCK_SIZE]; starts at 0
|
||||
offs_bs_n = tl.arange(0, BLOCK_SIZE)
|
||||
# [N]; starts at 0
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
# [D]; starts at 0
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
|
||||
# [M]; starts at current position in query
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
# [M,D]
|
||||
off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
|
||||
cur_head * stride_qh + offs_d[None, :] * stride_qd)
|
||||
|
||||
dim_mask = tl.where(
|
||||
tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
|
||||
0).to(tl.int1) # [D]
|
||||
|
||||
q = tl.load(Q + off_q,
|
||||
mask=dim_mask[None, :] &
|
||||
(offs_m[:, None] < cur_batch_query_len),
|
||||
other=0.0) # [M,D]
|
||||
|
||||
# initialize pointer to m and l
|
||||
if not USE_SINKS:
|
||||
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
else:
|
||||
m_i = tl.load(
|
||||
sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64),
|
||||
mask=(offs_m < cur_batch_query_len),
|
||||
other=float("-inf"),
|
||||
).to(dtype=tl.float32)
|
||||
|
||||
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D]
|
||||
|
||||
# compute query against context (no causal mask here)
|
||||
for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \
|
||||
loop_unroll_factor=num_unroll_cache):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_SIZE)
|
||||
# -- compute qk ----
|
||||
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
|
||||
(start_n // BLOCK_SIZE) * stride_b_loc_s).to(tl.int64)
|
||||
# [D,BLOCK_SIZE]
|
||||
off_k = (
|
||||
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
|
||||
(offs_d[:, None] // x) * stride_k_cache_d +
|
||||
((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl +
|
||||
(offs_d[:, None] % x) * stride_k_cache_x)
|
||||
|
||||
# [BLOCK_SIZE,D]
|
||||
off_v = (bn[:, None] * stride_v_cache_bs +
|
||||
cur_kv_head * stride_v_cache_h +
|
||||
offs_d[None, :] * stride_v_cache_d +
|
||||
offs_bs_n[:, None] * stride_v_cache_bl)
|
||||
|
||||
if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
|
||||
BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
|
||||
k_load = tl.load(
|
||||
K_cache + off_k,
|
||||
mask=dim_mask[:, None] &
|
||||
((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len),
|
||||
other=0.0) # [D,N]
|
||||
else:
|
||||
k_load = tl.load(K_cache + off_k)
|
||||
|
||||
if k_load.dtype.is_fp8():
|
||||
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
|
||||
else:
|
||||
k = k_load
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N]
|
||||
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
|
||||
qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk,
|
||||
float("-inf"))
|
||||
qk *= sm_scale
|
||||
if SLIDING_WINDOW > 0:
|
||||
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
|
||||
# Q entries in sequence
|
||||
# (start_n + offs_bs_n[None, :]) are the positions of
|
||||
# KV entries in sequence
|
||||
# So the condition makes sure each entry in Q only attends
|
||||
# to KV entries not more than SLIDING_WINDOW away.
|
||||
#
|
||||
# We can't use -inf here, because the
|
||||
# sliding window may lead to the entire row being masked.
|
||||
# This then makes m_ij contain -inf, which causes NaNs in
|
||||
# exp().
|
||||
qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
|
||||
(start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk,
|
||||
-10000)
|
||||
|
||||
# compute running maximum
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, axis=1)
|
||||
alpha = tl.exp(m_i - m_ij)
|
||||
acc = acc * alpha[:, None]
|
||||
|
||||
# update acc
|
||||
if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
|
||||
BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
|
||||
v_load = tl.load(
|
||||
V_cache + off_v,
|
||||
mask=dim_mask[None, :] &
|
||||
((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len),
|
||||
other=0.0) # [N,D]
|
||||
else:
|
||||
v_load = tl.load(V_cache + off_v)
|
||||
|
||||
if v_load.dtype.is_fp8():
|
||||
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
|
||||
else:
|
||||
v = v_load
|
||||
p = p.to(v.dtype)
|
||||
|
||||
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
|
||||
# # update m_i and l_i
|
||||
l_i = l_i * alpha + l_ij
|
||||
m_i = m_ij
|
||||
|
||||
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
|
||||
offs_d[:, None] * stride_kd)
|
||||
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
|
||||
offs_d[None, :] * stride_vd)
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
|
||||
# block_mask is 0 when we're already past the current query length
|
||||
block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
|
||||
|
||||
# compute query against itself (with causal mask)
|
||||
for start_n in tl.range(0, \
|
||||
block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \
|
||||
loop_unroll_factor=num_unroll_request):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs +
|
||||
(cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||
mask=dim_mask[:, None] &
|
||||
((start_n + offs_n[None, :]) < cur_batch_query_len),
|
||||
other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
|
||||
qk *= sm_scale
|
||||
# apply causal mask
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
|
||||
float("-inf"))
|
||||
if SLIDING_WINDOW > 0:
|
||||
qk = tl.where(
|
||||
offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
|
||||
qk, -10000)
|
||||
|
||||
# compute running maximum
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, axis=1)
|
||||
alpha = tl.exp(m_i - m_ij)
|
||||
acc = acc * alpha[:, None]
|
||||
|
||||
# update acc
|
||||
v = tl.load(v_ptrs +
|
||||
(cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||
mask=dim_mask[None, :] &
|
||||
((start_n + offs_n[:, None]) < cur_batch_query_len),
|
||||
other=0.0)
|
||||
p = p.to(v.dtype)
|
||||
|
||||
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
|
||||
# update m_i and l_i
|
||||
l_i = l_i * alpha + l_ij
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i[:, None]
|
||||
|
||||
# initialize pointers to output
|
||||
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
|
||||
cur_head * stride_oh + offs_d[None, :] * stride_od)
|
||||
out_ptrs = Out + off_o
|
||||
if USE_FP8:
|
||||
acc = acc * tl.load(out_scale_inv)
|
||||
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
|
||||
tl.store(out_ptrs,
|
||||
acc,
|
||||
mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
|
||||
return
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_flash_attn_v2(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
K_cache,
|
||||
V_cache,
|
||||
B_Loc,
|
||||
sm_scale,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
B_Ctxlen,
|
||||
block_size,
|
||||
x,
|
||||
Out,
|
||||
stride_b_loc_b,
|
||||
stride_b_loc_s,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
stride_k_cache_bs,
|
||||
stride_k_cache_h,
|
||||
stride_k_cache_d,
|
||||
stride_k_cache_bl,
|
||||
stride_k_cache_x,
|
||||
stride_v_cache_bs,
|
||||
stride_v_cache_h,
|
||||
stride_v_cache_d,
|
||||
stride_v_cache_bl,
|
||||
num_queries_per_kv: int,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
cur_kv_head = cur_head // num_queries_per_kv
|
||||
|
||||
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
# initialize offsets
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
|
||||
cur_head * stride_qh + offs_d[None, :] * stride_qd)
|
||||
|
||||
q = tl.load(Q + off_q,
|
||||
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
# # initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
|
||||
((start_n + offs_n) // block_size) * stride_b_loc_s,
|
||||
mask=(start_n + offs_n) < cur_batch_ctx_len,
|
||||
other=0).to(tl.int64)
|
||||
off_k = (
|
||||
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
|
||||
(offs_d[:, None] // x) * stride_k_cache_d +
|
||||
((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
|
||||
(offs_d[:, None] % x) * stride_k_cache_x)
|
||||
off_v = (bn[:, None] * stride_v_cache_bs +
|
||||
cur_kv_head * stride_v_cache_h +
|
||||
offs_d[None, :] * stride_v_cache_d +
|
||||
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
|
||||
k = tl.load(K_cache + off_k,
|
||||
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
|
||||
float("-inf"))
|
||||
qk *= sm_scale
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
p = tl.math.exp(qk - m_i_new[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
|
||||
alpha = tl.math.exp(m_i - m_i_new)
|
||||
l_i_new = alpha * l_i + l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
# scale acc
|
||||
acc_scale = alpha
|
||||
# acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(V_cache + off_v,
|
||||
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
|
||||
offs_d[:, None] * stride_kd)
|
||||
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
|
||||
offs_d[None, :] * stride_vd)
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
|
||||
block_mask = tl.where(
|
||||
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
|
||||
|
||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs +
|
||||
(cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||
mask=(start_n + offs_n[None, :])
|
||||
< cur_batch_seq_len - cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
|
||||
float("-inf"))
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
p = tl.math.exp(qk - m_i_new[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
|
||||
alpha = tl.math.exp(m_i - m_i_new)
|
||||
l_i_new = alpha * l_i + l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
# scale acc
|
||||
acc_scale = alpha
|
||||
# acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(v_ptrs +
|
||||
(cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||
mask=(start_n + offs_n[:, None])
|
||||
< cur_batch_seq_len - cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
# acc /= l_i[:, None]
|
||||
# initialize pointers to output
|
||||
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
|
||||
cur_head * stride_oh + offs_d[None, :] * stride_od)
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs,
|
||||
acc,
|
||||
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
|
||||
return
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_alibi(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
K_cache,
|
||||
V_cache,
|
||||
B_Loc,
|
||||
sm_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
Alibi_slopes,
|
||||
block_size,
|
||||
x,
|
||||
Out,
|
||||
stride_b_loc_b,
|
||||
stride_b_loc_s,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
stride_k_cache_bs,
|
||||
stride_k_cache_h,
|
||||
stride_k_cache_d,
|
||||
stride_k_cache_bl,
|
||||
stride_k_cache_x,
|
||||
stride_v_cache_bs,
|
||||
stride_v_cache_h,
|
||||
stride_v_cache_d,
|
||||
stride_v_cache_bl,
|
||||
num_queries_per_kv: int,
|
||||
IN_PRECISION: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr, # head size
|
||||
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
|
||||
BLOCK_N: tl.constexpr,
|
||||
SKIP_DECODE: tl.constexpr,
|
||||
):
|
||||
# attn_bias[]
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
cur_kv_head = cur_head // num_queries_per_kv
|
||||
|
||||
# cur_batch_seq_len: the length of prompts
|
||||
# cur_batch_ctx_len: the length of prefix
|
||||
# cur_batch_in_all_start_index: the start id of the dim=0
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
|
||||
cur_batch_query_len = (cur_batch_in_all_stop_index -
|
||||
cur_batch_in_all_start_index)
|
||||
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
|
||||
|
||||
if SKIP_DECODE and cur_batch_query_len == 1:
|
||||
return
|
||||
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
# initialize offsets
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
|
||||
cur_head * stride_qh + offs_d[None, :] * stride_qd)
|
||||
|
||||
dim_mask = tl.where(
|
||||
tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
|
||||
|
||||
q = tl.load(Q + off_q,
|
||||
mask=dim_mask[None, :] &
|
||||
(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
|
||||
other=0.0)
|
||||
|
||||
# # initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
|
||||
|
||||
alibi_slope = tl.load(Alibi_slopes + cur_head)
|
||||
alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
|
||||
alibi_start_k = 0
|
||||
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
|
||||
((start_n + offs_n) // block_size) * stride_b_loc_s,
|
||||
mask=(start_n + offs_n) < cur_batch_ctx_len,
|
||||
other=0).to(tl.int64)
|
||||
off_k = (
|
||||
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
|
||||
(offs_d[:, None] // x) * stride_k_cache_d +
|
||||
((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
|
||||
(offs_d[:, None] % x) * stride_k_cache_x)
|
||||
off_v = (bn[:, None] * stride_v_cache_bs +
|
||||
cur_kv_head * stride_v_cache_h +
|
||||
offs_d[None, :] * stride_v_cache_d +
|
||||
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
|
||||
k_load = tl.load(K_cache + off_k,
|
||||
mask=dim_mask[:, None] &
|
||||
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
|
||||
other=0.0) # [D,N]
|
||||
|
||||
if k_load.dtype.is_fp8():
|
||||
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
|
||||
else:
|
||||
k = k_load
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
|
||||
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
|
||||
float("-inf"))
|
||||
qk *= sm_scale
|
||||
|
||||
# load alibi
|
||||
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
|
||||
alibi_start_q[:, None]) * alibi_slope
|
||||
alibi = tl.where(
|
||||
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
|
||||
float("-inf"))
|
||||
qk += alibi
|
||||
alibi_start_k += BLOCK_N
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
p = tl.math.exp(qk - m_i_new[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
|
||||
alpha = tl.math.exp(m_i - m_i_new)
|
||||
l_i_new = alpha * l_i + l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
# scale acc
|
||||
acc_scale = alpha
|
||||
# acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v_load = tl.load(V_cache + off_v,
|
||||
mask=dim_mask[None, :] &
|
||||
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
|
||||
other=0.0)
|
||||
if v_load.dtype.is_fp8():
|
||||
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
|
||||
else:
|
||||
v = v_load
|
||||
p = p.to(v.dtype)
|
||||
|
||||
acc = tl.dot(p, v, acc=acc, input_precision='ieee')
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
|
||||
offs_d[:, None] * stride_kd)
|
||||
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
|
||||
offs_d[None, :] * stride_vd)
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
|
||||
block_mask = tl.where(
|
||||
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
|
||||
|
||||
# init alibi
|
||||
alibi_slope = tl.load(Alibi_slopes + cur_head)
|
||||
alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
|
||||
alibi_start_k = cur_batch_ctx_len
|
||||
# # init debugger
|
||||
# offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
|
||||
# offset_db_k = tl.arange(0, BLOCK_N)
|
||||
# calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
|
||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(
|
||||
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||
mask=dim_mask[:, None] & ((start_n + offs_n[None, :])
|
||||
< cur_batch_seq_len - cur_batch_ctx_len),
|
||||
other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk = tl.dot(q, k, acc=qk, input_precision='ieee')
|
||||
qk *= sm_scale
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
|
||||
float("-inf"))
|
||||
|
||||
# load alibi
|
||||
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
|
||||
alibi_start_q[:, None]) * alibi_slope
|
||||
alibi = tl.where(
|
||||
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
|
||||
float("-inf"))
|
||||
qk += alibi
|
||||
alibi_start_k += BLOCK_N
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
p = tl.math.exp(qk - m_i_new[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
|
||||
alpha = tl.math.exp(m_i - m_i_new)
|
||||
l_i_new = alpha * l_i + l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
# scale acc
|
||||
acc_scale = alpha
|
||||
# acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(
|
||||
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||
mask=dim_mask[None, :] & ((start_n + offs_n[:, None])
|
||||
< cur_batch_seq_len - cur_batch_ctx_len),
|
||||
other=0.0)
|
||||
p = p.to(v.dtype)
|
||||
|
||||
acc = tl.dot(p, v, acc=acc, input_precision='ieee')
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
acc = acc / l_i[:, None]
|
||||
|
||||
# initialize pointers to output
|
||||
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
|
||||
cur_head * stride_oh + offs_d[None, :] * stride_od)
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs,
|
||||
acc,
|
||||
mask=dim_mask[None, :] &
|
||||
(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
|
||||
return
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def context_attention_fwd(q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
kv_cache_dtype: str,
|
||||
k_cache,
|
||||
v_cache,
|
||||
b_loc,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
max_seq_len,
|
||||
max_input_len,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
sm_scale=None,
|
||||
skip_decode=False,
|
||||
fp8_out_scale=None,
|
||||
sinks=None):
|
||||
|
||||
q_dtype_is_f32 = q.dtype is torch.float32
|
||||
|
||||
# Turing does not have tensor cores for float32 multiplication, so
# use 'ieee' precision as a fallback so the Triton kernels work. There
# is also a warning in vllm/config.py to inform users about this
# fallback implementation.
|
||||
IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None
|
||||
|
||||
# Conversion of FP8 Tensor from uint8 storage to
|
||||
# appropriate torch.dtype for interpretation by Triton
|
||||
if "fp8" in kv_cache_dtype:
|
||||
assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
|
||||
assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
|
||||
|
||||
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
|
||||
target_dtype = current_platform.fp8_dtype()
|
||||
elif kv_cache_dtype == "fp8_e5m2":
|
||||
target_dtype = torch.float8_e5m2
|
||||
else:
|
||||
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
|
||||
|
||||
k_cache = k_cache.view(target_dtype)
|
||||
v_cache = v_cache.view(target_dtype)
|
||||
|
||||
if (k_cache.dtype == torch.uint8
|
||||
or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"):
|
||||
raise ValueError("kv_cache_dtype='auto' unsupported for\
|
||||
FP8 KV Cache prefill kernel")
|
||||
|
||||
# shape constraints
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
# round up Lk to a power of 2 - this is required for Triton block size
|
||||
Lk_padded = triton.next_power_of_2(Lk)
|
||||
|
||||
if sm_scale is None:
|
||||
sm_scale = 1.0 / (Lq**0.5)
|
||||
batch, head = b_seq_len.shape[0], q.shape[1]
|
||||
num_queries_per_kv = q.shape[1] // k.shape[1]
|
||||
|
||||
assert batch + 1 == len(b_start_loc)
|
||||
|
||||
# 0 means "disable"
|
||||
if sliding_window is None or sliding_window <= 0:
|
||||
sliding_window = 0
|
||||
|
||||
if alibi_slopes is not None:
|
||||
assert sinks is None, "Sinks arg is not supported with alibi"
|
||||
assert fp8_out_scale is None, "FP8 output not supported with alibi"
|
||||
# need to reduce num. blocks when using fp32
|
||||
# due to increased use of GPU shared memory
|
||||
# if q.dtype is torch.float32:
|
||||
BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
|
||||
# batch, head,
|
||||
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
|
||||
_fwd_kernel_alibi[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
b_loc,
|
||||
sm_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
alibi_slopes,
|
||||
v_cache.shape[3],
|
||||
k_cache.shape[4],
|
||||
o,
|
||||
b_loc.stride(0),
|
||||
b_loc.stride(1),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(2),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
k_cache.stride(
|
||||
4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
v_cache.stride(0),
|
||||
v_cache.stride(1),
|
||||
v_cache.stride(2),
|
||||
v_cache.stride(
|
||||
3), #[num_blocks, num_kv_heads, head_size, block_size]
|
||||
num_queries_per_kv=num_queries_per_kv,
|
||||
IN_PRECISION=IN_PRECISION,
|
||||
BLOCK_M=BLOCK,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BLOCK_DMODEL_PADDED=Lk_padded,
|
||||
BLOCK_N=BLOCK,
|
||||
SKIP_DECODE=skip_decode,
|
||||
num_warps=NUM_WARPS,
|
||||
num_stages=1,
|
||||
)
|
||||
return
|
||||
|
||||
max_seq_len = 0 if max_seq_len is None else max_seq_len
|
||||
extra_kargs = {}
|
||||
if current_platform.is_rocm():
|
||||
extra_kargs = {"kpack": 1, "waves_per_eu": 2}
|
||||
|
||||
grid = lambda META: (batch, head,
|
||||
triton.cdiv(max_input_len, META["BLOCK_M"]))
|
||||
_fwd_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
sinks,
|
||||
b_loc,
|
||||
sm_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
k_cache.shape[4],
|
||||
o,
|
||||
b_loc.stride(0),
|
||||
b_loc.stride(1),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(2),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
k_cache.stride(
|
||||
4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
v_cache.stride(0),
|
||||
v_cache.stride(1),
|
||||
v_cache.stride(2),
|
||||
v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size]
|
||||
BLOCK_SIZE=v_cache.shape[3],
|
||||
num_queries_per_kv=num_queries_per_kv,
|
||||
IN_PRECISION=IN_PRECISION,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BLOCK_DMODEL_PADDED=Lk_padded,
|
||||
SLIDING_WINDOW=sliding_window,
|
||||
SKIP_DECODE=skip_decode,
|
||||
USE_FP8=fp8_out_scale is not None,
|
||||
BLOCK_M=128,
|
||||
BLOCK_N=64,
|
||||
num_unroll_cache=4,
|
||||
num_unroll_request=1,
|
||||
num_warps=4,
|
||||
num_stages=1,
|
||||
USE_SINKS=sinks is not None,
|
||||
**extra_kargs)
|
||||
return
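# Shape reference for the wrapper above (illustrative; tensors use the packed
# "varlen" layout implied by the kernels):
#   q, k, v, o : [total_query_tokens, num_(kv_)heads, head_size]
#   k_cache    : [num_blocks, num_kv_heads, head_size // x, block_size, x]
#   v_cache    : [num_blocks, num_kv_heads, head_size, block_size]
#   b_loc      : [batch, max_blocks_per_seq] block table per sequence
#   b_start_loc: [batch + 1] cumulative query-token offsets
#   b_seq_len  : [batch] total (context + query) tokens per sequence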
|
||||
104
vllm/attention/ops/rocm_aiter_mla.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
|
||||
|
||||
def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
|
||||
max_block_per_batch: int,
|
||||
device: torch.device) -> tuple[torch.Tensor, ...]:
|
||||
paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
paged_kv_indptr = torch.zeros(max_batch_size + 1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
paged_kv_last_page_lens = torch.full((max_batch_size, ),
|
||||
block_size,
|
||||
dtype=torch.int32)
|
||||
qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
|
||||
return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr
|
||||
|
||||
|
||||
def aiter_mla_decode_fwd(
|
||||
q: torch.Tensor,
|
||||
kv_buffer: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
sm_scale: float,
|
||||
qo_indptr: torch.Tensor,
|
||||
max_seqlen_qo: int,
|
||||
kv_indptr: Optional[torch.Tensor] = None,
|
||||
kv_indices: Optional[torch.Tensor] = None,
|
||||
kv_last_page_lens: Optional[torch.Tensor] = None,
|
||||
logit_cap: float = 0.0,
|
||||
):
|
||||
|
||||
torch.ops.vllm.rocm_aiter_mla_decode_fwd(q,
|
||||
kv_buffer.view(
|
||||
-1, 1, 1, q.shape[-1]),
|
||||
o,
|
||||
qo_indptr,
|
||||
max_seqlen_qo,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
sm_scale=sm_scale,
|
||||
logit_cap=logit_cap)
|
||||
|
||||
|
||||
def mla_decode_fwd_impl(
|
||||
q: torch.Tensor,
|
||||
kv_buffer: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
qo_indptr: torch.Tensor,
|
||||
max_seqlen_qo: int,
|
||||
kv_indptr: Optional[torch.Tensor] = None,
|
||||
kv_indices: Optional[torch.Tensor] = None,
|
||||
kv_last_page_lens: Optional[torch.Tensor] = None,
|
||||
sm_scale: float = 1.0,
|
||||
logit_cap: float = 0.0,
|
||||
) -> None:
|
||||
from aiter.mla import mla_decode_fwd
|
||||
|
||||
mla_decode_fwd(q,
|
||||
kv_buffer.view(-1, 1, 1, q.shape[-1]),
|
||||
o,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
max_seqlen_qo,
|
||||
sm_scale=sm_scale,
|
||||
logit_cap=logit_cap)
|
||||
|
||||
|
||||
def mla_decode_fwd_fake(
|
||||
q: torch.Tensor,
|
||||
kv_buffer: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
qo_indptr: torch.Tensor,
|
||||
max_seqlen_qo: int,
|
||||
kv_indptr: Optional[torch.Tensor] = None,
|
||||
kv_indices: Optional[torch.Tensor] = None,
|
||||
kv_last_page_lens: Optional[torch.Tensor] = None,
|
||||
sm_scale: float = 1.0,
|
||||
logit_cap: float = 0.0,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
if is_torch_equal_or_newer("2.7.0"):
|
||||
tags = ()
|
||||
else:
|
||||
tags = (torch.Tag.needs_fixed_stride_order, )
|
||||
direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd",
|
||||
op_func=mla_decode_fwd_impl,
|
||||
mutates_args=["o"],
|
||||
fake_impl=mla_decode_fwd_fake,
|
||||
tags=tags)
|
||||
102
vllm/attention/ops/rocm_aiter_paged_attn.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import aiter as rocm_aiter
|
||||
import torch
|
||||
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class AITERPagedAttention(PagedAttention):
|
||||
|
||||
@staticmethod
|
||||
def write_to_paged_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
) -> None:
|
||||
if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
|
||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||
value_cache, slot_mapping,
|
||||
kv_cache_dtype, k_scale,
|
||||
v_scale)
|
||||
else:
|
||||
kv_cache_torch_dtype = (FP8_DTYPE
|
||||
if "fp8" in kv_cache_dtype else torch.int8)
|
||||
key_cache = key_cache.view(kv_cache_torch_dtype)
|
||||
value_cache = value_cache.view(kv_cache_torch_dtype)
|
||||
|
||||
rocm_aiter.reshape_and_cache_with_pertoken_quant(
|
||||
key, value, key_cache, value_cache, k_scale, v_scale,
|
||||
slot_mapping.flatten(), True)
|
||||
|
||||
@staticmethod
|
||||
def forward_decode(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
kv_cache_dtype: str,
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
tp_rank: int = 0,
|
||||
blocksparse_local_blocks: int = 0,
|
||||
blocksparse_vert_stride: int = 0,
|
||||
blocksparse_block_size: int = 64,
|
||||
blocksparse_head_sliding_step: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
|
||||
return PagedAttention.forward_decode(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
block_tables=block_tables,
|
||||
seq_lens=seq_lens,
|
||||
max_seq_len=max_seq_len,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
num_kv_heads=num_kv_heads,
|
||||
scale=scale,
|
||||
alibi_slopes=alibi_slopes,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
tp_rank=tp_rank,
|
||||
blocksparse_local_blocks=blocksparse_local_blocks,
|
||||
blocksparse_vert_stride=blocksparse_vert_stride,
|
||||
blocksparse_block_size=blocksparse_block_size,
|
||||
blocksparse_head_sliding_step=blocksparse_head_sliding_step)
|
||||
|
||||
if "fp8" in kv_cache_dtype:
|
||||
key_cache = key_cache.view(current_platform.fp8_dtype())
|
||||
value_cache = value_cache.view(current_platform.fp8_dtype())
|
||||
|
||||
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
|
||||
# use blocksparse paged attention
|
||||
block_size = value_cache.size(-1)
|
||||
assert (blocksparse_block_size > 0 and
|
||||
blocksparse_block_size % block_size == 0), \
|
||||
(f"{blocksparse_block_size=} needs to be a multiple of"
|
||||
f"{block_size=} used in block_tables.")
|
||||
|
||||
output = torch.empty_like(query)
|
||||
block_size = value_cache.shape[3]
|
||||
max_num_blocks_per_seq = cdiv(max_seq_len, block_size)
|
||||
|
||||
rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables,
|
||||
seq_lens, max_num_blocks_per_seq, k_scale,
|
||||
v_scale, output)
|
||||
return output
|
||||
691
vllm/attention/ops/triton_decode_attention.py
Normal file
@@ -0,0 +1,691 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
|
||||
# which was originally adapted from
|
||||
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
|
||||
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
|
||||
|
||||
# Changes:
|
||||
# - Add support for page size >= 1.
|
||||
|
||||
# Copyright 2025 vLLM Team
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Memory-efficient attention for decoding.
|
||||
It supports page size >= 1.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from packaging import version
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
is_hip_ = current_platform.is_rocm()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Only print the following warnings when triton version < 3.2.0.
|
||||
# The issue won't affect performance or accuracy.
|
||||
if version.parse(triton.__version__) < version.parse('3.2.0'):
|
||||
logger.warning(
|
||||
"The following error message 'operation scheduled before its operands' "
|
||||
"can be ignored.")
|
||||
|
||||
|
||||
@triton.jit
|
||||
def tanh(x):
|
||||
# Tanh is just a scaled sigmoid
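# Identity used: tanh(x) = 2 / (1 + exp(-2x)) - 1 = 2 * sigmoid(2x) - 1.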
|
||||
return 2 * tl.sigmoid(2 * x) - 1
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_stage1(
|
||||
Q,
|
||||
K_Buffer,
|
||||
V_Buffer,
|
||||
sm_scale,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
Att_Out,
|
||||
stride_req_to_tokens_b,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_buf_kbs,
|
||||
stride_buf_kh,
|
||||
stride_buf_vbs,
|
||||
stride_buf_vh,
|
||||
stride_mid_ob,
|
||||
stride_mid_oh,
|
||||
stride_mid_os,
|
||||
kv_group_num: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_DV: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
NUM_KV_SPLITS: tl.constexpr,
|
||||
PAGE_SIZE: tl.constexpr,
|
||||
logit_cap: tl.constexpr,
|
||||
Lk: tl.constexpr,
|
||||
Lv: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
split_kv_id = tl.program_id(2)
|
||||
|
||||
cur_kv_head = cur_head // kv_group_num
|
||||
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_dv = tl.arange(0, BLOCK_DV)
|
||||
mask_d = offs_d < Lk
|
||||
mask_dv = offs_dv < Lv
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_req_idx = cur_batch
|
||||
|
||||
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
|
||||
q = tl.load(Q + off_q, mask=mask_d, other=0.0)
|
||||
|
||||
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
|
||||
split_kv_start = kv_len_per_split * split_kv_id
|
||||
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
|
||||
cur_batch_seq_len)
|
||||
|
||||
e_max = -float("inf")
|
||||
e_sum = 0.0
|
||||
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
|
||||
|
||||
if split_kv_end > split_kv_start:
|
||||
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N)
|
||||
kv_page_number = tl.load(
|
||||
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
|
||||
offs_n // PAGE_SIZE,
|
||||
mask=offs_n < split_kv_end,
|
||||
other=0,
|
||||
)
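# Page-table translation: offs_n are logical token positions; Req_to_tokens
# maps each position's page to a physical page, and kv_loc below is the
# physical row index into the flat K/V buffers.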
|
||||
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
|
||||
offs_buf_k = (kv_loc[:, None] * stride_buf_kbs +
|
||||
cur_kv_head * stride_buf_kh + offs_d[None, :])
|
||||
k = tl.load(
|
||||
K_Buffer + offs_buf_k,
|
||||
mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
qk = tl.sum(q[None, :] * k, 1)
|
||||
qk *= sm_scale
|
||||
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
|
||||
|
||||
offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
|
||||
cur_kv_head * stride_buf_vh + offs_dv[None, :])
|
||||
v = tl.load(
|
||||
V_Buffer + offs_buf_v,
|
||||
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
|
||||
other=0.0,
|
||||
)
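# Online-softmax update: rescale the running accumulator and normalizer by
# exp(old_max - new_max) before adding this block's contribution, keeping
# the exp() arguments non-positive and numerically stable.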
|
||||
|
||||
n_e_max = tl.maximum(tl.max(qk, 0), e_max)
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
p = tl.exp(qk - n_e_max)
|
||||
acc *= re_scale
|
||||
acc += tl.sum(p[:, None] * v, 0)
|
||||
|
||||
e_sum = e_sum * re_scale + tl.sum(p, 0)
|
||||
e_max = n_e_max
|
||||
|
||||
offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
|
||||
split_kv_id * stride_mid_os + offs_dv)
|
||||
|
||||
tl.store(
|
||||
Att_Out + offs_mid_o,
|
||||
acc / e_sum,
|
||||
mask=(mask_dv),
|
||||
)
|
||||
|
||||
offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
|
||||
split_kv_id * stride_mid_os + Lv)
|
||||
|
||||
tl.store(
|
||||
Att_Out + offs_mid_o_1,
|
||||
e_max + tl.log(e_sum),
|
||||
)
|
||||
|
||||
|
||||
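# For reference, the per-split loop in _fwd_kernel_stage1 above is the
# standard online-softmax recurrence. With running maximum m (e_max), running
# normalizer l (e_sum) and running weighted sum acc, each score tile s
# updates:
#   m'   = max(m, max_j(s_j))
#   acc' = acc * exp(m - m') + sum_j exp(s_j - m') * v_j
#   l'   = l   * exp(m - m') + sum_j exp(s_j - m')
# Each split stores acc' / l' into Att_Out[..., :Lv] and its partial
# log-sum-exp m' + log(l') into Att_Out[..., Lv]; stage 2 merges the splits.
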
def _decode_att_m_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
att_out,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
):
|
||||
BLOCK = 64 if not is_hip_ else 8
|
||||
|
||||
NUM_KV_SPLITS = num_kv_splits
|
||||
Lk = k_buffer.shape[-1]
|
||||
Lv = v_buffer.shape[-1]
|
||||
|
||||
batch, head_num = q.shape[0], q.shape[1]
|
||||
|
||||
grid = (batch, head_num, NUM_KV_SPLITS)
|
||||
kv_group_num = q.shape[1] // k_buffer.shape[-2]
|
||||
|
||||
num_warps = 4
|
||||
if kv_group_num != 1:
|
||||
num_warps = 1 if is_hip_ else 2
|
||||
|
||||
BLOCK_DMODEL = triton.next_power_of_2(Lk)
|
||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||
|
||||
_fwd_kernel_stage1[grid](
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
sm_scale,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
att_out,
|
||||
Req_to_tokens.stride(0),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
att_out.stride(0),
|
||||
att_out.stride(1),
|
||||
att_out.stride(2),
|
||||
kv_group_num=kv_group_num,
|
||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_DV=BLOCK_DV,
|
||||
BLOCK_N=BLOCK,
|
||||
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
||||
PAGE_SIZE=page_size,
|
||||
logit_cap=logit_cap,
|
||||
num_warps=num_warps,
|
||||
num_stages=2,
|
||||
Lk=Lk,
|
||||
Lv=Lv,
|
||||
)
|
||||
|
||||
|
||||
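# Rough sketch of the intermediate buffer that _decode_att_m_fwd above
# expects, inferred from the strides it passes (the actual allocation lives
# with the caller, so treat this as an assumption rather than a spec):
#   att_out: [batch, num_q_heads, num_kv_splits, Lv + 1]  (typically float32)
# with the trailing element of the last dimension holding the per-split
# log-sum-exp. The launch grid is (batch, num_q_heads, num_kv_splits), so each
# program reduces one contiguous slice of one sequence's KV tokens for one
# query head.
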
@triton.jit
|
||||
def _fwd_grouped_kernel_stage1(
|
||||
Q,
|
||||
K_Buffer,
|
||||
V_Buffer,
|
||||
sm_scale,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
Att_Out,
|
||||
stride_req_to_tokens_b,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_buf_kbs,
|
||||
stride_buf_kh,
|
||||
stride_buf_vbs,
|
||||
stride_buf_vh,
|
||||
stride_mid_ob,
|
||||
stride_mid_oh,
|
||||
stride_mid_os,
|
||||
kv_group_num: tl.constexpr,
|
||||
q_head_num: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_DPE: tl.constexpr,
|
||||
BLOCK_DV: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_H: tl.constexpr,
|
||||
NUM_KV_SPLITS: tl.constexpr,
|
||||
PAGE_SIZE: tl.constexpr,
|
||||
logit_cap: tl.constexpr,
|
||||
Lk: tl.constexpr,
|
||||
Lv: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head_id = tl.program_id(1)
|
||||
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
|
||||
split_kv_id = tl.program_id(2)
|
||||
|
||||
if kv_group_num > BLOCK_H:
|
||||
VALID_BLOCK_H: tl.constexpr = BLOCK_H
|
||||
else:
|
||||
VALID_BLOCK_H: tl.constexpr = kv_group_num
|
||||
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
|
||||
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
|
||||
mask_h = mask_h & (cur_head < q_head_num)
|
||||
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_dv = tl.arange(0, BLOCK_DV)
|
||||
mask_d = offs_d < Lk
|
||||
mask_dv = offs_dv < Lv
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_req_idx = cur_batch
|
||||
|
||||
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[
|
||||
None, :]
|
||||
q = tl.load(Q + offs_q,
|
||||
mask=(mask_h[:, None]) & (mask_d[None, :]),
|
||||
other=0.0)
|
||||
|
||||
if BLOCK_DPE > 0:
|
||||
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
||||
mask_dpe = offs_dpe < Lk
|
||||
off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh +
|
||||
offs_dpe[None, :])
|
||||
qpe = tl.load(Q + off_qpe,
|
||||
mask=(mask_h[:, None]) & (mask_dpe[None, :]),
|
||||
other=0.0)
|
||||
|
||||
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
|
||||
split_kv_start = kv_len_per_split * split_kv_id
|
||||
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
|
||||
cur_batch_seq_len)
|
||||
|
||||
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
|
||||
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
|
||||
|
||||
if split_kv_end > split_kv_start:
|
||||
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N)
|
||||
kv_page_number = tl.load(
|
||||
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
|
||||
offs_n // PAGE_SIZE,
|
||||
mask=offs_n < split_kv_end,
|
||||
other=0,
|
||||
)
|
||||
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
|
||||
offs_buf_k = (kv_loc[None, :] * stride_buf_kbs +
|
||||
cur_kv_head * stride_buf_kh + offs_d[:, None])
|
||||
k = tl.load(
|
||||
K_Buffer + offs_buf_k,
|
||||
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
|
||||
other=0.0,
|
||||
)
|
||||
qk = tl.dot(q, k.to(q.dtype))
|
||||
if BLOCK_DPE > 0:
|
||||
offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs +
|
||||
cur_kv_head * stride_buf_kh +
|
||||
offs_dpe[:, None])
|
||||
kpe = tl.load(
|
||||
K_Buffer + offs_buf_kpe,
|
||||
mask=(offs_n[None, :] < split_kv_end) &
|
||||
(mask_dpe[:, None]),
|
||||
other=0.0,
|
||||
)
|
||||
qk += tl.dot(qpe, kpe.to(qpe.dtype))
|
||||
qk *= sm_scale
|
||||
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end),
|
||||
qk, float("-inf"))
|
||||
|
||||
offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
|
||||
cur_kv_head * stride_buf_vh + offs_dv[None, :])
|
||||
v = tl.load(
|
||||
V_Buffer + offs_buf_v,
|
||||
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
p = tl.exp(qk - n_e_max[:, None])
|
||||
acc *= re_scale[:, None]
|
||||
acc += tl.dot(p.to(v.dtype), v)
|
||||
|
||||
e_sum = e_sum * re_scale + tl.sum(p, 1)
|
||||
e_max = n_e_max
|
||||
|
||||
offs_mid_o = (cur_batch * stride_mid_ob +
|
||||
cur_head[:, None] * stride_mid_oh +
|
||||
split_kv_id * stride_mid_os + offs_dv[None, :])
|
||||
|
||||
tl.store(
|
||||
Att_Out + offs_mid_o,
|
||||
acc / e_sum[:, None],
|
||||
mask=(mask_h[:, None]) & (mask_dv[None, :]),
|
||||
)
|
||||
|
||||
offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
|
||||
split_kv_id * stride_mid_os + Lv)
|
||||
|
||||
tl.store(
|
||||
Att_Out + offs_mid_o_1,
|
||||
e_max + tl.log(e_sum),
|
||||
mask=mask_h,
|
||||
)
|
||||
|
||||
|
||||
def _decode_grouped_att_m_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
att_out,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
):
|
||||
BLOCK = 32
|
||||
Lk = k_buffer.shape[-1]
|
||||
Lv = v_buffer.shape[-1]
|
||||
|
||||
# [TODO] work around shmem limit on MI3xx
|
||||
if is_hip_ and Lk >= 576:
|
||||
BLOCK = 16
|
||||
|
||||
if Lk == 576:
|
||||
BLOCK_DMODEL = 512
|
||||
BLOCK_DPE = 64
|
||||
elif Lk == 288:
|
||||
BLOCK_DMODEL = 256
|
||||
BLOCK_DPE = 32
|
||||
else:
|
||||
BLOCK_DMODEL = triton.next_power_of_2(Lk)
|
||||
BLOCK_DPE = 0
|
||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||
|
||||
batch, head_num = q.shape[0], q.shape[1]
|
||||
kv_group_num = q.shape[1] // k_buffer.shape[-2]
|
||||
|
||||
BLOCK_H = 16
|
||||
NUM_KV_SPLITS = num_kv_splits
|
||||
grid = (
|
||||
batch,
|
||||
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
|
||||
NUM_KV_SPLITS,
|
||||
)
|
||||
|
||||
extra_kargs = {}
|
||||
num_stages = 2
|
||||
if is_hip_:
|
||||
# https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization
|
||||
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
||||
extra_kargs = {
|
||||
"waves_per_eu": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
}
|
||||
num_stages = 1
|
||||
|
||||
_fwd_grouped_kernel_stage1[grid](
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
sm_scale,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
att_out,
|
||||
Req_to_tokens.stride(0),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
att_out.stride(0),
|
||||
att_out.stride(1),
|
||||
att_out.stride(2),
|
||||
kv_group_num=kv_group_num,
|
||||
q_head_num=head_num,
|
||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_DPE=BLOCK_DPE,
|
||||
BLOCK_DV=BLOCK_DV,
|
||||
BLOCK_N=BLOCK,
|
||||
BLOCK_H=BLOCK_H,
|
||||
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
||||
PAGE_SIZE=page_size,
|
||||
logit_cap=logit_cap,
|
||||
num_warps=4,
|
||||
num_stages=num_stages,
|
||||
Lk=Lk,
|
||||
Lv=Lv,
|
||||
**extra_kargs,
|
||||
)
|
||||
|
||||
|
||||
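# For reference: unlike the MHA launcher _decode_att_m_fwd, the grouped
# launcher above packs up to BLOCK_H query heads that share a KV head into one
# program, so the kernel can use tl.dot on a [BLOCK_H, d] x [d, BLOCK_N] tile
# instead of per-head vector products. That is why the second grid dimension
# is ceil(head_num / min(BLOCK_H, kv_group_num)) rather than head_num.
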
@triton.jit
|
||||
def _fwd_kernel_stage2(
|
||||
Mid_O,
|
||||
o,
|
||||
lse,
|
||||
B_Seqlen,
|
||||
stride_mid_ob,
|
||||
stride_mid_oh,
|
||||
stride_mid_os,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_lse_bs,
|
||||
NUM_KV_SPLITS: tl.constexpr,
|
||||
BLOCK_DV: tl.constexpr,
|
||||
Lv: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
|
||||
offs_d = tl.arange(0, BLOCK_DV)
|
||||
mask_d = offs_d < Lv
|
||||
|
||||
e_sum = 0.0
|
||||
e_max = -float("inf")
|
||||
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
|
||||
|
||||
offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
|
||||
offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
|
||||
|
||||
for split_kv_id in range(0, NUM_KV_SPLITS):
|
||||
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
|
||||
split_kv_start = kv_len_per_split * split_kv_id
|
||||
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
|
||||
cur_batch_seq_len)
|
||||
|
||||
if split_kv_end > split_kv_start:
|
||||
tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os,
|
||||
mask=mask_d,
|
||||
other=0.0)
|
||||
tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
|
||||
n_e_max = tl.maximum(tlogic, e_max)
|
||||
|
||||
old_scale = tl.exp(e_max - n_e_max)
|
||||
acc *= old_scale
|
||||
exp_logic = tl.exp(tlogic - n_e_max)
|
||||
acc += exp_logic * tv
|
||||
|
||||
e_sum = e_sum * old_scale + exp_logic
|
||||
e_max = n_e_max
|
||||
|
||||
tl.store(
|
||||
o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
|
||||
acc / e_sum,
|
||||
mask=mask_d,
|
||||
)
|
||||
lse_val = e_max + tl.log(e_sum)
|
||||
tl.store(
|
||||
lse + cur_batch * stride_lse_bs + cur_head,
|
||||
lse_val,
|
||||
)
|
||||
|
||||
|
||||
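# For reference, _fwd_kernel_stage2 above merges the per-split partial
# results via their log-sum-exp values. With per-split outputs o_s and
# lse_s = m_s + log(l_s) from stage 1, and m = max_s(lse_s):
#   out = sum_s exp(lse_s - m) * o_s / sum_s exp(lse_s - m)
#   lse = m + log(sum_s exp(lse_s - m))
# which is numerically stable and, up to floating-point rounding, independent
# of the split order.
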
def _decode_softmax_reducev_fwd(
|
||||
logits,
|
||||
q,
|
||||
o,
|
||||
lse,
|
||||
v_buffer,
|
||||
b_seq_len,
|
||||
num_kv_splits,
|
||||
):
|
||||
batch, head_num = q.shape[0], q.shape[1]
|
||||
Lv = v_buffer.shape[-1]
|
||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||
|
||||
NUM_KV_SPLITS = num_kv_splits
|
||||
|
||||
extra_kargs = {}
|
||||
if is_hip_:
|
||||
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
||||
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
||||
extra_kargs = {
|
||||
"waves_per_eu": 4,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
}
|
||||
|
||||
grid = (batch, head_num)
|
||||
_fwd_kernel_stage2[grid](
|
||||
logits,
|
||||
o,
|
||||
lse,
|
||||
b_seq_len,
|
||||
logits.stride(0),
|
||||
logits.stride(1),
|
||||
logits.stride(2),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
lse.stride(0),
|
||||
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
||||
BLOCK_DV=BLOCK_DV,
|
||||
Lv=Lv,
|
||||
num_warps=4,
|
||||
num_stages=2,
|
||||
**extra_kargs,
|
||||
)
|
||||
|
||||
|
||||
def decode_attention_fwd_normal(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
lse,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap=0.0,
|
||||
):
|
||||
_decode_att_m_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
attn_logits,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
)
|
||||
_decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
|
||||
num_kv_splits)
|
||||
|
||||
|
||||
def decode_attention_fwd_grouped(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
lse,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap=0.0,
|
||||
):
|
||||
_decode_grouped_att_m_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
attn_logits,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
)
|
||||
_decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
|
||||
num_kv_splits)
|
||||
|
||||
|
||||
def decode_attention_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
lse,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size=1,
|
||||
logit_cap=0.0,
|
||||
):
|
||||
assert num_kv_splits == attn_logits.shape[2]
|
||||
kv_group_num = q.shape[1] // v_buffer.shape[-2]
|
||||
|
||||
if kv_group_num == 1:
|
||||
# MHA
|
||||
decode_attention_fwd_normal(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
lse,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
)
|
||||
else:
|
||||
# GQA/MQA/MLA
|
||||
decode_attention_fwd_grouped(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
lse,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
)
|
||||
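# Illustrative usage sketch (not part of the original file): one way a caller
# might drive the paged-KV decode path end to end. The shapes, the page layout
# (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) and the helper name are assumptions
# inferred from the stride arguments above, not vLLM's actual call site.
def _example_decode_attention_fwd():
    import torch  # local import so the sketch is self-contained

    batch, num_q_heads, num_kv_heads, head_dim = 2, 32, 8, 128
    page_size, num_pages, max_seq_len, num_kv_splits = 16, 64, 512, 4
    dev = "cuda"

    q = torch.randn(batch, num_q_heads, head_dim,
                    dtype=torch.float16, device=dev)
    k_buffer = torch.randn(num_pages, page_size, num_kv_heads, head_dim,
                           dtype=torch.float16, device=dev)
    v_buffer = torch.randn_like(k_buffer)

    # Page table: one row of physical page ids per sequence.
    req_to_token = torch.randint(0, num_pages,
                                 (batch, max_seq_len // page_size),
                                 dtype=torch.int32, device=dev)
    b_seq_len = torch.tensor([200, 391], dtype=torch.int32, device=dev)

    # Per-split intermediate; the last slot of the final dim holds the lse.
    attn_logits = torch.empty(batch, num_q_heads, num_kv_splits, head_dim + 1,
                              dtype=torch.float32, device=dev)
    o = torch.empty(batch, num_q_heads, head_dim,
                    dtype=torch.float16, device=dev)
    lse = torch.empty(batch, num_q_heads, dtype=torch.float32, device=dev)

    decode_attention_fwd(q, k_buffer, v_buffer, o, lse, req_to_token,
                         b_seq_len, attn_logits, num_kv_splits,
                         sm_scale=head_dim**-0.5, page_size=page_size)
    return o, lse
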
984
vllm/attention/ops/triton_flash_attention.py
Normal file
@@ -0,0 +1,984 @@
|
||||
#!/usr/bin/env python
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
|
||||
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
|
||||
(https://tridao.me/publications/flash2/flash2.pdf)
|
||||
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
|
||||
|
||||
Features supported:
|
||||
|
||||
1) Fwd with causal masking
|
||||
2) Any sequence lengths without padding (currently fwd kernel only)
|
||||
3) Support for different sequence lengths for q and k
|
||||
4) Nested tensor API currently does not support dropout or bias.
|
||||
|
||||
Not currently supported:
|
||||
|
||||
1) Non power of two head dims
|
||||
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
# Avoid misleading ROCm warning.
|
||||
if current_platform.is_rocm():
|
||||
from vllm.platforms.rocm import on_gfx1x
|
||||
else:
|
||||
on_gfx1x = lambda *args, **kwargs: False
|
||||
|
||||
torch_dtype: tl.constexpr = torch.float16
|
||||
|
||||
|
||||
@triton.jit
def cdiv_fn(x, y):
    return (x + y - 1) // y


@triton.jit
def max_fn(x, y):
    return tl.math.max(x, y)

@triton.jit
|
||||
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
ms = tl.arange(0, m)
|
||||
ns = tl.arange(0, n)
|
||||
return philox_offset + ms[:, None] * stride + ns[None, :]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
|
||||
stride).to(tl.uint32)
|
||||
# TODO: use tl.randint for better performance
|
||||
return tl.rand(philox_seed, rng_offsets)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
|
||||
stride)
|
||||
rng_keep = rng_output > dropout_p
|
||||
return rng_keep
|
||||
|
||||
|
||||
@triton.jit
|
||||
def load_fn(block_ptr, first, second, pad):
|
||||
if first and second:
|
||||
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
|
||||
elif first:
|
||||
tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
|
||||
elif second:
|
||||
tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
|
||||
else:
|
||||
tensor = tl.load(block_ptr)
|
||||
return tensor
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
start_m,
|
||||
actual_seqlen_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
batch_philox_offset,
|
||||
encoded_softmax_block_ptr,
|
||||
block_min,
|
||||
block_max,
|
||||
offs_n_causal,
|
||||
masked_blocks,
|
||||
n_extra_tokens,
|
||||
bias_ptr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
OFFS_M: tl.constexpr,
|
||||
OFFS_N: tl.constexpr,
|
||||
PRE_LOAD_V: tl.constexpr,
|
||||
MASK_STEPS: tl.constexpr,
|
||||
ENABLE_DROPOUT: tl.constexpr,
|
||||
RETURN_ENCODED_SOFTMAX: tl.constexpr,
|
||||
PADDED_HEAD: tl.constexpr,
|
||||
USE_FP8: tl.constexpr,
|
||||
qk_scale,
|
||||
p_descale,
|
||||
):
|
||||
# loop over k, v, and update accumulator
|
||||
for start_n in range(block_min, block_max, BLOCK_N):
|
||||
# For padded blocks, we will overrun the tensor size if
|
||||
# we load all BLOCK_N. For others, the blocks are all within range.
|
||||
k = load_fn(
|
||||
K_block_ptr,
|
||||
PADDED_HEAD,
|
||||
MASK_STEPS and (n_extra_tokens != 0),
|
||||
"zero",
|
||||
)
|
||||
if PRE_LOAD_V:
|
||||
v = load_fn(
|
||||
V_block_ptr,
|
||||
MASK_STEPS and (n_extra_tokens != 0),
|
||||
PADDED_HEAD,
|
||||
"zero",
|
||||
)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
# We start from end of seqlen_k so only the first iteration would need
|
||||
# to be checked for padding if it is not a multiple of block_n
|
||||
# TODO: This can be optimized to only be true for the padded block.
|
||||
if MASK_STEPS: # noqa: SIM102
|
||||
# If this is the last block / iteration, we want to
|
||||
# mask if the sequence length is not a multiple of block size
|
||||
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps
|
||||
# if not is_modulo_mn. last step might get wasted but that is okay.
|
||||
# check if this masking works for that case.
|
||||
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
|
||||
boundary_m = tl.full([BLOCK_M],
|
||||
actual_seqlen_k,
|
||||
dtype=tl.int32)
|
||||
size_n = start_n + OFFS_N[None, :]
|
||||
mask = size_n < boundary_m[:, None]
|
||||
qk = tl.where(mask, qk, float("-inf"))
|
||||
if IS_CAUSAL:
|
||||
causal_boundary = start_n + offs_n_causal
|
||||
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
|
||||
qk = tl.where(causal_mask, qk, float("-inf"))
|
||||
# -- compute qk ----
|
||||
qk += tl.dot(q, k)
|
||||
if USE_FP8:
|
||||
qk *= qk_scale
|
||||
if bias_ptr is not None:
|
||||
bias = load_fn(bias_ptr, False, MASK_STEPS
|
||||
and (n_extra_tokens != 0), "zero")
|
||||
# While bias is added after multiplying qk with sm_scale, our
|
||||
# optimization to use 2^x instead of e^x results in an additional
|
||||
# scale factor of log2(e) which we must also multiply the bias with.
|
||||
qk += bias * 1.44269504089
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
||||
qk = qk - m_ij[:, None]
|
||||
p = tl.math.exp2(qk)
|
||||
|
||||
# CAVEAT: Must update l_ij before applying dropout
|
||||
l_ij = tl.sum(p, 1)
|
||||
if ENABLE_DROPOUT:
|
||||
philox_offset = (batch_philox_offset +
|
||||
start_m * BLOCK_M * actual_seqlen_k + start_n -
|
||||
BLOCK_N)
|
||||
keep = dropout_mask(
|
||||
philox_seed,
|
||||
philox_offset,
|
||||
dropout_p,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
actual_seqlen_k,
|
||||
)
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
tl.store(
|
||||
encoded_softmax_block_ptr,
|
||||
tl.where(keep, p,
|
||||
-p).to(encoded_softmax_block_ptr.type.element_ty),
|
||||
)
|
||||
p = tl.where(keep, p, 0.0)
|
||||
elif RETURN_ENCODED_SOFTMAX:
|
||||
tl.store(
|
||||
encoded_softmax_block_ptr,
|
||||
p.to(encoded_softmax_block_ptr.type.element_ty),
|
||||
)
|
||||
# -- update output accumulator --
|
||||
alpha = tl.math.exp2(m_i - m_ij)
|
||||
acc = acc * alpha[:, None]
|
||||
if not PRE_LOAD_V:
|
||||
v = load_fn(
|
||||
V_block_ptr,
|
||||
MASK_STEPS and (n_extra_tokens != 0),
|
||||
PADDED_HEAD,
|
||||
"zero",
|
||||
)
|
||||
# -- update m_i and l_i
|
||||
l_i = l_i * alpha + l_ij
|
||||
# update m_i and l_i
|
||||
m_i = m_ij
|
||||
|
||||
if USE_FP8:
|
||||
p *= p_descale
|
||||
|
||||
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
|
||||
|
||||
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||
if bias_ptr is not None:
|
||||
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
|
||||
(0, BLOCK_N))
|
||||
return acc, l_i, m_i
|
||||
|
||||
|
||||
def get_cdna_autotune_configs():
|
||||
return [
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 256,
|
||||
'BLOCK_N': 64,
|
||||
'waves_per_eu': 2,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 128,
|
||||
'BLOCK_N': 128,
|
||||
'waves_per_eu': 2,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 256,
|
||||
'BLOCK_N': 128,
|
||||
'waves_per_eu': 2,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 128,
|
||||
'BLOCK_N': 64,
|
||||
'waves_per_eu': 1,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 128,
|
||||
'BLOCK_N': 64,
|
||||
'waves_per_eu': 3,
|
||||
'PRE_LOAD_V': True
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 128,
|
||||
'BLOCK_N': 64,
|
||||
'waves_per_eu': 3,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 64,
|
||||
'BLOCK_N': 64,
|
||||
'waves_per_eu': 4,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 32,
|
||||
'BLOCK_N': 32,
|
||||
'waves_per_eu': 4,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8),
|
||||
# TODO: This config fails with head_size not pow2 with data mismatches.
|
||||
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
|
||||
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
|
||||
|
||||
# Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
|
||||
# triton.Config(
|
||||
# {
|
||||
# "BLOCK_M": 16,
|
||||
# "BLOCK_N": 16,
|
||||
# "waves_per_eu": 1,
|
||||
# "PRE_LOAD_V": False,
|
||||
# },
|
||||
# num_stages=1,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
|
||||
|
||||
|
||||
def get_rdna_autotune_configs():
|
||||
return [
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 32,
|
||||
'BLOCK_N': 32,
|
||||
'waves_per_eu': 4,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 32,
|
||||
'BLOCK_N': 32,
|
||||
'waves_per_eu': 2,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 32,
|
||||
'BLOCK_N': 16,
|
||||
'waves_per_eu': 4,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 32,
|
||||
'BLOCK_N': 16,
|
||||
'waves_per_eu': 2,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=2),
|
||||
# Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
|
||||
# triton.Config(
|
||||
# {
|
||||
# 'BLOCK_M': 16,
|
||||
# 'BLOCK_N': 16,
|
||||
# 'waves_per_eu': 4,
|
||||
# 'PRE_LOAD_V': False
|
||||
# },
|
||||
# num_stages=1,
|
||||
# num_warps=2),
|
||||
# triton.Config(
|
||||
# {
|
||||
# 'BLOCK_M': 16,
|
||||
# 'BLOCK_N': 16,
|
||||
# 'waves_per_eu': 2,
|
||||
# 'PRE_LOAD_V': False
|
||||
# },
|
||||
# num_stages=1,
|
||||
# num_warps=2),
|
||||
# # Fall-back config.
|
||||
# triton.Config(
|
||||
# {
|
||||
# 'BLOCK_M': 16,
|
||||
# 'BLOCK_N': 16,
|
||||
# 'waves_per_eu': 1,
|
||||
# 'PRE_LOAD_V': False
|
||||
# },
|
||||
# num_stages=1,
|
||||
# num_warps=2),
|
||||
], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
|
||||
|
||||
|
||||
def get_autotune_configs():
|
||||
if on_gfx1x():
|
||||
return get_rdna_autotune_configs()
|
||||
else:
|
||||
return get_cdna_autotune_configs()
|
||||
|
||||
|
||||
autotune_configs, autotune_keys = get_autotune_configs()
|
||||
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=autotune_configs,
|
||||
key=autotune_keys,
|
||||
)
|
||||
@triton.jit
|
||||
def attn_fwd(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
bias,
|
||||
sm_scale,
|
||||
q_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
p_scale,
|
||||
p_descale,
|
||||
o_descale,
|
||||
L,
|
||||
Out,
|
||||
stride_qz: tl.int64,
|
||||
stride_qh: tl.int64,
|
||||
stride_qm: tl.int64,
|
||||
stride_qk: tl.int64,
|
||||
stride_kz: tl.int64,
|
||||
stride_kh: tl.int64,
|
||||
stride_kn: tl.int64,
|
||||
stride_kk: tl.int64,
|
||||
stride_vz: tl.int64,
|
||||
stride_vh: tl.int64,
|
||||
stride_vk: tl.int64,
|
||||
stride_vn: tl.int64,
|
||||
stride_oz: tl.int64,
|
||||
stride_oh: tl.int64,
|
||||
stride_om: tl.int64,
|
||||
stride_on: tl.int64,
|
||||
stride_bz: tl.int64,
|
||||
stride_bh: tl.int64,
|
||||
stride_bm: tl.int64,
|
||||
stride_bn: tl.int64,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
philox_offset_base,
|
||||
encoded_softmax,
|
||||
HQ: tl.constexpr,
|
||||
HK: tl.constexpr,
|
||||
ACTUAL_BLOCK_DMODEL: tl.constexpr,
|
||||
MAX_SEQLENS_Q: tl.constexpr,
|
||||
MAX_SEQLENS_K: tl.constexpr,
|
||||
VARLEN: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
USE_FP8: tl.constexpr,
|
||||
USE_FP8_OUT: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
PRE_LOAD_V: tl.constexpr,
|
||||
BIAS_TYPE: tl.constexpr,
|
||||
ENABLE_DROPOUT: tl.constexpr,
|
||||
RETURN_ENCODED_SOFTMAX: tl.constexpr,
|
||||
FP8_MIN: tl.constexpr = float8_info.min,
|
||||
FP8_MAX: tl.constexpr = float8_info.max,
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_h_q = tl.program_id(1)
|
||||
off_z = tl.program_id(2)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
if VARLEN:
|
||||
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
|
||||
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
|
||||
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
|
||||
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
|
||||
# small for all start_m so for those we return early.
|
||||
if start_m * BLOCK_M > seqlen_q:
|
||||
return
|
||||
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
|
||||
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
|
||||
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
|
||||
else:
|
||||
cu_seqlens_q_start = 0
|
||||
cu_seqlens_k_start = 0
|
||||
seqlen_q = MAX_SEQLENS_Q
|
||||
seqlen_k = MAX_SEQLENS_K
|
||||
|
||||
# Now we compute whether we need to exit early due to causal masking.
|
||||
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
|
||||
# are completely masked, resulting in 0s written to the output, and
|
||||
# inf written to LSE. We don't need to do any GEMMs in this case.
|
||||
# This block of code determines what N is, and if this WG is operating
|
||||
# on those M rows.
|
||||
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
|
||||
if IS_CAUSAL:
|
||||
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
|
||||
# If seqlen_q != seqlen_k, attn scores are rectangular which means
|
||||
# the causal mask boundary is bottom right aligned, and ends at either
|
||||
# the top edge (seqlen_q < seqlen_k) or left edge.
|
||||
# This captures the decrease in n_blocks if we have a rectangular attn
|
||||
# matrix
|
||||
n_blocks_seqlen = cdiv_fn(
|
||||
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
|
||||
# This is what adjusts the block_max for the current WG, only
|
||||
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
|
||||
n_blocks = min(n_blocks, n_blocks_seqlen)
|
||||
# If we have no blocks after adjusting for seqlen deltas, this WG is
|
||||
# part of the blocks that are all 0. We exit early.
|
||||
if n_blocks <= 0:
|
||||
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
|
||||
off_h_q * stride_oh)
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out + o_offset,
|
||||
shape=(seqlen_q, BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
|
||||
# We still need to write 0s to the result
|
||||
# tl.store(O_block_ptr,
|
||||
# acc.to(Out.type.element_ty), boundary_check=(0,1))
|
||||
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
|
||||
# + offs_m
|
||||
# We store inf to LSE, not -inf because in the bwd pass,
|
||||
# we subtract this
|
||||
# from qk which makes it -inf, such that exp(qk - inf) = 0
|
||||
# for these masked blocks.
|
||||
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
|
||||
# tl.store(l_ptrs, l)
|
||||
# TODO: Should dropout and return encoded softmax be handled here?
|
||||
return
|
||||
|
||||
# If MQA / GQA, set the K and V head offsets appropriately.
|
||||
GROUP_SIZE: tl.constexpr = HQ // HK
|
||||
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
|
||||
|
||||
n_extra_tokens = 0
|
||||
if seqlen_k < BLOCK_N:
|
||||
n_extra_tokens = BLOCK_N - seqlen_k
|
||||
elif seqlen_k % BLOCK_N:
|
||||
n_extra_tokens = seqlen_k % BLOCK_N
|
||||
padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
|
||||
|
||||
# Compute pointers for all the tensors used in this kernel.
|
||||
q_offset = (off_z * stride_qz + off_h_q * stride_qh +
|
||||
cu_seqlens_q_start * stride_qm)
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + q_offset,
|
||||
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
k_offset = (off_z * stride_kz + off_h_k * stride_kh +
|
||||
cu_seqlens_k_start * stride_kn)
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K + k_offset,
|
||||
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1),
|
||||
)
|
||||
v_offset = (off_z * stride_vz + off_h_k * stride_vh +
|
||||
cu_seqlens_k_start * stride_vk)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V + v_offset,
|
||||
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
if BIAS_TYPE != 0:
|
||||
bias_ptr = tl.make_block_ptr(
|
||||
base=bias + off_h_q * stride_bh,
|
||||
shape=(seqlen_q, seqlen_k),
|
||||
strides=(stride_bm, stride_bn),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_N),
|
||||
order=(1, 0),
|
||||
)
|
||||
else:
|
||||
bias_ptr = None
|
||||
if ENABLE_DROPOUT:
|
||||
batch_philox_offset = philox_offset_base \
|
||||
+ (off_z * HQ + off_h_q) \
|
||||
* seqlen_q * seqlen_k
|
||||
else:
|
||||
batch_philox_offset = 0
|
||||
# We can ask to return the dropout mask without actually doing any dropout.
|
||||
# In this case, we return an invalid pointer to indicate the mask is not
# valid.
|
||||
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
encoded_softmax_block_ptr = tl.make_block_ptr(
|
||||
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
|
||||
shape=(seqlen_q, seqlen_k),
|
||||
strides=(seqlen_k, 1),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_N),
|
||||
order=(1, 0),
|
||||
)
|
||||
else:
|
||||
encoded_softmax_block_ptr = 0
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
|
||||
# have native e^x support in HW.
|
||||
qk_scale = sm_scale * 1.44269504089
|
||||
# Q is loaded once at the beginning and shared by all N blocks.
|
||||
q = load_fn(Q_block_ptr, True, padded_head, "zero")
|
||||
if not USE_FP8:
|
||||
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
|
||||
acc_scale = 1.0
|
||||
else:
|
||||
qk_scale *= q_scale * k_scale
|
||||
acc_scale = p_scale * v_scale
|
||||
|
||||
# Here we compute how many full and masked blocks we have.
|
||||
padded_block_k = n_extra_tokens != 0
|
||||
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
|
||||
if IS_CAUSAL:
|
||||
# There are always at least BLOCK_M // BLOCK_N masked blocks.
|
||||
# Additionally there might be one more due to dissimilar seqlens.
|
||||
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
|
||||
else:
|
||||
# Padding on Q does not need to be masked in the FA loop.
|
||||
masked_blocks = padded_block_k
|
||||
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional
|
||||
# block. In this case we might exceed n_blocks so pick the min.
|
||||
masked_blocks = min(masked_blocks, n_blocks)
|
||||
n_full_blocks = n_blocks - masked_blocks
|
||||
block_min = 0
|
||||
block_max = n_blocks * BLOCK_N
|
||||
# Compute for full blocks. Here we set causal to false regardless of its
|
||||
# value because there is no masking. Similarly we do not need padding.
|
||||
if n_full_blocks > 0:
|
||||
block_max = (n_blocks - masked_blocks) * BLOCK_N
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
start_m,
|
||||
seqlen_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
batch_philox_offset,
|
||||
encoded_softmax_block_ptr,
|
||||
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
|
||||
block_min,
|
||||
block_max,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
bias_ptr,
|
||||
# IS_CAUSAL, ....
|
||||
False,
|
||||
BLOCK_M,
|
||||
BLOCK_DMODEL,
|
||||
BLOCK_N,
|
||||
offs_m,
|
||||
offs_n,
|
||||
# _, MASK_STEPS, ...
|
||||
PRE_LOAD_V,
|
||||
False,
|
||||
ENABLE_DROPOUT,
|
||||
RETURN_ENCODED_SOFTMAX,
|
||||
padded_head,
|
||||
USE_FP8,
|
||||
qk_scale,
|
||||
p_descale,
|
||||
)
|
||||
block_min = block_max
|
||||
block_max = n_blocks * BLOCK_N
|
||||
|
||||
tl.debug_barrier()
|
||||
# Remaining blocks, if any, are masked and need boundary/causal handling.
|
||||
if masked_blocks > 0:
|
||||
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
|
||||
if bias_ptr is not None:
|
||||
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
|
||||
(0, n_full_blocks))
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
start_m,
|
||||
seqlen_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
batch_philox_offset,
|
||||
encoded_softmax_block_ptr,
|
||||
block_min,
|
||||
block_max,
|
||||
offs_n_causal,
|
||||
masked_blocks,
|
||||
n_extra_tokens,
|
||||
bias_ptr,
|
||||
IS_CAUSAL,
|
||||
BLOCK_M,
|
||||
BLOCK_DMODEL,
|
||||
BLOCK_N,
|
||||
offs_m,
|
||||
offs_n,
|
||||
# _, MASK_STEPS, ...
|
||||
PRE_LOAD_V,
|
||||
True,
|
||||
ENABLE_DROPOUT,
|
||||
RETURN_ENCODED_SOFTMAX,
|
||||
padded_head,
|
||||
USE_FP8,
|
||||
qk_scale,
|
||||
p_descale,
|
||||
)
|
||||
# epilogue
|
||||
|
||||
if USE_FP8:
|
||||
acc *= acc_scale
|
||||
acc = acc / l_i[:, None]
|
||||
if ENABLE_DROPOUT:
|
||||
acc = acc / (1 - dropout_p)
|
||||
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
|
||||
# then we have one block with a row of all NaNs which come from computing
|
||||
# softmax over a row of all -infs (-inf - (-inf) = NaN). We check for that here
|
||||
# and store 0s where there are NaNs as these rows should've been zeroed out.
|
||||
end_m_idx = (start_m + 1) * BLOCK_M
|
||||
start_m_idx = start_m * BLOCK_M
|
||||
causal_start_idx = seqlen_q - seqlen_k
|
||||
if USE_FP8_OUT:
|
||||
acc *= o_descale
|
||||
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
|
||||
acc = acc.to(Out.type.element_ty)
|
||||
if IS_CAUSAL: # noqa: SIM102
|
||||
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
|
||||
out_mask_boundary = tl.full((BLOCK_DMODEL, ),
|
||||
causal_start_idx,
|
||||
dtype=tl.int32)
|
||||
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
|
||||
out_ptrs_mask = (mask_m_offsets[:, None]
|
||||
>= out_mask_boundary[None, :])
|
||||
z = tl.zeros((1, ), tl.float32)
|
||||
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
|
||||
# write back LSE
|
||||
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
|
||||
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
|
||||
# few rows. This is only true for the last M block. For others,
|
||||
# overflow_size will be -ve
|
||||
# overflow_size = end_m_idx - seqlen_q
|
||||
# if overflow_size > 0:
|
||||
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
|
||||
# # This is a > check because mask being 0 blocks the store.
|
||||
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
|
||||
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
|
||||
# else:
|
||||
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
|
||||
|
||||
# write back O
|
||||
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
|
||||
off_h_q * stride_oh)
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out + o_offset,
|
||||
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
# Need boundary check on this to make sure the padding from the
|
||||
# Q and KV tensors in both dims are not part of what we store back.
|
||||
# TODO: Do the boundary check optionally.
|
||||
tl.store(O_block_ptr, acc, boundary_check=(0, 1))
|
||||
|
||||
|
||||
def check_args(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
varlen=True,
|
||||
max_seqlens=None,
|
||||
cu_seqlens_q=None,
|
||||
cu_seqlens_k=None,
|
||||
):
|
||||
assert q.dim() == k.dim() and q.dim() == v.dim()
|
||||
if varlen:
|
||||
assert q.dim() == 3
|
||||
total_q, nheads_q, head_size = q.shape
|
||||
total_k, nheads_k, _ = k.shape
|
||||
assert cu_seqlens_q is not None
|
||||
assert cu_seqlens_k is not None
|
||||
assert len(cu_seqlens_q) == len(cu_seqlens_k)
|
||||
else:
|
||||
assert q.dim() == 4
|
||||
batch, nheads_q, seqlen_q, head_size = q.shape
|
||||
_, nheads_k, seqlen_k, _ = k.shape
|
||||
assert max_seqlens > 0
|
||||
assert k.shape == v.shape
|
||||
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
|
||||
# TODO: Change assert if we support qkl f8 and v f16
|
||||
assert q.dtype == k.dtype and q.dtype == v.dtype
|
||||
assert head_size <= 256
|
||||
assert o.shape == q.shape
|
||||
assert (nheads_q % nheads_k) == 0
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlens_q,
|
||||
max_seqlens_k,
|
||||
causal=False,
|
||||
sm_scale=1.0,
|
||||
bias=None,
|
||||
fp8_scales=None,
|
||||
fp8_out_scale=None,
|
||||
):
|
||||
if fp8_scales is not None:
|
||||
use_fp8 = True
|
||||
(q_scale, k_scale, v_scale, p_scale) = fp8_scales
|
||||
float8 = current_platform.fp8_dtype()
|
||||
|
||||
def check_and_convert(t, scale):
|
||||
if t.dtype != float8:
|
||||
descale = 1.0 / scale
|
||||
ts = (t * descale).clamp(min=float8_info.min,
|
||||
max=float8_info.max)
|
||||
return ts.to(float8)
|
||||
else:
|
||||
return t
|
||||
|
||||
q = check_and_convert(q, q_scale)
|
||||
k = check_and_convert(k, k_scale)
|
||||
v = check_and_convert(v, v_scale)
|
||||
else:
|
||||
use_fp8 = False
|
||||
q_scale = k_scale = v_scale = p_scale = 1.0
|
||||
|
||||
if o is None:
|
||||
o = torch.empty_like(q, dtype=v.dtype)
|
||||
|
||||
check_args(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
varlen=True,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
)
|
||||
if True: # varlen
|
||||
total_q, nheads_q, head_size = q.shape
|
||||
total_k, nheads_k, _ = k.shape
|
||||
batch = len(cu_seqlens_q) - 1
|
||||
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
|
||||
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
|
||||
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
|
||||
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
|
||||
else:
|
||||
batch, seqlen_q, nheads_q, head_size = q.shape
|
||||
_, seqlen_k, nheads_k, _ = k.shape
|
||||
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
|
||||
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
|
||||
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
|
||||
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
|
||||
|
||||
# Pad the head size up to the nearest supported size (32, 64, 128, or 256).
|
||||
unpadded_head_dims = {32, 64, 128, 256}
|
||||
if head_size not in unpadded_head_dims:
|
||||
padded_d_model = None
|
||||
for i in unpadded_head_dims:
|
||||
if i > head_size:
|
||||
padded_d_model = i
|
||||
break
|
||||
assert padded_d_model is not None
|
||||
else:
|
||||
padded_d_model = head_size
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
|
||||
nheads_q,
|
||||
batch,
|
||||
)
|
||||
|
||||
encoded_softmax = None
|
||||
|
||||
# Seed the RNG so we get reproducible results for testing.
|
||||
philox_seed = 0x1BF52
|
||||
philox_offset = 0x1D4B42
|
||||
|
||||
if bias is not None:
|
||||
bias_strides = (
|
||||
bias.stride(0),
|
||||
bias.stride(1),
|
||||
bias.stride(2),
|
||||
bias.stride(3),
|
||||
)
|
||||
else:
|
||||
bias_strides = (0, 0, 0, 0)
|
||||
|
||||
p_descale = 1.0 / p_scale
|
||||
o_descale = 1.0 / fp8_out_scale.item(
|
||||
) if fp8_out_scale is not None else 1.0
|
||||
|
||||
arg_max_seqlens_q = 0 if on_gfx1x() else max_seqlens_q
|
||||
arg_max_seqlens_k = 0 if on_gfx1x() else max_seqlens_k
|
||||
|
||||
attn_fwd[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
bias,
|
||||
sm_scale,
|
||||
q_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
p_scale,
|
||||
p_descale,
|
||||
o_descale,
|
||||
None,
|
||||
o,
|
||||
*q_strides,
|
||||
*k_strides,
|
||||
*v_strides,
|
||||
*o_strides,
|
||||
*bias_strides,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
dropout_p=0.0,
|
||||
philox_seed=philox_seed,
|
||||
philox_offset_base=philox_offset,
|
||||
encoded_softmax=encoded_softmax,
|
||||
HQ=nheads_q,
|
||||
HK=nheads_k,
|
||||
ACTUAL_BLOCK_DMODEL=head_size,
|
||||
MAX_SEQLENS_Q=arg_max_seqlens_q,
|
||||
MAX_SEQLENS_K=arg_max_seqlens_k,
|
||||
IS_CAUSAL=causal,
|
||||
VARLEN=True,
|
||||
BLOCK_DMODEL=padded_d_model,
|
||||
BIAS_TYPE=0 if bias is None else 1,
|
||||
ENABLE_DROPOUT=False,
|
||||
RETURN_ENCODED_SOFTMAX=False,
|
||||
USE_FP8=use_fp8,
|
||||
USE_FP8_OUT=fp8_out_scale is not None,
|
||||
)
|
||||
|
||||
ctx.grid = grid
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.BLOCK_DMODEL = head_size
|
||||
ctx.causal = causal
|
||||
ctx.dropout_p = 0.0
|
||||
ctx.philox_seed = philox_seed
|
||||
ctx.philox_offset = philox_offset
|
||||
ctx.encoded_softmax = encoded_softmax
|
||||
ctx.return_encoded_softmax = False
|
||||
return o, encoded_softmax
|
||||
|
||||
|
||||
triton_attention = _attention.apply
|
||||
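# Illustrative usage sketch (not part of the original file): calling the
# varlen prefill entry point defined above. The sequence lengths and head
# sizes are made-up values; only the 3-D (total_tokens, heads, head_size)
# layout with cu_seqlens offsets is exercised.
def _example_triton_attention():
    seqlens = [5, 11]
    total_tokens, nheads, head_size = sum(seqlens), 8, 128
    dev = "cuda"

    q = torch.randn(total_tokens, nheads, head_size,
                    dtype=torch.float16, device=dev)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    cu_seqlens = torch.tensor([0, 5, 16], dtype=torch.int32, device=dev)

    out, _ = triton_attention(
        q, k, v,
        None,                        # o: allocated internally when None
        cu_seqlens, cu_seqlens,      # same offsets for q and k here
        max(seqlens), max(seqlens),  # max_seqlens_q / max_seqlens_k
        True,                        # causal
        head_size**-0.5,             # sm_scale
    )
    return out
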
97
vllm/attention/ops/triton_merge_attn_states.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
# can be used to combine partial attention results (in the split-KV case)
|
||||
def merge_attn_states(
|
||||
output: torch.Tensor,
|
||||
prefix_output: torch.Tensor,
|
||||
prefix_lse: torch.Tensor,
|
||||
suffix_output: torch.Tensor,
|
||||
suffix_lse: torch.Tensor,
|
||||
output_lse: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
num_tokens = output.shape[0]
|
||||
num_query_heads = output.shape[1]
|
||||
head_size = output.shape[2]
|
||||
padded_head_size = triton.next_power_of_2(head_size)
|
||||
|
||||
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
|
||||
merge_attn_states_kernel[(num_tokens, num_query_heads)](
|
||||
output,
|
||||
output_lse,
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
head_size,
|
||||
padded_head_size,
|
||||
output_lse is not None,
|
||||
)
|
||||
|
||||
|
||||
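# Illustrative usage sketch (not part of the original file): merging a prefix
# (e.g. cached-prompt) attention result with a suffix (new-token) result.
# Per token and head, with m = max(lse_p, lse_s), the kernel below computes
#   out = (exp(lse_p - m) * out_p + exp(lse_s - m) * out_s)
#         / (exp(lse_p - m) + exp(lse_s - m))
# Shapes follow the comments on the kernel; the values here are made up.
def _example_merge_attn_states():
    num_tokens, num_heads, head_size = 4, 8, 128
    dev = "cuda"

    prefix_out = torch.randn(num_tokens, num_heads, head_size,
                             dtype=torch.float16, device=dev)
    suffix_out = torch.randn_like(prefix_out)
    prefix_lse = torch.randn(num_heads, num_tokens,
                             dtype=torch.float32, device=dev)
    suffix_lse = torch.randn_like(prefix_lse)

    output = torch.empty_like(prefix_out)
    output_lse = torch.empty_like(prefix_lse)
    merge_attn_states(output, prefix_out, prefix_lse, suffix_out, suffix_lse,
                      output_lse)
    return output, output_lse
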
@triton.jit
|
||||
def merge_attn_states_kernel(
|
||||
output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
output_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||
prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
prefix_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||
HEAD_SIZE: tl.constexpr,
|
||||
PADDED_HEAD_SIZE: tl.constexpr,
|
||||
OUTPUT_LSE: tl.constexpr,
|
||||
):
|
||||
token_idx = tl.program_id(0)
|
||||
num_tokens = tl.num_programs(0)
|
||||
head_idx = tl.program_id(1)
|
||||
num_heads = tl.num_programs(1)
|
||||
|
||||
p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
|
||||
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
|
||||
|
||||
# FA2 and FA3 have different behavior when the sum of exponentials is 0, which
|
||||
# arises with 0 len seqlens. FA3 returns -inf here while FA2 returns inf.
|
||||
# If we see an inf assume FA2 and convert inf to -inf for consistency
|
||||
# and correctness. Inf generally doesn't make sense in this context outside
|
||||
# of the undefined-behavior/FA2 case, so I think this is a safe assumption.
|
||||
p_lse = float('-inf') if p_lse == float('inf') else p_lse
|
||||
s_lse = float('-inf') if s_lse == float('inf') else s_lse
|
||||
|
||||
max_lse = tl.maximum(p_lse, s_lse)
|
||||
p_lse = p_lse - max_lse
|
||||
s_lse = s_lse - max_lse
|
||||
# Will reuse precomputed Exp values for scale factor computation.
|
||||
p_se = tl.exp(p_lse)
|
||||
s_se = tl.exp(s_lse)
|
||||
out_se = (p_se + s_se)
|
||||
|
||||
if OUTPUT_LSE:
|
||||
out_lse = tl.log(out_se) + max_lse
|
||||
tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse)
|
||||
|
||||
head_arange = tl.arange(0, PADDED_HEAD_SIZE)
|
||||
head_mask = head_arange < HEAD_SIZE
|
||||
p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE +
|
||||
head_idx * HEAD_SIZE + head_arange,
|
||||
mask=head_mask)
|
||||
s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE +
|
||||
head_idx * HEAD_SIZE + head_arange,
|
||||
mask=head_mask)
|
||||
|
||||
# NOTE(woosuk): Be careful with the numerical stability.
|
||||
# We should compute the scale first, and then multiply it with the output.
|
||||
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
|
||||
p_scale = p_se / out_se
|
||||
s_scale = s_se / out_se
|
||||
out = p_out * p_scale + s_out * s_scale
|
||||
tl.store(output + token_idx * num_heads * HEAD_SIZE +
|
||||
head_idx * HEAD_SIZE + head_arange,
|
||||
out,
|
||||
mask=head_mask)
|
||||
175
vllm/attention/ops/triton_reshape_and_cache_flash.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def reshape_and_cache_kernel_flash(
|
||||
key_ptr, # [num_tokens, num_heads, head_size]
|
||||
value_ptr, # [num_tokens, num_heads, head_size]
|
||||
key_cache_ptr, # [num_blocks, block_size, num_heads, head_size]
|
||||
value_cache_ptr, # [num_blocks, block_size, num_heads, head_size]
|
||||
slot_mapping_ptr, # [num_tokens]
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
# strides
|
||||
key_stride: tl.int64,
|
||||
value_stride: tl.int64,
|
||||
block_stride: tl.int64,
|
||||
page_stride: tl.int64,
|
||||
num_heads: tl.constexpr,
|
||||
head_size: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
# FP8 flags
|
||||
FP8_KV_CACHE: tl.constexpr,
|
||||
# tune parameters
|
||||
TILE_SIZE: tl.constexpr,
|
||||
):
|
||||
|
||||
token_idx = tl.program_id(axis=0)
|
||||
slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
|
||||
if slot_idx < 0:
|
||||
# Padding token that should be ignored.
|
||||
return
|
||||
|
||||
tile_i = tl.program_id(axis=1)
|
||||
tile_offs = tl.arange(0, TILE_SIZE)
|
||||
tile_pos = tile_i * TILE_SIZE + tile_offs
|
||||
|
||||
block_idx = slot_idx // block_size
|
||||
block_offset = slot_idx % block_size
|
||||
|
||||
src_key_idx = token_idx * key_stride
|
||||
src_value_idx = token_idx * value_stride
|
||||
|
||||
tgt_idx = block_idx * block_stride + block_offset * page_stride
|
||||
|
||||
# [TILE_SIZE]
|
||||
key_load = tl.load(key_ptr + src_key_idx + tile_pos,
|
||||
mask=tile_pos < (num_heads * head_size))
|
||||
if FP8_KV_CACHE:
|
||||
if key_load.dtype.is_fp8():
|
||||
key_tile = key_load
|
||||
else:
|
||||
# tl.store will do the correct implicit cast to fp8,
|
||||
# based on the key_cache_ptr.dtype.element_ty
|
||||
key_tile = key_load / tl.load(k_scale)
|
||||
else:
|
||||
key_tile = key_load
|
||||
|
||||
# [TILE_SIZE]
|
||||
value_load = tl.load(value_ptr + src_value_idx + tile_pos,
|
||||
mask=tile_pos < (num_heads * head_size))
|
||||
if FP8_KV_CACHE:
|
||||
if value_load.dtype.is_fp8():
|
||||
value_tile = value_load
|
||||
else:
|
||||
# tl.store will do the correct implicit cast to fp8,
|
||||
# based on the value_cache_ptr.dtype.element_ty
|
||||
value_tile = value_load / tl.load(v_scale)
|
||||
else:
|
||||
value_tile = value_load
|
||||
|
||||
tl.store(
|
||||
key_cache_ptr + tgt_idx + tile_pos,
|
||||
key_tile,
|
||||
mask=tile_pos < (num_heads * head_size),
|
||||
)
|
||||
tl.store(
|
||||
value_cache_ptr + tgt_idx + tile_pos,
|
||||
value_tile,
|
||||
mask=tile_pos < (num_heads * head_size),
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def triton_reshape_and_cache_flash(
|
||||
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
value: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
# [num_blocks, block_size, num_heads, head_size]
|
||||
key_cache: torch.Tensor,
|
||||
# [num_blocks, block_size, num_heads, head_size]
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor, # [num_tokens]
|
||||
kv_cache_dtype: str, # "auto", "fp8"
|
||||
k_scale: torch.Tensor, # float32
|
||||
v_scale: torch.Tensor, # float32
|
||||
):
|
||||
num_tokens = key.shape[0]
|
||||
num_heads = key.shape[1]
|
||||
head_size = key.shape[2]
|
||||
block_size = key_cache.shape[1]
|
||||
n = num_heads * head_size
|
||||
|
||||
key_stride = key.stride()[0]
|
||||
value_stride = value.stride()[0]
|
||||
block_stride = key_cache.stride()[0]
|
||||
page_stride = key_cache.stride()[1]
|
||||
|
||||
head_stride = key_cache.stride()[2]
|
||||
assert head_stride == head_size, "only contiguous heads are supported"
|
||||
|
||||
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), \
|
||||
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
|
||||
kv_cache_torch_dtype = current_platform.fp8_dtype() if \
|
||||
kv_cache_dtype.startswith("fp8") else key_cache.dtype
|
||||
|
||||
if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith(
|
||||
"fp8"):
|
||||
# to avoid an erroneous implicit cast in the triton kernel (tl.store to uint8)
|
||||
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
|
||||
key_cache = key_cache.view(kv_cache_torch_dtype)
|
||||
value_cache = value_cache.view(kv_cache_torch_dtype)
|
||||
assert kv_cache_dtype != torch.uint8, "explicit fp8 cast and store to "\
|
||||
"uint8 is not supported by triton reshape_and_cache_flash"
|
||||
|
||||
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
|
||||
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
|
||||
torch.float8_e4m3fn, torch.float8_e5m2, torch.uint8,
|
||||
torch.float8_e4m3fnuz], \
|
||||
"unsupported dtype of KV cache tensor, got "\
|
||||
"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " \
|
||||
"fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."
|
||||
|
||||
# heuristics instead of autotuning
|
||||
TILE_SIZE = min(2048, triton.next_power_of_2(n))
|
||||
if torch.version.hip or torch.version.xpu:
|
||||
num_stages = 4
|
||||
num_warps = 8
|
||||
else: # cuda
|
||||
num_stages = 10
|
||||
num_warps = 16
|
||||
if torch.cuda.get_device_capability(key.device)[0] < 9:
|
||||
TILE_SIZE = min(512, TILE_SIZE)
|
||||
|
||||
# TODO(ngl): maybe replace with static launch grid to avoid overhead if
|
||||
# using cudagraphs
|
||||
grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))
|
||||
|
||||
reshape_and_cache_kernel_flash[grid](
|
||||
key_ptr=key,
|
||||
value_ptr=value,
|
||||
key_cache_ptr=key_cache,
|
||||
value_cache_ptr=value_cache,
|
||||
slot_mapping_ptr=slot_mapping,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
# strides
|
||||
key_stride=key_stride,
|
||||
value_stride=value_stride,
|
||||
block_stride=block_stride,
|
||||
page_stride=page_stride,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
block_size=block_size,
|
||||
# FP8 flags
|
||||
FP8_KV_CACHE=FP8_KV_CACHE,
|
||||
# autotune parameters
|
||||
TILE_SIZE=TILE_SIZE,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
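# Illustrative usage sketch (not part of the original file): scattering one
# step's key/value tensors into a paged KV cache. slot_mapping holds the flat
# slot index (block_idx * block_size + block_offset) per token; a negative
# value marks a padding token that is skipped. The sizes are made up.
def _example_triton_reshape_and_cache_flash():
    num_tokens, num_heads, head_size = 3, 8, 128
    num_blocks, block_size = 16, 16
    dev = "cuda"

    key = torch.randn(num_tokens, num_heads, head_size,
                      dtype=torch.float16, device=dev)
    value = torch.randn_like(key)
    key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size,
                            dtype=torch.float16, device=dev)
    value_cache = torch.zeros_like(key_cache)

    # Tokens 0 and 1 land in block 2 (slots 32, 33); token 2 is padding.
    slot_mapping = torch.tensor([32, 33, -1], dtype=torch.int64, device=dev)
    k_scale = torch.ones(1, dtype=torch.float32, device=dev)
    v_scale = torch.ones(1, dtype=torch.float32, device=dev)

    triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
                                   slot_mapping, "auto", k_scale, v_scale)
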
894
vllm/attention/ops/triton_unified_attention.py
Normal file
@@ -0,0 +1,894 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Authors:
|
||||
# - Burkhard Ringlein <ngl@zurich.ibm.com>
|
||||
# - Jan van Lunteren <jvl@zurich.ibm.com>
|
||||
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
|
||||
# - Thomas Parnell <tpa@zurich.ibm.com>
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
logger = init_logger(__name__)
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
@triton.jit
|
||||
def cdiv_fn(x, y):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
@triton.jit
|
||||
def apply_softcap(S, x):
|
||||
Sdiv = S / x
|
||||
p1 = tl.exp(Sdiv)
|
||||
p2 = tl.exp(-Sdiv)
|
||||
return x * (p1 - p2) / (p1 + p2)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,
|
||||
BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr):
|
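||||
# binary search over query_start_len_ptr (cumulative query start offsets,
|
||||
# length num_seqs + 1) for the sequence that owns target_idx; with
|
||||
# use_q_block_mode the comparison key is the sequence's first q-block
|
||||
# index (start // BLOCK_Q + seq_idx) rather than its first token index
|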
||||
left: tl.int32 = 0
|
||||
right = num_seqs
|
||||
while left < right:
|
||||
mid = (left + right) // 2
|
||||
val = tl.load(query_start_len_ptr + mid)
|
||||
mid_val = val // BLOCK_Q + mid if use_q_block_mode else val
|
||||
|
||||
if mid_val <= target_idx:
|
||||
left = mid + 1
|
||||
else:
|
||||
right = mid
|
||||
|
||||
return left - 1
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_unified_attention_2d(
|
||||
output_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
query_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
sink_ptr, # [num_query_heads]
|
||||
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
alibi_slopes_ptr, # [num_query_heads]
|
||||
qq_bias_ptr, # [num_query_tokens, num_query_tokens]
|
||||
scale, # float32
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
out_scale, # float32
|
||||
softcap, # float32
|
||||
num_query_heads: tl.constexpr, # int
|
||||
num_queries_per_kv: tl.constexpr, # int
|
||||
block_table_stride: tl.int64, # int
|
||||
query_stride_0: tl.int64, # int
|
||||
query_stride_1: tl.int64, # int, should be equal to head_size
|
||||
output_stride_0: tl.int64, # int
|
||||
output_stride_1: tl.int64, # int, should be equal to head_size
|
||||
qq_bias_stride_0: tl.int64, # int
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
TILE_SIZE: tl.constexpr, # int must be power of 2
|
||||
HEAD_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||
USE_QQ_BIAS: tl.constexpr, # bool
|
||||
USE_SOFTCAP: tl.constexpr, # bool
|
||||
USE_SINKS: tl.constexpr, # bool
|
||||
SLIDING_WINDOW: tl.constexpr, # int
|
||||
stride_k_cache_0: tl.int64, # int
|
||||
stride_k_cache_1: tl.int64, # int
|
||||
stride_k_cache_2: tl.int64, # int
|
||||
stride_k_cache_3: tl.constexpr, # int
|
||||
stride_v_cache_0: tl.int64, # int
|
||||
stride_v_cache_1: tl.int64, # int
|
||||
stride_v_cache_2: tl.int64, # int
|
||||
stride_v_cache_3: tl.constexpr, # int
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
BLOCK_Q: tl.constexpr, # int
|
||||
num_seqs: tl.int32,
|
||||
BLOCK_M: tl.constexpr, # int
|
||||
USE_FP8: tl.constexpr, # bool
|
||||
FP8_MIN: tl.constexpr = float8_info.min,
|
||||
FP8_MAX: tl.constexpr = float8_info.max,
|
||||
):
|
||||
q_block_global_idx = tl.program_id(0)
|
||||
kv_head_idx = tl.program_id(1)
|
||||
|
||||
seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs,
|
||||
BLOCK_Q, True)
|
||||
|
||||
q_block_start_idx = tl.load(query_start_len_ptr +
|
||||
seq_idx) // BLOCK_Q + seq_idx
|
||||
|
||||
q_block_local_idx = q_block_global_idx - q_block_start_idx
|
||||
|
||||
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
|
||||
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
|
||||
|
||||
cur_batch_query_len = cur_batch_in_all_stop_index \
|
||||
- cur_batch_in_all_start_index
|
||||
|
||||
if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
|
||||
return
|
||||
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
|
||||
offs_t = tl.arange(0, TILE_SIZE)
|
||||
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
|
||||
|
||||
query_offset_0 = cur_batch_in_all_start_index + query_pos
|
||||
query_offset_1 = kv_head_idx * num_queries_per_kv + \
|
||||
offs_m % num_queries_per_kv
|
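||||
# each of the BLOCK_M rows handles one (query token, query head) pair:
|
||||
# offs_m // num_queries_per_kv picks the token within this q-block and
|
||||
# offs_m % num_queries_per_kv picks the head within this KV-head group
|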
||||
query_offset = (query_offset_0[:, None] * query_stride_0 +
|
||||
query_offset_1[:, None] * query_stride_1 + offs_d[None, :])
|
||||
|
||||
dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
|
||||
query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
|
||||
query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)
|
||||
|
||||
# Q : (BLOCK_M, HEAD_SIZE_PADDED)
|
||||
Q = tl.load(
|
||||
query_ptr + query_offset,
|
||||
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
block_table_offset = seq_idx * block_table_stride
|
||||
|
||||
if not USE_SINKS:
|
||||
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
else:
|
||||
M = tl.load(
|
||||
sink_ptr + query_offset_1,
|
||||
mask=query_mask_1,
|
||||
other=float("-inf"),
|
||||
).to(dtype=tl.float32)
|
||||
|
||||
L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
|
||||
|
||||
# sequence len for this particular sequence
|
||||
seq_len = tl.load(seq_lens_ptr + seq_idx)
|
||||
|
||||
# context length for this particular sequence
|
||||
context_len = seq_len - cur_batch_query_len
|
||||
|
||||
# alibi slope for this head
|
||||
if USE_ALIBI_SLOPES:
|
||||
alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1,
|
||||
mask=query_mask_1,
|
||||
other=0.0)
|
||||
|
||||
# query-query attention bias
|
||||
if USE_QQ_BIAS:
|
||||
qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
|
||||
) # shape: [BLOCK_M, 1]
|
||||
|
||||
# compute the length of the longest sequence prefix spanned by any
|
||||
# query token in the current q_block (q_block_local_idx)
|
||||
max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
|
||||
BLOCK_M - 1) // num_queries_per_kv + 1
|
||||
|
||||
# adjust for potential padding in the last q_block by considering the
|
||||
# actual sequence length
|
||||
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
|
||||
|
||||
# calculate the number of tiles that need to be processed to
|
||||
# cover the longest sequence prefix (due to causal masking, tiles beyond
|
||||
# this prefix can be skipped)
|
||||
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
|
||||
|
||||
# ---- Sliding-window tile pruning --------------------
|
||||
# Default: keep previous global behavior
|
||||
tile_start = 0
|
||||
tile_end = num_tiles
|
||||
if SLIDING_WINDOW > 0:
|
||||
# Query rows covered by this Q-block
|
||||
qpos_lo = q_block_local_idx * BLOCK_Q
|
||||
qpos_hi = tl.minimum(
|
||||
qpos_lo + (BLOCK_M - 1) // num_queries_per_kv,
|
||||
cur_batch_query_len - 1,
|
||||
)
|
||||
# For sliding window, each query position q can only attend to
|
||||
# keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs]
|
||||
# where q_abs = context_len + q
|
||||
# The union of allowed key positions for this Q-block is:
|
||||
# [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi]
|
||||
first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1
|
||||
last_allowed_key = context_len + qpos_hi
|
||||
# Convert to tile indices and clamp
|
||||
tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE)
|
||||
tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles)
|
||||
|
||||
# iterate through tiles (now limited to the sliding window range)
|
||||
for j in range(tile_start, tile_end):
|
||||
seq_offset = j * TILE_SIZE + offs_t
|
||||
tile_mask = seq_offset < max_seq_prefix_len
|
||||
|
||||
physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
|
||||
seq_offset // BLOCK_SIZE).to(tl.int64)
|
||||
|
||||
v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
|
||||
kv_head_idx * stride_v_cache_2 +
|
||||
offs_d[None, :] * stride_v_cache_3 +
|
||||
(seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)
|
||||
|
||||
k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
|
||||
kv_head_idx * stride_k_cache_2 +
|
||||
offs_d[:, None] * stride_k_cache_3 +
|
||||
(seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)
|
||||
|
||||
# K : (HEAD_SIZE, TILE_SIZE)
|
||||
K_load = tl.load(key_cache_ptr + k_offset,
|
||||
mask=dim_mask[:, None] & tile_mask[None, :],
|
||||
other=0.0)
|
||||
|
||||
if K_load.dtype.is_fp8():
|
||||
if Q.dtype.is_fp8():
|
||||
K = K_load
|
||||
else:
|
||||
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
|
||||
else:
|
||||
K = K_load
|
||||
|
||||
# V : (TILE_SIZE, HEAD_SIZE)
|
||||
V_load = tl.load(value_cache_ptr + v_offset,
|
||||
mask=dim_mask[None, :] & tile_mask[:, None],
|
||||
other=0.0)
|
||||
|
||||
if V_load.dtype.is_fp8():
|
||||
if Q.dtype.is_fp8():
|
||||
V = V_load
|
||||
else:
|
||||
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
|
||||
else:
|
||||
V = V_load
|
||||
|
||||
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
|
||||
|
||||
# S : (BLOCK_M, TILE_SIZE)
|
||||
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
||||
|
||||
S += scale * tl.dot(Q, K)
|
||||
|
||||
if USE_SOFTCAP:
|
||||
S = apply_softcap(S, softcap)
|
||||
|
||||
S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
|
||||
S, float("-inf"))
|
||||
|
||||
if SLIDING_WINDOW > 0:
|
||||
S = tl.where((context_len + query_pos[:, None] - seq_offset)
|
||||
< SLIDING_WINDOW, S, float("-inf"))
|
||||
|
||||
if USE_ALIBI_SLOPES:
|
||||
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||
|
||||
if USE_QQ_BIAS:
|
||||
# compute key positions relative to query section
|
||||
key_rel_pos = seq_offset - context_len # shape: [TILE_SIZE]
|
||||
# load bias only for keys that correspond to queries
|
||||
is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
|
||||
qq_bias = tl.load(
|
||||
qq_bias_row_ptrs + key_rel_pos[None, :],
|
||||
mask=is_query_key[None, :], # avoid OOB for context keys
|
||||
other=0.0,
|
||||
)
|
||||
S += qq_bias
|
||||
|
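||||
# online-softmax (flash-attention style) update: track the running row
|
||||
# maximum M and denominator L, and rescale the accumulator by
|
||||
# alpha = exp(M_old - M_new) whenever the maximum grows, so that acc / L
|
||||
# equals the softmax-weighted sum over all keys processed so far
|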
||||
# compute running maximum
|
||||
# m_j : (BLOCK_M,)
|
||||
m_j = tl.maximum(M, tl.max(S, axis=1))
|
||||
|
||||
# For sliding window there's a chance the max is -inf due to masking of
|
||||
# the entire row. In this case we need to set m_j to 0 to avoid NaN
|
||||
m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
|
||||
|
||||
# P : (BLOCK_M, TILE_SIZE)
|
||||
P = tl.exp(S - m_j[:, None])
|
||||
|
||||
# l_j : (BLOCK_M,)
|
||||
l_j = tl.sum(P, axis=1)
|
||||
|
||||
# alpha : (BLOCK_M, )
|
||||
alpha = tl.exp(M - m_j)
|
||||
|
||||
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
|
||||
acc = acc * alpha[:, None]
|
||||
|
||||
# update constants
|
||||
L = L * alpha + l_j
|
||||
M = m_j
|
||||
|
||||
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
|
||||
acc += tl.dot(P.to(V.dtype), V)
|
||||
|
||||
# epilogue
|
||||
acc = acc / L[:, None]
|
||||
if USE_FP8:
|
||||
acc = acc * tl.load(out_scale)
|
||||
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
|
||||
|
||||
output_offset = (query_offset_0[:, None] * output_stride_0 +
|
||||
query_offset_1[:, None] * output_stride_1 +
|
||||
offs_d[None, :])
|
||||
|
||||
tl.store(
|
||||
output_ptr + output_offset,
|
||||
acc,
|
||||
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_unified_attention_3d(
|
||||
segm_output_ptr,
|
||||
# [num_tokens, num_query_heads, num_segments, head_size]
|
||||
segm_max_ptr, # [num_tokens, num_query_heads, num_segments]
|
||||
segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments]
|
||||
query_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
sink_ptr, # [num_query_heads]
|
||||
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
alibi_slopes_ptr, # [num_query_heads]
|
||||
qq_bias_ptr, # [num_query_tokens, num_query_tokens]
|
||||
scale, # float32
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
softcap, # float32
|
||||
num_query_heads: tl.constexpr, # int
|
||||
num_queries_per_kv: tl.constexpr, # int
|
||||
block_table_stride: tl.int64, # int
|
||||
query_stride_0: tl.int64, # int
|
||||
query_stride_1: tl.int64, # int, should be equal to head_size
|
||||
qq_bias_stride_0: tl.int64, # int
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
TILE_SIZE: tl.constexpr, # int, must be power of 2
|
||||
HEAD_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||
USE_QQ_BIAS: tl.constexpr, # bool
|
||||
USE_SOFTCAP: tl.constexpr, # bool
|
||||
USE_SINKS: tl.constexpr, # bool
|
||||
SLIDING_WINDOW: tl.constexpr, # int
|
||||
stride_k_cache_0: tl.int64, # int
|
||||
stride_k_cache_1: tl.int64, # int
|
||||
stride_k_cache_2: tl.int64, # int
|
||||
stride_k_cache_3: tl.constexpr, # int
|
||||
stride_v_cache_0: tl.int64, # int
|
||||
stride_v_cache_1: tl.int64, # int
|
||||
stride_v_cache_2: tl.int64, # int
|
||||
stride_v_cache_3: tl.constexpr, # int
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
BLOCK_Q: tl.constexpr, # int
|
||||
num_seqs: tl.int32,
|
||||
BLOCK_M: tl.constexpr, # int
|
||||
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
|
||||
):
|
||||
q_block_global_idx = tl.program_id(0)
|
||||
kv_head_idx = tl.program_id(1)
|
||||
segm_idx = tl.program_id(2)
|
||||
|
||||
seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs,
|
||||
BLOCK_Q, True)
|
||||
|
||||
q_block_start_idx = tl.load(query_start_len_ptr +
|
||||
seq_idx) // BLOCK_Q + seq_idx
|
||||
|
||||
q_block_local_idx = q_block_global_idx - q_block_start_idx
|
||||
|
||||
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
|
||||
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
|
||||
|
||||
cur_batch_query_len = cur_batch_in_all_stop_index \
|
||||
- cur_batch_in_all_start_index
|
||||
|
||||
if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
|
||||
return
|
||||
|
||||
# sequence len for this particular sequence
|
||||
seq_len = tl.load(seq_lens_ptr + seq_idx)
|
||||
|
||||
# number of segments for this particular sequence
|
||||
num_segments = NUM_SEGMENTS_PER_SEQ
|
||||
tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
|
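||||
# split-KV variant: each sequence's key/value positions are divided into
|
||||
# NUM_SEGMENTS_PER_SEQ segments of tiles_per_segment tiles each; every
|
||||
# program reduces one segment and reduce_segments merges the per-segment
|
||||
# partial results (running max, exp-sum and weighted accumulator)
|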
||||
|
||||
if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len:
|
||||
return
|
||||
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
|
||||
offs_t = tl.arange(0, TILE_SIZE)
|
||||
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
|
||||
|
||||
query_offset_0 = cur_batch_in_all_start_index + query_pos
|
||||
query_offset_1 = kv_head_idx * num_queries_per_kv + \
|
||||
offs_m % num_queries_per_kv
|
||||
query_offset = (query_offset_0[:, None] * query_stride_0 +
|
||||
query_offset_1[:, None] * query_stride_1 + offs_d[None, :])
|
||||
|
||||
dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
|
||||
query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
|
||||
query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)
|
||||
|
||||
# Q : (BLOCK_M, HEAD_SIZE_PADDED)
|
||||
Q = tl.load(
|
||||
query_ptr + query_offset,
|
||||
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
block_table_offset = seq_idx * block_table_stride
|
||||
|
||||
if USE_SINKS:
|
||||
if segm_idx == 0:
|
||||
M = tl.load(
|
||||
sink_ptr + query_offset_1,
|
||||
mask=query_mask_1,
|
||||
other=float("-inf"),
|
||||
).to(dtype=tl.float32)
|
||||
else:
|
||||
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
else:
|
||||
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
|
||||
L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
|
||||
|
||||
# context length for this particular sequence
|
||||
context_len = seq_len - cur_batch_query_len
|
||||
|
||||
# alibi slope for this head
|
||||
if USE_ALIBI_SLOPES:
|
||||
alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1,
|
||||
mask=query_mask_1,
|
||||
other=0.0)
|
||||
|
||||
# query-query attention bias
|
||||
if USE_QQ_BIAS:
|
||||
qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
|
||||
) # shape: [BLOCK_M, 1]
|
||||
|
||||
# compute the length of the longest sequence prefix spanned by any
|
||||
# query token in the current q_block (q_block_local_idx)
|
||||
max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
|
||||
BLOCK_M - 1) // num_queries_per_kv + 1
|
||||
|
||||
# adjust for potential padding in the last q_block by considering the
|
||||
# actual sequence length
|
||||
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
|
||||
|
||||
# calculate the number of tiles that need to be processed to
|
||||
# cover the longest sequence prefix (due to causal masking, tiles beyond
|
||||
# this prefix can be skipped)
|
||||
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
|
||||
|
||||
# iterate through tiles within current segment
|
||||
for j in range(
|
||||
segm_idx * tiles_per_segment,
|
||||
min((segm_idx + 1) * tiles_per_segment, num_tiles),
|
||||
):
|
||||
seq_offset = j * TILE_SIZE + offs_t
|
||||
tile_mask = seq_offset < max_seq_prefix_len
|
||||
|
||||
physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
|
||||
seq_offset // BLOCK_SIZE).to(tl.int64)
|
||||
|
||||
v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
|
||||
kv_head_idx * stride_v_cache_2 +
|
||||
offs_d[None, :] * stride_v_cache_3 +
|
||||
(seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)
|
||||
|
||||
k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
|
||||
kv_head_idx * stride_k_cache_2 +
|
||||
offs_d[:, None] * stride_k_cache_3 +
|
||||
(seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)
|
||||
|
||||
# K : (HEAD_SIZE, TILE_SIZE)
|
||||
K_load = tl.load(key_cache_ptr + k_offset,
|
||||
mask=dim_mask[:, None] & tile_mask[None, :],
|
||||
other=0.0)
|
||||
|
||||
if K_load.dtype.is_fp8():
|
||||
if Q.dtype.is_fp8():
|
||||
K = K_load
|
||||
else:
|
||||
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
|
||||
else:
|
||||
K = K_load
|
||||
|
||||
# V : (TILE_SIZE, HEAD_SIZE)
|
||||
V_load = tl.load(value_cache_ptr + v_offset,
|
||||
mask=dim_mask[None, :] & tile_mask[:, None],
|
||||
other=0.0)
|
||||
|
||||
if V_load.dtype.is_fp8():
|
||||
if Q.dtype.is_fp8():
|
||||
V = V_load
|
||||
else:
|
||||
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
|
||||
else:
|
||||
V = V_load
|
||||
|
||||
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
|
||||
|
||||
# S : (BLOCK_M, TILE_SIZE)
|
||||
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
||||
S += scale * tl.dot(Q, K)
|
||||
|
||||
if USE_SOFTCAP:
|
||||
S = apply_softcap(S, softcap)
|
||||
|
||||
S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
|
||||
S, float("-inf"))
|
||||
|
||||
if SLIDING_WINDOW > 0:
|
||||
S = tl.where((context_len + query_pos[:, None] - seq_offset)
|
||||
< SLIDING_WINDOW, S, float("-inf"))
|
||||
|
||||
if USE_ALIBI_SLOPES:
|
||||
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||
|
||||
if USE_QQ_BIAS:
|
||||
# compute key positions relative to query section
|
||||
key_rel_pos = seq_offset - context_len # shape: [TILE_SIZE]
|
||||
# load bias only for keys that correspond to queries
|
||||
is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
|
||||
qq_bias = tl.load(
|
||||
qq_bias_row_ptrs + key_rel_pos[None, :],
|
||||
mask=is_query_key[None, :], # avoid OOB for context keys
|
||||
other=0.0,
|
||||
)
|
||||
S += qq_bias
|
||||
|
||||
# compute running maximum
|
||||
# m_j : (BLOCK_M,)
|
||||
m_j = tl.maximum(M, tl.max(S, axis=1))
|
||||
|
||||
# For sliding window there's a chance the max is -inf due to masking of
|
||||
# the entire row. In this case we need to set m_j to 0 to avoid NaN
|
||||
m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
|
||||
|
||||
# P : (BLOCK_M, TILE_SIZE,)
|
||||
P = tl.exp(S - m_j[:, None])
|
||||
|
||||
# l_j : (BLOCK_M,)
|
||||
l_j = tl.sum(P, axis=1)
|
||||
|
||||
# alpha : (BLOCK_M, )
|
||||
alpha = tl.exp(M - m_j)
|
||||
|
||||
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
|
||||
acc = acc * alpha[:, None]
|
||||
|
||||
# update constants
|
||||
L = L * alpha + l_j
|
||||
M = m_j
|
||||
|
||||
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
|
||||
acc += tl.dot(P.to(V.dtype), V)
|
||||
|
||||
segm_output_offset = (
|
||||
query_offset_0[:, None].to(tl.int64) *
|
||||
(num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
|
||||
query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
|
||||
segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :])
|
||||
tl.store(
|
||||
segm_output_ptr + segm_output_offset,
|
||||
acc,
|
||||
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
|
||||
)
|
||||
segm_offset = (query_offset_0.to(tl.int64) *
|
||||
(num_query_heads * NUM_SEGMENTS_PER_SEQ) +
|
||||
query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx)
|
||||
tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1)
|
||||
tl.store(segm_expsum_ptr + segm_offset,
|
||||
L,
|
||||
mask=query_mask_0 & query_mask_1)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def reduce_segments(
|
||||
output_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
segm_output_ptr,
|
||||
#[num_tokens, num_query_heads, max_num_segments, head_size]
|
||||
segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments]
|
||||
segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
num_seqs, # int
|
||||
num_query_heads: tl.constexpr, # int
|
||||
out_scale_inv, # float32
|
||||
output_stride_0: tl.int64, # int
|
||||
output_stride_1: tl.int64, # int, should be equal to head_size
|
||||
block_table_stride: tl.int64, # int
|
||||
TILE_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
BLOCK_Q: tl.constexpr, # int
|
||||
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
|
||||
USE_FP8: tl.constexpr, # bool
|
||||
FP8_MIN: tl.constexpr = float8_info.min,
|
||||
FP8_MAX: tl.constexpr = float8_info.max,
|
||||
):
|
||||
query_token_idx = tl.program_id(0)
|
||||
query_head_idx = tl.program_id(1)
|
||||
|
||||
seq_idx = find_seq_idx(query_start_len_ptr, query_token_idx, num_seqs,
|
||||
BLOCK_Q, False)
|
||||
|
||||
# sequence len for this particular sequence
|
||||
seq_len = tl.load(seq_lens_ptr + seq_idx)
|
||||
|
||||
# number of segments for this particular sequence
|
||||
num_segments = NUM_SEGMENTS_PER_SEQ
|
||||
tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
|
||||
|
||||
# create masks for subsequent loads
|
||||
act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE)
|
||||
segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full(
|
||||
[NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32)
|
||||
dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
|
||||
0).to(tl.int1)
|
||||
|
||||
# load segment maxima
|
||||
segm_offset = (query_token_idx.to(tl.int64) *
|
||||
(num_query_heads * NUM_SEGMENTS_PER_SEQ) +
|
||||
query_head_idx * NUM_SEGMENTS_PER_SEQ +
|
||||
tl.arange(0, NUM_SEGMENTS_PER_SEQ))
|
||||
segm_max = tl.load(segm_max_ptr + segm_offset,
|
||||
mask=segm_mask,
|
||||
other=float("-inf"))
|
||||
overall_max = tl.max(segm_max)
|
||||
|
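||||
# merge the per-segment softmax statistics: with per-segment maxima m_i
|
||||
# and exp-sums s_i, the global denominator is sum_i s_i * exp(m_i - m_max),
|
||||
# and each segment's partial output is rescaled by the same factor
|
||||
# before being summed
|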
||||
# load and rescale segment exp sums
|
||||
segm_expsum = tl.load(segm_expsum_ptr + segm_offset,
|
||||
mask=segm_mask,
|
||||
other=0.0)
|
||||
segm_expsum = segm_expsum * tl.exp(segm_max - overall_max)
|
||||
overall_expsum = tl.sum(segm_expsum)
|
||||
|
||||
# load, rescale, and add segment attention outputs
|
||||
segm_output_offset = (
|
||||
query_token_idx.to(tl.int64) *
|
||||
(num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
|
||||
query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
|
||||
tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED +
|
||||
tl.arange(0, HEAD_SIZE_PADDED)[None, :])
|
||||
segm_output = tl.load(
|
||||
segm_output_ptr + segm_output_offset,
|
||||
mask=segm_mask[:, None] & dim_mask[None, :],
|
||||
other=0.0,
|
||||
)
|
||||
segm_output *= tl.exp(segm_max - overall_max)[:, None]
|
||||
acc_sum = tl.sum(segm_output, axis=0)
|
||||
# safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
|
||||
acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)
|
||||
|
||||
if USE_FP8:
|
||||
acc = acc * tl.load(out_scale_inv)
|
||||
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
|
||||
|
||||
# write result
|
||||
output_offset = (query_token_idx * output_stride_0 +
|
||||
query_head_idx * output_stride_1 +
|
||||
tl.arange(0, HEAD_SIZE_PADDED))
|
||||
tl.store(output_ptr + output_offset, acc, mask=dim_mask)
|
||||
|
||||
|
||||
def unified_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
seqused_k,
|
||||
max_seqlen_k,
|
||||
softmax_scale,
|
||||
causal,
|
||||
window_size,
|
||||
block_table,
|
||||
softcap,
|
||||
q_descale,
|
||||
k_descale,
|
||||
v_descale,
|
||||
alibi_slopes=None,
|
||||
output_scale=None,
|
||||
qq_bias=None,
|
||||
# Optional tensor for sinks
|
||||
sinks=None,
|
||||
):
|
||||
|
||||
assert causal, "Only causal attention is supported"
|
||||
assert q_descale is None, "Q scales not supported"
|
||||
|
||||
if sinks is not None:
|
||||
assert sinks.shape[0] == q.shape[1], \
|
||||
"Sinks must be num_query_heads size"
|
||||
|
||||
use_alibi_slopes = alibi_slopes is not None
|
||||
use_qq_bias = qq_bias is not None
|
||||
|
||||
block_size = v.shape[1]
|
||||
num_seqs = len(seqused_k)
|
||||
num_query_heads = q.shape[1]
|
||||
num_kv_heads = k.shape[2]
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
head_size = q.shape[2]
|
||||
|
||||
BLOCK_M = 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(
|
||||
num_queries_per_kv)
|
||||
BLOCK_Q = BLOCK_M // num_queries_per_kv
|
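||||
# the BLOCK_M kernel rows cover BLOCK_Q query tokens, each replicated
|
||||
# across the num_queries_per_kv query heads that share one KV head
|
||||
# (BLOCK_Q = BLOCK_M // num_queries_per_kv, rounded down)
|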
||||
|
||||
# Ideally we would launch the kernel with:
|
||||
# \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks.
|
||||
# However, it is slow to realize the query_lens on the CPU.
|
||||
# Instead we use upper-bound:
|
||||
# \sum_i[ceil(query_len[i] / BLOCK_Q)]
|
||||
# <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1]
|
||||
# = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs
|
||||
# <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs
|
||||
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
|
||||
total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
|
||||
|
||||
# Assigning default tile sizes for prefill and decode.
|
||||
# Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
|
||||
# and at least 16 for all other data types.
|
||||
TILE_SIZE_PREFILL = 32
|
||||
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
|
||||
|
||||
# use the single-pass 2D kernel if the batch contains a prefill or
|
||||
# already exposes enough parallel work; otherwise use the split-KV
|
||||
# 3D kernel followed by a segment reduction
|
||||
if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
|
||||
kernel_unified_attention_2d[(
|
||||
total_num_q_blocks,
|
||||
num_kv_heads,
|
||||
)](
|
||||
output_ptr=out,
|
||||
query_ptr=q,
|
||||
key_cache_ptr=k,
|
||||
value_cache_ptr=v,
|
||||
sink_ptr=sinks,
|
||||
block_tables_ptr=block_table,
|
||||
seq_lens_ptr=seqused_k,
|
||||
alibi_slopes_ptr=alibi_slopes,
|
||||
qq_bias_ptr=qq_bias,
|
||||
scale=softmax_scale,
|
||||
k_scale=k_descale,
|
||||
v_scale=v_descale,
|
||||
out_scale=1 / output_scale if output_scale is not None else 1.0,
|
||||
softcap=softcap,
|
||||
num_query_heads=num_query_heads,
|
||||
num_queries_per_kv=num_queries_per_kv,
|
||||
block_table_stride=block_table.stride(0),
|
||||
query_stride_0=q.stride(0),
|
||||
query_stride_1=q.stride(1),
|
||||
output_stride_0=out.stride(0),
|
||||
output_stride_1=out.stride(1),
|
||||
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
|
||||
BLOCK_SIZE=block_size,
|
||||
TILE_SIZE=TILE_SIZE_PREFILL,
|
||||
HEAD_SIZE=head_size,
|
||||
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
|
||||
USE_ALIBI_SLOPES=use_alibi_slopes,
|
||||
USE_QQ_BIAS=use_qq_bias,
|
||||
USE_SOFTCAP=(softcap > 0),
|
||||
USE_SINKS=(sinks is not None),
|
||||
SLIDING_WINDOW=(1 + window_size[0]),
|
||||
stride_k_cache_0=k.stride(0),
|
||||
stride_k_cache_1=k.stride(1),
|
||||
stride_k_cache_2=k.stride(2),
|
||||
stride_k_cache_3=k.stride(3),
|
||||
stride_v_cache_0=v.stride(0),
|
||||
stride_v_cache_1=v.stride(1),
|
||||
stride_v_cache_2=v.stride(2),
|
||||
stride_v_cache_3=v.stride(3),
|
||||
query_start_len_ptr=cu_seqlens_q,
|
||||
BLOCK_Q=BLOCK_Q,
|
||||
num_seqs=num_seqs,
|
||||
BLOCK_M=BLOCK_M,
|
||||
USE_FP8=output_scale is not None,
|
||||
)
|
||||
else:
|
||||
# for initial version, NUM_SEGMENTS = 16 is chosen as a default
|
||||
# value that showed good performance in tests
|
||||
NUM_SEGMENTS = 16
|
||||
|
||||
segm_output = torch.empty(
|
||||
q.shape[0],
|
||||
num_query_heads,
|
||||
NUM_SEGMENTS,
|
||||
triton.next_power_of_2(head_size),
|
||||
dtype=torch.float32,
|
||||
device=q.device,
|
||||
)
|
||||
segm_max = torch.empty(
|
||||
q.shape[0],
|
||||
num_query_heads,
|
||||
NUM_SEGMENTS,
|
||||
dtype=torch.float32,
|
||||
device=q.device,
|
||||
)
|
||||
segm_expsum = torch.empty(
|
||||
q.shape[0],
|
||||
num_query_heads,
|
||||
NUM_SEGMENTS,
|
||||
dtype=torch.float32,
|
||||
device=q.device,
|
||||
)
|
||||
|
||||
kernel_unified_attention_3d[(
|
||||
total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
|
||||
segm_output_ptr=segm_output,
|
||||
segm_max_ptr=segm_max,
|
||||
segm_expsum_ptr=segm_expsum,
|
||||
query_ptr=q,
|
||||
key_cache_ptr=k,
|
||||
value_cache_ptr=v,
|
||||
sink_ptr=sinks,
|
||||
block_tables_ptr=block_table,
|
||||
seq_lens_ptr=seqused_k,
|
||||
alibi_slopes_ptr=alibi_slopes,
|
||||
qq_bias_ptr=qq_bias,
|
||||
scale=softmax_scale,
|
||||
k_scale=k_descale,
|
||||
v_scale=v_descale,
|
||||
softcap=softcap,
|
||||
num_query_heads=num_query_heads,
|
||||
num_queries_per_kv=num_queries_per_kv,
|
||||
block_table_stride=block_table.stride(0),
|
||||
query_stride_0=q.stride(0),
|
||||
query_stride_1=q.stride(1),
|
||||
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
|
||||
BLOCK_SIZE=block_size,
|
||||
TILE_SIZE=TILE_SIZE_DECODE,
|
||||
HEAD_SIZE=head_size,
|
||||
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
|
||||
USE_ALIBI_SLOPES=use_alibi_slopes,
|
||||
USE_QQ_BIAS=use_qq_bias,
|
||||
USE_SOFTCAP=(softcap > 0),
|
||||
USE_SINKS=(sinks is not None),
|
||||
SLIDING_WINDOW=(1 + window_size[0]),
|
||||
stride_k_cache_0=k.stride(0),
|
||||
stride_k_cache_1=k.stride(1),
|
||||
stride_k_cache_2=k.stride(2),
|
||||
stride_k_cache_3=k.stride(3),
|
||||
stride_v_cache_0=v.stride(0),
|
||||
stride_v_cache_1=v.stride(1),
|
||||
stride_v_cache_2=v.stride(2),
|
||||
stride_v_cache_3=v.stride(3),
|
||||
query_start_len_ptr=cu_seqlens_q,
|
||||
BLOCK_Q=BLOCK_Q,
|
||||
num_seqs=num_seqs,
|
||||
BLOCK_M=BLOCK_M,
|
||||
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
|
||||
)
|
||||
reduce_segments[(q.shape[0], num_query_heads)](
|
||||
output_ptr=out,
|
||||
segm_output_ptr=segm_output,
|
||||
segm_max_ptr=segm_max,
|
||||
segm_expsum_ptr=segm_expsum,
|
||||
seq_lens_ptr=seqused_k,
|
||||
num_seqs=num_seqs,
|
||||
num_query_heads=num_query_heads,
|
||||
out_scale_inv=1 /
|
||||
output_scale if output_scale is not None else 1.0,
|
||||
output_stride_0=out.stride(0),
|
||||
output_stride_1=out.stride(1),
|
||||
block_table_stride=block_table.stride(0),
|
||||
TILE_SIZE=TILE_SIZE_DECODE,
|
||||
HEAD_SIZE=head_size,
|
||||
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
|
||||
query_start_len_ptr=cu_seqlens_q,
|
||||
BLOCK_Q=BLOCK_Q,
|
||||
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
|
||||
USE_FP8=output_scale is not None,
|
||||
)
|
||||
245
vllm/attention/selector.py
Normal file
245
vllm/attention/selector.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from typing import Generator, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
|
||||
"""
|
||||
Convert a string backend name to a _Backend enum value.
|
||||
|
||||
Returns:
|
||||
* _Backend: enum value if backend_name is a valid in-tree type
|
||||
* None: otherwise it's an invalid in-tree type or an out-of-tree platform is
|
||||
loaded.
|
||||
"""
|
||||
assert backend_name is not None
|
||||
return _Backend[backend_name] if backend_name in _Backend.__members__ else \
|
||||
None
|
||||
|
||||
|
||||
def get_env_variable_attn_backend() -> Optional[_Backend]:
|
||||
'''
|
||||
Get the backend override specified by the vLLM attention
|
||||
backend environment variable, if one is specified.
|
||||
|
||||
Returns:
|
||||
|
||||
* _Backend enum value if an override is specified
|
||||
* None otherwise
|
||||
'''
|
||||
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
|
||||
return (None
|
||||
if backend_name is None else backend_name_to_enum(backend_name))
|
||||
|
||||
|
||||
# Global state allows a particular choice of backend
|
||||
# to be forced, overriding the logic which auto-selects
|
||||
# a backend based on system & workload configuration
|
||||
# (default behavior if this variable is None)
|
||||
#
|
||||
# THIS SELECTION TAKES PRECEDENCE OVER THE
|
||||
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
|
||||
forced_attn_backend: Optional[_Backend] = None
|
||||
|
||||
|
||||
def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
|
||||
'''
|
||||
Force all attention operations to use a specified backend.
|
||||
|
||||
Passing `None` for the argument re-enables automatic
|
||||
backend selection.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_backend: backend selection (None to revert to auto)
|
||||
'''
|
||||
global forced_attn_backend
|
||||
forced_attn_backend = attn_backend
|
||||
|
||||
|
||||
def get_global_forced_attn_backend() -> Optional[_Backend]:
|
||||
'''
|
||||
Get the currently-forced choice of attention backend,
|
||||
or None if auto-selection is currently enabled.
|
||||
'''
|
||||
return forced_attn_backend
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _IsSupported:
|
||||
can_import: bool
|
||||
head_size: bool
|
||||
dtype: bool
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return self.can_import and self.head_size and self.dtype
|
||||
|
||||
|
||||
def is_attn_backend_supported(
|
||||
attn_backend: Union[str, type[AttentionBackend]],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
allow_import_error: bool = True,
|
||||
) -> _IsSupported:
|
||||
if isinstance(attn_backend, str):
|
||||
try:
|
||||
attn_backend = resolve_obj_by_qualname(attn_backend)
|
||||
except ImportError:
|
||||
if not allow_import_error:
|
||||
raise
|
||||
|
||||
return _IsSupported(can_import=False, head_size=False, dtype=False)
|
||||
|
||||
assert isinstance(attn_backend, type)
|
||||
|
||||
# TODO: Update the interface once V0 is removed
|
||||
if get_supported_head_sizes := getattr(attn_backend,
|
||||
"get_supported_head_sizes", None):
|
||||
is_head_size_supported = head_size in get_supported_head_sizes()
|
||||
elif validate_head_size := getattr(attn_backend, "validate_head_size",
|
||||
None):
|
||||
try:
|
||||
validate_head_size(head_size)
|
||||
is_head_size_supported = True
|
||||
except Exception:
|
||||
is_head_size_supported = False
|
||||
else:
|
||||
raise NotImplementedError(f"{attn_backend.__name__} does not support "
|
||||
"head size validation")
|
||||
|
||||
if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes",
|
||||
None):
|
||||
is_dtype_supported = dtype in get_supported_dtypes()
|
||||
else:
|
||||
raise NotImplementedError(f"{attn_backend.__name__} does not support "
|
||||
"dtype validation")
|
||||
|
||||
return _IsSupported(
|
||||
can_import=True,
|
||||
head_size=is_head_size_supported,
|
||||
dtype=is_dtype_supported,
|
||||
)
|
||||
|
||||
|
||||
def get_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
use_mla: bool = False,
|
||||
has_sink: bool = False,
|
||||
use_sparse: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
|
||||
# value to be returned from the cache if the value changes between calls.
|
||||
# To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
|
||||
# private function.
|
||||
return _cached_get_attn_backend(
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
block_size=block_size,
|
||||
use_v1=envs.VLLM_USE_V1,
|
||||
use_mla=use_mla,
|
||||
has_sink=has_sink,
|
||||
use_sparse=use_sparse,
|
||||
)
|
||||
|
||||
|
||||
@cache
|
||||
def _cached_get_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
use_v1: bool = False,
|
||||
use_mla: bool = False,
|
||||
has_sink: bool = False,
|
||||
use_sparse: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
|
||||
# Check whether a particular choice of backend was
|
||||
# previously forced.
|
||||
#
|
||||
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
|
||||
# ENVIRONMENT VARIABLE.
|
||||
selected_backend = None
|
||||
backend_by_global_setting: Optional[_Backend] = (
|
||||
get_global_forced_attn_backend())
|
||||
if backend_by_global_setting is not None:
|
||||
selected_backend = backend_by_global_setting
|
||||
else:
|
||||
# Check the environment variable and override if specified
|
||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||
if backend_by_env_var is not None:
|
||||
if backend_by_env_var.endswith("_VLLM_V1"):
|
||||
logger.warning(
|
||||
"The suffix '_VLLM_V1' in the environment variable "
|
||||
"%s is no longer necessary as V0 backends have been "
|
||||
"deprecated. Please remove this suffix from your "
|
||||
"environment variable setting.", STR_BACKEND_ENV_VAR)
|
||||
backend_by_env_var = backend_by_env_var.removesuffix(
|
||||
"_VLLM_V1")
|
||||
selected_backend = backend_name_to_enum(backend_by_env_var)
|
||||
if selected_backend is None:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend: '{backend_by_env_var}'. "
|
||||
f"Valid backends are: {list(_Backend.__members__.keys())}")
|
||||
|
||||
# get device-specific attn_backend
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
|
||||
use_mla, has_sink, use_sparse)
|
||||
if not attention_cls:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for {current_platform.device_name}")
|
||||
return resolve_obj_by_qualname(attention_cls)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def global_force_attn_backend_context_manager(
|
||||
attn_backend: _Backend) -> Generator[None, None, None]:
|
||||
'''
|
||||
Globally force a vLLM attention backend override within a
|
||||
context manager, reverting the global attention backend
|
||||
override to its prior state upon exiting the context
|
||||
manager.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_backend: attention backend to force
|
||||
|
||||
Returns:
|
||||
|
||||
* Generator
|
||||
'''
|
||||
|
||||
# Save the current state of the global backend override (if any)
|
||||
original_value = get_global_forced_attn_backend()
|
||||
|
||||
# Globally force the new backend override
|
||||
global_force_attn_backend(attn_backend)
|
||||
|
||||
# Yield control back to the enclosed code block
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Revert the original global backend override, if any
|
||||
global_force_attn_backend(original_value)
|
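||||
# Illustrative usage sketch (assuming _Backend.FLASH_ATTN is a valid
|
||||
# member on the running platform):
|
||||
#
|
||||
#     with global_force_attn_backend_context_manager(_Backend.FLASH_ATTN):
|
||||
#         backend = get_attn_backend(head_size=128, dtype=torch.float16,
|
||||
#                                    kv_cache_dtype="auto", block_size=16)
|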
||||
0
vllm/attention/utils/__init__.py
Normal file
0
vllm/attention/utils/__init__.py
Normal file
85
vllm/attention/utils/fa_utils.py
Normal file
85
vllm/attention/utils/fa_utils.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from vllm import _custom_ops as ops
|
||||
reshape_and_cache_flash = ops.reshape_and_cache_flash
|
||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||
get_scheduler_metadata)
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
reshape_and_cache_flash = ops.reshape_and_cache_flash
|
||||
flash_attn_varlen_func = ops.flash_attn_varlen_func
|
||||
get_scheduler_metadata = ops.get_scheduler_metadata
|
||||
|
||||
|
||||
def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
|
||||
# import here to avoid circular dependencies
|
||||
from vllm.platforms import current_platform
|
||||
if current_platform.is_xpu():
|
||||
return 2
|
||||
try:
|
||||
from vllm.vllm_flash_attn.flash_attn_interface import (
|
||||
fa_version_unsupported_reason, is_fa_version_supported)
|
||||
device_capability = current_platform.get_device_capability()
|
||||
|
||||
assert device_capability is not None
|
||||
|
||||
# 1. default version depending on platform
|
||||
fa_version = 3 if (device_capability.major == 9
|
||||
and is_fa_version_supported(3)) else 2
|
||||
|
||||
# 2. override if passed by environment
|
||||
if envs.VLLM_FLASH_ATTN_VERSION is not None:
|
||||
assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
|
||||
fa_version = envs.VLLM_FLASH_ATTN_VERSION
|
||||
|
||||
# 3. fallback for unsupported combinations
|
||||
if device_capability.major == 10 and fa_version == 3:
|
||||
logger.warning_once(
|
||||
"Cannot use FA version 3 on Blackwell platform "
|
||||
"defaulting to FA version 2.")
|
||||
fa_version = 2
|
||||
|
||||
if requires_alibi and fa_version == 3:
|
||||
logger.warning_once("Cannot use FA version 3 with ALiBi, "
|
||||
"defaulting to FA version 2.")
|
||||
fa_version = 2
|
||||
|
||||
if not is_fa_version_supported(fa_version):
|
||||
logger.error("Cannot use FA version %d is not supported due to %s",
|
||||
fa_version, fa_version_unsupported_reason(fa_version))
|
||||
|
||||
assert is_fa_version_supported(fa_version)
|
||||
return fa_version
|
||||
except (ImportError, AssertionError):
|
||||
return None
|
||||
|
||||
|
||||
def flash_attn_supports_fp8() -> bool:
|
||||
return get_flash_attn_version() == 3 and \
|
||||
current_platform.get_device_capability().major == 9
|
||||
|
||||
|
||||
def flash_attn_supports_mla():
|
||||
from vllm.platforms import current_platform
|
||||
if current_platform.is_cuda():
|
||||
try:
|
||||
from vllm.vllm_flash_attn.flash_attn_interface import (
|
||||
is_fa_version_supported)
|
||||
return is_fa_version_supported(3) \
|
||||
and current_platform.get_device_capability()[0] == 9
|
||||
except (ImportError, AssertionError):
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def is_flash_attn_varlen_func_available() -> bool:
|
||||
return current_platform.is_cuda() or current_platform.is_xpu()
|
||||
33
vllm/attention/utils/kv_sharing_utils.py
Normal file
33
vllm/attention/utils/kv_sharing_utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
def validate_kv_sharing_target(current_layer_name, target_layer_name,
|
||||
static_forward_context):
|
||||
error_msg = (f"Specified KV sharing target layer for {current_layer_name} "
|
||||
f"is not valid: target layer {target_layer_name} ")
|
||||
|
||||
if current_layer_name == target_layer_name:
|
||||
raise ValueError(error_msg +
|
||||
"cannot be the same as the current layer.")
|
||||
|
||||
if target_layer_name not in static_forward_context:
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
|
||||
# If target layer name is not in the static fwd context, it means either
|
||||
# a) the target layer does not come BEFORE the current layer, or
|
||||
# b) the target layer is not an Attention layer that exists in the model
|
||||
current_layer_idx = extract_layer_index(current_layer_name)
|
||||
target_layer_idx = extract_layer_index(target_layer_name)
|
||||
if current_layer_idx <= target_layer_idx:
|
||||
raise ValueError(error_msg + "must come before the current layer.")
|
||||
else:
|
||||
raise ValueError(error_msg +
|
||||
"is not a valid Attention layer in the model.")
|
||||
|
||||
# Currently KV sharing is only supported between layers of the same type
|
||||
target_layer_attn_type = static_forward_context[
|
||||
target_layer_name].attn_type
|
||||
expected = static_forward_context[current_layer_name].attn_type
|
||||
if target_layer_attn_type != expected:
|
||||
raise ValueError(
|
||||
error_msg +
|
||||
f"must be the same type as the current layer ({expected}).")
|
||||
87
vllm/beam_search.py
Normal file
87
vllm/beam_search.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeamSearchSequence:
|
||||
"""A sequence for beam search.
|
||||
It keeps track of the tokens and the log probability of the sequence.
|
||||
The text field is optional and will only be filled when the sequence is
|
||||
about to be returned to the user.
|
||||
"""
|
||||
# The tokens include the prompt.
|
||||
tokens: list[int]
|
||||
logprobs: list[dict[int, Logprob]]
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
cum_logprob: float = 0.0
|
||||
text: Optional[str] = None
|
||||
finish_reason: Optional[str] = None
|
||||
stop_reason: Union[int, str, None] = None
|
||||
multi_modal_data: Optional["MultiModalDataDict"] = None
|
||||
mm_processor_kwargs: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeamSearchOutput:
|
||||
"""The output of beam search.
|
||||
It contains the list of the best beam search sequences.
|
||||
The length of the list is equal to the beam width.
|
||||
"""
|
||||
sequences: list[BeamSearchSequence]
|
||||
|
||||
|
||||
class BeamSearchInstance:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_tokens: list[int],
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
logprobs: Optional[list[dict[int, Logprob]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.beams: list[BeamSearchSequence] = [
|
||||
BeamSearchSequence(
|
||||
tokens=prompt_tokens,
|
||||
logprobs=[] if logprobs is None else list(logprobs),
|
||||
lora_request=lora_request,
|
||||
**kwargs,
|
||||
)
|
||||
]
|
||||
self.completed: list[BeamSearchSequence] = []
|
||||
|
||||
|
||||
def get_beam_search_score(
|
||||
tokens: list[int],
|
||||
cumulative_logprob: float,
|
||||
eos_token_id: int,
|
||||
length_penalty: float = 1.0,
|
||||
) -> float:
|
||||
"""Calculate the beam search score with length penalty.
|
||||
|
||||
Adapted from
|
||||
|
||||
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
|
||||
"""
|
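||||
# e.g. cumulative_logprob = -6.0 over seq_len = 4 scored tokens with
|
||||
# length_penalty = 1.0 scores -6.0 / 4**1.0 = -1.5; a larger
|
||||
# length_penalty divides by a larger factor, which (for negative
|
||||
# log-probs) ranks longer sequences higher
|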
||||
seq_len = len(tokens)
|
||||
if tokens[-1] == eos_token_id:
|
||||
seq_len -= 1
|
||||
|
||||
return cumulative_logprob / (seq_len**length_penalty)
|
||||
|
||||
|
||||
def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
|
||||
|
||||
def sort_beams_key(x: BeamSearchSequence) -> float:
|
||||
return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id,
|
||||
length_penalty)
|
||||
|
||||
return sort_beams_key
|
||||
0
vllm/benchmarks/__init__.py
Normal file
0
vllm/benchmarks/__init__.py
Normal file
2723
vllm/benchmarks/datasets.py
Normal file
2723
vllm/benchmarks/datasets.py
Normal file
File diff suppressed because it is too large
Load Diff
170
vllm/benchmarks/latency.py
Normal file
170
vllm/benchmarks/latency.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark the latency of processing a single batch of requests."""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format,
|
||||
write_to_json)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
results: dict[str, Any]) -> None:
|
||||
pt_records = convert_to_pytorch_benchmark_format(
|
||||
args=args,
|
||||
metrics={"latency": results["latencies"]},
|
||||
extra_info={k: results[k]
|
||||
for k in ["avg_latency", "percentiles"]})
|
||||
if pt_records:
|
||||
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
|
||||
write_to_json(pt_file, pt_records)
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--input-len", type=int, default=32)
|
||||
parser.add_argument("--output-len", type=int, default=128)
|
||||
parser.add_argument("--batch-size", type=int, default=8)
|
||||
parser.add_argument(
|
||||
"--n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of generated sequences per prompt.",
|
||||
)
|
||||
parser.add_argument("--use-beam-search", action="store_true")
|
||||
parser.add_argument(
|
||||
"--num-iters-warmup",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of iterations to run for warmup.",
|
||||
)
|
||||
parser.add_argument("--num-iters",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Number of iterations to run.")
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
help="profile the generation process of a single batch",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-json",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to save the latency results in JSON format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-detokenize",
|
||||
action="store_true",
|
||||
help=("Do not detokenize responses (i.e. do not include "
|
||||
"detokenization time in the latency measurement)"),
|
||||
)
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# V1 enables prefix caching by default which skews the latency
|
||||
# numbers. We need to disable prefix caching by default.
|
||||
parser.set_defaults(enable_prefix_caching=False)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
|
||||
raise OSError(
|
||||
"The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
|
||||
"Please set it to a valid path to use torch profiler.")
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
|
||||
# Lazy import to avoid importing LLM when the bench command is not selected.
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# NOTE(woosuk): If the request cannot be processed in a single batch,
|
||||
# the engine will automatically process the request in multiple batches.
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
assert llm.llm_engine.model_config.max_model_len >= (
|
||||
args.input_len +
|
||||
args.output_len), ("Please ensure that max_model_len is greater than"
|
||||
" the sum of input_len and output_len.")
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
n=args.n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=args.output_len,
|
||||
detokenize=not args.disable_detokenize,
|
||||
)
|
||||
dummy_prompt_token_ids = np.random.randint(10000,
|
||||
size=(args.batch_size,
|
||||
args.input_len))
|
||||
dummy_prompts: list[PromptType] = [{
|
||||
"prompt_token_ids": batch
|
||||
} for batch in dummy_prompt_token_ids.tolist()]
|
||||
|
||||
def llm_generate():
|
||||
if not args.use_beam_search:
|
||||
llm.generate(dummy_prompts,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
else:
|
||||
llm.beam_search(
|
||||
dummy_prompts,
|
||||
BeamSearchParams(
|
||||
beam_width=args.n,
|
||||
max_tokens=args.output_len,
|
||||
ignore_eos=True,
|
||||
),
|
||||
)
|
||||
|
||||
def run_to_completion(profile_dir: Optional[str] = None):
|
||||
if profile_dir:
|
||||
llm.start_profile()
|
||||
llm_generate()
|
||||
llm.stop_profile()
|
||||
else:
|
||||
start_time = time.perf_counter()
|
||||
llm_generate()
|
||||
end_time = time.perf_counter()
|
||||
latency = end_time - start_time
|
||||
return latency
|
||||
|
||||
print("Warming up...")
|
||||
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
|
||||
run_to_completion(profile_dir=None)
|
||||
|
||||
if args.profile:
|
||||
profile_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
print(f"Profiling (results will be saved to '{profile_dir}')...")
|
||||
run_to_completion(profile_dir=profile_dir)
|
||||
return
|
||||
|
||||
# Benchmark.
|
||||
latencies = []
|
||||
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
|
||||
latencies.append(run_to_completion(profile_dir=None))
|
||||
latencies = np.array(latencies)
|
||||
percentages = [10, 25, 50, 75, 90, 99]
|
||||
percentiles = np.percentile(latencies, percentages)
|
||||
print(f"Avg latency: {np.mean(latencies)} seconds")
|
||||
for percentage, percentile in zip(percentages, percentiles):
|
||||
print(f"{percentage}% percentile latency: {percentile} seconds")
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json:
|
||||
results = {
|
||||
"avg_latency": np.mean(latencies),
|
||||
"latencies": latencies.tolist(),
|
||||
"percentiles": dict(zip(percentages, percentiles.tolist())),
|
||||
}
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
save_to_pytorch_benchmark_format(args, results)
|
||||
3
vllm/benchmarks/lib/__init__.py
Normal file
3
vllm/benchmarks/lib/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark library utilities."""
|
||||
533
vllm/benchmarks/lib/endpoint_request_func.py
Normal file
533
vllm/benchmarks/lib/endpoint_request_func.py
Normal file
@@ -0,0 +1,533 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""The request function for API endpoints."""
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Protocol, Union
|
||||
|
||||
import aiohttp
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
|
||||
|
||||
class StreamedResponseHandler:
|
||||
"""Handles streaming HTTP responses by accumulating chunks until complete
|
||||
messages are available."""
|
||||
|
||||
def __init__(self):
|
||||
self.buffer = ""
|
||||
|
||||
def add_chunk(self, chunk_bytes: bytes) -> list[str]:
|
||||
"""Add a chunk of bytes to the buffer and return any complete
|
||||
messages."""
|
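||||
# e.g. add_chunk(b'data: {"text": "h') returns [] (the partial JSON
|
||||
# stays buffered) and a subsequent add_chunk(b'i"}\n\n') returns that
|
||||
# single completed "data: ..." SSE message
|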
||||
chunk_str = chunk_bytes.decode("utf-8")
|
||||
self.buffer += chunk_str
|
||||
|
||||
messages = []
|
||||
|
||||
# Split by double newlines (SSE message separator)
|
||||
while "\n\n" in self.buffer:
|
||||
message, self.buffer = self.buffer.split("\n\n", 1)
|
||||
message = message.strip()
|
||||
if message:
|
||||
messages.append(message)
|
||||
|
||||
# if self.buffer is not empty, check if it is a complete message
|
||||
# by removing data: prefix and check if it is a valid JSON
|
||||
if self.buffer.startswith("data: "):
|
||||
message_content = self.buffer.removeprefix("data: ").strip()
|
||||
if message_content == "[DONE]":
|
||||
messages.append(self.buffer.strip())
|
||||
self.buffer = ""
|
||||
elif message_content:
|
||||
try:
|
||||
json.loads(message_content)
|
||||
messages.append(self.buffer.strip())
|
||||
self.buffer = ""
|
||||
except json.JSONDecodeError:
|
||||
# Incomplete JSON, wait for more chunks.
|
||||
pass
|
||||
|
||||
return messages
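# Usage sketch (the chunk boundaries below are made up): add_chunk buffers
# partial SSE data and only returns messages once a "\n\n" separator or a
# complete JSON payload has arrived.
def _example_streamed_response_handler() -> None:
    handler = StreamedResponseHandler()
    # The first chunk ends mid-JSON, so no complete message is returned yet.
    assert handler.add_chunk(b'data: {"choices": [{"te') == []
    # The second chunk completes the JSON message and adds the [DONE] marker.
    messages = handler.add_chunk(b'xt": "hi"}]}\n\ndata: [DONE]')
    assert messages == ['data: {"choices": [{"text": "hi"}]}', 'data: [DONE]']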
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestFuncInput:
|
||||
"""The input for the request function."""
|
||||
prompt: str
|
||||
api_url: str
|
||||
prompt_len: int
|
||||
output_len: int
|
||||
model: str
|
||||
model_name: Optional[str] = None
|
||||
logprobs: Optional[int] = None
|
||||
extra_headers: Optional[dict] = None
|
||||
extra_body: Optional[dict] = None
|
||||
multi_modal_content: Optional[Union[dict, list[dict]]] = None
|
||||
ignore_eos: bool = False
|
||||
language: Optional[str] = None
|
||||
request_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestFuncOutput:
|
||||
"""The output of the request function including metrics."""
|
||||
generated_text: str = ""
|
||||
success: bool = False
|
||||
latency: float = 0.0
|
||||
output_tokens: int = 0
|
||||
ttft: float = 0.0 # Time to first token
|
||||
itl: list[float] = field(
|
||||
default_factory=list) # list of inter-token latencies
|
||||
tpot: float = 0.0 # avg next-token latencies
|
||||
prompt_len: int = 0
|
||||
error: str = ""
|
||||
start_time: float = 0.0
|
||||
|
||||
|
||||
class RequestFunc(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> Awaitable[RequestFuncOutput]:
|
||||
...
|
||||
|
||||
|
||||
async def async_request_openai_completions(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
"""The async request function for the OpenAI Completions API.
|
||||
|
||||
Args:
request_func_input: The input for the request function.
session: The aiohttp client session to reuse for the request.
pbar: The progress bar to display the progress.
|
||||
|
||||
Returns:
|
||||
The output of the request function.
|
||||
"""
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(
|
||||
("completions", "profile")
|
||||
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
|
||||
|
||||
payload = {
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
"temperature": 0.0,
|
||||
"repetition_penalty": 1.0,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"logprobs": request_func_input.logprobs,
|
||||
"stream": True,
|
||||
"stream_options": {
|
||||
"include_usage": True,
|
||||
},
|
||||
}
|
||||
if request_func_input.ignore_eos:
|
||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
||||
}
|
||||
if request_func_input.extra_headers:
|
||||
headers |= request_func_input.extra_headers
|
||||
if request_func_input.request_id:
|
||||
headers["x-request-id"] = request_func_input.request_id
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
st = time.perf_counter()
|
||||
output.start_time = st
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
first_chunk_received = False
|
||||
handler = StreamedResponseHandler()
|
||||
|
||||
async for chunk_bytes in response.content.iter_any():
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
messages = handler.add_chunk(chunk_bytes)
|
||||
for message in messages:
|
||||
# NOTE: SSE comments (often used as pings) start with
# a colon. These are not JSON data payloads and should
# be skipped.
|
||||
if message.startswith(":"):
|
||||
continue
|
||||
|
||||
chunk = message.removeprefix("data: ")
|
||||
|
||||
if chunk != "[DONE]":
|
||||
data = json.loads(chunk)
|
||||
|
||||
# NOTE: Some Completions APIs might send a final
# usage summary response without a token, so we
# check that a token was actually generated.
|
||||
if choices := data.get("choices"):
|
||||
# Note that text could be empty here
|
||||
# e.g. for special tokens
|
||||
text = choices[0].get("text")
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if not first_chunk_received:
|
||||
first_chunk_received = True
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text += text or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
if first_chunk_received:
|
||||
output.success = True
|
||||
else:
|
||||
output.success = False
|
||||
output.error = (
"Never received a valid chunk to calculate TTFT. "
"This response will be marked as failed!")
|
||||
output.generated_text = generated_text
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
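# Minimal usage sketch (the URL and model name below are assumptions): build a
# RequestFuncInput and drive the completions request function inside a shared
# aiohttp session.
async def _example_completions_request() -> RequestFuncOutput:
    test_input = RequestFuncInput(
        prompt="Hello, world!",
        api_url="http://localhost:8000/v1/completions",  # assumed endpoint
        prompt_len=4,
        output_len=16,
        model="my-model",  # hypothetical served model name
    )
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        return await async_request_openai_completions(test_input, session)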
|
||||
|
||||
|
||||
async def async_request_openai_chat_completions(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(("chat/completions", "profile")), (
|
||||
"OpenAI Chat Completions API URL must end with 'chat/completions'.")
|
||||
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
if request_func_input.multi_modal_content:
|
||||
mm_content = request_func_input.multi_modal_content
|
||||
if isinstance(mm_content, list):
|
||||
content.extend(mm_content)
|
||||
elif isinstance(mm_content, dict):
|
||||
content.append(mm_content)
|
||||
else:
|
||||
raise TypeError(
|
||||
"multi_modal_content must be a dict or list[dict] "
|
||||
"for openai-chat"
|
||||
)
|
||||
payload = {
|
||||
"model":
|
||||
request_func_input.model_name
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": content
|
||||
},
|
||||
],
|
||||
"temperature":
|
||||
0.0,
|
||||
"max_completion_tokens":
|
||||
request_func_input.output_len,
|
||||
"stream":
|
||||
True,
|
||||
"stream_options": {
|
||||
"include_usage": True,
|
||||
},
|
||||
}
|
||||
if request_func_input.ignore_eos:
|
||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
}
|
||||
if request_func_input.extra_headers:
|
||||
headers |= request_func_input.extra_headers
|
||||
if request_func_input.request_id:
|
||||
headers["x-request-id"] = request_func_input.request_id
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
output.start_time = st
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
handler = StreamedResponseHandler()
|
||||
async for chunk_bytes in response.content.iter_any():
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
messages = handler.add_chunk(chunk_bytes)
|
||||
for message in messages:
|
||||
# NOTE: SSE comments (often used as pings) start with
# a colon. These are not JSON data payloads and should
# be skipped.
|
||||
if message.startswith(":"):
|
||||
continue
|
||||
|
||||
chunk = message.removeprefix("data: ")
|
||||
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get("content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_openai_audio(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
# Lazy import without PlaceholderModule to avoid vllm dep.
|
||||
import soundfile
|
||||
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(("transcriptions", "translations")), (
|
||||
"OpenAI Chat Completions API URL must end with 'transcriptions' ")
|
||||
"or `translations`."
|
||||
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
payload = {
|
||||
"model":
|
||||
request_func_input.model_name
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"temperature":
|
||||
0.0,
|
||||
"max_completion_tokens":
|
||||
request_func_input.output_len,
|
||||
"stream":
|
||||
True,
|
||||
"language":
|
||||
"en",
|
||||
# Flattened due to multipart/form-data
|
||||
"stream_include_usage":
|
||||
True,
|
||||
"stream_continuous_usage_stats":
|
||||
True,
|
||||
}
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
}
|
||||
if request_func_input.extra_headers:
|
||||
headers |= request_func_input.extra_headers
|
||||
if request_func_input.request_id:
|
||||
headers["x-request-id"] = request_func_input.request_id
|
||||
|
||||
# Send audio file
|
||||
def to_bytes(y, sr):
|
||||
buffer = io.BytesIO()
|
||||
soundfile.write(buffer, y, sr, format="WAV")
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
mm_audio = request_func_input.multi_modal_content
|
||||
if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
|
||||
raise TypeError("multi_modal_content must be a dict containing 'audio'")
|
||||
with to_bytes(*mm_audio["audio"]) as f:
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", f, content_type="audio/wav")
|
||||
for key, value in payload.items():
|
||||
form.add_field(key, str(value))
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
output.start_time = st
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url,
|
||||
data=form,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
handler = StreamedResponseHandler()
|
||||
|
||||
async for chunk_bytes in response.content.iter_any():
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
messages = handler.add_chunk(chunk_bytes)
|
||||
for message in messages:
|
||||
chunk = message.removeprefix("data: ")
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get(
|
||||
"content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(
|
||||
timestamp - most_recent_timestamp)
|
||||
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_openai_embeddings(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: Optional[tqdm] = None,
|
||||
):
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(
|
||||
"embeddings"
|
||||
), "OpenAI Embeddings API URL must end with 'embeddings'."
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": request_func_input.model,
|
||||
"input": request_func_input.prompt,
|
||||
}
|
||||
|
||||
output = RequestFuncOutput()
|
||||
st = time.perf_counter()
|
||||
output.start_time = st
|
||||
try:
|
||||
async with session.post(
|
||||
url=api_url,
|
||||
headers=headers,
|
||||
json=payload
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
output.latency = time.perf_counter() - st
|
||||
data = await response.json()
|
||||
output.success = True
|
||||
output.generated_text = ""
|
||||
output.prompt_len = data.get(
|
||||
"usage", {}).get(
|
||||
"prompt_tokens", 0)
|
||||
else:
|
||||
output.success = False
|
||||
output.error = response.reason or ""
|
||||
except Exception as e:
|
||||
output.success = False
|
||||
output.error = str(e)
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
# TODO: Add more request functions for different API protocols.
|
||||
ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
|
||||
"vllm": async_request_openai_completions,
|
||||
"openai": async_request_openai_completions,
|
||||
"openai-chat": async_request_openai_chat_completions,
|
||||
"openai-audio": async_request_openai_audio,
|
||||
"openai-embeddings": async_request_openai_embeddings,
|
||||
}
|
||||
|
||||
OPENAI_COMPATIBLE_BACKENDS = [
|
||||
k for k, v in ASYNC_REQUEST_FUNCS.items()
|
||||
if v in (async_request_openai_completions,
|
||||
async_request_openai_chat_completions)
|
||||
]
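# Dispatch sketch: look up the request function for a configured backend name
# (the "openai" default below is just an example) and fail clearly otherwise.
def _get_request_func_example(backend: str = "openai") -> RequestFunc:
    try:
        return ASYNC_REQUEST_FUNCS[backend]
    except KeyError:
        raise ValueError(
            f"Unknown backend: {backend!r}. "
            f"Choose from {sorted(ASYNC_REQUEST_FUNCS)}") from None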
|
||||
73
vllm/benchmarks/lib/ready_checker.py
Normal file
73
vllm/benchmarks/lib/ready_checker.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utilities for checking endpoint readiness."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
from .endpoint_request_func import (RequestFunc, RequestFuncInput,
|
||||
RequestFuncOutput)
|
||||
|
||||
|
||||
async def wait_for_endpoint(
|
||||
request_func: RequestFunc,
|
||||
test_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
timeout_seconds: int = 600,
|
||||
retry_interval: int = 5,
|
||||
) -> RequestFuncOutput:
|
||||
"""
|
||||
Wait for an endpoint to become available before starting benchmarks.
|
||||
|
||||
Args:
request_func: The async request function to call
test_input: The RequestFuncInput to test with
session: The aiohttp client session to use for the requests
timeout_seconds: Maximum time to wait in seconds (default: 10 minutes)
retry_interval: Time between retries in seconds (default: 5 seconds)

Returns:
RequestFuncOutput: The last response received. If the endpoint never
became available within the timeout, its success field is False.
|
||||
"""
|
||||
deadline = time.perf_counter() + timeout_seconds
|
||||
output = RequestFuncOutput(success=False)
|
||||
print(f"Waiting for endpoint to become up in {timeout_seconds} seconds")
|
||||
|
||||
with tqdm(
|
||||
total=timeout_seconds,
|
||||
bar_format="{desc} |{bar}| {elapsed} elapsed, {remaining} remaining",
|
||||
unit="s",
|
||||
) as pbar:
|
||||
|
||||
while True:
|
||||
# update progress bar
|
||||
remaining = deadline - time.perf_counter()
|
||||
elapsed = timeout_seconds - remaining
|
||||
update_amount = min(elapsed - pbar.n, timeout_seconds - pbar.n)
|
||||
pbar.update(update_amount)
|
||||
pbar.refresh()
|
||||
if remaining <= 0:
|
||||
pbar.close()
|
||||
break
|
||||
|
||||
# ping the endpoint using request_func
|
||||
try:
|
||||
output = await request_func(
|
||||
request_func_input=test_input, session=session)
|
||||
if output.success:
|
||||
pbar.close()
|
||||
return output
|
||||
except aiohttp.ClientConnectorError:
|
||||
pass
|
||||
|
||||
# retry after a delay
|
||||
sleep_duration = min(retry_interval, remaining)
|
||||
if sleep_duration > 0:
|
||||
await asyncio.sleep(sleep_duration)
|
||||
|
||||
return output
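# Usage sketch (the endpoint URL and model name are assumptions): poll a local
# server with a one-token test request before starting a benchmark run.
async def _example_wait_for_endpoint() -> None:
    from .endpoint_request_func import async_request_openai_completions

    test_input = RequestFuncInput(
        prompt="ping",
        api_url="http://localhost:8000/v1/completions",  # assumed endpoint
        prompt_len=1,
        output_len=1,
        model="my-model",  # hypothetical served model name
    )
    async with aiohttp.ClientSession() as session:
        output = await wait_for_endpoint(
            async_request_openai_completions, test_input, session,
            timeout_seconds=120, retry_interval=5)
        if not output.success:
            raise ValueError("Endpoint did not become available in time.")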
|
||||
80
vllm/benchmarks/lib/utils.py
Normal file
80
vllm/benchmarks/lib/utils.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
|
||||
def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
metrics: dict[str, list],
|
||||
extra_info: dict[str, Any]) -> list:
|
||||
"""
|
||||
Save the benchmark results in the format used by PyTorch OSS benchmark with
one metric per record
|
||||
https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
|
||||
"""
|
||||
records = []
|
||||
if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False):
|
||||
return records
|
||||
|
||||
for name, benchmark_values in metrics.items():
|
||||
record = {
|
||||
"benchmark": {
|
||||
"name": "vLLM benchmark",
|
||||
"extra_info": {
|
||||
"args": vars(args),
|
||||
},
|
||||
},
|
||||
"model": {
|
||||
"name": args.model,
|
||||
},
|
||||
"metric": {
|
||||
"name": name,
|
||||
"benchmark_values": benchmark_values,
|
||||
"extra_info": extra_info,
|
||||
},
|
||||
}
|
||||
|
||||
tp = record["benchmark"]["extra_info"]["args"].get(
|
||||
"tensor_parallel_size")
|
||||
# Save tensor_parallel_size parameter if it's part of the metadata
|
||||
if not tp and "tensor_parallel_size" in extra_info:
|
||||
record["benchmark"]["extra_info"]["args"][
|
||||
"tensor_parallel_size"] = extra_info["tensor_parallel_size"]
|
||||
|
||||
records.append(record)
|
||||
|
||||
return records
|
||||
|
||||
|
||||
class InfEncoder(json.JSONEncoder):
|
||||
|
||||
def clear_inf(self, o: Any):
|
||||
if isinstance(o, dict):
|
||||
return {
|
||||
str(k)
|
||||
if not isinstance(k, (str, int, float, bool, type(None)))
|
||||
else k: self.clear_inf(v)
|
||||
for k, v in o.items()
|
||||
}
|
||||
elif isinstance(o, list):
|
||||
return [self.clear_inf(v) for v in o]
|
||||
elif isinstance(o, float) and math.isinf(o):
|
||||
return "inf"
|
||||
return o
|
||||
|
||||
def iterencode(self, o: Any, *args, **kwargs) -> Any:
|
||||
return super().iterencode(self.clear_inf(o), *args, **kwargs)
|
||||
|
||||
|
||||
def write_to_json(filename: str, records: list) -> None:
|
||||
with open(filename, "w") as f:
|
||||
json.dump(
|
||||
records,
|
||||
f,
|
||||
cls=InfEncoder,
|
||||
default=lambda o: f"<{type(o).__name__} is not JSON serializable>",
|
||||
)
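# Usage sketch (the metric values are made up): convert one latency metric to
# the PyTorch OSS benchmark format and write it out. Records are only emitted
# when the SAVE_TO_PYTORCH_BENCHMARK_FORMAT environment variable is set.
def _example_save_records() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="my-model")  # hypothetical model
    args = parser.parse_args([])
    records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={"avg_latency": [1.23]},
        extra_info={"tensor_parallel_size": 1},
    )
    if records:
        write_to_json("example.pytorch.json", records)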
|
||||
1358
vllm/benchmarks/serve.py
Normal file
1358
vllm/benchmarks/serve.py
Normal file
File diff suppressed because it is too large
696
vllm/benchmarks/throughput.py
Normal file
696
vllm/benchmarks/throughput.py
Normal file
@@ -0,0 +1,696 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark offline inference throughput."""
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from tqdm import tqdm
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
PreTrainedTokenizerBase)
|
||||
|
||||
from vllm.benchmarks.datasets import (AIMODataset, BurstGPTDataset,
|
||||
ConversationDataset,
|
||||
InstructCoderDataset,
|
||||
PrefixRepetitionRandomDataset,
|
||||
RandomDataset, SampleRequest,
|
||||
ShareGPTDataset, SonnetDataset,
|
||||
VisionArenaDataset)
|
||||
from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format,
|
||||
write_to_json)
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.utils import merge_async_iterators
|
||||
|
||||
|
||||
def run_vllm(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
do_profile: bool,
|
||||
disable_detokenize: bool = False,
|
||||
) -> tuple[float, Optional[list[RequestOutput]]]:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len >= (
|
||||
request.prompt_len + request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
# Add the requests to the engine.
|
||||
prompts: list[Union[TextPrompt, TokensPrompt]] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
|
||||
multi_modal_data=request.multi_modal_data)
|
||||
if "prompt_token_ids" in request.prompt else \
|
||||
TextPrompt(prompt=request.prompt,
|
||||
multi_modal_data=request.multi_modal_data))
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
lora_requests: Optional[list[LoRARequest]] = None
|
||||
if engine_args.enable_lora:
|
||||
lora_requests = [request.lora_request for request in requests]
|
||||
|
||||
use_beam_search = False
|
||||
|
||||
outputs = None
|
||||
if not use_beam_search:
|
||||
start = time.perf_counter()
|
||||
if do_profile:
|
||||
llm.start_profile()
|
||||
outputs = llm.generate(prompts,
|
||||
sampling_params,
|
||||
lora_request=lora_requests,
|
||||
use_tqdm=True)
|
||||
if do_profile:
|
||||
llm.stop_profile()
|
||||
end = time.perf_counter()
|
||||
else:
|
||||
assert lora_requests is None, "BeamSearch API does not support LoRA"
|
||||
prompts = [request.prompt for request in requests]
|
||||
# output_len should be the same for all requests.
|
||||
output_len = requests[0].expected_output_len
|
||||
for request in requests:
|
||||
assert request.expected_output_len == output_len
|
||||
start = time.perf_counter()
|
||||
if do_profile:
|
||||
llm.start_profile()
|
||||
llm.beam_search(
|
||||
prompts,
|
||||
BeamSearchParams(
|
||||
beam_width=n,
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
))
|
||||
if do_profile:
|
||||
llm.stop_profile()
|
||||
end = time.perf_counter()
|
||||
return end - start, outputs
|
||||
|
||||
|
||||
def run_vllm_chat(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
do_profile: bool,
|
||||
disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
|
||||
"""
|
||||
Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
|
||||
multimodal models as it properly handles multimodal inputs and chat
|
||||
formatting. For non-multimodal models, use run_vllm() instead.
|
||||
"""
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len >= (
|
||||
request.prompt_len + request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of "
|
||||
"prompt_len and expected_output_len for all requests.")
|
||||
|
||||
prompts = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
for request in requests:
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
start = time.perf_counter()
|
||||
if do_profile:
|
||||
llm.start_profile()
|
||||
outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
|
||||
if do_profile:
|
||||
llm.stop_profile()
|
||||
end = time.perf_counter()
|
||||
return end - start, outputs
|
||||
|
||||
|
||||
async def run_vllm_async(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: AsyncEngineArgs,
|
||||
do_profile: bool,
|
||||
disable_frontend_multiprocessing: bool = False,
|
||||
disable_detokenize: bool = False,
|
||||
) -> float:
|
||||
from vllm import SamplingParams
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args)
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args,
|
||||
disable_frontend_multiprocessing=disable_frontend_multiprocessing,
|
||||
) as llm:
|
||||
model_config = await llm.get_model_config()
|
||||
assert all(
|
||||
model_config.max_model_len >= (request.prompt_len +
|
||||
request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: list[Union[TextPrompt, TokensPrompt]] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
lora_requests: list[Optional[LoRARequest]] = []
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
|
||||
multi_modal_data=request.multi_modal_data)
|
||||
if "prompt_token_ids" in request.prompt else \
|
||||
TextPrompt(prompt=request.prompt,
|
||||
multi_modal_data=request.multi_modal_data))
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
lora_requests.append(request.lora_request)
|
||||
|
||||
generators = []
|
||||
start = time.perf_counter()
|
||||
if do_profile:
|
||||
await llm.start_profile()
|
||||
for i, (prompt, sp,
|
||||
lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
|
||||
generator = llm.generate(prompt,
|
||||
sp,
|
||||
lora_request=lr,
|
||||
request_id=f"test{i}")
|
||||
generators.append(generator)
|
||||
all_gens = merge_async_iterators(*generators)
|
||||
async for i, res in all_gens:
|
||||
pass
|
||||
if do_profile:
|
||||
await llm.stop_profile()
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
def run_hf(
|
||||
requests: list[SampleRequest],
|
||||
model: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
n: int,
|
||||
max_batch_size: int,
|
||||
trust_remote_code: bool,
|
||||
disable_detokenize: bool = False,
|
||||
) -> float:
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
||||
if llm.config.model_type == "llama":
|
||||
# To enable padding in the HF backend.
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
llm = llm.cuda()
|
||||
|
||||
pbar = tqdm(total=len(requests))
|
||||
start = time.perf_counter()
|
||||
batch: list[str] = []
|
||||
max_prompt_len = 0
|
||||
max_output_len = 0
|
||||
for i in range(len(requests)):
|
||||
prompt = requests[i].prompt
|
||||
prompt_len = requests[i].prompt_len
|
||||
output_len = requests[i].expected_output_len
|
||||
# Add the prompt to the batch.
|
||||
batch.append(prompt)
|
||||
max_prompt_len = max(max_prompt_len, prompt_len)
|
||||
max_output_len = max(max_output_len, output_len)
|
||||
if len(batch) < max_batch_size and i != len(requests) - 1:
|
||||
# Check if we can add more requests to the batch.
|
||||
next_prompt_len = requests[i + 1].prompt_len
|
||||
next_output_len = requests[i + 1].expected_output_len
|
||||
if (max(max_prompt_len, next_prompt_len) +
|
||||
max(max_output_len, next_output_len)) <= 2048:
|
||||
# We can add more requests to the batch.
|
||||
continue
|
||||
|
||||
# Generate the sequences.
|
||||
input_ids = tokenizer(batch, return_tensors="pt",
|
||||
padding=True).input_ids
|
||||
llm_outputs = llm.generate(
|
||||
input_ids=input_ids.cuda(),
|
||||
do_sample=True,
|
||||
num_return_sequences=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
use_cache=True,
|
||||
max_new_tokens=max_output_len,
|
||||
)
|
||||
if not disable_detokenize:
|
||||
# Include the decoding time.
|
||||
tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
|
||||
pbar.update(len(batch))
|
||||
|
||||
# Clear the batch.
|
||||
batch = []
|
||||
max_prompt_len = 0
|
||||
max_output_len = 0
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
results: dict[str, Any]) -> None:
|
||||
pt_records = convert_to_pytorch_benchmark_format(
|
||||
args=args,
|
||||
metrics={
|
||||
"requests_per_second": [results["requests_per_second"]],
|
||||
"tokens_per_second": [results["tokens_per_second"]],
|
||||
},
|
||||
extra_info={
|
||||
k: results[k]
|
||||
for k in ["elapsed_time", "num_requests", "total_num_tokens"]
|
||||
})
|
||||
if pt_records:
|
||||
# Don't use json suffix here as we don't want CI to pick it up
|
||||
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
|
||||
write_to_json(pt_file, pt_records)
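# Shape sketch (numbers are made up): the results dict consumed by
# save_to_pytorch_benchmark_format above, mirroring what main() assembles
# below when --output-json is passed.
_EXAMPLE_THROUGHPUT_RESULTS: dict[str, Any] = {
    "elapsed_time": 12.5,
    "num_requests": 1000,
    "total_num_tokens": 256000,
    "requests_per_second": 80.0,
    "tokens_per_second": 20480.0,
}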
|
||||
|
||||
|
||||
def get_requests(args, tokenizer):
|
||||
# Common parameters for all dataset types.
|
||||
common_kwargs = {
|
||||
"dataset_path": args.dataset_path,
|
||||
"random_seed": args.seed,
|
||||
}
|
||||
sample_kwargs = {
|
||||
"tokenizer": tokenizer,
|
||||
"lora_path": args.lora_path,
|
||||
"max_loras": args.max_loras,
|
||||
"num_requests": args.num_prompts,
|
||||
"input_len": args.input_len,
|
||||
"output_len": args.output_len,
|
||||
}
|
||||
|
||||
if args.dataset_path is None or args.dataset_name == "random":
|
||||
sample_kwargs["range_ratio"] = args.random_range_ratio
|
||||
sample_kwargs["prefix_len"] = args.prefix_len
|
||||
dataset_cls = RandomDataset
|
||||
elif args.dataset_name == "sharegpt":
|
||||
dataset_cls = ShareGPTDataset
|
||||
if args.backend == "vllm-chat":
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_name == "sonnet":
|
||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||
"Tokenizer/model must have chat template for sonnet dataset.")
|
||||
dataset_cls = SonnetDataset
|
||||
sample_kwargs["prefix_len"] = args.prefix_len
|
||||
sample_kwargs["return_prompt_formatted"] = True
|
||||
elif args.dataset_name == "burstgpt":
|
||||
dataset_cls = BurstGPTDataset
|
||||
elif args.dataset_name == "hf":
|
||||
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = VisionArenaDataset
|
||||
common_kwargs['dataset_subset'] = None
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = InstructCoderDataset
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = ConversationDataset
|
||||
common_kwargs['dataset_subset'] = args.hf_subset
|
||||
common_kwargs['dataset_split'] = args.hf_split
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = AIMODataset
|
||||
common_kwargs['dataset_subset'] = None
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
elif args.dataset_name == "prefix_repetition":
|
||||
dataset_cls = PrefixRepetitionRandomDataset
|
||||
sample_kwargs["prefix_len"] = args.prefix_repetition_prefix_len
|
||||
sample_kwargs["suffix_len"] = args.prefix_repetition_suffix_len
|
||||
sample_kwargs["num_prefixes"] = args.prefix_repetition_num_prefixes
|
||||
sample_kwargs["output_len"] = args.prefix_repetition_output_len
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
|
||||
# Remove None values
|
||||
sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
|
||||
return dataset_cls(**common_kwargs).sample(**sample_kwargs)
|
||||
|
||||
|
||||
def validate_args(args):
|
||||
"""
|
||||
Validate command-line arguments.
|
||||
"""
|
||||
|
||||
# === Deprecation and Defaulting ===
|
||||
if args.dataset is not None:
|
||||
warnings.warn(
|
||||
"The '--dataset' argument will be deprecated in the next release. "
|
||||
"Please use '--dataset-name' and '--dataset-path' instead.",
|
||||
stacklevel=2)
|
||||
args.dataset_path = args.dataset
|
||||
|
||||
if not getattr(args, "tokenizer", None):
|
||||
args.tokenizer = args.model
|
||||
|
||||
# === Backend Validation ===
|
||||
valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
|
||||
if args.backend not in valid_backends:
|
||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||
|
||||
# === Dataset Configuration ===
|
||||
if (
|
||||
not args.dataset
|
||||
and not args.dataset_path
|
||||
and args.dataset_name not in {"prefix_repetition"}
|
||||
):
|
||||
print(
|
||||
"When dataset path is not set, it will default to random dataset")
|
||||
args.dataset_name = 'random'
|
||||
if args.input_len is None:
|
||||
raise ValueError("input_len must be provided for a random dataset")
|
||||
|
||||
# === Dataset Name Specific Checks ===
|
||||
# --hf-subset and --hf-split: only used
|
||||
# when dataset_name is 'hf'
|
||||
if args.dataset_name != "hf" and (
|
||||
getattr(args, "hf_subset", None) is not None
|
||||
or getattr(args, "hf_split", None) is not None):
|
||||
warnings.warn("--hf-subset and --hf-split will be ignored \
|
||||
since --dataset-name is not 'hf'.",
|
||||
stacklevel=2)
|
||||
elif args.dataset_name == "hf":
|
||||
if args.dataset_path in (
|
||||
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
|
||||
| ConversationDataset.SUPPORTED_DATASET_PATHS):
|
||||
assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501
|
||||
elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS
|
||||
| AIMODataset.SUPPORTED_DATASET_PATHS):
|
||||
assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{args.dataset_path} is not supported by hf dataset.")
|
||||
|
||||
# --random-range-ratio: only used when dataset_name is 'random'
|
||||
if args.dataset_name != 'random' and args.random_range_ratio is not None:
|
||||
warnings.warn("--random-range-ratio will be ignored since \
|
||||
--dataset-name is not 'random'.",
|
||||
stacklevel=2)
|
||||
|
||||
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
|
||||
# set.
|
||||
if args.dataset_name not in {"random", "sonnet", None
|
||||
} and args.prefix_len is not None:
|
||||
warnings.warn("--prefix-len will be ignored since --dataset-name\
|
||||
is not 'random', 'sonnet', or not set.",
|
||||
stacklevel=2)
|
||||
|
||||
# === LoRA Settings ===
|
||||
if getattr(args, "enable_lora", False) and args.backend != "vllm":
|
||||
raise ValueError(
|
||||
"LoRA benchmarking is only supported for vLLM backend")
|
||||
if getattr(args, "enable_lora", False) and args.lora_path is None:
|
||||
raise ValueError("LoRA path must be provided when enable_lora is True")
|
||||
|
||||
# === Backend-specific Validations ===
|
||||
if args.backend == "hf" and args.hf_max_batch_size is None:
|
||||
raise ValueError("HF max batch size is required for HF backend")
|
||||
if args.backend != "hf" and args.hf_max_batch_size is not None:
|
||||
raise ValueError("HF max batch size is only for HF backend.")
|
||||
|
||||
if args.backend in {"hf", "mii"} and getattr(args, "quantization",
|
||||
None) is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
|
||||
if args.backend == "mii" and args.dtype != "auto":
|
||||
raise ValueError("dtype must be auto for MII backend.")
|
||||
if args.backend == "mii" and args.n != 1:
|
||||
raise ValueError("n must be 1 for MII backend.")
|
||||
if args.backend == "mii" and args.tokenizer != args.model:
|
||||
raise ValueError(
|
||||
"Tokenizer must be the same as the model for MII backend.")
|
||||
|
||||
# --data-parallel is not supported currently.
|
||||
# https://github.com/vllm-project/vllm/issues/16222
|
||||
if args.data_parallel_size > 1:
|
||||
raise ValueError(
|
||||
"Data parallel is not supported in offline benchmark, "
|
||||
"please use benchmark serving instead"
|
||||
)
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--backend",
|
||||
type=str,
|
||||
choices=["vllm", "hf", "mii", "vllm-chat"],
|
||||
default="vllm")
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
choices=[
|
||||
"sharegpt", "random", "sonnet", "burstgpt", "hf",
|
||||
"prefix_repetition"
|
||||
],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
default="sharegpt")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the ShareGPT dataset, will be deprecated in\
|
||||
the next release. The dataset is expected to "
|
||||
"be a json in form of list[dict[..., conversations: "
|
||||
"list[dict[..., value: <prompt_or_response>]]]]")
|
||||
parser.add_argument("--dataset-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset")
|
||||
parser.add_argument("--input-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Input prompt length for each request")
|
||||
parser.add_argument("--output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. Overrides the "
|
||||
"output length from the dataset.")
|
||||
parser.add_argument("--n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of generated sequences per prompt.")
|
||||
parser.add_argument("--num-prompts",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of prompts to process.")
|
||||
parser.add_argument("--hf-max-batch-size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum batch size for HF backend.")
|
||||
parser.add_argument(
|
||||
'--output-json',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Path to save the throughput results in JSON format.')
|
||||
parser.add_argument("--async-engine",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Use vLLM async engine rather than LLM class.")
|
||||
parser.add_argument("--disable-frontend-multiprocessing",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable decoupled async engine frontend.")
|
||||
parser.add_argument(
|
||||
"--disable-detokenize",
|
||||
action="store_true",
|
||||
help=("Do not detokenize the response (i.e. do not include "
|
||||
"detokenization time in the measurement)"))
|
||||
# LoRA
|
||||
parser.add_argument(
|
||||
"--lora-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the lora adapters to use. This can be an absolute path, "
|
||||
"a relative path, or a Hugging Face model identifier.")
|
||||
parser.add_argument(
|
||||
"--prefix-len",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of fixed prefix tokens before the random "
|
||||
"context in a request (default: 0).",
|
||||
)
|
||||
# random dataset
|
||||
parser.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Range ratio for sampling input/output length, "
|
||||
"used only for RandomDataset. Must be in the range [0, 1) to define "
|
||||
"a symmetric sampling range "
|
||||
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
|
||||
)
|
||||
|
||||
# hf dataset
|
||||
parser.add_argument("--hf-subset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Subset of the HF dataset.")
|
||||
parser.add_argument("--hf-split",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Split of the HF dataset.")
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use Torch Profiler. The env variable "
|
||||
"VLLM_TORCH_PROFILER_DIR must be set to enable profiler.")
|
||||
|
||||
# prefix repetition dataset
|
||||
prefix_repetition_group = parser.add_argument_group(
|
||||
"prefix repetition dataset options")
|
||||
prefix_repetition_group.add_argument(
|
||||
"--prefix-repetition-prefix-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of prefix tokens per request, used only for prefix "
|
||||
"repetition dataset.",
|
||||
)
|
||||
prefix_repetition_group.add_argument(
|
||||
"--prefix-repetition-suffix-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of suffix tokens per request, used only for prefix "
|
||||
"repetition dataset. Total input length is prefix_len + suffix_len.",
|
||||
)
|
||||
prefix_repetition_group.add_argument(
|
||||
"--prefix-repetition-num-prefixes",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of prefixes to generate, used only for prefix repetition "
|
||||
"dataset. Prompts per prefix is num_requests // num_prefixes.",
|
||||
)
|
||||
prefix_repetition_group.add_argument(
|
||||
"--prefix-repetition-output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of output tokens per request, used only for prefix "
|
||||
"repetition dataset.",
|
||||
)
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
validate_args(args)
|
||||
if args.seed is None:
|
||||
args.seed = 0
|
||||
random.seed(args.seed)
|
||||
# Sample the requests.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
requests = get_requests(args, tokenizer)
|
||||
is_multi_modal = any(request.multi_modal_data is not None
|
||||
for request in requests)
|
||||
request_outputs: Optional[list[RequestOutput]] = None
|
||||
if args.backend == "vllm":
|
||||
if args.async_engine:
|
||||
elapsed_time = uvloop.run(
|
||||
run_vllm_async(
|
||||
requests,
|
||||
args.n,
|
||||
AsyncEngineArgs.from_cli_args(args),
|
||||
disable_frontend_multiprocessing=args.disable_frontend_multiprocessing,
|
||||
disable_detokenize=args.disable_detokenize,
|
||||
do_profile=args.profile,
|
||||
))
|
||||
else:
|
||||
elapsed_time, request_outputs = run_vllm(
|
||||
requests, args.n, EngineArgs.from_cli_args(args),
|
||||
disable_detokenize=args.disable_detokenize,
|
||||
do_profile=args.profile)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
if args.profile:
|
||||
raise NotImplementedError(
|
||||
"Profiling not implemented yet for backend='hf'.")
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
args.hf_max_batch_size, args.trust_remote_code,
|
||||
args.disable_detokenize)
|
||||
elif args.backend == "vllm-chat":
|
||||
elapsed_time, request_outputs = run_vllm_chat(
|
||||
requests, args.n, EngineArgs.from_cli_args(args),
|
||||
disable_detokenize=args.disable_detokenize, do_profile=args.profile)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
|
||||
if request_outputs:
|
||||
# Note: with the vllm and vllm-chat backends,
|
||||
# we have request_outputs, which we use to count tokens.
|
||||
total_prompt_tokens = 0
|
||||
total_output_tokens = 0
|
||||
for ro in request_outputs:
|
||||
if not isinstance(ro, RequestOutput):
|
||||
continue
|
||||
total_prompt_tokens += len(
|
||||
ro.prompt_token_ids) if ro.prompt_token_ids else 0
|
||||
total_output_tokens += sum(
|
||||
len(o.token_ids) for o in ro.outputs if o)
|
||||
total_num_tokens = total_prompt_tokens + total_output_tokens
|
||||
else:
|
||||
total_num_tokens = sum(r.prompt_len + r.expected_output_len
|
||||
for r in requests)
|
||||
total_output_tokens = sum(r.expected_output_len for r in requests)
|
||||
total_prompt_tokens = total_num_tokens - total_output_tokens
|
||||
|
||||
if is_multi_modal and args.backend != "vllm-chat":
|
||||
print("\033[91mWARNING\033[0m: Multi-modal request with "
|
||||
f"{args.backend} backend detected. The "
|
||||
"following metrics are not accurate because image tokens are not"
|
||||
" counted. See vllm-project/vllm/issues/9778 for details.")
|
||||
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
|
||||
# vllm-chat backend counts the image tokens now
|
||||
|
||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
|
||||
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
|
||||
print(f"Total num prompt tokens: {total_prompt_tokens}")
|
||||
print(f"Total num output tokens: {total_output_tokens}")
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json:
|
||||
results = {
|
||||
"elapsed_time": elapsed_time,
|
||||
"num_requests": len(requests),
|
||||
"total_num_tokens": total_num_tokens,
|
||||
"requests_per_second": len(requests) / elapsed_time,
|
||||
"tokens_per_second": total_num_tokens / elapsed_time,
|
||||
}
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
save_to_pytorch_benchmark_format(args, results)
|
||||
823
vllm/collect_env.py
Normal file
823
vllm/collect_env.py
Normal file
@@ -0,0 +1,823 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# ruff: noqa
|
||||
# code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py
|
||||
|
||||
import datetime
|
||||
import locale
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
# Unlike the rest of the PyTorch this file must be python2 compliant.
|
||||
# This script outputs relevant system environment info
|
||||
# Run it with `python collect_env.py` or `python -m torch.utils.collect_env`
|
||||
from collections import namedtuple
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.envs import environment_variables
|
||||
|
||||
try:
|
||||
import torch
|
||||
TORCH_AVAILABLE = True
|
||||
except (ImportError, NameError, AttributeError, OSError):
|
||||
TORCH_AVAILABLE = False
|
||||
|
||||
# System Environment Information
|
||||
SystemEnv = namedtuple(
|
||||
'SystemEnv',
|
||||
[
|
||||
'torch_version',
|
||||
'is_debug_build',
|
||||
'cuda_compiled_version',
|
||||
'gcc_version',
|
||||
'clang_version',
|
||||
'cmake_version',
|
||||
'os',
|
||||
'libc_version',
|
||||
'python_version',
|
||||
'python_platform',
|
||||
'is_cuda_available',
|
||||
'cuda_runtime_version',
|
||||
'cuda_module_loading',
|
||||
'nvidia_driver_version',
|
||||
'nvidia_gpu_models',
|
||||
'cudnn_version',
|
||||
'pip_version', # 'pip' or 'pip3'
|
||||
'pip_packages',
|
||||
'conda_packages',
|
||||
'hip_compiled_version',
|
||||
'hip_runtime_version',
|
||||
'miopen_runtime_version',
|
||||
'caching_allocator_config',
|
||||
'is_xnnpack_available',
|
||||
'cpu_info',
|
||||
'rocm_version', # vllm specific field
|
||||
'vllm_version', # vllm specific field
|
||||
'vllm_build_flags', # vllm specific field
|
||||
'gpu_topo', # vllm specific field
|
||||
'env_vars',
|
||||
])
|
||||
|
||||
DEFAULT_CONDA_PATTERNS = {
|
||||
"torch",
|
||||
"numpy",
|
||||
"cudatoolkit",
|
||||
"soumith",
|
||||
"mkl",
|
||||
"magma",
|
||||
"triton",
|
||||
"optree",
|
||||
"nccl",
|
||||
"transformers",
|
||||
"zmq",
|
||||
"nvidia",
|
||||
"pynvml",
|
||||
"flashinfer-python",
|
||||
}
|
||||
|
||||
DEFAULT_PIP_PATTERNS = {
|
||||
"torch",
|
||||
"numpy",
|
||||
"mypy",
|
||||
"flake8",
|
||||
"triton",
|
||||
"optree",
|
||||
"onnx",
|
||||
"nccl",
|
||||
"transformers",
|
||||
"zmq",
|
||||
"nvidia",
|
||||
"pynvml",
|
||||
"flashinfer-python",
|
||||
}
|
||||
|
||||
|
||||
def run(command):
|
||||
"""Return (return-code, stdout, stderr)."""
|
||||
shell = True if type(command) is str else False
|
||||
try:
|
||||
p = subprocess.Popen(command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
shell=shell)
|
||||
raw_output, raw_err = p.communicate()
|
||||
rc = p.returncode
|
||||
if get_platform() == 'win32':
|
||||
enc = 'oem'
|
||||
else:
|
||||
enc = locale.getpreferredencoding()
|
||||
output = raw_output.decode(enc)
|
||||
if command == 'nvidia-smi topo -m':
|
||||
# don't remove the leading whitespace of `nvidia-smi topo -m`
|
||||
# because they are meaningful
|
||||
output = output.rstrip()
|
||||
else:
|
||||
output = output.strip()
|
||||
err = raw_err.decode(enc)
|
||||
return rc, output, err.strip()
|
||||
|
||||
except FileNotFoundError:
|
||||
cmd_str = command if isinstance(command, str) else command[0]
|
||||
return 127, '', f"Command not found: {cmd_str}"
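# Usage sketch (assumes a POSIX shell): run() returns (return-code, stdout,
# stderr) instead of raising; a missing binary is reported with code 127.
def _example_run() -> None:
    rc, out, err = run("echo hello")
    assert rc == 0 and out == "hello" and err == ""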
|
||||
|
||||
|
||||
def run_and_read_all(run_lambda, command):
|
||||
"""Run command using run_lambda; reads and returns entire output if rc is 0."""
|
||||
rc, out, _ = run_lambda(command)
|
||||
if rc != 0:
|
||||
return None
|
||||
return out
|
||||
|
||||
|
||||
def run_and_parse_first_match(run_lambda, command, regex):
|
||||
"""Run command using run_lambda, returns the first regex match if it exists."""
|
||||
rc, out, _ = run_lambda(command)
|
||||
if rc != 0:
|
||||
return None
|
||||
match = re.search(regex, out)
|
||||
if match is None:
|
||||
return None
|
||||
return match.group(1)
|
||||
|
||||
|
||||
def run_and_return_first_line(run_lambda, command):
|
||||
"""Run command using run_lambda and returns first line if output is not empty."""
|
||||
rc, out, _ = run_lambda(command)
|
||||
if rc != 0:
|
||||
return None
|
||||
return out.split('\n')[0]
|
||||
|
||||
|
||||
def get_conda_packages(run_lambda, patterns=None):
|
||||
if patterns is None:
|
||||
patterns = DEFAULT_CONDA_PATTERNS
|
||||
conda = os.environ.get('CONDA_EXE', 'conda')
|
||||
out = run_and_read_all(run_lambda, [conda, 'list'])
|
||||
if out is None:
|
||||
return out
|
||||
|
||||
return "\n".join(line for line in out.splitlines()
|
||||
if not line.startswith("#") and any(name in line
|
||||
for name in patterns))
|
||||
|
||||
|
||||
def get_gcc_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)')
|
||||
|
||||
|
||||
def get_clang_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'clang --version',
|
||||
r'clang version (.*)')
|
||||
|
||||
|
||||
def get_cmake_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'cmake --version',
|
||||
r'cmake (.*)')
|
||||
|
||||
|
||||
def get_nvidia_driver_version(run_lambda):
|
||||
if get_platform() == 'darwin':
|
||||
cmd = 'kextstat | grep -i cuda'
|
||||
return run_and_parse_first_match(run_lambda, cmd,
|
||||
r'com[.]nvidia[.]CUDA [(](.*?)[)]')
|
||||
smi = get_nvidia_smi()
|
||||
return run_and_parse_first_match(run_lambda, smi,
|
||||
r'Driver Version: (.*?) ')
|
||||
|
||||
|
||||
def get_gpu_info(run_lambda):
|
||||
if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(
|
||||
torch.version, 'hip') and torch.version.hip is not None):
|
||||
if TORCH_AVAILABLE and torch.cuda.is_available():
|
||||
if torch.version.hip is not None:
|
||||
prop = torch.cuda.get_device_properties(0)
|
||||
if hasattr(prop, "gcnArchName"):
|
||||
gcnArch = " ({})".format(prop.gcnArchName)
|
||||
else:
|
||||
gcnArch = "NoGCNArchNameOnOldPyTorch"
|
||||
else:
|
||||
gcnArch = ""
|
||||
return torch.cuda.get_device_name(None) + gcnArch
|
||||
return None
|
||||
smi = get_nvidia_smi()
|
||||
uuid_regex = re.compile(r' \(UUID: .+?\)')
|
||||
rc, out, _ = run_lambda(smi + ' -L')
|
||||
if rc != 0:
|
||||
return None
|
||||
# Anonymize GPUs by removing their UUID
|
||||
return re.sub(uuid_regex, '', out)
|
||||
|
||||
|
||||
def get_running_cuda_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'nvcc --version',
|
||||
r'release .+ V(.*)')
|
||||
|
||||
|
||||
def get_cudnn_version(run_lambda):
|
||||
"""Return a list of libcudnn.so; it's hard to tell which one is being used."""
|
||||
if get_platform() == 'win32':
|
||||
system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
|
||||
cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%")
|
||||
where_cmd = os.path.join(system_root, 'System32', 'where')
|
||||
cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path)
|
||||
elif get_platform() == 'darwin':
|
||||
# CUDA libraries and drivers can be found in /usr/local/cuda/. See
|
||||
# https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install
|
||||
# https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac
|
||||
# Use CUDNN_LIBRARY when cudnn library is installed elsewhere.
|
||||
cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*'
|
||||
else:
|
||||
cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev'
|
||||
rc, out, _ = run_lambda(cudnn_cmd)
|
||||
# find will return 1 if there are permission errors or if not found
|
||||
if len(out) == 0 or (rc != 1 and rc != 0):
|
||||
l = os.environ.get('CUDNN_LIBRARY')
|
||||
if l is not None and os.path.isfile(l):
|
||||
return os.path.realpath(l)
|
||||
return None
|
||||
files_set = set()
|
||||
for fn in out.split('\n'):
|
||||
fn = os.path.realpath(fn) # eliminate symbolic links
|
||||
if os.path.isfile(fn):
|
||||
files_set.add(fn)
|
||||
if not files_set:
|
||||
return None
|
||||
# Alphabetize the result because the order is non-deterministic otherwise
|
||||
files = sorted(files_set)
|
||||
if len(files) == 1:
|
||||
return files[0]
|
||||
result = '\n'.join(files)
|
||||
return 'Probably one of the following:\n{}'.format(result)
|
||||
|
||||
|
||||
def get_nvidia_smi():
|
||||
# Note: nvidia-smi is currently available only on Windows and Linux
|
||||
smi = 'nvidia-smi'
|
||||
if get_platform() == 'win32':
|
||||
system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
|
||||
program_files_root = os.environ.get('PROGRAMFILES',
|
||||
'C:\\Program Files')
|
||||
legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation',
|
||||
'NVSMI', smi)
|
||||
new_path = os.path.join(system_root, 'System32', smi)
|
||||
smis = [new_path, legacy_path]
|
||||
for candidate_smi in smis:
|
||||
if os.path.exists(candidate_smi):
|
||||
smi = '"{}"'.format(candidate_smi)
|
||||
break
|
||||
return smi
|
||||
|
||||
|
||||
def get_rocm_version(run_lambda):
|
||||
"""Returns the ROCm version if available, otherwise 'N/A'."""
|
||||
return run_and_parse_first_match(run_lambda, 'hipcc --version',
|
||||
r'HIP version: (\S+)')
|
||||
|
||||
|
||||
def get_vllm_version():
|
||||
from vllm import __version__, __version_tuple__
|
||||
|
||||
if __version__ == "dev":
|
||||
return "N/A (dev)"
|
||||
version_str = __version_tuple__[-1]
|
||||
if isinstance(version_str, str) and version_str.startswith('g'):
|
||||
# it's a dev build
|
||||
if '.' in version_str:
|
||||
# it's a dev build containing local changes
|
||||
git_sha = version_str.split('.')[0][1:]
|
||||
date = version_str.split('.')[-1][1:]
|
||||
return f"{__version__} (git sha: {git_sha}, date: {date})"
|
||||
else:
|
||||
# it's a dev build without local changes
|
||||
git_sha = version_str[1:] # type: ignore
|
||||
return f"{__version__} (git sha: {git_sha})"
|
||||
return __version__
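# A minimal sketch of the dev-suffix parsing above, applied to a hypothetical
# __version_tuple__ value; the real value comes from vllm.version and is
# generated at build time.
def _parse_dev_suffix(suffix: str):
    if '.' in suffix:
        # e.g. 'g1a2b3c4d.d20250101' -> ('1a2b3c4d', '20250101')
        return suffix.split('.')[0][1:], suffix.split('.')[-1][1:]
    # e.g. 'g1a2b3c4d' -> ('1a2b3c4d', None)
    return suffix[1:], None

assert _parse_dev_suffix('g1a2b3c4d.d20250101') == ('1a2b3c4d', '20250101')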
|
||||
|
||||
|
||||
def summarize_vllm_build_flags():
|
||||
# Read dynamically so the summary reflects the current environment; this could be made static if the flags were constant.
|
||||
return 'CUDA Archs: {}; ROCm: {}'.format(
|
||||
os.environ.get('TORCH_CUDA_ARCH_LIST', 'Not Set'),
|
||||
'Enabled' if os.environ.get('ROCM_HOME') else 'Disabled',
|
||||
)
|
||||
|
||||
|
||||
def get_gpu_topo(run_lambda):
|
||||
output = None
|
||||
|
||||
if get_platform() == 'linux':
|
||||
output = run_and_read_all(run_lambda, 'nvidia-smi topo -m')
|
||||
if output is None:
|
||||
output = run_and_read_all(run_lambda, 'rocm-smi --showtopo')
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# example outputs of CPU infos
|
||||
# * linux
|
||||
# Architecture: x86_64
|
||||
# CPU op-mode(s): 32-bit, 64-bit
|
||||
# Address sizes: 46 bits physical, 48 bits virtual
|
||||
# Byte Order: Little Endian
|
||||
# CPU(s): 128
|
||||
# On-line CPU(s) list: 0-127
|
||||
# Vendor ID: GenuineIntel
|
||||
# Model name: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
|
||||
# CPU family: 6
|
||||
# Model: 106
|
||||
# Thread(s) per core: 2
|
||||
# Core(s) per socket: 32
|
||||
# Socket(s): 2
|
||||
# Stepping: 6
|
||||
# BogoMIPS: 5799.78
|
||||
# Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr
|
||||
# sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl
|
||||
# xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16
|
||||
# pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand
|
||||
# hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced
|
||||
# fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap
|
||||
# avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1
|
||||
# xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq
|
||||
# avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities
|
||||
# Virtualization features:
|
||||
# Hypervisor vendor: KVM
|
||||
# Virtualization type: full
|
||||
# Caches (sum of all):
|
||||
# L1d: 3 MiB (64 instances)
|
||||
# L1i: 2 MiB (64 instances)
|
||||
# L2: 80 MiB (64 instances)
|
||||
# L3: 108 MiB (2 instances)
|
||||
# NUMA:
|
||||
# NUMA node(s): 2
|
||||
# NUMA node0 CPU(s): 0-31,64-95
|
||||
# NUMA node1 CPU(s): 32-63,96-127
|
||||
# Vulnerabilities:
|
||||
# Itlb multihit: Not affected
|
||||
# L1tf: Not affected
|
||||
# Mds: Not affected
|
||||
# Meltdown: Not affected
|
||||
# Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
|
||||
# Retbleed: Not affected
|
||||
# Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
|
||||
# Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
|
||||
# Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
|
||||
# Srbds: Not affected
|
||||
# Tsx async abort: Not affected
|
||||
# * win32
|
||||
# Architecture=9
|
||||
# CurrentClockSpeed=2900
|
||||
# DeviceID=CPU0
|
||||
# Family=179
|
||||
# L2CacheSize=40960
|
||||
# L2CacheSpeed=
|
||||
# Manufacturer=GenuineIntel
|
||||
# MaxClockSpeed=2900
|
||||
# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
|
||||
# ProcessorType=3
|
||||
# Revision=27142
|
||||
#
|
||||
# Architecture=9
|
||||
# CurrentClockSpeed=2900
|
||||
# DeviceID=CPU1
|
||||
# Family=179
|
||||
# L2CacheSize=40960
|
||||
# L2CacheSpeed=
|
||||
# Manufacturer=GenuineIntel
|
||||
# MaxClockSpeed=2900
|
||||
# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
|
||||
# ProcessorType=3
|
||||
# Revision=27142
|
||||
|
||||
|
||||
def get_cpu_info(run_lambda):
|
||||
rc, out, err = 0, '', ''
|
||||
if get_platform() == 'linux':
|
||||
rc, out, err = run_lambda('lscpu')
|
||||
elif get_platform() == 'win32':
|
||||
rc, out, err = run_lambda(
|
||||
'wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \
|
||||
CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE'
|
||||
)
|
||||
elif get_platform() == 'darwin':
|
||||
rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string")
|
||||
cpu_info = 'None'
|
||||
if rc == 0:
|
||||
cpu_info = out
|
||||
else:
|
||||
cpu_info = err
|
||||
return cpu_info
|
||||
|
||||
|
||||
def get_platform():
|
||||
if sys.platform.startswith('linux'):
|
||||
return 'linux'
|
||||
elif sys.platform.startswith('win32'):
|
||||
return 'win32'
|
||||
elif sys.platform.startswith('cygwin'):
|
||||
return 'cygwin'
|
||||
elif sys.platform.startswith('darwin'):
|
||||
return 'darwin'
|
||||
else:
|
||||
return sys.platform
|
||||
|
||||
|
||||
def get_mac_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion',
|
||||
r'(.*)')
|
||||
|
||||
|
||||
def get_windows_version(run_lambda):
|
||||
system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
|
||||
wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic')
|
||||
findstr_cmd = os.path.join(system_root, 'System32', 'findstr')
|
||||
return run_and_read_all(
|
||||
run_lambda,
|
||||
'{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd))
|
||||
|
||||
|
||||
def get_lsb_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'lsb_release -a',
|
||||
r'Description:\t(.*)')
|
||||
|
||||
|
||||
def check_release_file(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'cat /etc/*-release',
|
||||
r'PRETTY_NAME="(.*)"')
|
||||
|
||||
|
||||
def get_os(run_lambda):
|
||||
from platform import machine
|
||||
platform = get_platform()
|
||||
|
||||
if platform == 'win32' or platform == 'cygwin':
|
||||
return get_windows_version(run_lambda)
|
||||
|
||||
if platform == 'darwin':
|
||||
version = get_mac_version(run_lambda)
|
||||
if version is None:
|
||||
return None
|
||||
return 'macOS {} ({})'.format(version, machine())
|
||||
|
||||
if platform == 'linux':
|
||||
# Ubuntu/Debian based
|
||||
desc = get_lsb_version(run_lambda)
|
||||
if desc is not None:
|
||||
return '{} ({})'.format(desc, machine())
|
||||
|
||||
# Try reading /etc/*-release
|
||||
desc = check_release_file(run_lambda)
|
||||
if desc is not None:
|
||||
return '{} ({})'.format(desc, machine())
|
||||
|
||||
return '{} ({})'.format(platform, machine())
|
||||
|
||||
# Unknown platform
|
||||
return platform
|
||||
|
||||
|
||||
def get_python_platform():
|
||||
import platform
|
||||
return platform.platform()
|
||||
|
||||
|
||||
def get_libc_version():
|
||||
import platform
|
||||
if get_platform() != 'linux':
|
||||
return 'N/A'
|
||||
return '-'.join(platform.libc_ver())
|
||||
|
||||
|
||||
def is_uv_venv():
|
||||
if os.environ.get("UV"):
|
||||
return True
|
||||
pyvenv_cfg_path = os.path.join(sys.prefix, 'pyvenv.cfg')
|
||||
if os.path.exists(pyvenv_cfg_path):
|
||||
with open(pyvenv_cfg_path, 'r') as f:
|
||||
return any(line.startswith('uv = ') for line in f)
|
||||
return False
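# Hedged illustration: a uv-created virtualenv typically records itself in
# pyvenv.cfg with a line like "uv = <version>", which is what the
# startswith('uv = ') check above matches (the version number below is made up).
def _looks_like_uv_cfg(cfg_text: str) -> bool:
    return any(line.startswith('uv = ') for line in cfg_text.splitlines())

assert _looks_like_uv_cfg("home = /usr\nuv = 0.4.18\n")
assert not _looks_like_uv_cfg("home = /usr\nversion = 3.12.0\n")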
|
||||
|
||||
|
||||
def get_pip_packages(run_lambda, patterns=None):
|
||||
"""Return `pip list` output. Note: will also find conda-installed pytorch and numpy packages."""
|
||||
if patterns is None:
|
||||
patterns = DEFAULT_PIP_PATTERNS
|
||||
|
||||
def run_with_pip():
|
||||
try:
|
||||
import importlib.util
|
||||
pip_spec = importlib.util.find_spec('pip')
|
||||
pip_available = pip_spec is not None
|
||||
except ImportError:
|
||||
pip_available = False
|
||||
|
||||
if pip_available:
|
||||
cmd = [sys.executable, '-mpip', 'list', '--format=freeze']
|
||||
elif is_uv_venv():
|
||||
print("uv is set")
|
||||
cmd = ["uv", "pip", "list", "--format=freeze"]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Could not collect pip list output (pip or uv module not available)"
|
||||
)
|
||||
|
||||
out = run_and_read_all(run_lambda, cmd)
|
||||
return "\n".join(line for line in out.splitlines()
|
||||
if any(name in line for name in patterns))
|
||||
|
||||
pip_version = 'pip3' if sys.version[0] == '3' else 'pip'
|
||||
out = run_with_pip()
|
||||
return pip_version, out
|
||||
|
||||
|
||||
def get_cachingallocator_config():
|
||||
ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '')
|
||||
return ca_config
|
||||
|
||||
|
||||
def get_cuda_module_loading_config():
|
||||
if TORCH_AVAILABLE and torch.cuda.is_available():
|
||||
torch.cuda.init()
|
||||
config = os.environ.get('CUDA_MODULE_LOADING', '')
|
||||
return config
|
||||
else:
|
||||
return "N/A"
|
||||
|
||||
|
||||
def is_xnnpack_available():
|
||||
if TORCH_AVAILABLE:
|
||||
import torch.backends.xnnpack
|
||||
return str(
|
||||
torch.backends.xnnpack.enabled) # type: ignore[attr-defined]
|
||||
else:
|
||||
return "N/A"
|
||||
|
||||
|
||||
def get_env_vars():
|
||||
env_vars = ''
|
||||
secret_terms = ('secret', 'token', 'api', 'access', 'password')
|
||||
report_prefix = ("TORCH", "NCCL", "PYTORCH", "CUDA", "CUBLAS", "CUDNN",
|
||||
"OMP_", "MKL_", "NVIDIA")
|
||||
for k, v in os.environ.items():
|
||||
if any(term in k.lower() for term in secret_terms):
|
||||
continue
|
||||
if k in environment_variables:
|
||||
env_vars = env_vars + "{}={}".format(k, v) + "\n"
|
||||
if k.startswith(report_prefix):
|
||||
env_vars = env_vars + "{}={}".format(k, v) + "\n"
|
||||
|
||||
return env_vars
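# Hedged sketch of the prefix-based filter above applied to a hypothetical
# environment; variables whose names look secret are always skipped, and the
# rest are reported only if they match a known prefix (the explicit
# environment_variables list is omitted here for brevity).
_example_env = {
    "CUDA_VISIBLE_DEVICES": "0,1",  # reported: CUDA prefix
    "NCCL_DEBUG": "INFO",           # reported: NCCL prefix
    "HF_API_TOKEN": "xyz",          # skipped: contains 'api'/'token'
    "HOME": "/root",                # skipped: no reported prefix
}
_reported = {
    k: v
    for k, v in _example_env.items()
    if not any(t in k.lower()
               for t in ('secret', 'token', 'api', 'access', 'password'))
    and k.startswith(("TORCH", "NCCL", "PYTORCH", "CUDA", "CUBLAS", "CUDNN",
                      "OMP_", "MKL_", "NVIDIA"))
}
assert set(_reported) == {"CUDA_VISIBLE_DEVICES", "NCCL_DEBUG"}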
|
||||
|
||||
|
||||
def get_env_info():
|
||||
run_lambda = run
|
||||
pip_version, pip_list_output = get_pip_packages(run_lambda)
|
||||
|
||||
if TORCH_AVAILABLE:
|
||||
version_str = torch.__version__
|
||||
debug_mode_str = str(torch.version.debug)
|
||||
cuda_available_str = str(torch.cuda.is_available())
|
||||
cuda_version_str = torch.version.cuda
|
||||
if not hasattr(torch.version,
|
||||
'hip') or torch.version.hip is None: # cuda version
|
||||
hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A'
|
||||
else: # HIP version
|
||||
|
||||
def get_version_or_na(cfg, prefix):
|
||||
_lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s]
|
||||
return _lst[0] if _lst else 'N/A'
|
||||
|
||||
cfg = torch._C._show_config().split('\n')
|
||||
hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime')
|
||||
miopen_runtime_version = get_version_or_na(cfg, 'MIOpen')
|
||||
cuda_version_str = 'N/A'
|
||||
hip_compiled_version = torch.version.hip
|
||||
else:
|
||||
version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A'
|
||||
hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A'
|
||||
|
||||
sys_version = sys.version.replace("\n", " ")
|
||||
|
||||
conda_packages = get_conda_packages(run_lambda)
|
||||
|
||||
rocm_version = get_rocm_version(run_lambda)
|
||||
vllm_version = get_vllm_version()
|
||||
vllm_build_flags = summarize_vllm_build_flags()
|
||||
gpu_topo = get_gpu_topo(run_lambda)
|
||||
|
||||
return SystemEnv(
|
||||
torch_version=version_str,
|
||||
is_debug_build=debug_mode_str,
|
||||
python_version='{} ({}-bit runtime)'.format(
|
||||
sys_version,
|
||||
sys.maxsize.bit_length() + 1),
|
||||
python_platform=get_python_platform(),
|
||||
is_cuda_available=cuda_available_str,
|
||||
cuda_compiled_version=cuda_version_str,
|
||||
cuda_runtime_version=get_running_cuda_version(run_lambda),
|
||||
cuda_module_loading=get_cuda_module_loading_config(),
|
||||
nvidia_gpu_models=get_gpu_info(run_lambda),
|
||||
nvidia_driver_version=get_nvidia_driver_version(run_lambda),
|
||||
cudnn_version=get_cudnn_version(run_lambda),
|
||||
hip_compiled_version=hip_compiled_version,
|
||||
hip_runtime_version=hip_runtime_version,
|
||||
miopen_runtime_version=miopen_runtime_version,
|
||||
pip_version=pip_version,
|
||||
pip_packages=pip_list_output,
|
||||
conda_packages=conda_packages,
|
||||
os=get_os(run_lambda),
|
||||
libc_version=get_libc_version(),
|
||||
gcc_version=get_gcc_version(run_lambda),
|
||||
clang_version=get_clang_version(run_lambda),
|
||||
cmake_version=get_cmake_version(run_lambda),
|
||||
caching_allocator_config=get_cachingallocator_config(),
|
||||
is_xnnpack_available=is_xnnpack_available(),
|
||||
cpu_info=get_cpu_info(run_lambda),
|
||||
rocm_version=rocm_version,
|
||||
vllm_version=vllm_version,
|
||||
vllm_build_flags=vllm_build_flags,
|
||||
gpu_topo=gpu_topo,
|
||||
env_vars=get_env_vars(),
|
||||
)
|
||||
|
||||
|
||||
env_info_fmt = """
|
||||
==============================
|
||||
System Info
|
||||
==============================
|
||||
OS : {os}
|
||||
GCC version : {gcc_version}
|
||||
Clang version : {clang_version}
|
||||
CMake version : {cmake_version}
|
||||
Libc version : {libc_version}
|
||||
|
||||
==============================
|
||||
PyTorch Info
|
||||
==============================
|
||||
PyTorch version : {torch_version}
|
||||
Is debug build : {is_debug_build}
|
||||
CUDA used to build PyTorch : {cuda_compiled_version}
|
||||
ROCM used to build PyTorch : {hip_compiled_version}
|
||||
|
||||
==============================
|
||||
Python Environment
|
||||
==============================
|
||||
Python version : {python_version}
|
||||
Python platform : {python_platform}
|
||||
|
||||
==============================
|
||||
CUDA / GPU Info
|
||||
==============================
|
||||
Is CUDA available : {is_cuda_available}
|
||||
CUDA runtime version : {cuda_runtime_version}
|
||||
CUDA_MODULE_LOADING set to : {cuda_module_loading}
|
||||
GPU models and configuration : {nvidia_gpu_models}
|
||||
Nvidia driver version : {nvidia_driver_version}
|
||||
cuDNN version : {cudnn_version}
|
||||
HIP runtime version : {hip_runtime_version}
|
||||
MIOpen runtime version : {miopen_runtime_version}
|
||||
Is XNNPACK available : {is_xnnpack_available}
|
||||
|
||||
==============================
|
||||
CPU Info
|
||||
==============================
|
||||
{cpu_info}
|
||||
|
||||
==============================
|
||||
Versions of relevant libraries
|
||||
==============================
|
||||
{pip_packages}
|
||||
{conda_packages}
|
||||
""".strip()
|
||||
|
||||
# both the above code and the following code use `strip()` to
|
||||
# remove leading/trailing whitespaces, so we need to add a newline
|
||||
# in between to separate the two sections
|
||||
env_info_fmt += "\n\n"
|
||||
|
||||
env_info_fmt += """
|
||||
==============================
|
||||
vLLM Info
|
||||
==============================
|
||||
ROCM Version : {rocm_version}
|
||||
vLLM Version : {vllm_version}
|
||||
vLLM Build Flags:
|
||||
{vllm_build_flags}
|
||||
GPU Topology:
|
||||
{gpu_topo}
|
||||
|
||||
==============================
|
||||
Environment Variables
|
||||
==============================
|
||||
{env_vars}
|
||||
""".strip()
|
||||
|
||||
|
||||
def pretty_str(envinfo):
|
||||
|
||||
def replace_nones(dct, replacement='Could not collect'):
|
||||
for key in dct.keys():
|
||||
if dct[key] is not None:
|
||||
continue
|
||||
dct[key] = replacement
|
||||
return dct
|
||||
|
||||
def replace_bools(dct, true='Yes', false='No'):
|
||||
for key in dct.keys():
|
||||
if dct[key] is True:
|
||||
dct[key] = true
|
||||
elif dct[key] is False:
|
||||
dct[key] = false
|
||||
return dct
|
||||
|
||||
def prepend(text, tag='[prepend]'):
|
||||
lines = text.split('\n')
|
||||
updated_lines = [tag + line for line in lines]
|
||||
return '\n'.join(updated_lines)
|
||||
|
||||
def replace_if_empty(text, replacement='No relevant packages'):
|
||||
if text is not None and len(text) == 0:
|
||||
return replacement
|
||||
return text
|
||||
|
||||
def maybe_start_on_next_line(string):
|
||||
# If `string` is multiline, prepend a \n to it.
|
||||
if string is not None and len(string.split('\n')) > 1:
|
||||
return '\n{}\n'.format(string)
|
||||
return string
|
||||
|
||||
mutable_dict = envinfo._asdict()
|
||||
|
||||
# If nvidia_gpu_models is multiline, start on the next line
|
||||
mutable_dict['nvidia_gpu_models'] = \
|
||||
maybe_start_on_next_line(envinfo.nvidia_gpu_models)
|
||||
|
||||
# If the machine doesn't have CUDA, report some fields as 'No CUDA'
|
||||
dynamic_cuda_fields = [
|
||||
'cuda_runtime_version',
|
||||
'nvidia_gpu_models',
|
||||
'nvidia_driver_version',
|
||||
]
|
||||
all_cuda_fields = dynamic_cuda_fields + ['cudnn_version']
|
||||
all_dynamic_cuda_fields_missing = all(mutable_dict[field] is None
|
||||
for field in dynamic_cuda_fields)
|
||||
if TORCH_AVAILABLE and not torch.cuda.is_available(
|
||||
) and all_dynamic_cuda_fields_missing:
|
||||
for field in all_cuda_fields:
|
||||
mutable_dict[field] = 'No CUDA'
|
||||
if envinfo.cuda_compiled_version is None:
|
||||
mutable_dict['cuda_compiled_version'] = 'None'
|
||||
|
||||
# Replace True with Yes, False with No
|
||||
mutable_dict = replace_bools(mutable_dict)
|
||||
|
||||
# Replace all None objects with 'Could not collect'
|
||||
mutable_dict = replace_nones(mutable_dict)
|
||||
|
||||
# If either of these are '', replace with 'No relevant packages'
|
||||
mutable_dict['pip_packages'] = replace_if_empty(
|
||||
mutable_dict['pip_packages'])
|
||||
mutable_dict['conda_packages'] = replace_if_empty(
|
||||
mutable_dict['conda_packages'])
|
||||
|
||||
# Tag conda and pip packages with a prefix
|
||||
# If they were previously None, they'll show up as ie '[conda] Could not collect'
|
||||
if mutable_dict['pip_packages']:
|
||||
mutable_dict['pip_packages'] = prepend(
|
||||
mutable_dict['pip_packages'], '[{}] '.format(envinfo.pip_version))
|
||||
if mutable_dict['conda_packages']:
|
||||
mutable_dict['conda_packages'] = prepend(
|
||||
mutable_dict['conda_packages'], '[conda] ')
|
||||
mutable_dict['cpu_info'] = envinfo.cpu_info
|
||||
return env_info_fmt.format(**mutable_dict)
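# Self-contained illustration (hypothetical package list) of the prepend()
# helper above; this is how pip/conda package lines end up prefixed with
# '[pip3] ' / '[conda] ' in the final report.
_pkgs = "torch==2.4.0\nnumpy==2.0.1"
_tagged = '\n'.join('[pip3] ' + line for line in _pkgs.split('\n'))
assert _tagged == "[pip3] torch==2.4.0\n[pip3] numpy==2.0.1"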
|
||||
|
||||
|
||||
def get_pretty_env_info():
|
||||
return pretty_str(get_env_info())
|
||||
|
||||
|
||||
def main():
|
||||
print("Collecting environment information...")
|
||||
output = get_pretty_env_info()
|
||||
print(output)
|
||||
|
||||
if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(
|
||||
torch.utils, '_crash_handler'):
|
||||
minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR
|
||||
if sys.platform == "linux" and os.path.exists(minidump_dir):
|
||||
dumps = [
|
||||
os.path.join(minidump_dir, dump)
|
||||
for dump in os.listdir(minidump_dir)
|
||||
]
|
||||
latest = max(dumps, key=os.path.getctime)
|
||||
ctime = os.path.getctime(latest)
|
||||
creation_time = datetime.datetime.fromtimestamp(ctime).strftime(
|
||||
'%Y-%m-%d %H:%M:%S')
|
||||
msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \
|
||||
"if this is related to your bug please include it when you file a report ***"
|
||||
print(msg, file=sys.stderr)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
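# Typical invocation (the file name is not shown in this diff, so the exact
# command is an assumption):
#
#   python collect_env.py
#
# which prints the "System Info" / "PyTorch Info" / "vLLM Info" report
# rendered from env_info_fmt above.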
|
||||
0
vllm/compilation/__init__.py
Normal file
BIN
vllm/compilation/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/compilation/__pycache__/inductor_pass.cpython-312.pyc
Normal file
Binary file not shown.
189
vllm/compilation/activation_quant_fusion.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
|
||||
register_replacement)
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey, kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
|
||||
|
||||
FUSED_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501
|
||||
}
|
||||
silu_and_mul_nvfp4_quant_supported = (current_platform.is_cuda() and hasattr(
|
||||
torch.ops._C, "silu_and_mul_nvfp4_quant"))
|
||||
if silu_and_mul_nvfp4_quant_supported:
|
||||
FUSED_OPS[
|
||||
kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
|
||||
|
||||
|
||||
class ActivationQuantPattern(ABC):
|
||||
"""
|
||||
The base class for Activation+Quant fusions.
|
||||
Should not be used directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_key: QuantKey,
|
||||
):
|
||||
self.quant_key = quant_key
|
||||
self.quant_dtype = quant_key.dtype
|
||||
|
||||
assert self.quant_key in QUANT_OPS, \
|
||||
f"unsupported quantization scheme {self.quant_key}"
|
||||
self.QUANT_OP = QUANT_OPS[self.quant_key]
|
||||
|
||||
assert self.quant_key in FUSED_OPS, \
|
||||
f"unsupported fusion scheme {self.quant_key}"
|
||||
self.FUSED_OP = FUSED_OPS[self.quant_key]
|
||||
|
||||
def empty_quant(self, *args, **kwargs):
|
||||
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
|
||||
"""
|
||||
Fusion for SiluMul+Fp8StaticQuant Pattern
|
||||
"""
|
||||
|
||||
def __init__(self, symmetric: bool = True):
|
||||
quant_key = QuantKey(dtype=FP8_DTYPE,
|
||||
scale=kStaticTensorScale,
|
||||
symmetric=symmetric)
|
||||
super().__init__(quant_key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(result: torch.Tensor, result_silu_mul: torch.Tensor,
|
||||
input: torch.Tensor, scale: torch.Tensor):
|
||||
at1 = auto_functionalized(SILU_MUL_OP,
|
||||
result=result_silu_mul,
|
||||
input=input)
|
||||
at2 = auto_functionalized(self.QUANT_OP,
|
||||
result=result,
|
||||
input=at1[1],
|
||||
scale=scale)
|
||||
return at2[1]
|
||||
|
||||
def replacement(result: torch.Tensor, result_silu_mul: torch.Tensor,
|
||||
input: torch.Tensor, scale: torch.Tensor):
|
||||
at = auto_functionalized(self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
scale=scale)
|
||||
return at[1]
|
||||
|
||||
inputs = [
|
||||
self.empty_quant(5, 4), # result
|
||||
empty_bf16(5, 4), # result_silu_mul
|
||||
empty_bf16(5, 4), # input
|
||||
empty_fp32(1, 1) # scale
|
||||
]
|
||||
|
||||
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
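# Hedged, self-contained sketch of the same pattern-matcher mechanism using
# plain torch ops instead of vLLM's custom kernels. torch._inductor's
# pattern_matcher is a private API, so details may change across PyTorch
# versions; the toy ops and names below are purely illustrative.
import torch
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                              register_replacement)

def _toy_pattern(x: torch.Tensor):
    # the "expensive" eager sequence we want to match
    return torch.relu(x) * 2.0

def _toy_replacement(x: torch.Tensor):
    # the equivalent "fused" form it gets rewritten into
    return torch.clamp(x, min=0.0) * 2.0

_toy_pass = PatternMatcherPass(pass_name="toy_fusion_pass")
register_replacement(_toy_pattern, _toy_replacement,
                     [torch.empty(4, 4)], fwd_only, _toy_pass)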
|
||||
|
||||
|
||||
class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
|
||||
"""
|
||||
Fusion for SiluMul+Nvfp4Quant Pattern
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(kNvfp4Quant)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(result: torch.Tensor, output_scale: torch.Tensor,
|
||||
result_silu_mul: torch.Tensor, input: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at1 = auto_functionalized(SILU_MUL_OP,
|
||||
result=result_silu_mul,
|
||||
input=input)
|
||||
at2 = auto_functionalized(self.QUANT_OP,
|
||||
output=result,
|
||||
input=at1[1],
|
||||
output_scale=output_scale,
|
||||
input_scale=scale)
|
||||
return at2[1], at2[2]
|
||||
|
||||
def replacement(result: torch.Tensor, output_scale: torch.Tensor,
|
||||
result_silu_mul: torch.Tensor, input: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at = auto_functionalized(self.FUSED_OP,
|
||||
result=result,
|
||||
result_block_scale=output_scale,
|
||||
input=input,
|
||||
input_global_scale=scale)
|
||||
return at[1], at[2]
|
||||
|
||||
inputs = [
|
||||
self.empty_quant(5, 32), # result
|
||||
empty_i32(128, 4), # output_scale
|
||||
empty_bf16(5, 64), # result_silu_mul
|
||||
empty_bf16(5, 64), # input
|
||||
empty_fp32(1, 1) # scale
|
||||
]
|
||||
|
||||
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
|
||||
|
||||
|
||||
class ActivationQuantFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses a pre-defined set of custom ops into fused ops.
|
||||
It uses the torch pattern matcher to find the patterns and replace them.
|
||||
|
||||
Because patterns can only be registered once, the pass is a singleton.
|
||||
This will be addressed in a future version of PyTorch:
|
||||
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="activation_quant_fusion_pass")
|
||||
|
||||
pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern()
|
||||
pattern_silu_mul_fp8.register(self.patterns)
|
||||
|
||||
if silu_and_mul_nvfp4_quant_supported:
|
||||
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
|
||||
pattern_silu_mul_nvfp4.register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self):
|
||||
return VllmInductorPass.hash_source(self, ActivationQuantPattern,
|
||||
SiluMulFp8StaticQuantPattern,
|
||||
SiluMulNvfp4QuantPattern)
|
||||
650
vllm/compilation/backends.py
Normal file
@@ -0,0 +1,650 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import dataclasses
|
||||
import os
|
||||
import pprint
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
|
||||
|
||||
from .compiler_interface import (CompilerInterface, EagerAdaptor,
|
||||
InductorAdaptor, InductorStandaloneAdaptor)
|
||||
from .counter import compilation_counter
|
||||
from .inductor_pass import InductorPass
|
||||
from .pass_manager import PostGradPassManager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
|
||||
if compilation_config.use_inductor:
|
||||
# Use standalone compile only if requested, version is new enough,
|
||||
# and the symbol actually exists in this PyTorch build.
|
||||
if (envs.VLLM_USE_STANDALONE_COMPILE
|
||||
and is_torch_equal_or_newer("2.8.0.dev")
|
||||
and hasattr(torch._inductor, "standalone_compile")):
|
||||
logger.debug("Using InductorStandaloneAdaptor")
|
||||
return InductorStandaloneAdaptor()
|
||||
else:
|
||||
logger.debug("Using InductorAdaptor")
|
||||
return InductorAdaptor()
|
||||
else:
|
||||
logger.debug("Using EagerAdaptor")
|
||||
return EagerAdaptor()
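# Hedged usage sketch (the constructor call below is an assumption, not
# verified against CompilationConfig's actual signature):
#
#   from vllm.config import CompilationConfig
#   compiler = make_compiler(CompilationConfig(use_inductor=True))
#   # -> InductorStandaloneAdaptor on PyTorch >= 2.8 with
#   #    VLLM_USE_STANDALONE_COMPILE set, otherwise InductorAdaptor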
|
||||
|
||||
|
||||
class CompilerManager:
|
||||
"""
|
||||
A manager to manage the compilation process, including
|
||||
caching the compiled graph, loading the compiled graph,
|
||||
and compiling the graph.
|
||||
|
||||
The cache is a dict mapping
|
||||
`(runtime_shape, graph_index, backend_name)`
|
||||
to `any_data` returned from the compiler.
|
||||
|
||||
When serializing the cache, we save it to a Python file
|
||||
for readability. We don't use json here because json doesn't
|
||||
support non-string keys (the cache keys here are tuples).
|
||||
"""
|
||||
|
||||
def __init__(self, compilation_config: CompilationConfig):
|
||||
self.cache: dict[tuple[Optional[int], int, str], Any] = dict()
|
||||
self.is_cache_updated = False
|
||||
self.compilation_config = compilation_config
|
||||
self.compiler = make_compiler(compilation_config)
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
return self.compiler.compute_hash(vllm_config)
|
||||
|
||||
def initialize_cache(self,
|
||||
cache_dir: str,
|
||||
disable_cache: bool = False,
|
||||
prefix: str = ""):
|
||||
"""
|
||||
Initialize the cache directory for the compiler.
|
||||
|
||||
The organization of the cache directory is as follows:
|
||||
cache_dir=/path/to/hash_str/rank_i_j/prefix/
|
||||
inside cache_dir, there will be:
|
||||
- vllm_compile_cache.py
|
||||
- computation_graph.py
|
||||
- transformed_code.py
|
||||
|
||||
multiple prefixes can share the same
|
||||
base cache dir of /path/to/hash_str/rank_i_j/ ,
|
||||
to store some common compilation artifacts.
|
||||
"""
|
||||
|
||||
self.disable_cache = disable_cache
|
||||
self.cache_dir = cache_dir
|
||||
self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
|
||||
|
||||
if not disable_cache and os.path.exists(self.cache_file_path):
|
||||
# load the cache from the file
|
||||
with open(self.cache_file_path) as f:
|
||||
# we use ast.literal_eval to parse the data
|
||||
# because it is a safe way to parse Python literals.
|
||||
# do not use eval(), it is unsafe.
|
||||
self.cache = ast.literal_eval(f.read())
|
||||
|
||||
self.compiler.initialize_cache(cache_dir=cache_dir,
|
||||
disable_cache=disable_cache,
|
||||
prefix=prefix)
|
||||
|
||||
def save_to_file(self):
|
||||
if self.disable_cache or not self.is_cache_updated:
|
||||
return
|
||||
printer = pprint.PrettyPrinter(indent=4)
|
||||
data = printer.pformat(self.cache)
|
||||
with open(self.cache_file_path, "w") as f:
|
||||
f.write(data)
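# Illustration (hypothetical paths/handles) of the on-disk cache written
# above: a Python-literal dict keyed by (runtime_shape, graph_index,
# compiler_name) that ast.literal_eval parses back on the next run, e.g.
#
#   {   (None, 0, 'inductor'): '/path/to/artifact_dynamic_0',
#       (512, 0, 'inductor'): '/path/to/artifact_shape_512_0'}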
|
||||
|
||||
def load(self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None) -> Optional[Callable]:
|
||||
if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
|
||||
return None
|
||||
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
|
||||
compiled_graph = self.compiler.load(handle, graph, example_inputs,
|
||||
graph_index, runtime_shape)
|
||||
if runtime_shape is None:
|
||||
logger.debug(
|
||||
"Directly load the %s-th graph for dynamic shape from %s via "
|
||||
"handle %s", graph_index, self.compiler.name, handle)
|
||||
else:
|
||||
logger.debug(
|
||||
"Directly load the %s-th graph for shape %s from %s via "
|
||||
"handle %s", graph_index, str(runtime_shape),
|
||||
self.compiler.name, handle)
|
||||
return compiled_graph
|
||||
|
||||
def compile(self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs,
|
||||
additional_inductor_config,
|
||||
compilation_config: CompilationConfig,
|
||||
graph_index: int = 0,
|
||||
num_graphs: int = 1,
|
||||
runtime_shape: Optional[int] = None) -> Any:
|
||||
if graph_index == 0:
|
||||
# before compiling the first graph, record the start time
|
||||
global compilation_start_time
|
||||
compilation_start_time = time.time()
|
||||
|
||||
compilation_counter.num_backend_compilations += 1
|
||||
|
||||
compiled_graph = None
|
||||
|
||||
# try to load from the cache
|
||||
compiled_graph = self.load(graph, example_inputs, graph_index,
|
||||
runtime_shape)
|
||||
if compiled_graph is not None:
|
||||
if graph_index == num_graphs - 1:
|
||||
# after loading the last graph for this shape, record the time.
|
||||
# there can be multiple graphs due to piecewise compilation.
|
||||
now = time.time()
|
||||
elapsed = now - compilation_start_time
|
||||
if runtime_shape is None:
|
||||
logger.info(
|
||||
"Directly load the compiled graph(s) for dynamic shape "
|
||||
"from the cache, took %.3f s", elapsed)
|
||||
else:
|
||||
logger.info(
|
||||
"Directly load the compiled graph(s) for shape %s "
|
||||
"from the cache, took %.3f s", str(runtime_shape),
|
||||
elapsed)
|
||||
return compiled_graph
|
||||
|
||||
# no compiler cached the graph, or the cache is disabled,
|
||||
# we need to compile it
|
||||
if isinstance(self.compiler, InductorAdaptor):
|
||||
# Let compile_fx generate a key for us
|
||||
maybe_key = None
|
||||
else:
|
||||
maybe_key = \
|
||||
f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
|
||||
compiled_graph, handle = self.compiler.compile(
|
||||
graph, example_inputs, additional_inductor_config, runtime_shape,
|
||||
maybe_key)
|
||||
|
||||
assert compiled_graph is not None, "Failed to compile the graph"
|
||||
|
||||
# store the artifact in the cache
|
||||
if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None:
|
||||
self.cache[(runtime_shape, graph_index,
|
||||
self.compiler.name)] = handle
|
||||
compilation_counter.num_cache_entries_updated += 1
|
||||
self.is_cache_updated = True
|
||||
if graph_index == 0:
|
||||
# adds some info logging for the first graph
|
||||
if runtime_shape is None:
|
||||
logger.info(
|
||||
"Cache the graph for dynamic shape for later use")
|
||||
else:
|
||||
logger.info("Cache the graph of shape %s for later use",
|
||||
str(runtime_shape))
|
||||
if runtime_shape is None:
|
||||
logger.debug(
|
||||
"Store the %s-th graph for dynamic shape from %s via "
|
||||
"handle %s", graph_index, self.compiler.name, handle)
|
||||
else:
|
||||
logger.debug(
|
||||
"Store the %s-th graph for shape %s from %s via handle %s",
|
||||
graph_index, str(runtime_shape), self.compiler.name,
|
||||
handle)
|
||||
|
||||
# after compiling the last graph, record the end time
|
||||
if graph_index == num_graphs - 1:
|
||||
now = time.time()
|
||||
elapsed = now - compilation_start_time
|
||||
compilation_config.compilation_time += elapsed
|
||||
if runtime_shape is None:
|
||||
logger.info("Compiling a graph for dynamic shape takes %.2f s",
|
||||
elapsed)
|
||||
else:
|
||||
logger.info("Compiling a graph for shape %s takes %.2f s",
|
||||
runtime_shape, elapsed)
|
||||
|
||||
return compiled_graph
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SplitItem:
|
||||
submod_name: str
|
||||
graph_id: int
|
||||
is_splitting_graph: bool
|
||||
graph: fx.GraphModule
|
||||
|
||||
|
||||
def split_graph(graph: fx.GraphModule,
|
||||
ops: list[str]) -> tuple[fx.GraphModule, list[SplitItem]]:
|
||||
# split graph by ops
|
||||
subgraph_id = 0
|
||||
node_to_subgraph_id = {}
|
||||
split_op_graphs = []
|
||||
for node in graph.graph.nodes:
|
||||
if node.op in ("output", "placeholder"):
|
||||
continue
|
||||
if node.op == 'call_function' and str(node.target) in ops:
|
||||
subgraph_id += 1
|
||||
node_to_subgraph_id[node] = subgraph_id
|
||||
split_op_graphs.append(subgraph_id)
|
||||
subgraph_id += 1
|
||||
else:
|
||||
node_to_subgraph_id[node] = subgraph_id
|
||||
|
||||
# `keep_original_order` is important!
|
||||
# otherwise pytorch might reorder the nodes and
|
||||
# the semantics of the graph will change when we
|
||||
# have mutations in the graph
|
||||
split_gm = torch.fx.passes.split_module.split_module(
|
||||
graph,
|
||||
None,
|
||||
lambda node: node_to_subgraph_id[node],
|
||||
keep_original_order=True)
|
||||
|
||||
outputs = []
|
||||
|
||||
names = [name for (name, module) in split_gm.named_modules()]
|
||||
|
||||
for name in names:
|
||||
if "." in name or name == "":
|
||||
# recursive child module or the root module
|
||||
continue
|
||||
|
||||
module = getattr(split_gm, name)
|
||||
|
||||
graph_id = int(name.replace("submod_", ""))
|
||||
outputs.append(
|
||||
SplitItem(name, graph_id, (graph_id in split_op_graphs), module))
|
||||
|
||||
# sort by integer graph_id, rather than string name
|
||||
outputs.sort(key=lambda x: x.graph_id)
|
||||
|
||||
return split_gm, outputs
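# Hedged sketch of what split_graph produces: with `ops` containing an
# attention op name (the real names come from
# CompilationConfig.splitting_ops), a call sequence such as
#
#     linear -> attention -> linear -> attention -> linear
#
# is split into submod_0 (linear, compiled piecewise), submod_1 (attention,
# is_splitting_graph=True, left to the attention backend), submod_2 (linear),
# and so on, with the SplitItems returned in graph_id order.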
|
||||
|
||||
|
||||
compilation_start_time = 0.0
|
||||
|
||||
|
||||
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
|
||||
It runs the given graph with fake inputs, and compile some
|
||||
submodules specified by `compile_submod_names` with the given
|
||||
compilation configs.
|
||||
|
||||
NOTE: the order in `compile_submod_names` matters, because
|
||||
it will be used to determine the order of the compiled piecewise
|
||||
graphs. The first graph will handle logging, and the last graph
|
||||
has some special cudagraph output handling.
|
||||
"""
|
||||
|
||||
def __init__(self, module: torch.fx.GraphModule,
|
||||
compile_submod_names: list[str], vllm_config: VllmConfig,
|
||||
vllm_backend: "VllmBackend"):
|
||||
super().__init__(module)
|
||||
from torch._guards import detect_fake_mode
|
||||
self.fake_mode = detect_fake_mode()
|
||||
self.compile_submod_names = compile_submod_names
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.vllm_config = vllm_config
|
||||
self.vllm_backend = vllm_backend
|
||||
# When True, the interpreter dumps the whole torch.fx.Graph on errors, which is very noisy, so keep it disabled.
|
||||
self.extra_traceback = False
|
||||
|
||||
def run(self, *args):
|
||||
fake_args = [
|
||||
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
||||
for t in args
|
||||
]
|
||||
with self.fake_mode, enable_python_dispatcher():
|
||||
return super().run(*fake_args)
|
||||
|
||||
def call_module(self, target: torch.fx.node.Target,
|
||||
args: tuple[torch.fx.node.Argument,
|
||||
...], kwargs: dict[str, Any]) -> Any:
|
||||
assert isinstance(target, str)
|
||||
output = super().call_module(target, args, kwargs)
|
||||
|
||||
if target in self.compile_submod_names:
|
||||
index = self.compile_submod_names.index(target)
|
||||
submod = self.fetch_attr(target)
|
||||
sym_shape_indices = [
|
||||
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
|
||||
]
|
||||
global compilation_start_time
|
||||
|
||||
compiled_graph_for_dynamic_shape = self.vllm_backend.\
|
||||
compiler_manager.compile(
|
||||
submod,
|
||||
args,
|
||||
self.compilation_config.inductor_compile_config,
|
||||
self.compilation_config,
|
||||
graph_index=index,
|
||||
num_graphs=len(self.compile_submod_names),
|
||||
runtime_shape=None)
|
||||
# Lazy import here to avoid circular import
|
||||
from .cuda_piecewise_backend import PiecewiseBackend
|
||||
|
||||
piecewise_backend = PiecewiseBackend(
|
||||
submod, self.vllm_config, index,
|
||||
len(self.compile_submod_names), sym_shape_indices,
|
||||
compiled_graph_for_dynamic_shape, self.vllm_backend)
|
||||
|
||||
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
and
|
||||
not self.compilation_config.use_inductor_graph_partition):
|
||||
# We're using Dynamo-based piecewise splitting, so we wrap
|
||||
# the whole subgraph with a static graph wrapper.
|
||||
from .cuda_graph import CUDAGraphOptions
|
||||
|
||||
# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
|
||||
# class) as platform dependent.
|
||||
static_graph_wrapper_class = resolve_obj_by_qualname(
|
||||
current_platform.get_static_graph_wrapper_cls())
|
||||
|
||||
# Always assign PIECEWISE runtime mode to the
|
||||
# CUDAGraphWrapper for piecewise_backend, to distinguish
|
||||
# it from the FULL cudagraph runtime mode, regardless of whether it
|
||||
# is wrapped on a full or piecewise fx graph.
|
||||
self.module.__dict__[target] = static_graph_wrapper_class(
|
||||
runnable=piecewise_backend,
|
||||
vllm_config=self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
cudagraph_options=CUDAGraphOptions(
|
||||
debug_log_enable=piecewise_backend.is_first_graph,
|
||||
gc_disable=not piecewise_backend.is_first_graph,
|
||||
weak_ref_output=piecewise_backend.is_last_graph))
|
||||
else:
|
||||
self.module.__dict__[target] = piecewise_backend
|
||||
|
||||
compilation_counter.num_piecewise_capturable_graphs_seen += 1
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# the tag for the part of model being compiled,
|
||||
# e.g. backbone/eagle_head
|
||||
model_tag: str = "backbone"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_model_tag(tag: str):
|
||||
"""Context manager to set the model tag."""
|
||||
global model_tag
|
||||
assert tag != model_tag, \
|
||||
f"Model tag {tag} is the same as the current tag {model_tag}."
|
||||
old_tag = model_tag
|
||||
model_tag = tag
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
model_tag = old_tag
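# Hedged usage sketch: compiling an auxiliary model part under its own tag so
# it gets a separate cache prefix ("eagle_head" and the helper below are
# illustrative, not taken from this diff):
#
#   with set_model_tag("eagle_head"):
#       draft_model = build_draft_model(vllm_config)  # hypothetical helper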
|
||||
|
||||
|
||||
class VllmBackend:
|
||||
"""The compilation backend for `torch.compile` with vLLM.
|
||||
It is used for compilation level of `CompilationLevel.PIECEWISE`,
|
||||
where we customize the compilation.
|
||||
|
||||
The major work of this backend is to split the graph into
|
||||
piecewise graphs, and pass them to the piecewise backend.
|
||||
|
||||
This backend also adds the PostGradPassManager to Inductor config,
|
||||
which handles the post-grad passes.
|
||||
"""
|
||||
|
||||
vllm_config: VllmConfig
|
||||
compilation_config: CompilationConfig
|
||||
_called: bool = False
|
||||
# the graph we compiled
|
||||
graph: fx.GraphModule
|
||||
# the stitching graph module for all the piecewise graphs
|
||||
split_gm: fx.GraphModule
|
||||
piecewise_graphs: list[SplitItem]
|
||||
returned_callable: Callable
|
||||
# Inductor passes to run on the graph pre-defunctionalization
|
||||
post_grad_passes: Sequence[Callable]
|
||||
sym_tensor_indices: list[int]
|
||||
input_buffers: list[torch.Tensor]
|
||||
compiler_manager: CompilerManager
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
):
|
||||
|
||||
# if the model is initialized with a non-empty prefix,
|
||||
# then usually it's enough to use that prefix,
|
||||
# e.g. language_model, vision_model, etc.
|
||||
# when multiple parts are initialized as independent
|
||||
# models, we need to use the model_tag to distinguish
|
||||
# them, e.g. backbone (default), eagle_head, etc.
|
||||
self.prefix = prefix or model_tag
|
||||
|
||||
# Passes to run on the graph post-grad.
|
||||
self.post_grad_pass_manager = PostGradPassManager()
|
||||
|
||||
self.sym_tensor_indices = []
|
||||
self.input_buffers = []
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
|
||||
self.compiler_manager: CompilerManager = CompilerManager(
|
||||
self.compilation_config)
|
||||
|
||||
# `torch.compile` is JIT compiled, so we don't need to
|
||||
# do anything here
|
||||
|
||||
def configure_post_pass(self):
|
||||
config = self.compilation_config
|
||||
self.post_grad_pass_manager.configure(self.vllm_config)
|
||||
|
||||
# Post-grad custom passes are run using the post_grad_custom_post_pass
|
||||
# hook. If a pass for that hook exists, add it to the pass manager.
|
||||
inductor_config = config.inductor_compile_config
|
||||
PASS_KEY = "post_grad_custom_post_pass"
|
||||
if PASS_KEY in inductor_config:
|
||||
if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
|
||||
# PassManager already added to config, make sure it's correct
|
||||
assert (inductor_config[PASS_KEY].uuid() ==
|
||||
self.post_grad_pass_manager.uuid())
|
||||
else:
|
||||
# Config should automatically wrap all inductor passes
|
||||
assert isinstance(inductor_config[PASS_KEY], InductorPass)
|
||||
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
|
||||
inductor_config[PASS_KEY] = self.post_grad_pass_manager
|
||||
|
||||
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
|
||||
|
||||
vllm_config = self.vllm_config
|
||||
if not self.compilation_config.cache_dir:
|
||||
# no provided cache dir, generate one based on the known factors
|
||||
# that affects the compilation. if none of the factors change,
|
||||
# the cache dir will be the same so that we can reuse the compiled
|
||||
# graph.
|
||||
|
||||
factors = []
|
||||
# 0. factors come from the env, for example, The values of
|
||||
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
|
||||
env_hash = envs.compute_hash()
|
||||
factors.append(env_hash)
|
||||
|
||||
# 1. factors come from the vllm_config (it mainly summarizes how the
|
||||
# model is created)
|
||||
config_hash = vllm_config.compute_hash()
|
||||
factors.append(config_hash)
|
||||
|
||||
# 2. factors come from the code files that are traced by Dynamo (
|
||||
# it mainly summarizes how the model is used in forward pass)
|
||||
forward_code_files = list(
|
||||
sorted(self.compilation_config.traced_files))
|
||||
self.compilation_config.traced_files.clear()
|
||||
logger.debug(
|
||||
"Traced files (to be considered for compilation cache):\n%s",
|
||||
"\n".join(forward_code_files))
|
||||
hash_content = []
|
||||
for filepath in forward_code_files:
|
||||
hash_content.append(filepath)
|
||||
if filepath == "<string>":
|
||||
# This means the function was dynamically generated, with
|
||||
# e.g. exec(). We can't actually check these.
|
||||
continue
|
||||
with open(filepath) as f:
|
||||
hash_content.append(f.read())
|
||||
import hashlib
|
||||
code_hash = hashlib.md5("\n".join(hash_content).encode(),
|
||||
usedforsecurity=False).hexdigest()
|
||||
factors.append(code_hash)
|
||||
|
||||
# 3. compiler hash
|
||||
compiler_hash = self.compiler_manager.compute_hash(vllm_config)
|
||||
factors.append(compiler_hash)
|
||||
|
||||
# combine all factors to generate the cache dir
|
||||
hash_key = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()[:10]
|
||||
|
||||
cache_dir = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT,
|
||||
"torch_compile_cache",
|
||||
hash_key,
|
||||
)
|
||||
self.compilation_config.cache_dir = cache_dir
|
||||
|
||||
cache_dir = self.compilation_config.cache_dir
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
self.compilation_config.cache_dir = cache_dir
|
||||
rank = vllm_config.parallel_config.rank
|
||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}",
|
||||
self.prefix)
|
||||
os.makedirs(local_cache_dir, exist_ok=True)
|
||||
self.compilation_config.local_cache_dir = local_cache_dir
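# Illustration (hypothetical hash and ranks) of the resulting layout, which
# matches the scheme documented in CompilerManager.initialize_cache:
#
#   $VLLM_CACHE_ROOT/torch_compile_cache/1a2b3c4d5e/rank_0_0/backbone/
#       vllm_compile_cache.py
#       computation_graph.py
#       transformed_code.py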
|
||||
|
||||
disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE
|
||||
|
||||
if disable_cache:
|
||||
logger.info("vLLM's torch.compile cache is disabled.")
|
||||
else:
|
||||
logger.info("Using cache directory: %s for vLLM's torch.compile",
|
||||
local_cache_dir)
|
||||
|
||||
self.compiler_manager.initialize_cache(local_cache_dir, disable_cache,
|
||||
self.prefix)
|
||||
|
||||
# when dynamo calls the backend, it means the bytecode
|
||||
# transform and analysis are done
|
||||
compilation_counter.num_graphs_seen += 1
|
||||
from .monitor import torch_compile_start_time
|
||||
dynamo_time = time.time() - torch_compile_start_time
|
||||
logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
|
||||
self.compilation_config.compilation_time += dynamo_time
|
||||
|
||||
# we control the compilation process, each instance can only be
|
||||
# called once
|
||||
assert not self._called, "VllmBackend can only be called once"
|
||||
|
||||
self.graph = graph
|
||||
self.configure_post_pass()
|
||||
|
||||
self.split_gm, self.piecewise_graphs = split_graph(
|
||||
graph, self.compilation_config.splitting_ops)
|
||||
|
||||
from torch._dynamo.utils import lazy_format_graph_code
|
||||
|
||||
# depyf will hook lazy_format_graph_code and dump the graph
|
||||
# for debugging, no need to print the graph here
|
||||
lazy_format_graph_code("before split", self.graph)
|
||||
lazy_format_graph_code("after split", self.split_gm)
|
||||
|
||||
compilation_counter.num_piecewise_graphs_seen += len(
|
||||
self.piecewise_graphs)
|
||||
submod_names_to_compile = [
|
||||
item.submod_name for item in self.piecewise_graphs
|
||||
if not item.is_splitting_graph
|
||||
]
|
||||
|
||||
# propagate the split graph to the piecewise backend,
|
||||
# compile submodules with symbolic shapes
|
||||
PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
|
||||
self.vllm_config,
|
||||
self).run(*example_inputs)
|
||||
|
||||
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
|
||||
if not os.path.exists(graph_path):
|
||||
# code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
|
||||
# use `print_readable` because it can include submodules
|
||||
src = "from __future__ import annotations\nimport torch\n" + \
|
||||
self.split_gm.print_readable(print_output=False)
|
||||
src = src.replace("<lambda>", "GraphModule")
|
||||
with open(graph_path, "w") as f:
|
||||
f.write(src)
|
||||
|
||||
logger.debug("Computation graph saved to %s", graph_path)
|
||||
|
||||
self._called = True
|
||||
|
||||
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \
|
||||
not self.compilation_config.cudagraph_copy_inputs:
|
||||
return self.split_gm
|
||||
|
||||
# if we need to copy input buffers for cudagraph
|
||||
from torch._guards import detect_fake_mode
|
||||
fake_mode = detect_fake_mode()
|
||||
fake_args = [
|
||||
fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
||||
for t in example_inputs
|
||||
]
|
||||
|
||||
# index of tensors that have symbolic shapes (batch size)
|
||||
# for weights and static buffers, they will have concrete shapes.
|
||||
# symbolic shape only happens for input tensors.
|
||||
from torch.fx.experimental.symbolic_shapes import is_symbolic
|
||||
self.sym_tensor_indices = [
|
||||
i for i, x in enumerate(fake_args)
|
||||
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \
|
||||
any(is_symbolic(d) for d in x.size())
|
||||
]
|
||||
|
||||
# compiler managed cudagraph input buffers
|
||||
# we assume the first run with symbolic shapes
|
||||
# has the maximum size among all the tensors
|
||||
self.input_buffers = [
|
||||
example_inputs[x].clone() for x in self.sym_tensor_indices
|
||||
]
|
||||
|
||||
# this is the callable we return to Dynamo to run
|
||||
def copy_and_call(*args):
|
||||
list_args = list(args)
|
||||
for i, index in enumerate(self.sym_tensor_indices):
|
||||
runtime_tensor = list_args[index]
|
||||
runtime_shape = runtime_tensor.shape[0]
|
||||
static_tensor = self.input_buffers[i][:runtime_shape]
|
||||
|
||||
# copy the tensor to the static buffer
|
||||
static_tensor.copy_(runtime_tensor)
|
||||
|
||||
# replace the tensor in the list_args to the static buffer
|
||||
list_args[index] = static_tensor
|
||||
return self.split_gm(*list_args)
|
||||
|
||||
return copy_and_call
|
||||
56
vllm/compilation/base_static_graph.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Protocol
|
||||
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
|
||||
|
||||
class AbstractStaticGraphWrapper(Protocol):
|
||||
"""
|
||||
StaticGraphWrapper interface that allows platforms to wrap a callable
|
||||
to be captured as a static graph.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runnable: Callable[..., Any],
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the StaticGraphWrapper class with graph capturing and
|
||||
execution-related configurations.
|
||||
|
||||
Args:
|
||||
runnable (Callable): The callable to be wrapped and captured.
|
||||
vllm_config (VllmConfig): Global configuration for vLLM.
|
||||
runtime_mode (CUDAGraphMode): The style of the static
|
||||
graph runtime. See CUDAGraphMode in vllm/config.py.
|
||||
Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
|
||||
are used as concrete runtime mode for cudagraph dispatching.
|
||||
Keyword Args:
|
||||
kwargs: Additional keyword arguments for platform-specific
|
||||
configurations.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Executes the wrapped callable.
|
||||
|
||||
If the current runtime mode in the ForwardContext matches the runtime
|
||||
mode of this instance, it replays the CUDAGraph or captures it using
|
||||
the callable if it hasn't been captured yet. Otherwise, it calls the
|
||||
original callable directly.
|
||||
|
||||
Args:
|
||||
*args: Variable length input arguments to be passed into the
|
||||
callable.
|
||||
**kwargs: Keyword arguments to be passed into the callable.
|
||||
|
||||
Returns:
|
||||
Any: Output of the executed callable.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
1188
vllm/compilation/collective_fusion.py
Normal file
File diff suppressed because it is too large
Load Diff
573
vllm/compilation/compiler_interface.py
Normal file
@@ -0,0 +1,573 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import copy
|
||||
import hashlib
|
||||
import os
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch._inductor.compile_fx
|
||||
import torch.fx as fx
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
|
||||
from .inductor_pass import pass_context
|
||||
|
||||
|
||||
class CompilerInterface:
|
||||
"""
|
||||
The interface for a compiler that can be used by vLLM.
|
||||
"""
|
||||
# The name of the compiler, e.g. inductor.
|
||||
# This is a class-level attribute.
|
||||
name: str
|
||||
|
||||
def initialize_cache(self,
|
||||
cache_dir: str,
|
||||
disable_cache: bool = False,
|
||||
prefix: str = ""):
|
||||
"""
|
||||
when the vLLM process uses `cache_dir` as the cache directory,
|
||||
the compiler should initialize itself with the cache directory,
|
||||
e.g. by re-directing its own cache directory to a sub-directory.
|
||||
|
||||
prefix can be used in combination with cache_dir to figure out the base
|
||||
cache directory, e.g. when there are multiple parts of the model being compiled,
|
||||
but we want to share the same cache directory for all of them.
|
||||
|
||||
e.g.
|
||||
cache_dir = "/path/to/dir/backbone", prefix = "backbone"
|
||||
cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
|
||||
"""
|
||||
pass
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
"""
|
||||
Gather all the relevant information from the vLLM config,
|
||||
to compute a hash so that we can cache the compiled model.
|
||||
|
||||
See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash]
|
||||
to check what information
|
||||
is already considered by default. This function should only
|
||||
consider the information that is specific to the compiler.
|
||||
"""
|
||||
return ""
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: Optional[int] = None,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
"""
|
||||
Compile the graph with the given example inputs and compiler config,
|
||||
with a runtime shape. If the `runtime_shape` is None, it means
|
||||
the `example_inputs` have a dynamic shape. Otherwise, the
|
||||
`runtime_shape` specifies the shape of the inputs. Right now we only
|
||||
support one variable shape for all inputs, which is the batch size
|
||||
(number of tokens) during inference.
|
||||
|
||||
Dynamo will make sure `graph(*example_inputs)` is valid.
|
||||
|
||||
The function should return a compiled callable function, as well as
|
||||
a handle that can be used to directly load the compiled function.
|
||||
|
||||
The handle should be a plain Python object, preferably a string or a
|
||||
file path for readability.
|
||||
|
||||
If the compiler doesn't support caching, it should return None for the
|
||||
handle. If the compiler fails to compile the graph, it should return
|
||||
None for the compiled function as well.
|
||||
|
||||
        `key` is required for InductorStandaloneAdaptor; it specifies where to
        save the compiled artifact. The compiled artifact gets saved to
        `cache_dir/key`.
|
||||
"""
|
||||
return None, None
|
||||
|
||||
def load(self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None) -> Callable:
|
||||
"""
|
||||
Load the compiled function from the handle.
|
||||
Raises an error if the handle is invalid.
|
||||
|
||||
The handle is the second return value of the `compile` function.
|
||||
"""
|
||||
raise NotImplementedError("caching is not supported")
|
||||
|
||||
|
||||
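To make the interface above concrete, here is a minimal sketch of a `CompilerInterface` implementation; the `NoOpCompiler` name and its behavior are hypothetical and only illustrate the expected return values of `compile` and the default no-cache behavior:

```python
# Hypothetical no-op compiler; names and behavior are illustrative only.
from typing import Any, Callable, Optional

import torch.fx as fx


class NoOpCompiler(CompilerInterface):
    name = "noop"

    def compute_hash(self, vllm_config) -> str:
        # nothing compiler-specific to add to the cache key
        return ""

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        # return the graph itself as the "compiled" callable and no handle,
        # i.e. this compiler does not support caching
        return graph, None
```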
class AlwaysHitShapeEnv:
|
||||
"""
|
||||
Why do we need this class:
|
||||
|
||||
For normal `torch.compile` usage, every compilation will have
|
||||
one Dynamo bytecode compilation and one Inductor compilation.
|
||||
The Inductor compilation happens under the context of the
|
||||
Dynamo bytecode compilation, and that context is used to
|
||||
determine the dynamic shape information, etc.
|
||||
|
||||
For our use case, we only run Dynamo bytecode compilation once,
|
||||
and run Inductor compilation multiple times with different shapes
|
||||
plus a general shape. The compilation for specific shapes happens
|
||||
outside of the context of the Dynamo bytecode compilation. At that
|
||||
time, we don't have shape environment to provide to Inductor, and
|
||||
it will fail the Inductor code cache lookup.
|
||||
|
||||
By providing a dummy shape environment that always hits, we can
|
||||
make the Inductor code cache lookup always hit, and we can
|
||||
compile the graph for different shapes as needed.
|
||||
|
||||
    The following dummy methods were obtained by trial and error
    until things worked.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.guards: list[Any] = []
|
||||
|
||||
def evaluate_guards_expression(self, *args, **kwargs):
|
||||
return True
|
||||
|
||||
def get_pruned_guards(self, *args, **kwargs):
|
||||
return []
|
||||
|
||||
def produce_guards_expression(self, *args, **kwargs):
|
||||
return ""
|
||||
|
||||
|
||||
def get_inductor_factors() -> list[Any]:
|
||||
factors: list[Any] = []
|
||||
# summarize system state
|
||||
from torch._inductor.codecache import CacheBase
|
||||
system_factors = CacheBase.get_system()
|
||||
factors.append(system_factors)
|
||||
|
||||
# summarize pytorch state
|
||||
from torch._inductor.codecache import torch_key
|
||||
torch_factors = torch_key()
|
||||
factors.append(torch_factors)
|
||||
return factors
|
||||
|
||||
|
||||
class InductorStandaloneAdaptor(CompilerInterface):
|
||||
"""
|
||||
The adaptor for the Inductor compiler.
|
||||
Requires PyTorch 2.8+.
|
||||
This is not on by default yet, but we plan to turn it on by default for
|
||||
PyTorch 2.8.
|
||||
|
||||
Use VLLM_USE_STANDALONE_COMPILE to toggle this on or off.
|
||||
"""
|
||||
name = "inductor_standalone"
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
factors = get_inductor_factors()
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()[:10]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(self,
|
||||
cache_dir: str,
|
||||
disable_cache: bool = False,
|
||||
prefix: str = ""):
|
||||
self.cache_dir = cache_dir
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: Optional[int] = None,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
current_config = {}
|
||||
if compiler_config is not None:
|
||||
current_config.update(compiler_config)
|
||||
set_inductor_config(current_config, runtime_shape)
|
||||
|
||||
if isinstance(runtime_shape, int):
|
||||
dynamic_shapes = "from_example_inputs"
|
||||
else:
|
||||
dynamic_shapes = "from_tracing_context"
|
||||
|
||||
from torch._inductor import standalone_compile
|
||||
with pass_context(runtime_shape):
|
||||
compiled_graph = standalone_compile(
|
||||
graph,
|
||||
example_inputs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
options={"config_patches": current_config})
|
||||
|
||||
# Save the compiled artifact to disk in the specified path
|
||||
assert key is not None
|
||||
path = os.path.join(self.cache_dir, key)
|
||||
if not envs.VLLM_DISABLE_COMPILE_CACHE:
|
||||
compiled_graph.save(path=path, format="unpacked")
|
||||
compilation_counter.num_compiled_artifacts_saved += 1
|
||||
return compiled_graph, (key, path)
|
||||
|
||||
def load(self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None) -> Callable:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
assert isinstance(handle[1], str)
|
||||
path = handle[1]
|
||||
inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
|
||||
path=path, format="unpacked")
|
||||
from torch._inductor.compile_fx import graph_returns_tuple
|
||||
returns_tuple = graph_returns_tuple(graph)
|
||||
|
||||
def compiled_graph_wrapper(*args):
|
||||
graph_output = inductor_compiled_graph(*args)
|
||||
# unpack the tuple if needed
|
||||
# TODO(rzou): the implication is that we're not
|
||||
# reading the python bytecode correctly in vLLM?
|
||||
if returns_tuple:
|
||||
return graph_output
|
||||
else:
|
||||
return graph_output[0]
|
||||
|
||||
return compiled_graph_wrapper
|
||||
|
||||
|
||||
class InductorAdaptor(CompilerInterface):
|
||||
"""
|
||||
The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.
|
||||
"""
|
||||
name = "inductor"
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
factors = get_inductor_factors()
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()[:10]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(self,
|
||||
cache_dir: str,
|
||||
disable_cache: bool = False,
|
||||
prefix: str = ""):
|
||||
self.cache_dir = cache_dir
|
||||
self.prefix = prefix
|
||||
self.base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir
|
||||
if disable_cache:
|
||||
return
|
||||
# redirect the cache directory to a sub-directory
|
||||
# set flags so that Inductor and Triton store their cache
|
||||
# in the cache_dir, then users only need to copy the cache_dir
|
||||
# to another machine to reuse the cache.
|
||||
inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
|
||||
os.makedirs(inductor_cache, exist_ok=True)
|
||||
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
|
||||
triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
|
||||
os.makedirs(triton_cache, exist_ok=True)
|
||||
os.environ["TRITON_CACHE_DIR"] = triton_cache
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: Optional[int] = None,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
current_config = {}
|
||||
if compiler_config is not None:
|
||||
current_config.update(compiler_config)
|
||||
|
||||
# disable remote cache
|
||||
current_config["fx_graph_cache"] = True
|
||||
current_config["fx_graph_remote_cache"] = False
|
||||
|
||||
set_inductor_config(current_config, runtime_shape)
|
||||
|
||||
# inductor can inplace modify the graph, so we need to copy it
|
||||
# see https://github.com/pytorch/pytorch/issues/138980
|
||||
graph = copy.deepcopy(graph)
|
||||
|
||||
# it's the first time we compile this graph
|
||||
# the assumption is that we don't have nested Inductor compilation.
|
||||
# compiled_fx_graph_hash will only be called once, and we can hook
|
||||
# it to get the hash of the compiled graph directly.
|
||||
|
||||
hash_str, file_path = None, None
|
||||
from torch._inductor.codecache import (FxGraphCache,
|
||||
compiled_fx_graph_hash)
|
||||
if torch.__version__.startswith("2.5"):
|
||||
original_load = FxGraphCache.load
|
||||
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
|
||||
|
||||
def hijack_load(*args, **kwargs):
|
||||
inductor_compiled_graph = original_load(*args, **kwargs)
|
||||
nonlocal file_path
|
||||
compiled_fn = inductor_compiled_graph.current_callable
|
||||
file_path = compiled_fn.__code__.co_filename # noqa
|
||||
if not file_path.startswith(self.base_cache_dir):
|
||||
# hooked in the align_inputs_from_check_idxs function
|
||||
# in torch/_inductor/utils.py
|
||||
for cell in compiled_fn.__closure__:
|
||||
if not callable(cell.cell_contents):
|
||||
continue
|
||||
if cell.cell_contents.__code__.co_filename.startswith(
|
||||
self.base_cache_dir):
|
||||
# this is the real file path compiled from Inductor
|
||||
file_path = cell.cell_contents.__code__.co_filename
|
||||
break
|
||||
return inductor_compiled_graph
|
||||
|
||||
hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner # noqa
|
||||
elif torch.__version__ >= "2.6":
|
||||
# function renamed in 2.6
|
||||
original_load_name = None
|
||||
|
||||
def hijacked_compile_fx_inner(*args, **kwargs):
|
||||
output = torch._inductor.compile_fx.compile_fx_inner(
|
||||
*args, **kwargs)
|
||||
nonlocal hash_str
|
||||
inductor_compiled_graph = output
|
||||
if inductor_compiled_graph is not None:
|
||||
nonlocal file_path
|
||||
compiled_fn = inductor_compiled_graph.current_callable
|
||||
file_path = compiled_fn.__code__.co_filename # noqa
|
||||
if not file_path.startswith(self.base_cache_dir):
|
||||
# hooked in the align_inputs_from_check_idxs function
|
||||
# in torch/_inductor/utils.py
|
||||
for cell in compiled_fn.__closure__:
|
||||
if not callable(cell.cell_contents):
|
||||
continue
|
||||
code = cell.cell_contents.__code__
|
||||
if code.co_filename.startswith(
|
||||
self.base_cache_dir):
|
||||
# this is the real file path
|
||||
# compiled from Inductor
|
||||
file_path = code.co_filename
|
||||
break
|
||||
hash_str = inductor_compiled_graph._fx_graph_cache_key
|
||||
return output
|
||||
|
||||
def hijack_compiled_fx_graph_hash(*args, **kwargs):
|
||||
out = compiled_fx_graph_hash(*args, **kwargs)
|
||||
nonlocal hash_str
|
||||
hash_str = out[0]
|
||||
return out
|
||||
|
||||
def _check_can_cache(*args, **kwargs):
|
||||
# no error means it can be cached.
|
||||
# Inductor refuses to cache the graph outside of Dynamo
|
||||
# tracing context, and also disables caching for graphs
|
||||
# with high-order ops.
|
||||
# For vLLM, in either case, we want to cache the graph.
|
||||
# see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
|
||||
return
|
||||
|
||||
def _get_shape_env() -> AlwaysHitShapeEnv:
|
||||
return AlwaysHitShapeEnv()
|
||||
|
||||
with ExitStack() as stack:
|
||||
# hijack to get the compiled graph itself
|
||||
if original_load_name is not None:
|
||||
stack.enter_context(patch(original_load_name, hijack_load))
|
||||
|
||||
# for hijacking the hash of the compiled graph
|
||||
stack.enter_context(
|
||||
patch("torch._inductor.codecache.compiled_fx_graph_hash",
|
||||
hijack_compiled_fx_graph_hash))
|
||||
|
||||
# for providing a dummy shape environment
|
||||
stack.enter_context(
|
||||
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
_get_shape_env))
|
||||
|
||||
from torch._functorch._aot_autograd.autograd_cache import (
|
||||
AOTAutogradCache)
|
||||
|
||||
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
|
||||
if hasattr(AOTAutogradCache, "_get_shape_env"):
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
|
||||
_get_shape_env))
|
||||
|
||||
# for forcing the graph to be cached
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.codecache.FxGraphCache._check_can_cache",
|
||||
_check_can_cache))
|
||||
|
||||
# Dynamo metrics context, see method for more details.
|
||||
stack.enter_context(self.metrics_context())
|
||||
|
||||
# Disable remote caching. When these are on, on remote cache-hit,
|
||||
# the monkey-patched functions never actually get called.
|
||||
# vLLM today assumes and requires the monkey-patched functions to
|
||||
# get hit.
|
||||
# TODO(zou3519): we're going to replace this all with
|
||||
# standalone_compile sometime.
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
stack.enter_context(
|
||||
torch._inductor.config.patch(fx_graph_remote_cache=False))
|
||||
# InductorAdaptor (unfortunately) requires AOTAutogradCache
|
||||
# to be turned off to run. It will fail to acquire the hash_str
|
||||
# and error if not.
|
||||
                # InductorStandaloneAdaptor (PyTorch 2.8+) fixes this problem.
|
||||
stack.enter_context(
|
||||
torch._functorch.config.patch(enable_autograd_cache=False))
|
||||
stack.enter_context(
|
||||
torch._functorch.config.patch(
|
||||
enable_remote_autograd_cache=False))
|
||||
|
||||
with pass_context(runtime_shape):
|
||||
compiled_graph = compile_fx(
|
||||
graph,
|
||||
example_inputs,
|
||||
inner_compile=hijacked_compile_fx_inner,
|
||||
config_patches=current_config)
|
||||
|
||||
# We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
|
||||
# compilation cache. So turn off the checks if we disable the
|
||||
# compilation cache.
|
||||
if not envs.VLLM_DISABLE_COMPILE_CACHE:
|
||||
if hash_str is None:
|
||||
raise RuntimeError(
|
||||
"vLLM failed to compile the model. The most "
|
||||
"likely reason for this is that a previous compilation "
|
||||
"failed, leading to a corrupted compilation artifact. "
|
||||
"We recommend trying to "
|
||||
"remove ~/.cache/vllm/torch_compile_cache and try again "
|
||||
"to see the real issue. ")
|
||||
assert file_path is not None, (
|
||||
"failed to get the file path of the compiled graph")
|
||||
return compiled_graph, (hash_str, file_path)
|
||||
|
||||
def load(self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None) -> Callable:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
assert isinstance(handle[1], str)
|
||||
hash_str = handle[0]
|
||||
|
||||
from torch._functorch._aot_autograd.autograd_cache import (
|
||||
AOTAutogradCache)
|
||||
from torch._inductor.codecache import FxGraphCache
|
||||
with ExitStack() as exit_stack:
|
||||
exit_stack.enter_context(
|
||||
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
lambda *args, **kwargs: AlwaysHitShapeEnv()))
|
||||
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
|
||||
if hasattr(AOTAutogradCache, "_get_shape_env"):
|
||||
exit_stack.enter_context(
|
||||
patch(
|
||||
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
|
||||
lambda *args, **kwargs: AlwaysHitShapeEnv()))
|
||||
|
||||
# Dynamo metrics context, see method for more details.
|
||||
exit_stack.enter_context(self.metrics_context())
|
||||
|
||||
if torch.__version__.startswith("2.5"):
|
||||
inductor_compiled_graph = FxGraphCache._lookup_graph(
|
||||
hash_str, example_inputs, True, False)
|
||||
                assert inductor_compiled_graph is not None, (
                    "Inductor cache lookup failed. Please remove "
                    "the cache directory and try again.")
|
||||
elif torch.__version__ >= "2.6":
|
||||
from torch._inductor.output_code import (
|
||||
CompiledFxGraphConstantsWithGm)
|
||||
constants = CompiledFxGraphConstantsWithGm(graph)
|
||||
inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
|
||||
hash_str, example_inputs, True, None, constants)
|
||||
                assert inductor_compiled_graph is not None, (
                    "Inductor cache lookup failed. Please remove "
                    "the cache directory and try again.")
|
||||
|
||||
# Inductor calling convention (function signature):
|
||||
# f(list) -> tuple
|
||||
# Dynamo calling convention (function signature):
|
||||
# f(*args) -> Any
|
||||
|
||||
# need to know if the graph returns a tuple
|
||||
from torch._inductor.compile_fx import graph_returns_tuple
|
||||
returns_tuple = graph_returns_tuple(graph)
|
||||
|
||||
# this is the callable we return to Dynamo to run
|
||||
def compiled_graph(*args):
|
||||
# convert args to list
|
||||
list_args = list(args)
|
||||
graph_output = inductor_compiled_graph(list_args)
|
||||
# unpack the tuple if needed
|
||||
if returns_tuple:
|
||||
return graph_output
|
||||
else:
|
||||
return graph_output[0]
|
||||
|
||||
return compiled_graph
|
||||
|
||||
def metrics_context(self) -> contextlib.AbstractContextManager:
|
||||
"""
|
||||
This method returns the Dynamo metrics context (if it exists,
|
||||
otherwise a null context). It is used by various compile components.
|
||||
Present in torch>=2.6, it's used inside FxGraphCache in
|
||||
torch==2.6 (but not after). It might also be used in various other
|
||||
torch.compile internal functions.
|
||||
|
||||
Because it is re-entrant, we always set it (even if entering via Dynamo
|
||||
and the context was already entered). We might want to revisit if it
|
||||
should be set at a different level of compilation.
|
||||
|
||||
This is likely a bug in PyTorch: public APIs should not rely on
|
||||
manually setting up internal contexts. But we also rely on non-public
|
||||
APIs which might not provide these guarantees.
|
||||
"""
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
import torch._dynamo.utils
|
||||
return torch._dynamo.utils.get_metrics_context()
|
||||
else:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
def set_inductor_config(config, runtime_shape):
|
||||
if isinstance(runtime_shape, int):
|
||||
# for a specific batchsize, tuning triton kernel parameters
|
||||
# can be beneficial
|
||||
config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
|
||||
config["coordinate_descent_tuning"] = (
|
||||
envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING)
|
||||
|
||||
|
||||
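A brief usage sketch for `set_inductor_config` (hypothetical values; the actual flags come from the `VLLM_ENABLE_INDUCTOR_*` environment variables): the extra tuning options are only added for static-shape compilations.

```python
# Illustrative only; the resulting values depend on the VLLM_* env vars.
cfg = {"fx_graph_cache": True}
set_inductor_config(cfg, runtime_shape=8)      # static shape: adds
# "max_autotune" and "coordinate_descent_tuning" entries to cfg
set_inductor_config(cfg, runtime_shape=None)   # dynamic shape: cfg unchanged
```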
class EagerAdaptor(CompilerInterface):
|
||||
name = "eager"
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: Optional[int] = None,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
compilation_counter.num_eager_compiles += 1
|
||||
# we don't need to compile the graph, just return the graph itself.
|
||||
# It does not support caching, return None for the handle.
|
||||
return graph, None
|
||||
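For orientation, a hypothetical helper showing how the adaptors above might be chosen; this selection logic is an assumption for illustration (including the `envs.VLLM_USE_STANDALONE_COMPILE` accessor) and not necessarily how vLLM's backend picks a compiler:

```python
# Hypothetical selection helper; not part of this file.
def pick_compiler(use_inductor: bool) -> CompilerInterface:
    if not use_inductor:
        return EagerAdaptor()
    if is_torch_equal_or_newer("2.8") and envs.VLLM_USE_STANDALONE_COMPILE:
        return InductorStandaloneAdaptor()
    return InductorAdaptor()
```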
47
vllm/compilation/counter.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CompilationCounter:
|
||||
num_models_seen: int = 0
|
||||
num_graphs_seen: int = 0
|
||||
# including the splitting ops
|
||||
num_piecewise_graphs_seen: int = 0
|
||||
# not including the splitting ops
|
||||
num_piecewise_capturable_graphs_seen: int = 0
|
||||
num_backend_compilations: int = 0
|
||||
# Number of gpu_model_runner attempts to trigger CUDAGraphs capture
|
||||
num_gpu_runner_capture_triggers: int = 0
|
||||
# Number of CUDAGraphs captured
|
||||
num_cudagraph_captured: int = 0
|
||||
    # InductorAdaptor.compile calls
|
||||
num_inductor_compiles: int = 0
|
||||
    # EagerAdaptor.compile calls
|
||||
num_eager_compiles: int = 0
|
||||
    # The number of times vLLM's compiler cache entry was updated
|
||||
num_cache_entries_updated: int = 0
|
||||
# The number of standalone_compile compiled artifacts saved
|
||||
num_compiled_artifacts_saved: int = 0
|
||||
# Number of times a model was loaded with CompilationLevel.DYNAMO_AS_IS
|
||||
dynamo_as_is_count: int = 0
|
||||
|
||||
def clone(self) -> "CompilationCounter":
|
||||
return copy.deepcopy(self)
|
||||
|
||||
@contextmanager
|
||||
def expect(self, **kwargs):
|
||||
old = self.clone()
|
||||
yield
|
||||
for k, v in kwargs.items():
|
||||
assert getattr(self, k) - getattr(old, k) == v, (
|
||||
f"{k} not as expected, before it is {getattr(old, k)}"
|
||||
f", after it is {getattr(self, k)}, "
|
||||
f"expected diff is {v}")
|
||||
|
||||
|
||||
compilation_counter = CompilationCounter()
|
||||
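A hypothetical usage of the `expect` helper above, e.g. in a test; `run_model_once` is a placeholder for whatever triggers a compilation:

```python
# Assert counter deltas produced by some compilation-triggering work.
with compilation_counter.expect(num_graphs_seen=1, num_eager_compiles=1):
    run_model_once()  # hypothetical function that compiles and runs a model
```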
199
vllm/compilation/cuda_graph.py
Normal file
@@ -0,0 +1,199 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
set_graph_pool_id)
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import weak_ref_tensors
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CUDAGraphEntry:
|
||||
batch_descriptor: BatchDescriptor
|
||||
cudagraph: Optional[torch.cuda.CUDAGraph] = None
|
||||
output: Optional[Any] = None
|
||||
|
||||
# for cudagraph debugging, track the input addresses
|
||||
# during capture, and check if they are the same during replay
|
||||
input_addresses: Optional[list[int]] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CUDAGraphOptions:
|
||||
debug_log_enable: bool = True
|
||||
gc_disable: bool = False
|
||||
weak_ref_output: bool = True
|
||||
|
||||
|
||||
class CUDAGraphWrapper:
|
||||
"""Wraps a runnable to add CUDA graph capturing and replaying ability. And
|
||||
provide attribute access to the underlying `runnable` via `__getattr__`.
|
||||
|
||||
The workflow of this wrapper in the cudagraph dispatching is as follows:
|
||||
1. At initialization, a runtime mode is assigned to the wrapper (FULL or
|
||||
PIECEWISE).
|
||||
    2. At runtime, the wrapper receives a runtime_mode and a
       batch_descriptor (key) from the forward context and blindly trusts
       them for cudagraph dispatching.
|
||||
3. If runtime_mode is NONE or runtime_mode does not match the mode of the
|
||||
wrapper, just call the runnable directly.
|
||||
4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
|
||||
the wrapper will perform cudagraph capture(if key does not exist, create
|
||||
a new entry and cache it) or replay (if key exists in the cache).
|
||||
|
||||
    Note: CUDAGraphWrapper does not store persistent buffers or copy any
    runtime inputs into such buffers for replay. We assume that this is
    implemented outside of the wrapper. That is because we do not make any
|
||||
assumption on the dynamic shape (batch size) of the runtime inputs, as a
|
||||
trade-off for staying orthogonal to compilation logic. Nevertheless,
|
||||
tracing and checking the input addresses to be consistent during replay is
|
||||
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
runnable: Callable,
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
cudagraph_options: Optional[CUDAGraphOptions] = None):
|
||||
self.runnable = runnable
|
||||
self.vllm_config = vllm_config
|
||||
self.runtime_mode = runtime_mode
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
|
||||
self.first_run_finished = False
|
||||
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
|
||||
|
||||
# assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
|
||||
# need to initialize a CUDAGraphWrapper.
|
||||
assert self.runtime_mode != CUDAGraphMode.NONE
|
||||
# TODO: in the future, if we want to use multiple
|
||||
# streams, it might not be safe to share a global pool.
|
||||
# only investigate this when we use multiple streams
|
||||
self.graph_pool = current_platform.get_global_graph_pool()
|
||||
|
||||
if cudagraph_options is None:
|
||||
cudagraph_options = CUDAGraphOptions()
|
||||
self.cudagraph_options = cudagraph_options
|
||||
# the entries for different batch descriptors that we need to capture
|
||||
# cudagraphs for.
|
||||
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry]\
|
||||
= {}
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
# allow accessing the attributes of the runnable.
|
||||
if hasattr(self.runnable, key):
|
||||
return getattr(self.runnable, key)
|
||||
raise AttributeError(f"Attribute {key} not exists in the runnable of "
|
||||
f"cudagraph wrapper: {self.runnable}")
|
||||
|
||||
def unwrap(self) -> Callable:
|
||||
# in case we need to access the original runnable.
|
||||
return self.runnable
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
forward_context = get_forward_context()
|
||||
batch_descriptor = forward_context.batch_descriptor
|
||||
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||
|
||||
if cudagraph_runtime_mode == CUDAGraphMode.NONE or \
|
||||
cudagraph_runtime_mode != self.runtime_mode:
|
||||
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
|
||||
# running without cudagraphs.
|
||||
            # We do not trigger capture/replay if the runtime mode does not
            # match. This enables properly dispatching to the correct
|
||||
# CUDAGraphWrapper when nesting multiple instances with different
|
||||
# runtime modes.
|
||||
return self.runnable(*args, **kwargs)
|
||||
|
||||
if batch_descriptor not in self.concrete_cudagraph_entries:
|
||||
# create a new entry for this batch descriptor
|
||||
self.concrete_cudagraph_entries[batch_descriptor] = \
|
||||
CUDAGraphEntry(batch_descriptor=batch_descriptor)
|
||||
|
||||
entry = self.concrete_cudagraph_entries[batch_descriptor]
|
||||
|
||||
if entry.cudagraph is None:
|
||||
if self.cudagraph_options.debug_log_enable:
|
||||
# Since we capture cudagraph for many different shapes and
|
||||
# capturing is fast, we don't need to log it for every
|
||||
# shape. E.g. we only log it for the first subgraph in
|
||||
# piecewise mode.
|
||||
logger.debug("Capturing a cudagraph on (%s,%s)",
|
||||
self.runtime_mode.name, entry.batch_descriptor)
|
||||
# validate that cudagraph capturing is legal at this point.
|
||||
validate_cudagraph_capturing_enabled()
|
||||
|
||||
input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
entry.input_addresses = input_addresses
|
||||
cudagraph = torch.cuda.CUDAGraph()
|
||||
|
||||
with ExitStack() as stack:
|
||||
if self.cudagraph_options.gc_disable:
|
||||
# during every model forward for piecewise cudagraph
|
||||
# mode, we will capture many pieces of cudagraphs
|
||||
# (roughly one per layer). running gc again and again
|
||||
# across layers will make the cudagraph capture very slow.
|
||||
# therefore, we only run gc for the first graph,
|
||||
# and disable gc for the rest of the graphs.
|
||||
stack.enter_context(patch("gc.collect", lambda: None))
|
||||
stack.enter_context(
|
||||
patch("torch.cuda.empty_cache", lambda: None))
|
||||
|
||||
if self.graph_pool is not None:
|
||||
set_graph_pool_id(self.graph_pool)
|
||||
else:
|
||||
set_graph_pool_id(current_platform.graph_pool_handle())
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
|
||||
# `output` is managed by pytorch's cudagraph pool
|
||||
output = self.runnable(*args, **kwargs)
|
||||
if self.cudagraph_options.weak_ref_output:
|
||||
# by converting it to weak ref,
|
||||
# the original `output` will immediately be released
|
||||
# to save memory. It is only safe to do this for
|
||||
                        # the last graph in piecewise cudagraph mode, because
|
||||
# the output of the last graph will not be used by
|
||||
# any other cuda graph.
|
||||
output = weak_ref_tensors(output)
|
||||
|
||||
# here we always use weak ref for the output
|
||||
# to save memory
|
||||
entry.output = weak_ref_tensors(output)
|
||||
entry.cudagraph = cudagraph
|
||||
|
||||
compilation_counter.num_cudagraph_captured += 1
|
||||
|
||||
# important: we need to return the output, rather than
|
||||
# the weak ref of the output, so that pytorch can correctly
|
||||
# manage the memory during cuda graph capture
|
||||
return output
|
||||
|
||||
if self.is_debugging_mode:
|
||||
# check if the input addresses are the same
|
||||
new_input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
assert new_input_addresses == entry.input_addresses, (
|
||||
f"Input addresses for cudagraphs are different "
|
||||
f"during replay. Expected {entry.input_addresses}, "
|
||||
f"got {new_input_addresses}")
|
||||
|
||||
entry.cudagraph.replay()
|
||||
return entry.output
|
||||
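A minimal sketch of how `CUDAGraphWrapper` is intended to be used, assuming an already constructed `vllm_config` and a callable `model.forward`; the argument names are placeholders, and at call time the forward context must supply a matching runtime mode and batch descriptor:

```python
# Hypothetical wrapping of a callable for FULL-mode capture/replay.
wrapper = CUDAGraphWrapper(
    runnable=model.forward,                 # placeholder callable
    vllm_config=vllm_config,                # assumed pre-built VllmConfig
    runtime_mode=CUDAGraphMode.FULL,
    cudagraph_options=CUDAGraphOptions(debug_log_enable=True),
)
# first call for a given batch descriptor captures; later calls replay
out = wrapper(input_ids, positions)         # placeholder tensor arguments
```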
117
vllm/compilation/cuda_piecewise_backend.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch.fx as fx
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.compilation.monitor import end_monitoring_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ConcreteSizeEntry:
|
||||
runtime_shape: int
|
||||
compiled: bool = False
|
||||
runnable: Callable = None # type: ignore
|
||||
|
||||
|
||||
class PiecewiseBackend:
|
||||
|
||||
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
|
||||
piecewise_compile_index: int, total_piecewise_compiles: int,
|
||||
sym_shape_indices: list[int],
|
||||
compiled_graph_for_general_shape: Callable,
|
||||
vllm_backend: VllmBackend):
|
||||
"""
|
||||
The backend for piecewise compilation.
|
||||
It mainly handles the compilation of static shapes and
|
||||
dispatching based on runtime shape.
|
||||
|
||||
We will compile `self.graph` once for the general shape,
|
||||
and then compile for different shapes specified in
|
||||
`compilation_config.compile_sizes`.
|
||||
"""
|
||||
self.graph = graph
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.piecewise_compile_index = piecewise_compile_index
|
||||
self.total_piecewise_compiles = total_piecewise_compiles
|
||||
self.vllm_backend = vllm_backend
|
||||
|
||||
self.is_first_graph = piecewise_compile_index == 0
|
||||
self.is_last_graph = (
|
||||
piecewise_compile_index == total_piecewise_compiles - 1)
|
||||
|
||||
self.is_full_graph = total_piecewise_compiles == 1
|
||||
|
||||
self.compile_sizes: set[int] = set(
|
||||
self.compilation_config.compile_sizes)
|
||||
|
||||
self.first_run_finished = False
|
||||
|
||||
self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa
|
||||
|
||||
self.sym_shape_indices = sym_shape_indices
|
||||
|
||||
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
|
||||
|
||||
# the entries for different shapes that we need to compile
|
||||
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
|
||||
|
||||
# to_be_compiled_sizes tracks the remaining sizes to compile,
|
||||
# and updates during the compilation process, so we need to copy it
|
||||
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
|
||||
|
||||
# We only keep compilation management inside this class directly.
|
||||
for shape in self.compile_sizes:
|
||||
self.concrete_size_entries[shape] = ConcreteSizeEntry(
|
||||
runtime_shape=shape,
|
||||
runnable=self.compiled_graph_for_general_shape,
|
||||
)
|
||||
|
||||
def check_for_ending_compilation(self):
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
# no specific sizes to compile
|
||||
# save the hash of the inductor graph for the next run
|
||||
self.vllm_backend.compiler_manager.save_to_file()
|
||||
end_monitoring_torch_compile(self.vllm_config)
|
||||
|
||||
def __call__(self, *args) -> Any:
|
||||
if not self.first_run_finished:
|
||||
self.first_run_finished = True
|
||||
self.check_for_ending_compilation()
|
||||
return self.compiled_graph_for_general_shape(*args)
|
||||
|
||||
runtime_shape = args[self.sym_shape_indices[0]]
|
||||
|
||||
if runtime_shape not in self.concrete_size_entries:
|
||||
# we don't need to do anything for this shape
|
||||
return self.compiled_graph_for_general_shape(*args)
|
||||
|
||||
entry = self.concrete_size_entries[runtime_shape]
|
||||
|
||||
if not entry.compiled:
|
||||
entry.compiled = True
|
||||
self.to_be_compiled_sizes.remove(runtime_shape)
|
||||
# args are real arguments
|
||||
entry.runnable = self.vllm_backend.compiler_manager.compile(
|
||||
self.graph,
|
||||
args,
|
||||
self.compilation_config.inductor_compile_config,
|
||||
self.compilation_config,
|
||||
graph_index=self.piecewise_compile_index,
|
||||
num_graphs=self.total_piecewise_compiles,
|
||||
runtime_shape=runtime_shape)
|
||||
|
||||
# finished compilations for all required shapes
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
self.check_for_ending_compilation()
|
||||
|
||||
return entry.runnable(*args)
|
||||
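The shape-dispatch idea above, reduced to a toy illustration (a sketch, not vLLM code): a general callable serves every shape, while shapes listed in `compile_sizes` lazily get their own specialized callable on first use.

```python
# Toy stand-in for the general-shape vs. specialized-shape dispatch.
from typing import Callable

general_fn: Callable = lambda xs: [x * 2 for x in xs]
specialized: dict[int, Callable] = {}
compile_sizes = {1, 8}


def dispatch(xs: list[int]) -> list[int]:
    shape = len(xs)
    if shape not in compile_sizes:
        return general_fn(xs)
    if shape not in specialized:
        # pretend this is a shape-specialized compilation
        specialized[shape] = general_fn
    return specialized[shape](xs)
```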
400
vllm/compilation/decorators.py
Normal file
@@ -0,0 +1,400 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
from typing import Callable, Optional, TypeVar, Union, overload
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging import version
|
||||
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
||||
from vllm.config import CompilationLevel, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import resolve_obj_by_qualname, supports_dynamo
|
||||
|
||||
from .monitor import start_monitoring_torch_compile
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
|
||||
|
||||
_T = TypeVar("_T", bound=type[nn.Module])
|
||||
|
||||
|
||||
def ignore_torch_compile(cls: _T) -> _T:
|
||||
"""
|
||||
A decorator to ignore support_torch_compile decorator
|
||||
on the class. This is useful when a parent class has
|
||||
a support_torch_compile decorator, but we don't want to
|
||||
compile the class `cls` that inherits the parent class.
|
||||
This only ignores compiling the forward of the class the
|
||||
decorator is applied to.
|
||||
|
||||
If the parent has ignore_torch_compile but the child has
|
||||
support_torch_compile, the child will still be compiled.
|
||||
|
||||
If the class has one or more submodules
|
||||
that have support_torch_compile decorator applied, compile will
|
||||
not be ignored for those submodules.
|
||||
"""
|
||||
setattr(cls, IGNORE_COMPILE_KEY, True)
|
||||
return cls
|
||||
|
||||
|
||||
def _should_ignore_torch_compile(cls) -> bool:
|
||||
"""
|
||||
Check if the class should be ignored for torch.compile.
|
||||
"""
|
||||
return getattr(cls, IGNORE_COMPILE_KEY, False)
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(
|
||||
*,
|
||||
enable_if: Optional[Callable[[VllmConfig], bool]] = None,
|
||||
) -> Callable[[_T], _T]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(
|
||||
*,
|
||||
dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]],
|
||||
) -> Callable[[_T], _T]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(cls: _T) -> _T:
|
||||
...
|
||||
|
||||
|
||||
def support_torch_compile(
|
||||
cls: Optional[_T] = None,
|
||||
*,
|
||||
dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
|
||||
enable_if: Optional[Callable[[VllmConfig], bool]] = None,
|
||||
) -> Union[Callable[[_T], _T], _T]:
|
||||
"""
|
||||
A decorator to add support for compiling the forward method of a class.
|
||||
|
||||
Usage 1: use directly as a decorator without arguments:
|
||||
|
||||
```python
|
||||
@support_torch_compile
|
||||
class MyModel(nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
|
||||
...
|
||||
```
|
||||
|
||||
Usage 2: use as a decorator with arguments:
|
||||
|
||||
```python
|
||||
@support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
|
||||
class MyModel(nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
|
||||
...
|
||||
```
|
||||
|
||||
`dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
|
||||
dimensions of the argument. The dynamic dimensions can be either a single
|
||||
integer or a list of integers.
|
||||
|
||||
if `dynamic_arg_dims` is `None`, it is inferred from the type annotation
|
||||
of the `forward` method, based on the following default rules:
|
||||
|
||||
- if the argument is annotated as `torch.Tensor` or
|
||||
`Optional[torch.Tensor]`, the first dimension will be
|
||||
marked as dynamic.
|
||||
- if the argument is annotated as `IntermediateTensors`, the first
|
||||
dimension of all the tensors in the intermediate tensors
|
||||
will be marked as dynamic.
|
||||
|
||||
During runtime, when we actually mark dimensions of tensors,
|
||||
it depends on the value of arguments:
|
||||
|
||||
- if it is a single integer (can be negative), the corresponding dimension
|
||||
of the argument will be marked as dynamic.
|
||||
- if it is `None`, ignored.
|
||||
- if it is `IntermediateTensors`, all the tensors in the intermediate
|
||||
tensors will be marked as dynamic.
|
||||
- otherwise, it will raise an error.
|
||||
|
||||
NOTE: if an argument is `None`, it should always be passed as `None` during
|
||||
the lifetime of the model, otherwise, it cannot be captured as a single
|
||||
computation graph.
|
||||
|
||||
`enable_if` is a function that takes a `VllmConfig` object as input and
|
||||
returns a boolean value indicating whether to compile the model or not.
|
||||
This is useful if you want to compile the model only when certain
|
||||
conditions are met.
|
||||
"""
|
||||
|
||||
def cls_decorator_helper(cls: _T) -> _T:
|
||||
        # helper to pass `dynamic_arg_dims` to `_support_torch_compile`
        # to avoid too much indentation for `_support_torch_compile`
|
||||
if not hasattr(cls, 'forward'):
|
||||
raise TypeError("decorated class should have a forward method.")
|
||||
sig = inspect.signature(cls.forward)
|
||||
inferred_dynamic_arg_dims = dynamic_arg_dims
|
||||
if inferred_dynamic_arg_dims is None:
|
||||
inferred_dynamic_arg_dims = {}
|
||||
for k, v in sig.parameters.items():
|
||||
if v.annotation in [
|
||||
torch.Tensor, Optional[torch.Tensor],
|
||||
IntermediateTensors, Optional[IntermediateTensors]
|
||||
]:
|
||||
inferred_dynamic_arg_dims[k] = 0
|
||||
|
||||
logger.debug(("Inferred dynamic dimensions for "
|
||||
"forward method of %s: %s"), cls,
|
||||
list(inferred_dynamic_arg_dims.keys()))
|
||||
|
||||
if len(inferred_dynamic_arg_dims) == 0:
|
||||
raise ValueError(
|
||||
"No dynamic dimensions found in the forward method of "
|
||||
f"{cls}. Please provide dynamic_arg_dims explicitly.")
|
||||
|
||||
for k in inferred_dynamic_arg_dims:
|
||||
if k not in sig.parameters:
|
||||
raise ValueError(
|
||||
f"Argument {k} not found in the forward method of {cls}")
|
||||
return _support_torch_compile(cls, inferred_dynamic_arg_dims,
|
||||
enable_if)
|
||||
|
||||
if cls is not None:
|
||||
# use `support_torch_compile` as a decorator without arguments
|
||||
assert isinstance(cls, type)
|
||||
return cls_decorator_helper(cls)
|
||||
|
||||
return cls_decorator_helper
|
||||
|
||||
|
||||
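A hypothetical `enable_if` usage sketch: any `VllmConfig -> bool` predicate works. The predicate and model below are illustrative assumptions, and the decorated class follows the `vllm_config`/`prefix` constructor convention that the generated `__init__` further down expects:

```python
from typing import Optional

import torch
import torch.nn as nn


def _compile_only_without_cudagraphs(vllm_config: VllmConfig) -> bool:
    # illustrative predicate: compile only when CUDA graphs are disabled
    from vllm.config import CUDAGraphMode
    return vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE


@support_torch_compile(enable_if=_compile_only_without_cudagraphs)
class ToyModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
        return x if y is None else x + y
```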
def _support_torch_compile(
|
||||
cls: _T,
|
||||
dynamic_arg_dims: dict[str, Union[int, list[int]]],
|
||||
enable_if: Optional[Callable[[VllmConfig], bool]] = None,
|
||||
) -> _T:
|
||||
"""
|
||||
A decorator to add support for compiling the forward method of a class.
|
||||
"""
|
||||
if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
|
||||
# support decorating multiple times
|
||||
return cls
|
||||
|
||||
# take care of method resolution order
|
||||
# make sure super().__init__ is called on the base class
|
||||
# other than TorchCompileWrapperWithCustomDispatcher
|
||||
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
|
||||
|
||||
old_init = cls.__init__
|
||||
|
||||
setattr(cls, IGNORE_COMPILE_KEY, False)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
|
||||
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
self.vllm_config = vllm_config
|
||||
enable_compile = enable_if is None or enable_if(vllm_config)
|
||||
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
|
||||
# will handle the compilation, so we don't need to do anything here.
|
||||
self.do_not_compile = \
|
||||
vllm_config.compilation_config.level in [
|
||||
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
|
||||
] or not supports_dynamo() or _should_ignore_torch_compile(
|
||||
self.__class__) or not enable_compile
|
||||
if self.do_not_compile:
|
||||
return
|
||||
|
||||
compilation_counter.num_models_seen += 1
|
||||
TorchCompileWrapperWithCustomDispatcher.__init__(
|
||||
self, compilation_level=vllm_config.compilation_config.level)
|
||||
|
||||
cls.__init__ = __init__
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# torch.compiler.is_compiling() means we are inside the compilation
|
||||
# e.g. TPU has the compilation logic in model runner, so we don't
|
||||
# need to compile the model inside.
|
||||
if self.do_not_compile or torch.compiler.is_compiling():
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
# the first compilation needs to have dynamic shapes marked
|
||||
if len(self.compiled_codes) < 1:
|
||||
sig = inspect.signature(self.__class__.forward)
|
||||
bound_args = sig.bind(self, *args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
for k, dims in dynamic_arg_dims.items():
|
||||
arg = bound_args.arguments.get(k)
|
||||
if arg is not None:
|
||||
dims = [dims] if isinstance(dims, int) else dims
|
||||
if isinstance(arg, torch.Tensor):
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [
|
||||
arg.ndim + dim if dim < 0 else dim for dim in dims
|
||||
]
|
||||
torch._dynamo.mark_dynamic(arg, dims)
|
||||
elif isinstance(arg, IntermediateTensors):
|
||||
for tensor in arg.tensors.values():
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [
|
||||
tensor.ndim + dim if dim < 0 else dim
|
||||
for dim in dims
|
||||
]
|
||||
torch._dynamo.mark_dynamic(tensor, dims)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported dynamic dimensions"
|
||||
f" {dims} for argument {k} with type {type(arg)}.")
|
||||
# here, it is the starting point of the `torch.compile` process
|
||||
start_monitoring_torch_compile(self.vllm_config)
|
||||
logger.debug("Start compiling function %s",
|
||||
self.original_code_object)
|
||||
|
||||
# if we don't use custom dispatcher, we can directly call the
|
||||
# compiled function and let torch.compile handle the dispatching,
|
||||
# with the overhead of guard evaluation and recompilation.
|
||||
if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
|
||||
            # It seems Dynamo reuses the compilation across instances,
            # while we need to make sure the compiled code is not reused.
|
||||
# we need to control all the compilation of the model.
|
||||
torch._dynamo.eval_frame.remove_from_cache(
|
||||
self.original_code_object)
|
||||
|
||||
# collect all relevant files traced by Dynamo,
|
||||
# so that the compilation cache can trigger re-compilation
|
||||
# properly when any of these files change.
|
||||
|
||||
# 1. the file containing the top-level forward function
|
||||
self.vllm_config.compilation_config.traced_files.add(
|
||||
self.original_code_object.co_filename)
|
||||
|
||||
# 2. every time Dynamo sees a function call, it will inline
|
||||
# the function by calling InliningInstructionTranslator.inline_call
|
||||
# we hijack this function to know all the functions called
|
||||
# during Dynamo tracing, and their corresponding files
|
||||
inline_call = InliningInstructionTranslator.inline_call
|
||||
|
||||
def patched_inline_call(parent, func, args, kwargs):
|
||||
code = func.get_code()
|
||||
self.vllm_config.compilation_config.traced_files.add(
|
||||
code.co_filename)
|
||||
return inline_call(parent, func, args, kwargs)
|
||||
|
||||
# Disable the C++ compilation of symbolic shape guards. C++-fication
|
||||
            # of symbolic shape guards can reduce guard overhead. But, since
            # vLLM skips guards anyway, setting this flag to False can improve
|
||||
# compile time.
|
||||
dynamo_config_patches = {}
|
||||
try:
|
||||
_ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
|
||||
dynamo_config_patches[
|
||||
"enable_cpp_symbolic_shape_guards"] = False
|
||||
except AttributeError:
|
||||
# Note: this config is not available in torch 2.6, we can skip
|
||||
# if the config doesn't exist
|
||||
logger.debug(
|
||||
"enable_cpp_symbolic_shape_guards config not available")
|
||||
|
||||
with patch.object(
|
||||
InliningInstructionTranslator, "inline_call",
|
||||
patched_inline_call), torch._dynamo.config.patch(
|
||||
**dynamo_config_patches
|
||||
), maybe_use_cudagraph_partition_wrapper(
|
||||
self.vllm_config), _torch27_patch_tensor_subclasses():
|
||||
output = self.compiled_callable(*args, **kwargs)
|
||||
return output
|
||||
|
||||
# usually, capturing the model once is enough, and then we can
|
||||
# dispatch to the compiled code directly, without going through
|
||||
# the Dynamo guard mechanism.
|
||||
with self.dispatch_to_code(0):
|
||||
model_output = self.forward(*args, **kwargs)
|
||||
return model_output
|
||||
|
||||
cls.__call__ = __call__
|
||||
return cls
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
|
||||
"""
|
||||
Context manager to set/unset customized cudagraph partition wrappers.
|
||||
|
||||
If we're using Inductor-based graph partitioning, we currently have the
|
||||
    whole `fx.Graph` before Inductor lowering, and the piecewise
|
||||
splitting happens after all graph passes and fusions. Here, we add
|
||||
a custom hook for Inductor to wrap each partition with our static
|
||||
graph wrapper class to maintain more control over static graph
|
||||
capture and replay.
|
||||
"""
|
||||
from vllm.config import CUDAGraphMode
|
||||
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
and compilation_config.use_inductor_graph_partition):
|
||||
from torch._inductor.utils import CUDAGraphWrapperMetadata
|
||||
|
||||
from vllm.compilation.cuda_graph import CUDAGraphOptions
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
static_graph_wrapper_class = resolve_obj_by_qualname(
|
||||
current_platform.get_static_graph_wrapper_cls())
|
||||
|
||||
def customized_cudagraph_wrapper(f,
|
||||
metadata: CUDAGraphWrapperMetadata):
|
||||
partition_id = metadata.partition_index
|
||||
num_partitions = metadata.num_partitions
|
||||
return static_graph_wrapper_class(
|
||||
runnable=f,
|
||||
vllm_config=vllm_config,
|
||||
runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
cudagraph_options=CUDAGraphOptions(
|
||||
debug_log_enable=partition_id == 0,
|
||||
gc_disable=partition_id != 0,
|
||||
weak_ref_output=partition_id == num_partitions - 1,
|
||||
))
|
||||
|
||||
torch._inductor.utils.set_customized_partition_wrappers(
|
||||
customized_cudagraph_wrapper)
|
||||
|
||||
yield
|
||||
|
||||
if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
and compilation_config.use_inductor_graph_partition):
|
||||
torch._inductor.utils.set_customized_partition_wrappers(None)
|
||||
|
||||
|
||||
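A brief, hypothetical usage note for the context manager above (`compiled_callable`, `args`, and `kwargs` are placeholders): the partition-wrapper hook is installed only for the duration of the compiled call, mirroring how `__call__` above uses it.

```python
# Install the Inductor partition wrapper hook only around the compiled call.
with maybe_use_cudagraph_partition_wrapper(vllm_config):
    output = compiled_callable(*args, **kwargs)  # placeholders
```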
@contextlib.contextmanager
|
||||
def _torch27_patch_tensor_subclasses():
|
||||
"""
|
||||
    Add support for using tensor subclasses (i.e. `BasevLLMParameter`, etc.) when
|
||||
using torch 2.7.0. This enables using weight_loader_v2 and the use of
|
||||
`BasevLLMParameters` without having to replace them with regular tensors
|
||||
before `torch.compile`-time.
|
||||
"""
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ModelWeightParameter,
|
||||
RowvLLMParameter,
|
||||
_ColumnvLLMParameter)
|
||||
|
||||
def return_false(*args, **kwargs):
|
||||
return False
|
||||
|
||||
if version.parse("2.7") <= version.parse(
|
||||
torch.__version__) < version.parse("2.8"):
|
||||
yield
|
||||
return
|
||||
|
||||
with (torch._dynamo.config.patch("traceable_tensor_subclasses", [
|
||||
BasevLLMParameter, ModelWeightParameter, _ColumnvLLMParameter,
|
||||
RowvLLMParameter
|
||||
]),
|
||||
patch("torch._dynamo.variables.torch.can_dispatch_torch_function",
|
||||
return_false)):
|
||||
yield
|
||||
205
vllm/compilation/fix_functionalization.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import operator
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fx_utils import is_func
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FixFunctionalizationPass(VllmInductorPass):
|
||||
"""
|
||||
This pass defunctionalizes certain nodes to avoid redundant tensor copies.
|
||||
After this pass, DCE (dead-code elimination) should never be run,
|
||||
as de-functionalized nodes may appear as dead code.
|
||||
|
||||
To add new nodes to defunctionalize, add to the if-elif chain in __call__.
|
||||
"""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
# XPU does not support auto-functionalization yet.
|
||||
        # Will enable this when we switch to vllm-xpu-kernels.
|
||||
if current_platform.is_xpu():
|
||||
logger.debug("XPU platform does not support fix functionalization"
|
||||
"pass currently.")
|
||||
return
|
||||
|
||||
self.nodes_to_remove: list[torch.fx.Node] = []
|
||||
count = 0
|
||||
for node in graph.nodes:
|
||||
if not is_func(node, auto_functionalized):
|
||||
continue # Avoid deep if-elif nesting
|
||||
|
||||
kwargs = node.kwargs
|
||||
at_target = node.args[0]
|
||||
|
||||
if at_target == torch.ops._C.rotary_embedding.default:
|
||||
query = kwargs['query']
|
||||
mm_node = query.args[0].args[0]
|
||||
|
||||
# rotary_embedding is a special case: the two mutating inputs
|
||||
# are query and key, which are slices of mm_node.
|
||||
# While functionalized, results at[1] and at[2] are scattered
|
||||
# back into mm_node. After de-functionalization, we can just
|
||||
# use mm_node directly.
|
||||
for idx, user in self.getitem_users(node).items():
|
||||
for user_of_getitem in user.users:
|
||||
if is_func(user_of_getitem,
|
||||
torch.ops.aten.slice_scatter.default):
|
||||
user_of_getitem.replace_all_uses_with(mm_node)
|
||||
self._remove(user_of_getitem)
|
||||
self._remove(user)
|
||||
|
||||
self.insert_defunctionalized(graph, node)
|
||||
self._remove(node)
|
||||
|
||||
# rms_norm replacements avoid the most copies for LLaMa.
|
||||
elif at_target == torch.ops._C.fused_add_rms_norm.default:
|
||||
mutated_args = {1: 'input', 2: 'residual'}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
|
||||
mutated_args = {1: 'result', 2: 'residual'}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
|
||||
mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target in [
|
||||
torch.ops._C.rms_norm.default,
|
||||
torch.ops._C.rms_norm_static_fp8_quant.default,
|
||||
]:
|
||||
mutated_args = {1: 'result'}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
# For some reason we need to specify the args for both
|
||||
# silu_and_mul and silu_and_mul_quant. The kwargs
|
||||
# pathway gets the wrong answer.
|
||||
elif at_target == torch.ops._C.silu_and_mul.default:
|
||||
mutated_args = {1: 'result'}
|
||||
self.defunctionalize(graph,
|
||||
node,
|
||||
mutated_args,
|
||||
args=('result', 'input'))
|
||||
elif at_target == torch.ops._C.silu_and_mul_quant.default:
|
||||
mutated_args = {1: 'result'}
|
||||
self.defunctionalize(graph,
|
||||
node,
|
||||
mutated_args,
|
||||
args=('result', 'input', 'scale'))
|
||||
elif hasattr(
|
||||
torch.ops._C, "silu_and_mul_nvfp4_quant"
|
||||
) and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default:
|
||||
mutated_args = {1: 'result', 2: 'result_block_scale'}
|
||||
self.defunctionalize(graph,
|
||||
node,
|
||||
mutated_args,
|
||||
args=('result', 'result_block_scale',
|
||||
'input', 'input_global_scale'))
|
||||
else:
|
||||
continue # skip the count
|
||||
|
||||
count += 1
|
||||
|
||||
self.dump_graph(graph, "before_cleanup")
|
||||
|
||||
# Remove the nodes all at once
|
||||
count_removed = len(self.nodes_to_remove)
|
||||
for node in self.nodes_to_remove:
|
||||
graph.erase_node(node)
|
||||
|
||||
logger.debug("De-functionalized %s nodes, removed %s nodes", count,
|
||||
count_removed)
|
||||
self.nodes_to_remove.clear()
|
||||
|
||||
def _remove(self, node_or_nodes: Union[torch.fx.Node,
|
||||
Iterable[torch.fx.Node]]):
|
||||
"""
|
||||
Stage a node (or nodes) for removal at the end of the pass.
|
||||
"""
|
||||
if isinstance(node_or_nodes, torch.fx.Node):
|
||||
self.nodes_to_remove.append(node_or_nodes)
|
||||
else:
|
||||
self.nodes_to_remove.extend(node_or_nodes)
|
||||
|
||||
def defunctionalize(self,
|
||||
graph: torch.fx.Graph,
|
||||
node: torch.fx.Node,
|
||||
mutated_args: dict[int, Union[torch.fx.Node, str]],
|
||||
args: Optional[tuple[Union[torch.fx.Node, str],
|
||||
...]] = None):
|
||||
"""
|
||||
De-functionalize a node by replacing it with a call to the original.
|
||||
It also replaces the getitem users with the mutated arguments.
|
||||
See replace_users_with_mutated_args and insert_defunctionalized.
|
||||
"""
|
||||
self.replace_users_with_mutated_args(node, mutated_args)
|
||||
self.insert_defunctionalized(graph, node, args=args)
|
||||
self._remove(node)
|
||||
|
||||
def replace_users_with_mutated_args(self, node: torch.fx.Node,
|
||||
mutated_args: dict[int,
|
||||
Union[torch.fx.Node,
|
||||
str]]):
|
||||
"""
|
||||
Replace all getitem users of the auto-functionalized node with the
|
||||
mutated arguments.
|
||||
:param node: The auto-functionalized node
|
||||
:param mutated_args: The mutated arguments, indexed by getitem index.
|
||||
If the value of an arg is a string, `node.kwargs[arg]` is used.
|
||||
"""
|
||||
for idx, user in self.getitem_users(node).items():
|
||||
arg = mutated_args[idx]
|
||||
arg = node.kwargs[arg] if isinstance(arg, str) else arg
|
||||
user.replace_all_uses_with(arg)
|
||||
self._remove(user)
|
||||
|
||||
def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
|
||||
"""
|
||||
Returns the operator.getitem users of the auto-functionalized node,
|
||||
indexed by the index they are getting.
|
||||
"""
|
||||
users = {}
|
||||
for user in node.users:
|
||||
if is_func(user, operator.getitem):
|
||||
idx = user.args[1]
|
||||
users[idx] = user
|
||||
return users
|
||||
|
||||
def insert_defunctionalized(self,
|
||||
graph: torch.fx.Graph,
|
||||
node: torch.fx.Node,
|
||||
args: Optional[tuple[Union[torch.fx.Node, str],
|
||||
...]] = None):
|
||||
"""
|
||||
Insert a new defunctionalized node into the graph before node.
|
||||
If one of the kwargs is 'out', provide args directly,
|
||||
as node.kwargs cannot be used.
|
||||
See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351
|
||||
|
||||
:param graph: Graph to insert the defunctionalized node into
|
||||
:param node: The auto-functionalized node to defunctionalize
|
||||
:param args: If we cannot use kwargs, specify args directly.
|
||||
If an arg is a string, `node.kwargs[arg]` is used.
|
||||
""" # noqa: E501
|
||||
assert is_func(node, auto_functionalized), \
|
||||
f"node must be auto-functionalized, is {node} instead"
|
||||
|
||||
# Create a new call to the original function
|
||||
with graph.inserting_before(node):
|
||||
function = node.args[0]
|
||||
if args is None:
|
||||
graph.call_function(function, kwargs=node.kwargs)
|
||||
else:
|
||||
# Args passed as strings refer to items in node.kwargs
|
||||
args = tuple(node.kwargs[arg] if isinstance(arg, str) else arg
|
||||
for arg in args)
|
||||
graph.call_function(function, args=args)
|
||||
383
vllm/compilation/fusion.py
Normal file
383
vllm/compilation/fusion.py
Normal file
@@ -0,0 +1,383 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def empty_bf16(*args, **kwargs):
|
||||
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
|
||||
def empty_fp32(*args, **kwargs):
|
||||
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
|
||||
|
||||
|
||||
def empty_i32(*args, **kwargs):
|
||||
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
|
||||
|
||||
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
|
||||
QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym:
|
||||
torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTensorSym:
|
||||
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTokenSym:
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
|
||||
|
||||
|
||||
class FusedRMSQuantKey(NamedTuple):
|
||||
"""
|
||||
Named tuple for identifying the type of RMSNorm + quant fusion.
|
||||
quant: type of quantization
|
||||
fused_add: does the op also perform the residual add
|
||||
"""
|
||||
quant: QuantKey
|
||||
fused_add: bool
|
||||
|
||||
def __str__(self):
|
||||
return (f"FusedQuantKey({self.quant}, with"
|
||||
f"{'' if self.fused_add else 'out'} residual)")
|
||||
|
||||
|
||||
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
|
||||
FusedRMSQuantKey(kFp8StaticTensorSym, False):
|
||||
torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(kFp8StaticTensorSym, True):
|
||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(kFp8DynamicTokenSym, False):
|
||||
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(kFp8DynamicTokenSym, True):
|
||||
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
|
||||
}
|
||||
|
||||
|
||||
class RMSNormQuantPattern:
|
||||
|
||||
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
|
||||
self.epsilon = epsilon
|
||||
self.quant_dtype = key.quant.dtype
|
||||
|
||||
assert key.quant in QUANT_OPS, \
|
||||
f"unsupported quantization scheme {key.quant}"
|
||||
self.QUANT_OP = QUANT_OPS[key.quant]
|
||||
|
||||
assert key in FUSED_OPS, \
|
||||
f"unsupported fused rmsnorm+quant op for {key}"
|
||||
self.FUSED_OP = FUSED_OPS[key]
|
||||
|
||||
|
||||
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
def __init__(self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
symmetric=True):
|
||||
fused_key = FusedRMSQuantKey(fused_add=False,
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
scale=kStaticTensorScale,
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, fused_key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
# Cannot use methods, as the self argument affects tracing
|
||||
def pattern(result: torch.Tensor, result_rms: torch.Tensor,
|
||||
input: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at1 = auto_functionalized(RMS_OP,
|
||||
result=result_rms,
|
||||
input=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon)
|
||||
at2 = auto_functionalized(self.QUANT_OP,
|
||||
result=result,
|
||||
input=at1[1],
|
||||
scale=scale)
|
||||
|
||||
# result
|
||||
return at2[1]
|
||||
|
||||
def replacement(result: torch.Tensor, result_rms: torch.Tensor,
|
||||
input: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at = auto_functionalized(self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon)
|
||||
|
||||
# result
|
||||
return at[1]
|
||||
|
||||
inputs = [
|
||||
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
|
||||
empty_bf16(5, 4), # result_rms
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1) # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only,
|
||||
pm_pass)
|
||||
|
||||
|
||||
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
def __init__(self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
symmetric=True):
|
||||
key = FusedRMSQuantKey(fused_add=True,
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
scale=kStaticTensorScale,
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(result: torch.Tensor, input: torch.Tensor,
|
||||
residual: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at = auto_functionalized(RMS_ADD_OP,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon)
|
||||
at1 = auto_functionalized(self.QUANT_OP,
|
||||
result=result,
|
||||
input=at[1],
|
||||
scale=scale)
|
||||
|
||||
# result, residual
|
||||
return at1[1], at[2]
|
||||
|
||||
def replacement(result: torch.Tensor, input: torch.Tensor,
|
||||
residual: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at = auto_functionalized(self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon)
|
||||
|
||||
# result, residual
|
||||
return at[1], at[2]
|
||||
|
||||
inputs = [
|
||||
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(5, 4), # residual
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1) # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
def __init__(self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True):
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(fused_add=False,
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
scale=scale,
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(result: torch.Tensor, result_rms: torch.Tensor,
|
||||
input: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at1 = auto_functionalized(RMS_OP,
|
||||
result=result_rms,
|
||||
input=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon)
|
||||
at2 = auto_functionalized(self.QUANT_OP,
|
||||
result=result,
|
||||
input=at1[1],
|
||||
scale=scale,
|
||||
scale_ub=None)
|
||||
|
||||
# result, scale
|
||||
return at2[1], at2[2]
|
||||
|
||||
def replacement(result: torch.Tensor, result_rms: torch.Tensor,
|
||||
input: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at = auto_functionalized(self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
scale_ub=None,
|
||||
residual=None)
|
||||
|
||||
# result, scale
|
||||
return at[1], at[2]
|
||||
|
||||
inputs = [
|
||||
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
|
||||
empty_bf16(5, 4), # result_rms
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1) # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
def __init__(self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True):
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(fused_add=True,
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
scale=scale,
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(result: torch.Tensor, input: torch.Tensor,
|
||||
residual: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at = auto_functionalized(RMS_ADD_OP,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon)
|
||||
at1 = auto_functionalized(self.QUANT_OP,
|
||||
result=result,
|
||||
input=at[1],
|
||||
scale=scale,
|
||||
scale_ub=None)
|
||||
|
||||
# result, residual, scale
|
||||
return at1[1], at[2], at1[2]
|
||||
|
||||
def replacement(result: torch.Tensor, input: torch.Tensor,
|
||||
residual: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at = auto_functionalized(self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
scale_ub=None,
|
||||
residual=residual)
|
||||
|
||||
# result, residual, scale
|
||||
return at[1], at[3], at[2]
|
||||
|
||||
inputs = [
|
||||
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(5, 4), # residual
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1) # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
|
||||
It also supports fused_add_rms_norm.
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="rmsnorm_quant_fusion_pass")
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# Fuse rms_norm + static fp8 quant
|
||||
RMSNormStaticQuantPattern(epsilon,
|
||||
FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Fuse fused_add_rms_norm + static fp8 quant
|
||||
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns)
|
||||
|
||||
# Fuse rms_norm + dynamic per-token fp8 quant
|
||||
RMSNormDynamicQuantPattern(epsilon,
|
||||
FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
|
||||
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self) -> Any:
|
||||
return self.hash_source(self, RMSNormQuantPattern,
|
||||
RMSNormStaticQuantPattern,
|
||||
RMSNormDynamicQuantPattern,
|
||||
FusedAddRMSNormStaticQuantPattern,
|
||||
FusedAddRMSNormDynamicQuantPattern)
|
||||
295
vllm/compilation/fusion_attn.py
Normal file
295
vllm/compilation/fusion_attn.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey, kNvfp4Quant, kStaticTensorScale)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import round_up
|
||||
|
||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
|
||||
RESHAPE_OP = torch.ops.aten.reshape.default
|
||||
|
||||
|
||||
class AttentionQuantPattern(ABC):
|
||||
"""
|
||||
The base class for Attn+Quant fusions.
|
||||
Should not be used directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer: Attention,
|
||||
quant_key: QuantKey,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
self.layer = layer
|
||||
self.layer_name = layer.layer_name
|
||||
self.num_heads = layer.num_heads
|
||||
self.head_size = layer.head_size
|
||||
self.quant_key = quant_key
|
||||
self.quant_dtype = quant_key.dtype
|
||||
self.dtype = dtype
|
||||
|
||||
assert self.quant_key in QUANT_OPS, \
|
||||
f"unsupported quantization scheme {self.quant_key}"
|
||||
self.QUANT_OP = QUANT_OPS[self.quant_key]
|
||||
|
||||
def empty(self, *args, **kwargs):
|
||||
kwargs = {'dtype': self.dtype, 'device': "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
def empty_quant(self, *args, **kwargs):
|
||||
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def wrap_trace_fn(process_fx, trace_fn):
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
return process_fx(trace_fn(*args, **kwargs))
|
||||
|
||||
return wrapped
|
||||
|
||||
@staticmethod
|
||||
def fx_view_to_reshape(gm: torch.fx.GraphModule):
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
view_to_reshape(gm)
|
||||
return gm
|
||||
|
||||
def register_if_supported(self, pm_pass: PatternMatcherPass):
|
||||
if self.layer.impl.fused_output_quant_supported(self.quant_key):
|
||||
self._register(pm_pass)
|
||||
|
||||
@abstractmethod
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
"""
|
||||
Fusion for Attention+Fp8StaticQuant.
|
||||
|
||||
Only triggers when the attention implementation returns True in
|
||||
`fused_output_quant_supported()`. If the pattern is found, the
|
||||
Fp8StaticQuant op will be removed from the graph, and its scale
|
||||
will be passed into Attention op as the `output_scale` argument.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer: Attention,
|
||||
dtype: torch.dtype,
|
||||
symmetric: bool = True,
|
||||
):
|
||||
quant_key = QuantKey(dtype=FP8_DTYPE,
|
||||
scale=kStaticTensorScale,
|
||||
symmetric=symmetric)
|
||||
super().__init__(layer, quant_key, dtype)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
output_attn: torch.Tensor, output_quant: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at1 = auto_functionalized(ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=None,
|
||||
output_block_scale=None)
|
||||
attn_out_view = RESHAPE_OP(
|
||||
at1[1], [q.shape[0], self.num_heads * self.head_size])
|
||||
at2 = auto_functionalized(self.QUANT_OP,
|
||||
result=output_quant,
|
||||
input=attn_out_view,
|
||||
scale=scale)
|
||||
return at2[1]
|
||||
|
||||
def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
output_attn: torch.Tensor, output_quant: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
# attn output in quant_dtype
|
||||
output_attn = torch.ops.aten.full.default(
|
||||
[q.shape[0], self.num_heads, self.head_size],
|
||||
0.0,
|
||||
dtype=self.quant_dtype,
|
||||
device=q.device)
|
||||
at1 = auto_functionalized(ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=scale,
|
||||
output_block_scale=None)
|
||||
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
||||
|
||||
inputs = [
|
||||
self.empty(5, self.num_heads, self.head_size,
|
||||
dtype=self.dtype), # q
|
||||
self.empty(5, self.num_heads, self.head_size,
|
||||
dtype=self.dtype), # k
|
||||
self.empty(5, self.num_heads, self.head_size,
|
||||
dtype=self.dtype), # v
|
||||
self.empty(5, self.num_heads, self.head_size,
|
||||
dtype=self.dtype), # attn_output
|
||||
self.empty_quant(5,
|
||||
self.num_heads * self.head_size), # quant_output
|
||||
empty_fp32(1, 1) # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, inputs,
|
||||
AttentionQuantPattern.wrap_trace_fn(
|
||||
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
|
||||
pm_pass)
|
||||
|
||||
|
||||
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
"""
|
||||
Fusion for Attention+Nvfp4Quant.
|
||||
|
||||
Only triggers when the attention implementation returns True in
|
||||
`fused_output_quant_supported()`. If the pattern is found, the
|
||||
Nvfp4Quant op will be removed from the graph, and its scale
|
||||
will be passed into Attention op as the `output_scale` argument.
|
||||
"""
|
||||
|
||||
def __init__(self, layer: Attention, dtype: torch.dtype):
|
||||
super().__init__(layer, kNvfp4Quant, dtype)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
output_attn: torch.Tensor, output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor, input_scale: torch.Tensor):
|
||||
at1 = auto_functionalized(ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=None,
|
||||
output_block_scale=None)
|
||||
attn_out_view = RESHAPE_OP(
|
||||
at1[1], [q.shape[0], self.num_heads * self.head_size])
|
||||
at2 = auto_functionalized(self.QUANT_OP,
|
||||
output=output_quant,
|
||||
input=attn_out_view,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_scale)
|
||||
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
|
||||
return at2[1], output_scale_view
|
||||
|
||||
def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
output_attn: torch.Tensor, output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor, input_scale: torch.Tensor):
|
||||
# attention output in quant_dtype
|
||||
output_attn = torch.ops.aten.full.default(
|
||||
[q.shape[0], self.num_heads, self.head_size // 2],
|
||||
0.0,
|
||||
dtype=self.quant_dtype,
|
||||
device=q.device)
|
||||
# attention output block scale
|
||||
output_scale_view = torch.ops.aten.view.dtype(
|
||||
output_scale, FP8_DTYPE)
|
||||
at2 = auto_functionalized(ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=input_scale,
|
||||
output_block_scale=output_scale_view)
|
||||
output = RESHAPE_OP(at2[1],
|
||||
[-1, self.num_heads * self.head_size // 2])
|
||||
return output, at2[2]
|
||||
|
||||
inputs = [
|
||||
empty_bf16(5, self.num_heads, self.head_size), # q
|
||||
empty_bf16(5, self.num_heads, self.head_size), # k
|
||||
empty_bf16(5, self.num_heads, self.head_size), # v
|
||||
empty_bf16(5, self.num_heads, self.head_size), # output_attn
|
||||
self.empty_quant(5, self.num_heads * self.head_size //
|
||||
2), # output_quant
|
||||
empty_i32(128, round_up(self.num_heads * self.head_size // 16,
|
||||
4)), # output_scale
|
||||
empty_fp32(1, 1), # input_scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, inputs,
|
||||
AttentionQuantPattern.wrap_trace_fn(
|
||||
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
|
||||
pm_pass)
|
||||
|
||||
|
||||
class AttnFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses post-attention quantization onto attention if supported.
|
||||
|
||||
It uses the pattern matcher and matches each layer manually, as strings
|
||||
cannot be wildcarded. This also lets us check support on attention layers
|
||||
upon registration instead of during pattern matching.
|
||||
|
||||
Currently, only static fp8 quant is supported, but patterns could easily be
|
||||
added for other quant schemes and dtypes. The bigger hurdle for wider
|
||||
support are attention kernels, which need to support fusing output quant.
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")
|
||||
|
||||
attn_layers = get_layers_from_vllm_config(config, Attention)
|
||||
for layer_name, layer in attn_layers.items():
|
||||
pattern_fp8 = AttentionFp8StaticQuantPattern(
|
||||
layer, config.model_config.dtype)
|
||||
pattern_fp8.register_if_supported(self.patterns)
|
||||
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C,
|
||||
"scaled_fp4_quant"):
|
||||
pattern_nvfp4 = AttentionNvfp4QuantPattern(
|
||||
layer, config.model_config.dtype)
|
||||
pattern_nvfp4.register_if_supported(self.patterns)
|
||||
|
||||
if len(attn_layers) == 0:
|
||||
logger.warning(
|
||||
"Attention + quant fusion is enabled, but no attention layers "
|
||||
"were found in CompilationConfig.static_forward_context "
|
||||
"so no fusion patterns were registered.")
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.graph.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Fused quant onto %s attention nodes", self.matched_count)
|
||||
|
||||
def uuid(self):
|
||||
return VllmInductorPass.hash_source(self, AttentionQuantPattern,
|
||||
AttentionFp8StaticQuantPattern,
|
||||
AttentionNvfp4QuantPattern)
|
||||
84
vllm/compilation/fx_utils.py
Normal file
84
vllm/compilation/fx_utils.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import operator
|
||||
from collections.abc import Iterable, Iterator
|
||||
from typing import Optional
|
||||
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._ops import OpOverload
|
||||
|
||||
|
||||
def is_func(node: fx.Node, target) -> bool:
|
||||
return node.op == "call_function" and node.target == target
|
||||
|
||||
|
||||
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
|
||||
return is_func(node, auto_functionalized) and node.args[0] == op
|
||||
|
||||
|
||||
# Returns the first specified node with the given op (if it exists)
|
||||
def find_specified_fn_maybe(nodes: Iterable[fx.Node],
|
||||
op: OpOverload) -> Optional[fx.Node]:
|
||||
for node in nodes:
|
||||
if node.target == op:
|
||||
return node
|
||||
return None
|
||||
|
||||
|
||||
# Returns the first specified node with the given op
|
||||
def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
|
||||
node = find_specified_fn_maybe(nodes, op)
|
||||
assert node is not None, f"Could not find {op} in nodes {nodes}"
|
||||
return node
|
||||
|
||||
|
||||
# Returns the first auto_functionalized node with the given op (if it exists)
|
||||
def find_auto_fn_maybe(nodes: Iterable[fx.Node],
|
||||
op: OpOverload) -> Optional[fx.Node]:
|
||||
for node in nodes:
|
||||
if is_func(node, auto_functionalized) and node.args[0] == op: # noqa
|
||||
return node
|
||||
return None
|
||||
|
||||
|
||||
# Returns the first auto_functionalized node with the given op
|
||||
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
|
||||
node = find_auto_fn_maybe(nodes, op)
|
||||
assert node is not None, f"Could not find {op} in nodes {nodes}"
|
||||
return node
|
||||
|
||||
|
||||
# Returns the getitem node that extracts the idx-th element from node
|
||||
# (if it exists)
|
||||
def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:
|
||||
for user in node.users:
|
||||
if is_func(user, operator.getitem) and user.args[1] == idx:
|
||||
return user
|
||||
return None
|
||||
|
||||
|
||||
# Returns the getitem node that extracts the idx-th element from node
|
||||
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
|
||||
ret = find_getitem_maybe(node, idx)
|
||||
assert ret is not None, f"Could not find getitem {idx} in node {node}"
|
||||
return ret
|
||||
|
||||
|
||||
# An auto-functionalization-aware utility for finding nodes with a specific op
|
||||
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
|
||||
if not op._schema.is_mutable:
|
||||
yield from graph.find_nodes(op="call_function", target=op)
|
||||
|
||||
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
|
||||
if n.args[0] == op:
|
||||
yield n
|
||||
|
||||
|
||||
# Asserts that the node only has one user and returns it
|
||||
# Even if a node has only 1 user, it might share storage with another node,
|
||||
# which might need to be taken into account.
|
||||
def get_only_user(node: fx.Node) -> fx.Node:
|
||||
assert len(node.users) == 1
|
||||
return next(iter(node.users))
|
||||
136
vllm/compilation/inductor_pass.py
Normal file
136
vllm/compilation/inductor_pass.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
from torch._subclasses.fake_tensor import (FakeTensorMode,
|
||||
unset_fake_temporarily)
|
||||
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
from torch._inductor.custom_graph_pass import CustomGraphPass
|
||||
else:
|
||||
# CustomGraphPass is not present in 2.5 or lower, import our version
|
||||
from .torch25_custom_graph_pass import ( # noqa: E501
|
||||
Torch25CustomGraphPass as CustomGraphPass)
|
||||
|
||||
_pass_context = None
|
||||
|
||||
|
||||
class PassContext:
|
||||
|
||||
def __init__(self, runtime_shape: Optional[int]):
|
||||
self.runtime_shape = runtime_shape
|
||||
|
||||
|
||||
def get_pass_context() -> PassContext:
|
||||
"""Get the current pass context."""
|
||||
assert _pass_context is not None
|
||||
return _pass_context
|
||||
|
||||
|
||||
@contextmanager
|
||||
def pass_context(runtime_shape: Optional[int]):
|
||||
"""A context manager that stores the current pass context,
|
||||
usually it is a list of sizes to specialize.
|
||||
"""
|
||||
global _pass_context
|
||||
prev_context = _pass_context
|
||||
_pass_context = PassContext(runtime_shape)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_pass_context = prev_context
|
||||
|
||||
|
||||
class InductorPass(CustomGraphPass):
|
||||
"""
|
||||
A custom graph pass that uses a hash of its source as the UUID.
|
||||
This is defined as a convenience and should work in most cases.
|
||||
"""
|
||||
|
||||
def uuid(self) -> Any:
|
||||
"""
|
||||
Provide a unique identifier for the pass, used in Inductor code cache.
|
||||
This should depend on the pass implementation, so that changes to the
|
||||
pass result in recompilation.
|
||||
By default, the object source is hashed.
|
||||
"""
|
||||
return InductorPass.hash_source(self)
|
||||
|
||||
@staticmethod
|
||||
def hash_source(*srcs: Union[str, Any]):
|
||||
"""
|
||||
Utility method to hash the sources of functions or objects.
|
||||
:param srcs: strings or objects to add to the hash.
|
||||
Objects and functions have their source inspected.
|
||||
:return:
|
||||
"""
|
||||
hasher = hashlib.sha256()
|
||||
for src in srcs:
|
||||
if isinstance(src, str):
|
||||
src_str = src
|
||||
elif isinstance(src, (types.FunctionType, type)):
|
||||
src_str = inspect.getsource(src)
|
||||
else:
|
||||
# object instance
|
||||
src_str = inspect.getsource(src.__class__)
|
||||
hasher.update(src_str.encode("utf-8"))
|
||||
return hasher.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def hash_dict(dict_: dict[Any, Any]):
|
||||
"""
|
||||
Utility method to hash a dictionary, can alternatively be used for uuid.
|
||||
:return: A sha256 hash of the json rep of the dictionary.
|
||||
"""
|
||||
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
||||
return hashlib.sha256(encoded).hexdigest()
|
||||
|
||||
def is_applicable_for_shape(self, shape: Optional[int]):
|
||||
return True
|
||||
|
||||
|
||||
class CallableInductorPass(InductorPass):
|
||||
"""
|
||||
This class is a wrapper for a callable that automatically provides an
|
||||
implementation of the UUID.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
callable: Callable[[fx.Graph], None],
|
||||
uuid: Optional[Any] = None):
|
||||
self.callable = callable
|
||||
self._uuid = self.hash_source(callable) if uuid is None else uuid
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.callable(graph)
|
||||
|
||||
def uuid(self) -> Any:
|
||||
return self._uuid
|
||||
|
||||
|
||||
def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||
"""
|
||||
Applies a FakeTensorMode context. This is useful when you don't want to
|
||||
create or run things with real tensors.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def fn_new(*args, **kwargs) -> Any:
|
||||
with torch._guards.tracing(
|
||||
None), unset_fake_temporarily(), FakeTensorMode():
|
||||
result = fn(*args, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
return fn_new
|
||||
57
vllm/compilation/monitor.py
Normal file
57
vllm/compilation/monitor.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
context_manager = None
|
||||
torch_compile_start_time: float = 0.0
|
||||
|
||||
|
||||
def start_monitoring_torch_compile(vllm_config: VllmConfig):
|
||||
global torch_compile_start_time
|
||||
torch_compile_start_time = time.time()
|
||||
|
||||
compilation_config: CompilationConfig = vllm_config.compilation_config
|
||||
if compilation_config.level == CompilationLevel.PIECEWISE and \
|
||||
compilation_config.debug_dump_path:
|
||||
import depyf
|
||||
path = os.path.join(compilation_config.debug_dump_path,
|
||||
f"rank_{vllm_config.parallel_config.rank}")
|
||||
global context_manager
|
||||
context_manager = depyf.prepare_debug(path)
|
||||
context_manager.__enter__()
|
||||
|
||||
|
||||
def end_monitoring_torch_compile(vllm_config: VllmConfig):
|
||||
compilation_config: CompilationConfig = vllm_config.compilation_config
|
||||
if compilation_config.level == CompilationLevel.PIECEWISE:
|
||||
logger.info("torch.compile takes %.2f s in total",
|
||||
compilation_config.compilation_time)
|
||||
global context_manager
|
||||
if context_manager is not None:
|
||||
context_manager.__exit__(None, None, None)
|
||||
context_manager = None
|
||||
|
||||
|
||||
cudagraph_capturing_enabled: bool = True
|
||||
|
||||
|
||||
def validate_cudagraph_capturing_enabled():
|
||||
# used to monitor whether a cudagraph capturing is legal at runtime.
|
||||
# should be called before any cudagraph capturing.
|
||||
# if an illegal cudagraph capturing happens, raise an error.
|
||||
global cudagraph_capturing_enabled
|
||||
if not cudagraph_capturing_enabled:
|
||||
raise RuntimeError("CUDA graph capturing detected at an inappropriate "
|
||||
"time. This operation is currently disabled.")
|
||||
|
||||
|
||||
def set_cudagraph_capturing_enabled(enabled: bool):
|
||||
global cudagraph_capturing_enabled
|
||||
cudagraph_capturing_enabled = enabled
|
||||
158
vllm/compilation/noop_elimination.py
Normal file
158
vllm/compilation/noop_elimination.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Union
|
||||
|
||||
import torch.fx
|
||||
from torch import SymInt
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .fx_utils import is_func
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class NoOpEliminationPass(VllmInductorPass):
|
||||
"""
|
||||
This is an inductor pass that removes redundant reshape/slice operations.
|
||||
It is required for RMSNorm-quant fusion to work properly.
|
||||
That's because apply_fp8_linear adds a reshape, which is redundant
|
||||
in the 2D-case. Additionally, torch internal no-op elimination pass does
|
||||
not handle certain slice variants.
|
||||
|
||||
Cases handled:
|
||||
1. A chain of reshapes is equivalent to the last reshape called on the
|
||||
base tensor (input of the first reshape).
|
||||
2. A reshape that produces the shape of the input is redundant
|
||||
3. A slice that produces the shape of the input is redundant
|
||||
|
||||
Example graph 1:
|
||||
mul_1: "f16[s0, 4096]" = ...
|
||||
view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
|
||||
view_2: "f16[s0, 4096]" = torch.reshape(view_2, [-1, 4096])
|
||||
view_3: "f16[s0, 128, 32]" = torch.reshape(view_3, [-1, 128, 32])
|
||||
|
||||
Can be replaced with:
|
||||
mul_1: "f16[s0, 4096]" = ...
|
||||
view_3: "f16[s0, 128, 32]" = ...
|
||||
|
||||
Example graph 2:
|
||||
getitem_1: "f16[s0, 4096]" = ...
|
||||
view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
|
||||
at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
|
||||
out: "f8e4m3fn[s0, 4096]" = at[1]
|
||||
|
||||
Can be replaced with:
|
||||
getitem_1: "f16[s0, 4096]" = ...
|
||||
at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
|
||||
out: "f8e4m3fn[s0, 4096]" = at[1]
|
||||
|
||||
Example graph 3:
|
||||
arg0: "s0" = SymInt(s0)
|
||||
scaled_mm: "f16[s0, 4096]" = ...
|
||||
slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
|
||||
at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
|
||||
out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)
|
||||
|
||||
Can be replaced with:
|
||||
arg0: "s0" = SymInt(s0)
|
||||
scaled_mm: "f16[s0, 4096]" = ...
|
||||
at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
|
||||
out: "f16[s0, 4096]" = at[1]
|
||||
"""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
count = 0
|
||||
# Remove no-op reshapes/views:
|
||||
for node in graph.nodes:
|
||||
if is_func(node, torch.ops.aten.reshape.default):
|
||||
# Case 1: rewrite reshape chains to reshapes on the base tensor
|
||||
input = node.args[0]
|
||||
# If the input is a reshape, rebind to that node
|
||||
if is_func(input, torch.ops.aten.reshape.default):
|
||||
# The new input is guaranteed not to be a reshape,
|
||||
# because we process nodes in order
|
||||
node.update_arg(0, input.args[0])
|
||||
if len(input.users) == 0:
|
||||
graph.erase_node(input)
|
||||
count += 1
|
||||
|
||||
# Case 2: remove this reshape if it produces the original shape
|
||||
input, shape = node.args[:2]
|
||||
input_shape = input.meta["val"].shape
|
||||
if len(shape) != len(input_shape):
|
||||
# Reshape changing rank, skip
|
||||
continue
|
||||
|
||||
if shape.count(-1) > 1:
|
||||
# Invalid reshape args, skip
|
||||
continue
|
||||
|
||||
if self.reshape_all_dims_equivalent(shape, input_shape):
|
||||
node.replace_all_uses_with(input)
|
||||
graph.erase_node(node)
|
||||
count += 1
|
||||
|
||||
elif is_func(node, torch.ops.aten.slice.Tensor):
|
||||
# python slicing semantics are different from reshape
|
||||
# Don't treat -1 as inferred dimension
|
||||
input, dim_index, start, end = node.args[:4]
|
||||
input_shape = input.meta["val"].shape
|
||||
output_shape = node.meta["val"].shape
|
||||
|
||||
if output_shape == input_shape:
|
||||
node.replace_all_uses_with(input)
|
||||
graph.erase_node(node)
|
||||
count += 1
|
||||
|
||||
elif is_func(node, torch.ops.aten.slice_scatter.default):
|
||||
base, view, dim_index, start, end = node.args[:5]
|
||||
base_shape = base.meta["val"].shape
|
||||
view_shape = view.meta["val"].shape
|
||||
|
||||
if base_shape == view_shape:
|
||||
node.replace_all_uses_with(view)
|
||||
graph.erase_node(node)
|
||||
count += 1
|
||||
|
||||
logger.debug("Removed %s no-op reshapes and slices", count)
|
||||
|
||||
# ---------------------- Reshape helpers ----------------------
|
||||
def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node],
|
||||
i_dim: Union[int, SymInt]) -> bool:
|
||||
"""
|
||||
This function checks if two dimensions are equivalent.
|
||||
:param dim: The dimension arg to reshape/slice
|
||||
:param i_dim: The corresponding dimension in the input tensor
|
||||
:return: Are the dimensions equivalent?
|
||||
|
||||
There are three cases in which the dimensions are equivalent:
|
||||
1. The dimensions are equal (both integers)
|
||||
2. The reshape dimension is -1 (i.e. inferred)
|
||||
3. The dimensions both correspond to the same SymInt
|
||||
|
||||
While case 2 does not guarantee the dimensions are equal,
|
||||
they are equal if all other dimensions are equal.
|
||||
|
||||
In case 3, the reshape dimension is a torch.fx.Node,
|
||||
and its value is a SymInt. That value is equal to the
|
||||
input dimension.
|
||||
"""
|
||||
# Case 1 and 2
|
||||
if dim == i_dim or dim == -1:
|
||||
return True
|
||||
# Case 3
|
||||
return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
|
||||
|
||||
def reshape_all_dims_equivalent(
|
||||
self,
|
||||
dims: Iterable[Union[int, torch.fx.Node]],
|
||||
i_dims: Iterable[Union[int, SymInt]],
|
||||
) -> bool:
|
||||
return all(
|
||||
self.reshape_dims_equivalent(s, i_s)
|
||||
for s, i_s in zip(dims, i_dims))
|
||||
125
vllm/compilation/pass_manager.py
Normal file
125
vllm/compilation/pass_manager.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
|
||||
from torch import fx as fx
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import set_env_var
|
||||
|
||||
from .post_cleanup import PostCleanupPass
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from .activation_quant_fusion import ActivationQuantFusionPass
|
||||
from .fusion import RMSNormQuantFusionPass
|
||||
from .fusion_attn import AttnFusionPass
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from .collective_fusion import AllReduceFusionPass, AsyncTPPass
|
||||
|
||||
from .fix_functionalization import FixFunctionalizationPass
|
||||
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
|
||||
from .noop_elimination import NoOpEliminationPass
|
||||
from .sequence_parallelism import SequenceParallelismPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def with_pattern_match_debug(fn):
|
||||
"""
|
||||
Function decorator that turns on inductor pattern match debug
|
||||
for the duration of the call.
|
||||
Used to avoid logging builtin Inductor pattern matching.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None:
|
||||
# optionally check rank here
|
||||
with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
|
||||
return fn(*args, **kwargs)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class PostGradPassManager(CustomGraphPass):
|
||||
"""
|
||||
The pass manager for post-grad passes.
|
||||
It handles configuration, adding custom passes, and running passes.
|
||||
It supports uuid for the Inductor code cache. That includes torch<2.6
|
||||
support using pickling (in .inductor_pass.CustomGraphPass).
|
||||
|
||||
The order of the post-grad post-passes is:
|
||||
1. passes (constructor parameter)
|
||||
2. default passes (NoopEliminationPass, FusionPass)
|
||||
3. config["post_grad_custom_post_pass"] (if it exists)
|
||||
4. fix_functionalization
|
||||
This way, all passes operate on a functionalized graph.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.passes: list[InductorPass] = []
|
||||
|
||||
@with_pattern_match_debug
|
||||
def __call__(self, graph: fx.Graph):
|
||||
VllmInductorPass.dump_prefix = 0 # reset dump index
|
||||
|
||||
shape = get_pass_context().runtime_shape
|
||||
for pass_ in self.passes:
|
||||
if pass_.is_applicable_for_shape(shape):
|
||||
pass_(graph)
|
||||
VllmInductorPass.dump_prefix += 1
|
||||
|
||||
# post-cleanup goes before fix_functionalization
|
||||
# because it requires a functional graph
|
||||
self.post_cleanup(graph)
|
||||
VllmInductorPass.dump_prefix += 1
|
||||
|
||||
# always run fix_functionalization last
|
||||
self.fix_functionalization(graph)
|
||||
VllmInductorPass.dump_prefix = None # Cleanup index
|
||||
|
||||
def configure(self, config: VllmConfig):
|
||||
self.pass_config = config.compilation_config.pass_config
|
||||
if self.pass_config.enable_noop:
|
||||
self.passes += [NoOpEliminationPass(config)]
|
||||
|
||||
if self.pass_config.enable_sequence_parallelism:
|
||||
self.passes += [SequenceParallelismPass(config)]
|
||||
if self.pass_config.enable_async_tp:
|
||||
self.passes += [AsyncTPPass(config)]
|
||||
|
||||
if self.pass_config.enable_fi_allreduce_fusion:
|
||||
self.passes += [AllReduceFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_fusion:
|
||||
self.passes += [RMSNormQuantFusionPass(config)]
|
||||
self.passes += [ActivationQuantFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_attn_fusion:
|
||||
self.passes += [AttnFusionPass(config)]
|
||||
|
||||
# needs a functional graph
|
||||
self.post_cleanup = PostCleanupPass(config)
|
||||
self.fix_functionalization = FixFunctionalizationPass(config)
|
||||
|
||||
def add(self, pass_: InductorPass):
|
||||
assert isinstance(pass_, InductorPass)
|
||||
self.passes.append(pass_)
|
||||
|
||||
def uuid(self):
|
||||
"""
|
||||
The PostGradPassManager is set as a custom pass in the Inductor and
|
||||
affects compilation caching. Its uuid depends on the UUIDs of all
|
||||
dependent passes and the pass config. See InductorPass for more info.
|
||||
"""
|
||||
state = {"pass_config": self.pass_config.uuid(), "passes": []}
|
||||
for pass_ in self.passes:
|
||||
state["passes"].append(pass_.uuid())
|
||||
state["passes"].append(self.fix_functionalization.uuid())
|
||||
return InductorPass.hash_dict(state)
|
||||
20
vllm/compilation/post_cleanup.py
Normal file
20
vllm/compilation/post_cleanup.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from torch import fx
|
||||
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
|
||||
class PostCleanupPass(VllmInductorPass):
|
||||
"""
|
||||
This pass performs cleanup after custom passes.
|
||||
It topologically sorts the graph and removes unused nodes.
|
||||
This is needed because the pattern matcher does not guarantee producing
|
||||
a topologically sorted graph, and there may be unused nodes left around.
|
||||
"""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
from torch._inductor.pattern_matcher import stable_topological_sort
|
||||
stable_topological_sort(graph)
|
||||
graph.eliminate_dead_code()
|
||||
478
vllm/compilation/sequence_parallelism.py
Normal file
478
vllm/compilation/sequence_parallelism.py
Normal file
@@ -0,0 +1,478 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _RMSNormAndQuantOpHelper:
|
||||
"""Base helper for RMSNorm and RMSNorm + Quantization functionalization."""
|
||||
|
||||
def __init__(self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
quant_op: Optional[torch._ops.OpOverload] = None,
|
||||
**kwargs):
|
||||
self.epsilon = epsilon
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.quant_op = quant_op
|
||||
|
||||
def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor):
|
||||
return torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.rms_norm.default,
|
||||
result=result_buffer,
|
||||
input=input_tensor,
|
||||
weight=weight_tensor,
|
||||
epsilon=self.epsilon)
|
||||
|
||||
def _functional_fused_add_rmsnorm(self, input_tensor, residual_tensor,
|
||||
weight_tensor):
|
||||
return torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
input=input_tensor,
|
||||
residual=residual_tensor,
|
||||
weight=weight_tensor,
|
||||
epsilon=self.epsilon)
|
||||
|
||||
def _functional_rmsnorm_then_quant(self, rmsnorm_result_buffer,
|
||||
quant_result_buffer, input_tensor,
|
||||
weight_tensor, scale_tensor):
|
||||
if self.quant_op is None:
|
||||
raise RuntimeError(
|
||||
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
|
||||
)
|
||||
rmsnorm_out_tuple = self._functional_rmsnorm(rmsnorm_result_buffer,
|
||||
input_tensor,
|
||||
weight_tensor)
|
||||
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
|
||||
self.quant_op,
|
||||
result=quant_result_buffer,
|
||||
input=rmsnorm_out_tuple[1],
|
||||
scale=scale_tensor)
|
||||
return quant_out_tuple
|
||||
|
||||
def _functional_fused_add_rmsnorm_then_quant(self, quant_result_buffer,
|
||||
input_tensor, residual_tensor,
|
||||
weight_tensor, scale_tensor):
|
||||
if self.quant_op is None:
|
||||
raise RuntimeError(
|
||||
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
|
||||
)
|
||||
fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm(
|
||||
input_tensor, residual_tensor, weight_tensor)
|
||||
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
|
||||
self.quant_op,
|
||||
result=quant_result_buffer,
|
||||
input=fused_add_rmsnorm_out_tuple[1],
|
||||
scale=scale_tensor)
|
||||
return quant_out_tuple, fused_add_rmsnorm_out_tuple[2]
|
||||
|
||||
|
||||
class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
|
||||
"""Helper for sequence parallelism patterns."""
|
||||
|
||||
def __init__(self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
quant_op: Optional[torch._ops.OpOverload] = None,
|
||||
**kwargs):
|
||||
super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs)
|
||||
self.tp_group = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return tensor_model_parallel_all_reduce(x)
|
||||
|
||||
def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.vllm.reduce_scatter.default(
|
||||
x,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp_group.unique_name)
|
||||
|
||||
def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.vllm.all_gather.default(
|
||||
x,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp_group.unique_name)
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
|
||||
def get_inputs(self):
|
||||
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [input, permute, arg3_1]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
permute: torch.Tensor,
|
||||
arg3_1: torch.Tensor,
|
||||
):
|
||||
all_reduce = self._all_reduce(input)
|
||||
rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1)
|
||||
|
||||
return rmsnorm[1], all_reduce
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
permute: torch.Tensor,
|
||||
arg3_1: torch.Tensor,
|
||||
):
|
||||
reduce_scatter = self._reduce_scatter(input)
|
||||
|
||||
rmsnorm_result = torch.empty_like(reduce_scatter)
|
||||
rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter,
|
||||
arg3_1)
|
||||
|
||||
all_gather = self._all_gather(rmsnorm[1])
|
||||
|
||||
return all_gather, reduce_scatter
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
|
||||
return [
|
||||
residual,
|
||||
mm_1,
|
||||
rms_norm_weights,
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
all_reduce, residual, rms_norm_weights)
|
||||
return rmsnorm[1], rmsnorm[2]
|
||||
|
||||
def replacement(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
reduce_scatter, residual, rms_norm_weights)
|
||||
all_gather = self._all_gather(rmsnorm[1])
|
||||
return all_gather, rmsnorm[2]
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
|
||||
return [
|
||||
residual,
|
||||
mm_1,
|
||||
rms_norm_weights,
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
all_reduce, residual, rms_norm_weights)
|
||||
return rmsnorm[1]
|
||||
|
||||
def replacement(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
reduce_scatter, residual, rms_norm_weights)
|
||||
normalized = self._all_gather(rmsnorm[1])
|
||||
return normalized
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
                 op: torch._ops.OpOverload):
        super().__init__(epsilon, dtype, device, quant_op=op)

    def get_inputs(self):
        input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
        rmsnorm_result = torch.empty([1, 8, 4],
                                     device=self.device,
                                     dtype=self.dtype)
        quant_result = torch.empty([1, 8, 4],
                                   device=self.device,
                                   dtype=FP8_DTYPE)
        weight = torch.empty([4], device=self.device, dtype=self.dtype)
        scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
        return [input, rmsnorm_result, quant_result, weight, scale]

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(
            input: torch.Tensor,
            rmsnorm_result: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            all_reduce = self._all_reduce(input)
            static_fp8 = self._functional_rmsnorm_then_quant(
                rmsnorm_result, quant_result, all_reduce, weight, scale)
            return static_fp8[1], all_reduce

        def replacement(
            input: torch.Tensor,
            rmsnorm_result: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            reduce_scatter = self._reduce_scatter(input)

            rmsnorm_result = torch.empty_like(reduce_scatter,
                                              dtype=rmsnorm_result.dtype)
            quant_result = torch.empty_like(
                rmsnorm_result,  # Output of RMSNorm
                dtype=quant_result.dtype)
            static_fp8 = self._functional_rmsnorm_then_quant(
                rmsnorm_result, quant_result, reduce_scatter, weight, scale)
            all_gather = self._all_gather(static_fp8[1])

            return all_gather, reduce_scatter

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
                 op: torch._ops.OpOverload):
        super().__init__(epsilon, dtype, device, quant_op=op)

    def get_inputs(self):
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4],
                                       device=self.device,
                                       dtype=self.dtype)
        result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
        scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)

        return [
            result,
            residual,
            mm_1,
            rms_norm_weights,
            scale,
        ]

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(mm_1)
            static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant(  # noqa: E501
                result, all_reduce, residual, rms_norm_weights, scale)
            return static_fp8[1], rmsnorm_residual_out

        def replacement(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduce_scatter = self._reduce_scatter(mm_1)
            quant_result_buf = torch.empty_like(reduce_scatter,
                                                dtype=result.dtype)
            static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant(  # noqa: E501
                quant_result_buf, reduce_scatter, residual, rms_norm_weights,
                scale)
            all_gather = self._all_gather(static_fp8[1])
            return all_gather, rmsnorm_residual_out

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
                 op: torch._ops.OpOverload):
        super().__init__(epsilon, dtype, device, quant_op=op)

    def get_inputs(self):
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4],
                                       device=self.device,
                                       dtype=self.dtype)
        result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
        scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)

        return [
            result,
            residual,
            mm_1,
            rms_norm_weights,
            scale,
        ]

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(mm_1)
            static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
                result, all_reduce, residual, rms_norm_weights, scale)
            return static_fp8[1]

        def replacement(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduce_scatter = self._reduce_scatter(mm_1)
            quant_result_buf = torch.empty_like(reduce_scatter,
                                                dtype=result.dtype)
            static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
                quant_result_buf, reduce_scatter, residual, rms_norm_weights,
                scale)
            normalized = self._all_gather(static_fp8[1])
            return normalized

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)

class SequenceParallelismPass(VllmPatternMatcherPass):
    """
    This pass enables sequence parallelism for models.
    It identifies patterns where an AllReduce operation is followed by
    an RMSNorm (or RMSNorm and then Quantization) operation.
    These patterns are replaced with a ReduceScatter operation, followed by
    a local RMSNorm/Quantization, and then an AllGather operation.

    The general transformation is:
    Input -> AllReduce -> RMSNorm -> Output
    becomes
    Input -> ReduceScatter -> RMSNorm -> AllGather -> Output

    While this pass itself does not directly yield performance improvements,
    it lays the groundwork for subsequent fusion passes, such as
    GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
    significantly reduce communication overhead and improve overall model
    performance.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="sequence_parallelism_pass")

        for epsilon in [1e-5, 1e-6]:
            # RMSNorm + Static FP8 quantization patterns
            fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
            FirstAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device,
                fp8_quant_op).register(self.patterns)
            MiddleAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device,
                fp8_quant_op).register(self.patterns)
            LastAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device,
                fp8_quant_op).register(self.patterns)

            # Normal RMSNorm patterns
            FirstAllReduceRMSNormPattern(epsilon, self.model_dtype,
                                         self.device).register(self.patterns)

            MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype,
                                          self.device).register(self.patterns)

            LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
                                        self.device).register(self.patterns)
        self.dump_patterns(config, self.patterns)

    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
        tp_size = get_tensor_model_parallel_world_size()
        return shape is not None and shape % tp_size == 0

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)
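# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the pass above): why the rewrite described
# in the SequenceParallelismPass docstring is numerically valid. RMSNorm acts
# row-wise (per token), so normalizing the fully all-reduced tensor equals
# normalizing each reduce-scattered row chunk locally and gathering. The two
# "ranks" and the rms_norm helper below are hypothetical stand-ins used only
# for this single-process demonstration.
# ---------------------------------------------------------------------------
import torch as _torch


def _rms_norm(x, weight, eps=1e-6):
    # row-wise RMSNorm, matching the per-token semantics assumed by the pass
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * _torch.rsqrt(variance + eps) * weight


_partial = [_torch.randn(8, 16) for _ in range(2)]  # per-"rank" GEMM outputs
_weight = _torch.randn(16)

# Path 1: AllReduce -> RMSNorm (what the matched pattern computes)
_full = _partial[0] + _partial[1]
_out_allreduce = _rms_norm(_full, _weight)

# Path 2: ReduceScatter -> local RMSNorm -> AllGather (the replacement)
_chunks = [a + b for a, b in zip(_partial[0].chunk(2), _partial[1].chunk(2))]
_out_sp = _torch.cat([_rms_norm(c, _weight) for c in _chunks], dim=0)

assert _torch.allclose(_out_allreduce, _out_sp, atol=1e-6)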
42
vllm/compilation/torch25_custom_graph_pass.py
Normal file
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any, Optional

import torch


class Torch25CustomGraphPass(ABC):  # noqa (redefinition)
    """
    This class replaces CustomGraphPass from torch==2.6 when using torch<2.6.
    It conforms to the 2.6 interface but also supports pickling, as that's what
    the inductor code cache uses to determine the cache key before 2.6.
    (in 2.6 and above, uuid() is used.)

    Subclasses can just "pretend" that uuid is used.
    """

    @abstractmethod
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """
        Implementation of the custom pass.
        """

    @abstractmethod
    def uuid(self) -> Optional[Any]:
        """
        Return an ID to uniquely identify your custom pass implementation.
        Return None to skip inductor code caching entirely.
        """

    def __getstate__(self):
        """
        Pickling is used instead of uuid() in torch<2.6. Just return uuid()
        to enable subclasses to only have to implement uuid.
        """
        return self.uuid()

    def __setstate__(self, state):
        raise ValueError("Cannot unpickle CustomGraphPass because pickling"
                         " is used for cache key uuid. Use torch>=2.6 with"
                         " native uuid support for custom passes.")
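# Illustrative sketch (assumption: not part of this file): a minimal subclass
# showing how the interface is meant to be used. The no-op pass and its uuid
# string below are hypothetical and exist only to demonstrate that, on
# torch<2.6, pickling via __getstate__ reduces to whatever uuid() returns.
class _NoOpGraphPass(Torch25CustomGraphPass):

    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        # a real pass would mutate `graph` in place here
        return None

    def uuid(self) -> Optional[Any]:
        # any stable identifier works; the inductor code cache hashes it to
        # decide whether previously compiled artifacts can be reused
        return "no-op-pass/v1"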
156
vllm/compilation/vllm_inductor_pass.py
Normal file
@@ -0,0 +1,156 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import operator
import time
from pathlib import Path
from typing import ClassVar, Optional

import regex as re
import torch
from torch._dynamo.utils import lazy_format_graph_code
from torch._inductor.pattern_matcher import (PatternMatcherPass,
                                             PatternPrettyPrinter)

from vllm.config import VllmConfig
from vllm.logger import init_logger

from .inductor_pass import InductorPass

logger = init_logger(__name__)


class VllmInductorPass(InductorPass):
    """
    An inductor pass with access to vLLM PassConfig.
    It provides timing, logging, and dumping utilities.
    """
    dump_prefix: ClassVar[Optional[int]] = None
    """Keep track of pass index for debug dump ordering."""

    def __init__(self, config: VllmConfig):
        self.pass_config = config.compilation_config.pass_config
        self.model_dtype = config.model_config.dtype if config.model_config \
            else None
        self.device = config.device_config.device if config.device_config \
            else None
        self.pass_name = self.__class__.__name__

    @staticmethod
    def time_and_log(call_fn):

        @functools.wraps(call_fn)
        def wrapped(self: "VllmInductorPass", graph: torch.fx.Graph):
            self.begin()
            self.dump_graph(graph, "before")
            call_fn(self, graph)
            self.dump_graph(graph, "after")
            self.end_and_log()

        return wrapped

    def dump_graph(self, graph: torch.fx.Graph, stage: str):
        i = VllmInductorPass.dump_prefix
        i_str = "" if i is None else f".{i}"
        lazy_format_graph_code(f"post_grad{i_str}.{self.pass_name}.{stage}",
                               graph.owning_module)

    def begin(self):
        self._start_time = time.perf_counter_ns()

    def end_and_log(self):
        self._end_time = time.perf_counter_ns()
        duration_ms = float(self._end_time - self._start_time) / 1.0e6
        logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)


class VllmPatternMatcherPass(VllmInductorPass):
    """
    A VllmInductorPass that uses the Inductor pattern matcher.
    Its main use is providing the dump_patterns utility that dumps the
    Inductor pattern matcher patterns into a file, which greatly aids
    debugging.

    TODO(luka) move more utilities to this pass.
    """
    matched_count: int = 0
    """The number of matched patterns in the pass."""

    _OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
        r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>")

    def _replace_op_overloads(self, string: str) -> str:
        """Replace <OpOverload(..., ...)> with nicer formulations"""
        return self._OP_OVERLOAD_PATTERN.sub(
            lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
            string,
        )

    def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass):
        """
        If debug dumping is enabled, dump the Inductor pattern-matcher patterns
        into the debug_dump_path folder next to the dumped fx graphs.

        This method does its best to print something that looks like Python
        code for easier debugging and navigation. If any errors appear in
        the output, please add to this method.

        TODO(luka): use pattern object to manually produce pattern graph
        """
        debug_dump_path = config.compilation_config.debug_dump_path
        if not debug_dump_path:
            return

        rank = config.parallel_config.rank
        debug_dump_path = Path(debug_dump_path) / f"rank_{rank}"
        debug_dump_path.mkdir(parents=True, exist_ok=True)

        from vllm.utils import unique_filepath
        file_path = unique_filepath(
            lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py")

        with file_path.open("w") as f:
            print(
                f'# This file was produced by VllmPatternMatcherPass.'
                f'dump_patterns for {self.pass_name}.\n'
                f'# It does its best to produce valid-Python-looking code but'
                f' please add to dump_patterns if there are any errors.\n\n'
                f'from torch._higher_order_ops.auto_functionalize import '
                f'auto_functionalized as auto_functionalized\n'
                f'from torch._inductor.pattern_matcher import *',
                file=f)

            for node, patterns in pm_pass.patterns.items():
                # fix the operator.getitem repr
                if node[1] == operator.getitem:
                    node_repr = f"({repr(node[0])}, operator.getitem)"
                else:
                    node_repr = repr(node)

                node_repr = self._replace_op_overloads(node_repr)

                print(f"\n\n# Patterns for op: {node_repr}", file=f)
                for i, pattern in enumerate(patterns):
                    # reserve auto_functionalized ahead of time
                    pp = PatternPrettyPrinter()
                    pp.namespace.create_name("auto_functionalized", None)

                    # Assemble pattern
                    out_node = pp.pretty_print(pattern.pattern)
                    pattern_repr = "\n".join([f"def pattern_{i}():"] + [
                        f"{pp.memoized_objs_names[key]} = "
                        f"{pp.memoized_objs_pp[key]}"
                        for key in pp.memoized_objs_names
                    ] + [f"return {out_node}"]).replace("\n", "\n    ")

                    pattern_repr = self._replace_op_overloads(pattern_repr)
                    print(f"{pattern_repr}\n", file=f)


class PrinterInductorPass(VllmInductorPass):

    def __init__(self, name: str, config: VllmConfig):
        super().__init__(config)
        self.name = name

    def __call__(self, graph: torch.fx.Graph):
        self.dump_graph(graph, self.name)
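# Illustrative sketch (not part of this file): the substitution performed by
# _replace_op_overloads, reproduced standalone so it can be run directly. The
# op name in the sample string is only an example of how an OpOverload reprs
# inside a dumped pattern.
import re as _re

_op_overload_pattern = _re.compile(
    r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>")
_before = "<OpOverload(op='_C.fused_add_rms_norm', overload='default')>"
_after = _op_overload_pattern.sub(
    lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}", _before)
assert _after == "torch.ops._C.fused_add_rms_norm.default"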
136
vllm/compilation/wrapper.py
Normal file
@@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
from typing import Callable, Optional

import torch

from vllm.config import (CompilationLevel, CUDAGraphMode,
                         get_current_vllm_config)
from vllm.logger import init_logger

logger = init_logger(__name__)


class TorchCompileWrapperWithCustomDispatcher:
    """
    A wrapper class for torch.compile, with a custom dispatch logic.
    Subclasses should:
    1. Implement the forward method
    2. Implement the dispatch logic in the __call__ method
        It can use `self.compiled_codes` to access the compiled bytecode,
        and `with self.dispatch_to_code(index):` to dispatch to
        the compiled code.
    3. Implement the `__init__` method to determine how to call
        `torch.compile` over the forward method.
    """

    def __init__(self,
                 compiled_callable: Optional[Callable] = None,
                 compilation_level: int = 0):

        vllm_config = get_current_vllm_config()
        self.vllm_config = vllm_config
        if compiled_callable is None:
            # default compilation settings
            # compiling the forward method

            backend = vllm_config.compilation_config.init_backend(vllm_config)
            options = None
            if isinstance(backend, str) and backend == "inductor":
                options = get_current_vllm_config(
                ).compilation_config.inductor_compile_config

            compiled_callable = torch.compile(self.forward,
                                              fullgraph=True,
                                              backend=backend,
                                              options=options)

        self.compiled_callable = compiled_callable
        self.original_code_object = self.__class__.forward.__code__
        self.compiled_codes: list[CodeType] = []
        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

        # read the env var to determine whether to use the custom dispatcher
        # subclasses can use this to switch between the custom dispatcher
        # and the default Dynamo guard mechanism.
        self.use_custom_dispatcher: bool = \
            compilation_level >= CompilationLevel.DYNAMO_ONCE

    def __call__(self, *args, **kwargs):
        """Implement the dispatch logic here, beyond the torch.compile level.
        NOTE: this function can have additional arguments beyond the forward
        method, for directly dispatching to the compiled code.
        """
        return self.compiled_callable(*args, **kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs):
        ...

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution."""
        if old_code is not self.original_code_object:
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        if frame.f_locals["self"] is not self:
            return

        self.compiled_codes.append(new_code)
        debug_dump_dir = self.vllm_config.compilation_config.debug_dump_path
        if isinstance(debug_dump_dir, str) and debug_dump_dir != "":
            rank = self.vllm_config.parallel_config.rank
            decompiled_file = os.path.join(debug_dump_dir, f"rank_{rank}",
                                           "transformed_code.py")
            if not os.path.exists(decompiled_file):
                try:
                    # usually the decompilation will succeed for most models,
                    # as we guarantee a full-graph compilation in Dynamo.
                    # but there's no 100% guarantee, since decompilation is
                    # not a reversible process.
                    import depyf
                    src = depyf.decompile(new_code)

                    with open(decompiled_file, "w") as f:
                        f.write(src)

                    logger.debug("Dynamo transformed code saved to %s",
                                 decompiled_file)
                except Exception:
                    pass

        if self.vllm_config.compilation_config.cudagraph_mode != \
                CUDAGraphMode.NONE and "update" in new_code.co_names:
            import depyf
            src = depyf.decompile(new_code)
            msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src  # noqa
            raise RuntimeError(msg)

    @contextmanager
    def dispatch_to_code(self, index: int):
        """Context manager to dispatch to the compiled code.
        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
        """  # noqa
        self.__class__.forward.__code__ = self.compiled_codes[index]
        yield
        self.__class__.forward.__code__ = self.original_code_object
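# Illustrative sketch (assumption: simplified, not part of this file): how a
# subclass typically uses the wrapper. Once Dynamo has compiled `forward` and
# bytecode_hook has captured the transformed bytecode, later calls can bypass
# guard evaluation by swapping in the saved code object. The subclass name and
# its trivial forward below are hypothetical.
#
#   class MyModelRunner(TorchCompileWrapperWithCustomDispatcher):
#
#       def __call__(self, x):
#           if self.compiled_codes and self.use_custom_dispatcher:
#               # direct dispatch: skip Dynamo guards entirely
#               with self.dispatch_to_code(0):
#                   return self.forward(x)
#           # first call(s): go through torch.compile to populate bytecode
#           return self.compiled_callable(x)
#
#       def forward(self, x):
#           return x * 2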
814
vllm/config/__init__.py
Normal file
@@ -0,0 +1,814 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# ruff: noqa: F401
|
||||
import ast
|
||||
import copy
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import textwrap
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import field, fields, is_dataclass, replace
|
||||
from functools import cached_property, lru_cache
|
||||
from typing import (TYPE_CHECKING, Any, Literal, Optional, Protocol, TypeVar,
|
||||
Union, cast)
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from pydantic import ConfigDict, SkipValidation
|
||||
from pydantic.dataclasses import dataclass
|
||||
from typing_extensions import runtime_checkable
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import version
|
||||
from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
|
||||
PrefixCachingHashAlgo)
|
||||
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
|
||||
CUDAGraphMode, PassConfig)
|
||||
from vllm.config.device import Device, DeviceConfig
|
||||
from vllm.config.kv_events import KVEventsConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.config.model import (ConvertOption, HfOverrides, LogprobsMode,
|
||||
ModelConfig, ModelDType, ModelImpl,
|
||||
RunnerOption, TaskOption, TokenizerMode,
|
||||
iter_architecture_defaults,
|
||||
try_match_architecture_defaults)
|
||||
from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
|
||||
MultiModalConfig)
|
||||
from vllm.config.observability import DetailedTraceModules, ObservabilityConfig
|
||||
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
|
||||
ParallelConfig)
|
||||
from vllm.config.pooler import PoolerConfig
|
||||
from vllm.config.scheduler import RunnerType, SchedulerConfig, SchedulerPolicy
|
||||
from vllm.config.speculative import SpeculativeConfig
|
||||
from vllm.config.speech_to_text import SpeechToTextConfig
|
||||
from vllm.config.structured_outputs import StructuredOutputsConfig
|
||||
from vllm.config.utils import ConfigType, config, get_attr_docs, is_init_field
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import DataclassInstance
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
else:
|
||||
DataclassInstance = Any
|
||||
PretrainedConfig = Any
|
||||
QuantizationConfig = Any
|
||||
QuantizationMethods = Any
|
||||
BaseModelLoader = Any
|
||||
LogitsProcessor = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
DataclassInstanceT = TypeVar("DataclassInstanceT", bound=DataclassInstance)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsHash(Protocol):
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
...
|
||||
|
||||
|
||||
class SupportsMetricsInfo(Protocol):
|
||||
|
||||
def metrics_info(self) -> dict[str, str]:
|
||||
...
|
||||
|
||||
|
||||
@config
|
||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class VllmConfig:
|
||||
"""Dataclass which contains all vllm-related configuration. This
|
||||
simplifies passing around the distinct configurations in the codebase.
|
||||
"""
|
||||
|
||||
# TODO: use default_factory once default constructing ModelConfig doesn't
|
||||
# try to download a model
|
||||
model_config: ModelConfig = None # type: ignore
|
||||
"""Model configuration."""
|
||||
cache_config: CacheConfig = field(default_factory=CacheConfig)
|
||||
"""Cache configuration."""
|
||||
parallel_config: ParallelConfig = field(default_factory=ParallelConfig)
|
||||
"""Parallel configuration."""
|
||||
scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig)
|
||||
"""Scheduler configuration."""
|
||||
device_config: DeviceConfig = field(default_factory=DeviceConfig)
|
||||
"""Device configuration."""
|
||||
load_config: LoadConfig = field(default_factory=LoadConfig)
|
||||
"""Load configuration."""
|
||||
lora_config: Optional[LoRAConfig] = None
|
||||
"""LoRA configuration."""
|
||||
speculative_config: Optional[SpeculativeConfig] = None
|
||||
"""Speculative decoding configuration."""
|
||||
structured_outputs_config: StructuredOutputsConfig = field(
|
||||
default_factory=StructuredOutputsConfig)
|
||||
"""Structured outputs configuration."""
|
||||
observability_config: Optional[ObservabilityConfig] = None
|
||||
"""Observability configuration."""
|
||||
quant_config: Optional[QuantizationConfig] = None
|
||||
"""Quantization configuration."""
|
||||
compilation_config: CompilationConfig = field(
|
||||
default_factory=CompilationConfig)
|
||||
"""`torch.compile` and cudagraph capture configuration for the model.
|
||||
|
||||
As a shorthand, `-O<n>` can be used to directly specify the compilation
|
||||
level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`).
|
||||
Currently, -O <n> and -O=<n> are supported as well but this will likely be
|
||||
removed in favor of clearer -O<n> syntax in the future.
|
||||
|
||||
NOTE: level 0 is the default level without any optimization. level 1 and 2
|
||||
are for internal testing only. level 3 is the recommended level for
|
||||
production, also default in V1.
|
||||
|
||||
You can specify the full compilation config like so:
|
||||
`{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
|
||||
"""
|
||||
kv_transfer_config: Optional[KVTransferConfig] = None
|
||||
"""The configurations for distributed KV cache transfer."""
|
||||
kv_events_config: Optional[KVEventsConfig] = None
|
||||
"""The configurations for event publishing."""
|
||||
# some opaque config, only used to provide additional information
|
||||
# for the hash computation, mainly used for testing, debugging or out of
|
||||
# tree config registration.
|
||||
additional_config: Union[dict, SupportsHash] = field(default_factory=dict)
|
||||
"""Additional config for specified platform. Different platforms may
|
||||
support different configs. Make sure the configs are valid for the platform
|
||||
you are using. Contents must be hashable."""
|
||||
instance_id: str = ""
|
||||
"""The ID of the vLLM instance."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
factors: list[Any] = []
|
||||
|
||||
# summarize vllm config
|
||||
vllm_factors: list[Any] = []
|
||||
from vllm import __version__
|
||||
vllm_factors.append(__version__)
|
||||
vllm_factors.append(envs.VLLM_USE_V1)
|
||||
if self.model_config:
|
||||
vllm_factors.append(self.model_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.cache_config:
|
||||
vllm_factors.append(self.cache_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.parallel_config:
|
||||
vllm_factors.append(self.parallel_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.scheduler_config:
|
||||
vllm_factors.append(self.scheduler_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.device_config:
|
||||
vllm_factors.append(self.device_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.load_config:
|
||||
vllm_factors.append(self.load_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.lora_config:
|
||||
vllm_factors.append(self.lora_config.compute_hash())
|
||||
# LoRA creates static buffers based on max_num_batched_tokens.
|
||||
# The tensor sizes and strides get captured in the torch.compile
|
||||
# graph explicitly.
|
||||
vllm_factors.append(
|
||||
str(self.scheduler_config.max_num_batched_tokens))
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.speculative_config:
|
||||
vllm_factors.append(self.speculative_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.structured_outputs_config:
|
||||
vllm_factors.append(self.structured_outputs_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.observability_config:
|
||||
vllm_factors.append(self.observability_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.quant_config:
|
||||
pass # should be captured by model_config.quantization
|
||||
if self.compilation_config:
|
||||
vllm_factors.append(self.compilation_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.kv_transfer_config:
|
||||
vllm_factors.append(self.kv_transfer_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.additional_config:
|
||||
if isinstance(additional_config := self.additional_config, dict):
|
||||
additional_config_hash = hashlib.md5(
|
||||
json.dumps(additional_config, sort_keys=True).encode(),
|
||||
usedforsecurity=False,
|
||||
).hexdigest()
|
||||
else:
|
||||
additional_config_hash = additional_config.compute_hash()
|
||||
vllm_factors.append(additional_config_hash)
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
factors.append(vllm_factors)
|
||||
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()[:10]
|
||||
return hash_str
|
||||
|
||||
def pad_for_cudagraph(self, batch_size: int) -> int:
|
||||
# if batch_size > self.compilation_config.max_capture_size,
|
||||
# it should raise an IndexError.
|
||||
# the caller should make sure the batch_size is within the range,
|
||||
# i.e., batch_size <= self.compilation_config.max_capture_size
|
||||
return self.compilation_config.bs_to_padded_graph_size[batch_size]
|
||||
|
||||
@staticmethod
|
||||
def _get_quantization_config(
|
||||
model_config: ModelConfig,
|
||||
load_config: LoadConfig) -> Optional[QuantizationConfig]:
|
||||
"""Get the quantization config."""
|
||||
from vllm.platforms import current_platform
|
||||
if model_config.quantization is not None:
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
get_quant_config)
|
||||
quant_config = get_quant_config(model_config, load_config)
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
|
||||
if capability_tuple is not None:
|
||||
capability = capability_tuple.to_int()
|
||||
if capability < quant_config.get_min_capability():
|
||||
raise ValueError(
|
||||
f"The quantization method {model_config.quantization} "
|
||||
"is not supported for the current GPU. Minimum "
|
||||
f"capability: {quant_config.get_min_capability()}. "
|
||||
f"Current capability: {capability}.")
|
||||
supported_dtypes = quant_config.get_supported_act_dtypes()
|
||||
if model_config.dtype not in supported_dtypes:
|
||||
raise ValueError(
|
||||
f"{model_config.dtype} is not supported for quantization "
|
||||
f"method {model_config.quantization}. Supported dtypes: "
|
||||
f"{supported_dtypes}")
|
||||
quant_config.maybe_update_config(model_config.model)
|
||||
return quant_config
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_quantization_config(
|
||||
model_config: ModelConfig,
|
||||
load_config: LoadConfig) -> Optional[QuantizationConfig]:
|
||||
import copy
|
||||
|
||||
# For some reason, the _ version of this modifies the model_config
|
||||
# object, so using deepcopy to avoid this problem.
|
||||
return VllmConfig._get_quantization_config(copy.deepcopy(model_config),
|
||||
load_config)
|
||||
|
||||
def with_hf_config(
|
||||
self,
|
||||
hf_config: PretrainedConfig,
|
||||
architectures: Optional[list[str]] = None,
|
||||
) -> "VllmConfig":
|
||||
if architectures is not None:
|
||||
hf_config = copy.deepcopy(hf_config)
|
||||
hf_config.architectures = architectures
|
||||
|
||||
model_config = copy.deepcopy(self.model_config)
|
||||
model_config.hf_config = hf_config
|
||||
|
||||
return replace(self, model_config=model_config)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Verify configs are valid & consistent with each other.
|
||||
"""
|
||||
|
||||
self.try_verify_and_update_config()
|
||||
|
||||
if self.model_config is not None:
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
self.model_config.verify_dual_chunk_attention_config(
|
||||
self.load_config)
|
||||
|
||||
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
||||
|
||||
if self.lora_config is not None:
|
||||
self.lora_config.verify_with_cache_config(self.cache_config)
|
||||
self.lora_config.verify_with_model_config(self.model_config)
|
||||
|
||||
if self.quant_config is None and self.model_config is not None:
|
||||
self.quant_config = VllmConfig._get_quantization_config(
|
||||
self.model_config, self.load_config)
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
if self.model_config is not None and \
|
||||
self.scheduler_config.chunked_prefill_enabled and \
|
||||
self.model_config.dtype == torch.float32 and \
|
||||
current_platform.get_device_capability() == (7, 5):
|
||||
logger.warning_once(
|
||||
"Turing devices tensor cores do not support float32 matmul. "
|
||||
"To workaround this limitation, vLLM will set 'ieee' input "
|
||||
"precision for chunked prefill triton kernels.")
|
||||
|
||||
# If the user does not explicitly set a compilation level, then
|
||||
# we use the default level. The default level depends on other
|
||||
# settings (see the below code).
|
||||
if self.compilation_config.level is None:
|
||||
if envs.VLLM_USE_V1:
|
||||
if (self.model_config is not None
|
||||
and not self.model_config.enforce_eager):
|
||||
self.compilation_config.level = CompilationLevel.PIECEWISE
|
||||
else:
|
||||
self.compilation_config.level = \
|
||||
CompilationLevel.NO_COMPILATION
|
||||
|
||||
else:
|
||||
# NB: Passing both --enforce-eager and a compilation level
|
||||
# in V0 means the compilation level wins out.
|
||||
self.compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
|
||||
# async tp is built on top of sequence parallelism
|
||||
# and requires it to be enabled.
|
||||
if self.compilation_config.pass_config.enable_async_tp:
|
||||
self.compilation_config.pass_config.enable_sequence_parallelism = \
|
||||
True
|
||||
if self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||
self.compilation_config.custom_ops.append("+rms_norm")
|
||||
|
||||
if current_platform.support_static_graph_mode():
|
||||
# if cudagraph_mode is not explicitly set by users, set default
|
||||
# value
|
||||
if self.compilation_config.cudagraph_mode is None:
|
||||
if envs.VLLM_USE_V1 and self.compilation_config.level \
|
||||
== CompilationLevel.PIECEWISE:
|
||||
# default to full and piecewise for most models
|
||||
self.compilation_config.cudagraph_mode = \
|
||||
CUDAGraphMode.FULL_AND_PIECEWISE
|
||||
|
||||
# pooling models and encoder-decoder models
|
||||
# do not support full cudagraphs
|
||||
if self.model_config is not None and \
|
||||
(self.model_config.pooler_config is not None
|
||||
or self.model_config.is_encoder_decoder):
|
||||
self.compilation_config.cudagraph_mode = \
|
||||
CUDAGraphMode.PIECEWISE
|
||||
else:
|
||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
# disable cudagraph when enforce eager execution
|
||||
if self.model_config is not None and \
|
||||
self.model_config.enforce_eager:
|
||||
logger.info("Cudagraph is disabled under eager mode")
|
||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
elif envs.VLLM_USE_V1:
|
||||
self.compilation_config.cudagraph_num_of_warmups = 1
|
||||
|
||||
self._set_cudagraph_sizes()
|
||||
else:
|
||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
if self.cache_config.kv_sharing_fast_prefill:
|
||||
|
||||
if self.speculative_config is not None and \
|
||||
self.speculative_config.use_eagle():
|
||||
raise NotImplementedError(
|
||||
"Fast prefill optimization for KV sharing is not "
|
||||
"compatible with EAGLE as EAGLE requires correct logits "
|
||||
"for all tokens while fast prefill gives incorrect logits "
|
||||
"for prompt tokens.")
|
||||
|
||||
logger.warning_once(
|
||||
"--kv-sharing-fast-prefill requires changes on model side for "
|
||||
"correctness and to realize prefill savings. ")
|
||||
|
||||
disable_chunked_prefill_reasons: list[str] = []
|
||||
|
||||
if self.model_config:
|
||||
if self.model_config.pooler_config:
|
||||
pooling_type = self.model_config.pooler_config.pooling_type
|
||||
if pooling_type is None or pooling_type.lower() != "last":
|
||||
disable_chunked_prefill_reasons.append(
|
||||
"Only \"last\" pooling supports chunked "
|
||||
"prefill and prefix caching; disabling both.")
|
||||
if not getattr(self.model_config.hf_config, "is_causal", True):
|
||||
disable_chunked_prefill_reasons.append(
|
||||
"Only models using causal attention supports chunked "
|
||||
"prefill and prefix caching; disabling both.")
|
||||
elif self.model_config.is_encoder_decoder:
|
||||
self.scheduler_config.max_num_encoder_input_tokens = \
|
||||
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
|
||||
logger.debug(
|
||||
"Encoder-decoder model detected: setting "
|
||||
"`max_num_encoder_input_tokens` to encoder length (%s)",
|
||||
self.scheduler_config.max_num_encoder_input_tokens)
|
||||
self.scheduler_config.disable_chunked_mm_input = True
|
||||
disable_chunked_prefill_reasons.append(
|
||||
"Encoder-decoder models do not support chunked prefill nor"
|
||||
" prefix caching; disabling both.")
|
||||
if (self.model_config.architecture
|
||||
== "WhisperForConditionalGeneration"
|
||||
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD")
|
||||
!= "spawn"):
|
||||
logger.warning(
|
||||
"Whisper is known to have issues with "
|
||||
"forked workers. If startup is hanging, "
|
||||
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
|
||||
"to 'spawn'.")
|
||||
|
||||
if disable_chunked_prefill_reasons:
|
||||
for reason in disable_chunked_prefill_reasons:
|
||||
logger.info(reason)
|
||||
self.scheduler_config.chunked_prefill_enabled = False
|
||||
self.scheduler_config.long_prefill_token_threshold = 0
|
||||
|
||||
if self.cache_config is not None:
|
||||
self.cache_config.enable_prefix_caching = False
|
||||
|
||||
if (self.kv_events_config is not None
|
||||
and self.kv_events_config.enable_kv_cache_events
|
||||
and not self.cache_config.enable_prefix_caching):
|
||||
logger.warning(
|
||||
"KV cache events are on, but prefix caching is not enabled."
|
||||
"Use --enable-prefix-caching to enable.")
|
||||
if (self.kv_events_config is not None
|
||||
and self.kv_events_config.publisher != "null"
|
||||
and not self.kv_events_config.enable_kv_cache_events):
|
||||
logger.warning("KV cache events are disabled,"
|
||||
"but the scheduler is configured to publish them."
|
||||
"Modify KVEventsConfig.enable_kv_cache_events"
|
||||
"to True to enable.")
|
||||
current_platform.check_and_update_config(self)
|
||||
|
||||
# final check of cudagraph mode after platform-specific update
|
||||
if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
|
||||
if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \
|
||||
and self.model_config is not None and \
|
||||
not self.model_config.disable_cascade_attn:
|
||||
logger.info("CUDAGraphMode.FULL is not supported with "
|
||||
"cascade attention currently. Disabling cascade"
|
||||
"attention.")
|
||||
self.model_config.disable_cascade_attn = True
|
||||
|
||||
if self.compilation_config.cudagraph_mode\
|
||||
.requires_piecewise_compilation():
|
||||
assert self.compilation_config.level == \
|
||||
CompilationLevel.PIECEWISE, \
|
||||
"Compilation level should be CompilationLevel.PIECEWISE "\
|
||||
"when cudagraph_mode piecewise cudagraphs is used, "\
|
||||
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
|
||||
|
||||
if self.parallel_config.enable_dbo:
|
||||
a2a_backend = envs.VLLM_ALL2ALL_BACKEND
|
||||
assert a2a_backend in \
|
||||
["deepep_low_latency", "deepep_high_throughput"], \
|
||||
"Microbatching currently only supports the deepep_low_latency and "\
|
||||
f"deepep_high_throughput all2all backend. {a2a_backend} is not "\
|
||||
"supported. To fix set the VLLM_ALL2ALL_BACKEND environment "\
|
||||
"variable to deepep_low_latency or deepep_high_throughput and "\
|
||||
"install the DeepEP kernels."
|
||||
|
||||
if not self.instance_id:
|
||||
self.instance_id = random_uuid()[:5]
|
||||
|
||||
# Do this after all the updates to compilation_config.level
|
||||
if envs.VLLM_USE_V1 and \
|
||||
self.compilation_config.level == CompilationLevel.PIECEWISE:
|
||||
self.compilation_config.set_splitting_ops_for_v1()
|
||||
|
||||
if (envs.VLLM_USE_V1
|
||||
and not self.scheduler_config.disable_hybrid_kv_cache_manager):
|
||||
# logger should only print warning message for hybrid models. As we
|
||||
# can't know whether the model is hybrid or not now, so we don't log
|
||||
# warning message here and will log it later.
|
||||
if not current_platform.support_hybrid_kv_cache():
|
||||
# Hybrid KV cache manager is not supported on non-GPU platforms.
|
||||
self.scheduler_config.disable_hybrid_kv_cache_manager = True
|
||||
if self.kv_transfer_config is not None:
|
||||
# Hybrid KV cache manager is not compatible with KV transfer.
|
||||
self.scheduler_config.disable_hybrid_kv_cache_manager = True
|
||||
if self.kv_events_config is not None:
|
||||
# Hybrid KV cache manager is not compatible with KV events.
|
||||
self.scheduler_config.disable_hybrid_kv_cache_manager = True
|
||||
if self.model_config is not None and \
|
||||
self.model_config.attention_chunk_size is not None:
|
||||
if self.speculative_config is not None and \
|
||||
self.speculative_config.use_eagle():
|
||||
# Hybrid KV cache manager is not yet supported with chunked
|
||||
# local attention + eagle.
|
||||
self.scheduler_config.disable_hybrid_kv_cache_manager = True
|
||||
elif \
|
||||
not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
|
||||
logger.warning(
|
||||
"There is a latency regression when using chunked local"
|
||||
" attention with the hybrid KV cache manager. Disabling"
|
||||
" it, by default. To enable it, set the environment "
|
||||
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
|
||||
)
|
||||
# Hybrid KV cache manager is not yet supported with chunked
|
||||
# local attention.
|
||||
self.scheduler_config.disable_hybrid_kv_cache_manager = True
|
||||
|
||||
def update_sizes_for_sequence_parallelism(self,
|
||||
possible_sizes: list) -> list:
|
||||
# remove the sizes that not multiple of tp_size when
|
||||
# enable sequence parallelism
|
||||
removed_sizes = [
|
||||
size for size in possible_sizes
|
||||
if size % self.parallel_config.tensor_parallel_size != 0
|
||||
]
|
||||
if removed_sizes:
|
||||
logger.warning(
|
||||
"Batch sizes %s are removed because they are not "
|
||||
"multiple of tp_size %d when "
|
||||
"sequence parallelism is enabled", removed_sizes,
|
||||
self.parallel_config.tensor_parallel_size)
|
||||
|
||||
return [
|
||||
size for size in possible_sizes
|
||||
if size % self.parallel_config.tensor_parallel_size == 0
|
||||
]
|
||||
|
||||
def _set_cudagraph_sizes(self):
|
||||
"""
|
||||
vLLM defines the default candidate list of batch sizes for CUDA graph
|
||||
capture as:
|
||||
|
||||
```python
|
||||
max_graph_size = min(max_num_seqs * 2, 512)
|
||||
# 1, 2, 4, then multiples of 8 up to max_graph_size
|
||||
cuda_graph_sizes = [1, 2, 4, 8, 16, 24, 32, 40, ..., max_graph_size]
|
||||
|
||||
In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
|
||||
will be the final sizes to capture cudagraph (in descending order).
|
||||
|
||||
These sizes are used to capture and reuse CUDA graphs for
|
||||
performance-critical paths (e.g., decoding). Capturing enables
|
||||
significantly faster kernel dispatch by avoiding Python overhead. The
|
||||
list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on
|
||||
most GPUs), which controls the total allowed number of tokens in a
|
||||
batch. Since each sequence may have a variable number of tokens, the
|
||||
maximum usable batch size will depend on actual sequence lengths.
|
||||
|
||||
Example:
|
||||
With `max_num_batched_tokens = 8192`, and typical sequences
|
||||
averaging ~32 tokens, most practical batch sizes fall below 256.
|
||||
However, the system will still allow capture sizes up to 512 if
|
||||
shape and memory permit.
|
||||
|
||||
Note:
|
||||
If users explicitly specify cudagraph capture sizes in the
|
||||
compilation config, those will override this default logic.
|
||||
At runtime:
|
||||
|
||||
- If batch size <= one of the `cudagraph_capture_sizes`, the closest
|
||||
padded CUDA graph will be used.
|
||||
- If batch size > largest `cudagraph_capture_sizes`, cudagraph will
|
||||
not be used.
|
||||
"""
|
||||
|
||||
# calculate the default `batch_size_capture_list`
|
||||
batch_size_capture_list = []
|
||||
if self.model_config is not None and \
|
||||
not self.model_config.enforce_eager:
|
||||
cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes
|
||||
if len(cuda_graph_sizes) == 1:
|
||||
batch_size_capture_list = [1, 2, 4] + [
|
||||
i for i in range(8, cuda_graph_sizes[0] + 1, 8)
|
||||
]
|
||||
elif len(cuda_graph_sizes) > 1:
|
||||
batch_size_capture_list = sorted(cuda_graph_sizes)
|
||||
else:
|
||||
raise TypeError(f"Invalid value for {cuda_graph_sizes=}.")
|
||||
if self.parallel_config.tensor_parallel_size > 1 and \
|
||||
self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||
batch_size_capture_list = \
|
||||
self.update_sizes_for_sequence_parallelism(batch_size_capture_list)
|
||||
max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
batch_size_capture_list = [
|
||||
size for size in batch_size_capture_list
|
||||
if size <= max_num_tokens
|
||||
]
|
||||
|
||||
self.compilation_config.init_with_cudagraph_sizes(
|
||||
batch_size_capture_list)
|
||||
|
||||
def recalculate_max_model_len(self, max_model_len: int):
|
||||
# Can only be called in try_verify_and_update_config
|
||||
model_config = self.model_config
|
||||
max_model_len = model_config.get_and_verify_max_len(max_model_len)
|
||||
self.model_config.max_model_len = max_model_len
|
||||
self.scheduler_config.max_model_len = max_model_len
|
||||
|
||||
def try_verify_and_update_config(self):
|
||||
if self.model_config is None:
|
||||
return
|
||||
|
||||
# Avoid running try_verify_and_update_config multiple times
|
||||
if getattr(self.model_config, "config_updated", False):
|
||||
return
|
||||
self.model_config.config_updated = True
|
||||
|
||||
architecture = self.model_config.architecture
|
||||
if architecture is None:
|
||||
return
|
||||
|
||||
from vllm.model_executor.models.config import (
|
||||
MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig)
|
||||
cls = MODELS_CONFIG_MAP.get(architecture, None)
|
||||
if cls is not None:
|
||||
cls.verify_and_update_config(self)
|
||||
|
||||
if self.model_config.is_hybrid:
|
||||
HybridAttentionMambaModelConfig.verify_and_update_config(self)
|
||||
|
||||
if self.model_config.convert_type == "classify":
|
||||
# Maybe convert ForCausalLM into ForSequenceClassification model.
|
||||
from vllm.model_executor.models.adapters import (
|
||||
SequenceClassificationConfig)
|
||||
SequenceClassificationConfig.verify_and_update_config(self)
|
||||
|
||||
if hasattr(self.model_config, "model_weights") and is_runai_obj_uri(
|
||||
self.model_config.model_weights):
|
||||
if self.load_config.load_format == "auto":
|
||||
logger.info("Detected Run:ai model config. "
|
||||
"Overriding `load_format` to 'runai_streamer'")
|
||||
self.load_config.load_format = "runai_streamer"
|
||||
elif self.load_config.load_format != "runai_streamer":
|
||||
raise ValueError(f"To load a model from S3, 'load_format' "
|
||||
f"must be 'runai_streamer', "
|
||||
f"but got '{self.load_config.load_format}'. "
|
||||
f"Model: {self.model_config.model}")
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"model={self.model_config.model!r}, "
|
||||
f"speculative_config={self.speculative_config!r}, "
|
||||
f"tokenizer={self.model_config.tokenizer!r}, "
|
||||
f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, "
|
||||
f"tokenizer_mode={self.model_config.tokenizer_mode}, "
|
||||
f"revision={self.model_config.revision}, "
|
||||
f"tokenizer_revision={self.model_config.tokenizer_revision}, "
|
||||
f"trust_remote_code={self.model_config.trust_remote_code}, "
|
||||
f"dtype={self.model_config.dtype}, "
|
||||
f"max_seq_len={self.model_config.max_model_len}, "
|
||||
f"download_dir={self.load_config.download_dir!r}, "
|
||||
f"load_format={self.load_config.load_format}, "
|
||||
f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, " # noqa
|
||||
f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa
|
||||
f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa
|
||||
f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa
|
||||
f"quantization={self.model_config.quantization}, "
|
||||
f"enforce_eager={self.model_config.enforce_eager}, "
|
||||
f"kv_cache_dtype={self.cache_config.cache_dtype}, "
|
||||
f"device_config={self.device_config.device}, "
|
||||
f"structured_outputs_config={self.structured_outputs_config!r}, "
|
||||
f"observability_config={self.observability_config!r}, "
|
||||
f"seed={self.model_config.seed}, "
|
||||
f"served_model_name={self.model_config.served_model_name}, "
|
||||
f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
|
||||
f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa
|
||||
f"pooler_config={self.model_config.pooler_config!r}, "
|
||||
f"compilation_config={self.compilation_config!r}")
|
||||
|
||||
|
||||
_current_vllm_config: Optional[VllmConfig] = None
|
||||
_current_prefix: Optional[str] = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_current_vllm_config(vllm_config: VllmConfig,
|
||||
check_compile=False,
|
||||
prefix: Optional[str] = None):
|
||||
"""
|
||||
Temporarily set the current vLLM config.
|
||||
Used during model initialization.
|
||||
We save the current vLLM config in a global variable,
|
||||
so that all modules can access it, e.g. custom ops
|
||||
can access the vLLM config to determine how to dispatch.
|
||||
"""
|
||||
global _current_vllm_config, _current_prefix
|
||||
old_vllm_config = _current_vllm_config
|
||||
old_prefix = _current_prefix
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
num_models_seen = compilation_counter.num_models_seen
|
||||
try:
|
||||
_current_vllm_config = vllm_config
|
||||
_current_prefix = prefix
|
||||
yield
|
||||
except Exception:
|
||||
raise
|
||||
else:
|
||||
if check_compile:
|
||||
vllm_config.compilation_config.custom_op_log_check()
|
||||
|
||||
if check_compile and \
|
||||
vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
|
||||
and compilation_counter.num_models_seen == num_models_seen:
|
||||
# If the model supports compilation,
|
||||
# compilation_counter.num_models_seen should be increased
|
||||
# by at least 1.
|
||||
# If it is not increased, it means the model does not support
|
||||
# compilation (does not have @support_torch_compile decorator).
|
||||
logger.warning(
|
||||
"`torch.compile` is turned on, but the model %s"
|
||||
" does not support it. Please open an issue on GitHub"
|
||||
" if you want it to be supported.",
|
||||
vllm_config.model_config.model)
|
||||
finally:
|
||||
_current_vllm_config = old_vllm_config
|
||||
_current_prefix = old_prefix
|
||||
# Clear the compilation config cache when context changes
|
||||
get_cached_compilation_config.cache_clear()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_cached_compilation_config():
|
||||
"""Cache config to avoid repeated calls to get_current_vllm_config()"""
|
||||
return get_current_vllm_config().compilation_config
|
||||
|
||||
|
||||
def get_current_vllm_config() -> VllmConfig:
|
||||
if _current_vllm_config is None:
|
||||
# in ci, usually when we test custom ops/modules directly,
|
||||
# we don't set the vllm config. In that case, we set a default
|
||||
# config.
|
||||
logger.warning("Current vLLM config is not set.")
|
||||
from vllm.config import VllmConfig
|
||||
return VllmConfig()
|
||||
return _current_vllm_config
|
||||
|
||||
|
||||
def get_current_model_prefix() -> str:
|
||||
"""
|
||||
Get the prefix of the model that's currently being initialized.
|
||||
"""
|
||||
assert _current_prefix is not None, \
|
||||
"Current model prefix is not set. "
|
||||
return _current_prefix
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_layers_from_vllm_config(
|
||||
vllm_config: VllmConfig,
|
||||
layer_type: type[T],
|
||||
layer_names: Optional[list[str]] = None) -> dict[str, T]:
|
||||
"""
|
||||
Get layers from the vLLM config.
|
||||
|
||||
Args:
|
||||
vllm_config: The vLLM config.
|
||||
layer_type: The type of the layer to get.
|
||||
layer_names: The names of the layers to get. If None, return all layers.
|
||||
"""
|
||||
|
||||
if layer_names is None:
|
||||
layer_names = list(
|
||||
vllm_config.compilation_config.static_forward_context.keys())
|
||||
|
||||
forward_context = vllm_config.compilation_config.static_forward_context
|
||||
|
||||
return {
|
||||
layer_name: forward_context[layer_name]
|
||||
for layer_name in layer_names
|
||||
if isinstance(forward_context[layer_name], layer_type)
|
||||
}
|
||||
|
||||
|
||||
def update_config(config: DataclassInstanceT,
                  overrides: dict[str, Any]) -> DataclassInstanceT:
    processed_overrides = {}
    for field_name, value in overrides.items():
        assert hasattr(
            config, field_name), f"{type(config)} has no field `{field_name}`"
        current_value = getattr(config, field_name)
        if is_dataclass(current_value) and not is_dataclass(value):
            assert isinstance(value, dict), (
                f"Overrides to {type(config)}.{field_name} must be a dict"
                f" or {type(current_value)}, but got {type(value)}")
            value = update_config(
                current_value,  # type: ignore[type-var]
                value)
        processed_overrides[field_name] = value
    return replace(config, **processed_overrides)
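# Illustrative sketch (not part of this file): how update_config applies flat
# and nested overrides to the dataclass-based configs above. The field names
# below are examples; any dataclass-typed field accepts a plain dict, which is
# applied recursively.
#
#   scheduler = update_config(SchedulerConfig(), {"max_num_seqs": 64})
#   vllm_cfg = update_config(vllm_cfg, {"cache_config": {"block_size": 32}})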
BIN
vllm/config/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/config/__pycache__/cache.cpython-312.pyc
Normal file
Binary file not shown.