update
This commit is contained in:
0
vllm_old/v1/attention/backends/__init__.py
Normal file
0
vllm_old/v1/attention/backends/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm_old/v1/attention/backends/__pycache__/utils.cpython-312.pyc
Normal file
BIN
vllm_old/v1/attention/backends/__pycache__/utils.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
496
vllm_old/v1/attention/backends/cpu_attn.py
Normal file
496
vllm_old/v1/attention/backends/cpu_attn.py
Normal file
@@ -0,0 +1,496 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86,)
|
||||
|
||||
|
||||
class CPUAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.float32,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "CPU_ATTN"
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""CPU attention supports decoder and encoder-only attention."""
|
||||
from vllm.attention import AttentionType
|
||||
|
||||
return attn_type in (
|
||||
AttentionType.DECODER,
|
||||
AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
|
||||
return CPUAttentionBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["CPUAttentionMetadataBuilder"]:
|
||||
return CPUAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
return 2, num_blocks, num_kv_heads, block_size, head_size
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class CPUAttentionMetadata:
|
||||
isa: str
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
scheduler_metadata: torch.Tensor | None
|
||||
causal: bool = True
|
||||
|
||||
# can be removed after deprecate sdpa
|
||||
use_sdpa_prefill: bool = False
|
||||
num_decode_tokens: int = 0
|
||||
sdpa_attn_masks: list[torch.Tensor | None] | None = None
|
||||
sdpa_start_loc: torch.Tensor | None = None
|
||||
|
||||
|
||||
class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]):
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
self.use_sdpa_prefill = False
|
||||
reorder_batch_threshold = None
|
||||
if current_platform.get_cpu_architecture() not in _CPU_ARCH_PREFER_MIXED_BATCH:
|
||||
# in this case, decode seqs are reordered to the front of prefill seqs
|
||||
# to split decode and prefill. Then use SDPA for prefill and
|
||||
# cpu_attention_with_kv_cache for decode
|
||||
reorder_batch_threshold = 1
|
||||
self.use_sdpa_prefill = True
|
||||
|
||||
self._init_reorder_batch_threshold(reorder_batch_threshold, False)
|
||||
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.vllm_config = vllm_config
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.num_kv_heads = vllm_config.model_config.get_num_kv_heads(parallel_config)
|
||||
self.num_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
parallel_config
|
||||
)
|
||||
self.head_dim = kv_cache_spec.head_size
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
self.window_size = getattr(kv_cache_spec, "sliding_window", -1)
|
||||
if self.window_size is None:
|
||||
self.window_size = -1
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.isa = _get_attn_isa(self.dtype, self.block_size)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> CPUAttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
causal = common_attn_metadata.causal
|
||||
|
||||
sdpa_start_loc = query_start_loc
|
||||
num_decode_tokens = 0
|
||||
if self.use_sdpa_prefill and causal:
|
||||
# Decoder, need reorder and truncate
|
||||
assert self.reorder_batch_threshold
|
||||
(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata,
|
||||
decode_threshold=self.reorder_batch_threshold,
|
||||
require_uniform=True,
|
||||
)
|
||||
)
|
||||
num_reqs = num_decodes
|
||||
sdpa_start_loc = sdpa_start_loc[num_decodes:] - num_decode_tokens
|
||||
seq_lens = seq_lens[:num_decodes]
|
||||
query_start_loc = query_start_loc[: num_decodes + 1]
|
||||
block_table_tensor = block_table_tensor[:num_decodes]
|
||||
|
||||
sheduler_metadata = None
|
||||
if causal:
|
||||
# for decode batch, use the custom kernel
|
||||
sheduler_metadata = ops.cpu_attn_get_scheduler_metadata(
|
||||
num_reqs=num_reqs,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_dim=self.head_dim,
|
||||
seq_lens=seq_lens,
|
||||
dtype=self.dtype,
|
||||
query_start_loc=query_start_loc,
|
||||
causal=causal,
|
||||
sliding_window_size=self.window_size,
|
||||
isa=self.isa,
|
||||
enable_kv_split=True,
|
||||
)
|
||||
|
||||
attn_metadata = CPUAttentionMetadata(
|
||||
isa=self.isa,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=query_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
scheduler_metadata=sheduler_metadata,
|
||||
causal=causal,
|
||||
use_sdpa_prefill=self.use_sdpa_prefill,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
sdpa_start_loc=sdpa_start_loc,
|
||||
)
|
||||
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class CPUAttentionBackendImpl(AttentionImpl):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: str | None = None,
|
||||
sinks: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
if logits_soft_cap is not None and attn_type in (
|
||||
AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY,
|
||||
):
|
||||
logger.warning_once(
|
||||
"CPU_ATTN does not support logits softcap for"
|
||||
" ENCODER and ENCODER_ONLY, outputs may be slightly off"
|
||||
)
|
||||
if logits_soft_cap is None:
|
||||
logits_soft_cap = 0
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
if sliding_window is None:
|
||||
self.sliding_window = (-1, -1)
|
||||
elif attn_type == AttentionType.ENCODER_ONLY:
|
||||
self.sliding_window = (sliding_window - 1, sliding_window - 1)
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
if is_quantized_kv_cache(kv_cache_dtype):
|
||||
raise NotImplementedError("FP8 KV cache is unsupported in CPU_ATTN")
|
||||
self.attn_type = attn_type
|
||||
|
||||
self.sinks = sinks
|
||||
if self.sinks is not None:
|
||||
assert self.sinks.shape[0] == num_heads, (
|
||||
"Sinks must have the same number of heads as the number of "
|
||||
"heads in the layer"
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: CPUAttentionMetadata | None,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass for CPU attention backend.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape =
|
||||
[2, num_blocks, num_kv_heads, block_size, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for CPUAttentionBackendImpl"
|
||||
)
|
||||
|
||||
# For warming-up
|
||||
if attn_metadata is None:
|
||||
return output
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
# Handle encoder attention differently - no KV cache needed
|
||||
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
|
||||
# For encoder attention,
|
||||
return self._run_sdpa_forward(
|
||||
query[:num_actual_tokens],
|
||||
key[:num_actual_tokens],
|
||||
value[:num_actual_tokens],
|
||||
output[:num_actual_tokens],
|
||||
attn_metadata,
|
||||
self.attn_type,
|
||||
)
|
||||
|
||||
# For decoder and cross-attention, use KV cache, size are
|
||||
# [num_blocks, num_kv_heads, block_size, head_size]
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
# in KV cache.
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
ops.cpu_attn_reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
attn_metadata.isa,
|
||||
)
|
||||
|
||||
if attn_metadata.use_sdpa_prefill:
|
||||
assert self.sinks is None, "Attention sink is unsupported in SDPA prefill"
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
self._run_sdpa_forward(
|
||||
query[num_decode_tokens:num_actual_tokens],
|
||||
key[num_decode_tokens:num_actual_tokens],
|
||||
value[num_decode_tokens:num_actual_tokens],
|
||||
output[num_decode_tokens:num_actual_tokens],
|
||||
attn_metadata,
|
||||
self.attn_type,
|
||||
)
|
||||
num_actual_tokens = num_decode_tokens
|
||||
|
||||
if num_actual_tokens > 0:
|
||||
ops.cpu_attention_with_kv_cache(
|
||||
query=query[:num_actual_tokens],
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
output=output[:num_actual_tokens], # type: ignore
|
||||
query_start_loc=attn_metadata.query_start_loc,
|
||||
seq_lens=attn_metadata.seq_lens,
|
||||
scale=self.scale,
|
||||
causal=attn_metadata.causal,
|
||||
alibi_slopes=self.alibi_slopes, # type: ignore
|
||||
sliding_window=self.sliding_window,
|
||||
block_table=attn_metadata.block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
scheduler_metadata=attn_metadata.scheduler_metadata,
|
||||
s_aux=self.sinks,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def _run_sdpa_forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
attn_metadata: CPUAttentionMetadata,
|
||||
attn_type: str,
|
||||
) -> torch.Tensor:
|
||||
attn_masks = attn_metadata.sdpa_attn_masks
|
||||
if attn_masks is None:
|
||||
if self.alibi_slopes is not None:
|
||||
attn_masks = _make_alibi_bias(
|
||||
self.alibi_slopes,
|
||||
query.dtype,
|
||||
attn_metadata.sdpa_start_loc,
|
||||
)
|
||||
elif self.sliding_window[0] != -1 or self.sliding_window[1] != -1:
|
||||
assert attn_metadata.seq_lens is not None
|
||||
attn_masks = _make_sliding_window_bias(
|
||||
attn_metadata.sdpa_start_loc,
|
||||
self.sliding_window[0],
|
||||
self.sliding_window[1],
|
||||
query.dtype,
|
||||
)
|
||||
else:
|
||||
attn_masks = [None] * (attn_metadata.sdpa_start_loc.size(0) - 1) # type: ignore
|
||||
attn_metadata.sdpa_attn_masks = attn_masks
|
||||
|
||||
query = query.movedim(0, query.dim() - 2)
|
||||
key = key.movedim(0, key.dim() - 2)
|
||||
value = value.movedim(0, value.dim() - 2)
|
||||
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
key = key.repeat_interleave(self.num_queries_per_kv, dim=-3)
|
||||
value = value.repeat_interleave(self.num_queries_per_kv, dim=-3)
|
||||
|
||||
causal_attn = attn_type == AttentionType.DECODER
|
||||
|
||||
sdpa_start_loc = attn_metadata.sdpa_start_loc.numpy() # type: ignore
|
||||
for i in range(len(attn_masks)):
|
||||
mask = attn_masks[i]
|
||||
start_q = sdpa_start_loc[i]
|
||||
end_q = sdpa_start_loc[i + 1]
|
||||
sub_out = (
|
||||
torch.nn.functional.scaled_dot_product_attention(
|
||||
query[None, :, start_q:end_q, :],
|
||||
key[None, :, start_q:end_q, :],
|
||||
value[None, :, start_q:end_q, :],
|
||||
attn_mask=mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=causal_attn and mask is None,
|
||||
scale=self.scale,
|
||||
)
|
||||
.squeeze(0)
|
||||
.movedim(query.dim() - 2, 0)
|
||||
)
|
||||
output[start_q:end_q, :, :] = sub_out
|
||||
return output
|
||||
|
||||
|
||||
def _make_alibi_bias(
|
||||
alibi_slopes: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
sdpa_start_loc: torch.Tensor,
|
||||
) -> list[torch.Tensor]:
|
||||
attn_biases: list[torch.Tensor] = []
|
||||
seq_num = sdpa_start_loc.size(0) - 1
|
||||
sdpa_start_loc = sdpa_start_loc.numpy() # type: ignore
|
||||
for i in range(seq_num):
|
||||
seq_len = sdpa_start_loc[i + 1] - sdpa_start_loc[i]
|
||||
bias = torch.arange(seq_len, dtype=dtype) # type: ignore
|
||||
# NOTE(zhuohan): HF uses
|
||||
# `bias = bias[None, :].repeat(seq_len, 1)`
|
||||
# here. We find that both biases give the same results, but
|
||||
# the bias below more accurately follows the original ALiBi
|
||||
# paper.
|
||||
bias = bias[None, :] - bias[:, None]
|
||||
|
||||
num_heads = alibi_slopes.shape[0]
|
||||
bias = bias[None, :].repeat((num_heads, 1, 1))
|
||||
bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
|
||||
inf_mask = (
|
||||
torch.empty((1, seq_len, seq_len), dtype=bias.dtype) # type: ignore
|
||||
.fill_(-torch.inf)
|
||||
.triu_(diagonal=1)
|
||||
)
|
||||
attn_biases.append((bias + inf_mask).to(dtype))
|
||||
|
||||
return attn_biases
|
||||
|
||||
|
||||
def _make_sliding_window_bias(
|
||||
sdpa_start_loc: torch.Tensor,
|
||||
left_window_size: int,
|
||||
right_window_size: int,
|
||||
dtype: torch.dtype,
|
||||
) -> list[torch.Tensor]:
|
||||
attn_biases: list[torch.Tensor] = []
|
||||
seq_num = sdpa_start_loc.size(0) - 1
|
||||
sdpa_start_loc = sdpa_start_loc.numpy() # type: ignore
|
||||
for i in range(seq_num):
|
||||
seq_len = sdpa_start_loc[i + 1] - sdpa_start_loc[i]
|
||||
mask = torch.full( # type: ignore
|
||||
(1, seq_len, seq_len), # type: ignore
|
||||
fill_value=1,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if right_window_size != -1:
|
||||
mask = torch.tril(mask, diagonal=right_window_size)
|
||||
if left_window_size != -1:
|
||||
mask = torch.triu(mask, diagonal=-left_window_size)
|
||||
mask = torch.log(mask)
|
||||
attn_biases.append(mask)
|
||||
|
||||
return attn_biases
|
||||
|
||||
|
||||
def _get_attn_isa(dtype: torch.dtype, block_size: int) -> str:
|
||||
supports_amx = torch._C._cpu._is_amx_tile_supported()
|
||||
if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
|
||||
return "amx"
|
||||
elif block_size % 32 == 0:
|
||||
return "vec"
|
||||
else:
|
||||
return "vec16"
|
||||
1215
vllm_old/v1/attention/backends/flash_attn.py
Normal file
1215
vllm_old/v1/attention/backends/flash_attn.py
Normal file
File diff suppressed because it is too large
Load Diff
1572
vllm_old/v1/attention/backends/flashinfer.py
Normal file
1572
vllm_old/v1/attention/backends/flashinfer.py
Normal file
File diff suppressed because it is too large
Load Diff
926
vllm_old/v1/attention/backends/flex_attention.py
Normal file
926
vllm_old/v1/attention/backends/flex_attention.py
Normal file
@@ -0,0 +1,926 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with FlexAttention."""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
import torch._dynamo.decorators
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.attention.flex_attention import (
|
||||
BlockMask,
|
||||
_mask_mod_signature,
|
||||
_score_mod_signature,
|
||||
and_masks,
|
||||
create_block_mask,
|
||||
flex_attention,
|
||||
)
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
create_block_mask_compiled = torch.compile(
|
||||
create_block_mask, fullgraph=True, mode="reduce-overhead"
|
||||
)
|
||||
flex_attention_compiled = torch.compile(flex_attention, fullgraph=True)
|
||||
|
||||
|
||||
def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
|
||||
device = offsets.device
|
||||
counts = offsets[1:] - offsets[:-1]
|
||||
return torch.repeat_interleave(
|
||||
torch.arange(len(counts), device=device, dtype=torch.int32), counts
|
||||
)
|
||||
|
||||
|
||||
def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int):
|
||||
difference = (multiple - (x.shape[dim] % multiple)) % multiple
|
||||
if difference == 0:
|
||||
return x
|
||||
|
||||
dim = dim if dim >= 0 else x.ndim + dim
|
||||
pad_list = []
|
||||
|
||||
for i in range(x.ndim - 1, dim - 1, -1):
|
||||
if i == dim:
|
||||
pad_list.extend([0, difference])
|
||||
else:
|
||||
pad_list.extend([0, 0])
|
||||
|
||||
return F.pad(x, pad_list, mode="constant", value=0)
|
||||
|
||||
|
||||
class FlexAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.float32,
|
||||
]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLEX_ATTENTION"
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""FlexAttention supports both decoder and encoder-only attention."""
|
||||
from vllm.attention import AttentionType
|
||||
|
||||
return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlexAttentionImpl"]:
|
||||
return FlexAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlexAttentionMetadataBuilder"]:
|
||||
return FlexAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return []
|
||||
|
||||
|
||||
# @torch.compile(fullgraph=True, mode="reduce-overhead")
|
||||
def physical_to_logical_mapping(
|
||||
block_table: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
block_size: int,
|
||||
total_blocks: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Creates an inverse mapping from physical block locations to logical indices.
|
||||
|
||||
The original block_table maps from logical blocks to physical locations:
|
||||
|
||||
Logical to Physical (Original block_table):
|
||||
┌───────────────────────────────────────────┐
|
||||
│ Request 0: │
|
||||
│ │
|
||||
│ Logical Blocks: 0 1 2 3 4 5 6 7 │
|
||||
│ │ │ │ │ │ │ │ │ │
|
||||
│ v v v v v v v v │
|
||||
│ Physical Blocks: 3 5 1 7 4 2 0 6 │
|
||||
└───────────────────────────────────────────┘
|
||||
|
||||
This function creates the inverse mapping:
|
||||
|
||||
Physical to Logical (Inverse mapping):
|
||||
┌───────────────────────────────────────────┐
|
||||
│ Request 0: │
|
||||
│ │
|
||||
│ Physical Blocks: 0 1 2 3 4 5 6 7 │
|
||||
│ │ │ │ │ │ │ │ │ │
|
||||
│ v v v v v v v v │
|
||||
│ Logical Blocks: 6 2 5 0 4 1 7 3 │
|
||||
└───────────────────────────────────────────┘
|
||||
|
||||
If multiple logical blocks map to the same physical block,
|
||||
this function returns the first (minimum) logical block index.
|
||||
|
||||
If a physical block is not mapped to by any logical block,
|
||||
its value in the result will be -1.
|
||||
|
||||
IMPORTANT: Garbage Value Protection
|
||||
────────────────────────────────────
|
||||
The block_table tensor may contain garbage values in unused positions
|
||||
(beyond the actual sequence length). For example, if a sequence only
|
||||
needs 3 blocks but the table has space for 8:
|
||||
|
||||
block_table[0] = [10, 25, 7, 999, 1234, 888, ...]
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
garbage values
|
||||
|
||||
These garbage values can cause issues because:
|
||||
1. They may map to valid physical blocks by coincidence
|
||||
2. The scatter_ operation will assign them logical indices
|
||||
3. Later attention computations may incorrectly access these blocks
|
||||
|
||||
To prevent this, we use seq_lens and block_size to mask out unused
|
||||
entries, ensuring only valid block references are processed.
|
||||
|
||||
Args:
|
||||
block_table: Tensor of shape [max_reqs, max_num_blocks]
|
||||
mapping logical blocks to physical locations. May contain
|
||||
garbage values in unused positions.
|
||||
seq_lens: Tensor of sequence lengths for each request. Used to
|
||||
determine how many blocks are actually needed per sequence.
|
||||
block_size: Size of each block in tokens. Used with seq_lens to
|
||||
compute the number of valid blocks per sequence.
|
||||
total_blocks: Total number of physical blocks available
|
||||
|
||||
Returns:
|
||||
A tensor of shape [max_reqs, total_blocks] where each entry
|
||||
physical_to_logical[req_id, physical_block] contains the logical
|
||||
block index for that physical block, or -1 if unused.
|
||||
"""
|
||||
max_reqs, max_num_blocks = block_table.shape
|
||||
device = block_table.device
|
||||
|
||||
physical_to_logical = torch.full(
|
||||
(max_reqs, total_blocks), -1, dtype=torch.long, device=device
|
||||
)
|
||||
|
||||
# Only process valid blocks to avoid garbage values
|
||||
num_blocks_per_seq = cdiv(seq_lens, block_size)
|
||||
mask = (
|
||||
torch.arange(max_num_blocks, device=device)[None, :]
|
||||
< num_blocks_per_seq[:, None]
|
||||
)
|
||||
|
||||
valid_block_table = torch.where(mask, block_table, 0)
|
||||
valid_logical_indices = torch.where(
|
||||
mask, torch.arange(max_num_blocks, device=device)[None, :], 0
|
||||
)
|
||||
|
||||
physical_to_logical.scatter_(
|
||||
-1, valid_block_table.to(torch.int64), valid_logical_indices
|
||||
)
|
||||
# NB - Seems like block 0 is always empty so we reset it manually
|
||||
physical_to_logical[:, 0] = -1
|
||||
return physical_to_logical
|
||||
|
||||
|
||||
def unique_static_unsorted(
|
||||
x: torch.Tensor,
|
||||
*,
|
||||
M: int, # maximum positive value (0 is “skip me”)
|
||||
dim: int = -1, # axis along which to deduplicate
|
||||
ignored_val: int = 0, # value to ignore
|
||||
pad_val: int = -1, # sentinel for unused slots
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
- Keeps the first occurrence of each non-zero value while preserving order,
|
||||
then left-packs those uniques and fills the rest with `pad_val`.
|
||||
- Returns (packed, keep_mask) with the *same shape* as `x`.
|
||||
- Requires that all values be in the range [0, M]
|
||||
- Skips ignored_val
|
||||
|
||||
Works on CPU or GPU, no Python loops, O(B·N) time / O(B·M) memory.
|
||||
|
||||
Example:
|
||||
x =[3, 1, 0, 1, 2], M=3, ignored_val=0 => [3, 1, 2, -1, -1]
|
||||
"""
|
||||
if not (-1 <= pad_val <= M):
|
||||
raise ValueError("`pad_val` must lie in [-1, M]")
|
||||
|
||||
# ── move `dim` to the end so we can treat tensor as [B, N] ──────────
|
||||
dim = dim % x.ndim
|
||||
x_perm = x.movedim(dim, -1) # shape [..., N]
|
||||
B, N = x_perm.numel() // x_perm.shape[-1], x_perm.shape[-1]
|
||||
x_flat = x_perm.reshape(B, N) # [B, N]
|
||||
|
||||
device = x.device
|
||||
idx = torch.arange(N, device=device).expand(B, N) # per-row indices
|
||||
|
||||
# ── build first-occurrence table for every v ∈ [0, M] ───────────────
|
||||
first_idx = torch.full((B, M + 1), N, device=device) # “∞”
|
||||
# scatter_reduce_: first_idx[b, v] = min(first_idx[b, v], i) for each i
|
||||
first_idx.scatter_reduce_(1, x_flat, idx, reduce="amin")
|
||||
|
||||
# ── keep mask: first occurrence *and* value ≠ 0 ─────────────────────
|
||||
keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, x_flat)) # [B, N]
|
||||
|
||||
# ── left-pack uniques into a fresh tensor ───────────────────────────
|
||||
dest_pos = torch.cumsum(keep.to(torch.long), dim=1) - 1 # where to go
|
||||
packed_flat = torch.full_like(x_flat, pad_val)
|
||||
|
||||
rows, src_cols = torch.nonzero(keep, as_tuple=True)
|
||||
packed_flat[rows, dest_pos[rows, src_cols]] = x_flat[rows, src_cols]
|
||||
|
||||
# ── restore original layout ─────────────────────────────────────────
|
||||
packed = packed_flat.reshape(x_perm.shape).movedim(-1, dim)
|
||||
return packed
|
||||
|
||||
|
||||
def causal_mask_mod(
|
||||
b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
|
||||
):
|
||||
return q_idx >= kv_idx
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlexAttentionMetadata:
|
||||
causal: bool
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
use_cascade: bool
|
||||
common_prefix_len: int
|
||||
cu_prefix_query_lens: torch.Tensor | None
|
||||
prefix_kv_lens: torch.Tensor | None
|
||||
suffix_kv_lens: torch.Tensor | None
|
||||
|
||||
# Block info
|
||||
total_cache_tokens: int
|
||||
block_size: int
|
||||
max_possible_sequence_length: int
|
||||
num_reqs: int
|
||||
physical_to_logical: torch.Tensor
|
||||
decode_offset: torch.Tensor
|
||||
num_blocks_per_seq: torch.Tensor
|
||||
|
||||
# For logging.
|
||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||
|
||||
# Flex Metadata
|
||||
num_blocks = 0
|
||||
block_mask: BlockMask | None = None
|
||||
score_mod: _score_mod_signature | None = None
|
||||
logical_mask_mod: _mask_mod_signature = causal_mask_mod
|
||||
doc_ids: torch.Tensor | None = None
|
||||
direct_build: bool = True
|
||||
q_block_size: int = 16
|
||||
kv_block_size: int = 16
|
||||
transformed_score_mod: _score_mod_signature | None = None
|
||||
sliding_window: int | None = None
|
||||
|
||||
def _convert_physical_to_logical(
|
||||
self,
|
||||
request_lookup: torch.Tensor,
|
||||
q_idx: torch.Tensor,
|
||||
physical_kv_idx: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Convert physical indices to logical indices for both query and kv.
|
||||
|
||||
NB is_within_lower_bound: do sequences start on block_boundaries?
|
||||
|
||||
Returns:
|
||||
tuple of (is_valid, logical_q_idx, logical_kv_idx)
|
||||
"""
|
||||
# Map query indices to corresponding request indices
|
||||
q_req = request_lookup[q_idx]
|
||||
|
||||
# Convert physical KV indices to logical indices
|
||||
physical_kv_block = physical_kv_idx // self.block_size
|
||||
physical_kv_offset = physical_kv_idx % self.block_size
|
||||
logical_block_idx = self.physical_to_logical[q_req, physical_kv_block]
|
||||
logical_kv_idx = logical_block_idx * self.block_size + physical_kv_offset
|
||||
|
||||
# Determine valid kv indices
|
||||
live_block = logical_block_idx >= 0
|
||||
within_upper_bound = logical_kv_idx < self.seq_lens[q_req]
|
||||
within_lower_bound = logical_kv_idx >= 0
|
||||
is_valid = live_block & within_upper_bound & within_lower_bound
|
||||
|
||||
# Convert physical query indices to logical indices
|
||||
local_q_idx = q_idx - self.query_start_loc[q_req]
|
||||
logical_q_idx = local_q_idx + self.decode_offset[q_req]
|
||||
|
||||
return is_valid, logical_q_idx, logical_kv_idx
|
||||
|
||||
def get_causal_mask_mod(self) -> _mask_mod_signature:
|
||||
"""Creates the mask_mod function for FlexAttention.
|
||||
|
||||
This function creates the combined mask mod function that handles:
|
||||
1. The paged attention block mapping
|
||||
2. The mapping from packed query sequences to logical query entries
|
||||
|
||||
It also by defaults adds the decoding offset to the query indices.
|
||||
With this info we create the "logical" indices that are passed to
|
||||
mask_mod functions. This allows mask mod functions to be agnostic to
|
||||
layout of the query and key/value tensors.
|
||||
"""
|
||||
assert self.doc_ids is not None
|
||||
|
||||
def final_mask_mod(
|
||||
b: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
q_idx: torch.Tensor,
|
||||
physical_kv_idx: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
(is_valid, logical_q_idx, logical_kv_idx) = (
|
||||
self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx)
|
||||
)
|
||||
# Apply mask modification only for valid indices
|
||||
return torch.where(
|
||||
is_valid,
|
||||
self.logical_mask_mod(b, h, logical_q_idx, logical_kv_idx),
|
||||
False,
|
||||
)
|
||||
|
||||
return final_mask_mod
|
||||
|
||||
def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
|
||||
"""Creates the encoder mask_mod function for FlexAttention.
|
||||
|
||||
Since the encoder bidirectional attention doesn't run with
|
||||
KV cache, this function creates a mask based on the
|
||||
packed query sequences.
|
||||
"""
|
||||
# Create a lookup mapping from query indices -> request number
|
||||
request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
|
||||
|
||||
def final_mask_mod(
|
||||
b: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
q_idx: torch.Tensor,
|
||||
kv_idx: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return request_lookup[q_idx] == request_lookup[kv_idx]
|
||||
|
||||
return final_mask_mod
|
||||
|
||||
def get_sliding_window_mask_mod(self) -> _mask_mod_signature:
|
||||
"""Creates the sliding window mask_mod function for FlexAttention.
|
||||
|
||||
Note that the sliding window mask here is bidirectional, we need
|
||||
to mask it with the bidirectional/causal mask for encoder/decoder.
|
||||
"""
|
||||
|
||||
if self.sliding_window is None:
|
||||
raise ValueError("sliding_window must be set for sliding window attention")
|
||||
|
||||
def sliding_window_mask_mod(
|
||||
b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
|
||||
):
|
||||
return torch.abs(q_idx - kv_idx) < self.sliding_window
|
||||
|
||||
def final_mask_mod(
|
||||
b: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
q_idx: torch.Tensor,
|
||||
physical_kv_idx: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
(is_valid, logical_q_idx, logical_kv_idx) = (
|
||||
self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx)
|
||||
)
|
||||
return torch.where(
|
||||
is_valid,
|
||||
sliding_window_mask_mod(b, h, logical_q_idx, logical_kv_idx),
|
||||
False,
|
||||
)
|
||||
|
||||
return final_mask_mod if self.causal else sliding_window_mask_mod
|
||||
|
||||
def get_mask_mod(self):
|
||||
# Stage-1: initialize the base mask_mod
|
||||
# (causal mask for decoder or bidirectional mask for encoder)
|
||||
if self.causal:
|
||||
mask_mod = self.get_causal_mask_mod()
|
||||
else:
|
||||
mask_mod = self.get_bidirectional_mask_mod()
|
||||
# stage-2: add external mask_mod for special attention during
|
||||
# forwarding runtime to create the combined mask_mod.
|
||||
if self.sliding_window is not None:
|
||||
# Add sliding window mask for sliding window attention
|
||||
sliding_window_mask_mod = self.get_sliding_window_mask_mod()
|
||||
mask_mod = and_masks(mask_mod, sliding_window_mask_mod)
|
||||
return mask_mod
|
||||
|
||||
def get_transformed_score_mod(self) -> _score_mod_signature | None:
|
||||
"""Creates the transformed score_mod function for FlexAttention.
|
||||
|
||||
This function wraps the user's score_mod to handle physical-to-logical
|
||||
index conversion, similar to how get_mask_mod works for mask functions.
|
||||
"""
|
||||
if self.score_mod is None:
|
||||
return None
|
||||
|
||||
# Create a lookup mapping from query indices -> request number
|
||||
request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
|
||||
user_score_mod = self.score_mod
|
||||
|
||||
def transformed_score_mod(
|
||||
score: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
q_idx: torch.Tensor,
|
||||
physical_kv_idx: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
(is_valid, logical_q_idx, logical_kv_idx) = (
|
||||
self._convert_physical_to_logical(
|
||||
request_lookup, q_idx, physical_kv_idx
|
||||
)
|
||||
)
|
||||
|
||||
return torch.where(
|
||||
is_valid,
|
||||
user_score_mod(
|
||||
score, b, h, logical_q_idx, logical_kv_idx, physical_q=q_idx
|
||||
),
|
||||
-float("inf"),
|
||||
)
|
||||
|
||||
return transformed_score_mod
|
||||
|
||||
def _build_block_mask_direct(self) -> BlockMask:
|
||||
"""Direct block mask construction for standard causal attention.
|
||||
|
||||
This method constructs the block mask directly using
|
||||
BlockMask.from_kv_blocks which is much more efficient than the
|
||||
generic create_block_mask approach.
|
||||
|
||||
The direct path works as follows:
|
||||
1. For each query token, fetch blocks from block_table using max_seq_len
|
||||
(this fetches more blocks than needed for shorter sequences)
|
||||
2. Group query tokens into chunks of q_block_size
|
||||
3. For each group, deduplicate the blocks using unique_static_unsorted
|
||||
4. Create BlockMask using the deduplicated block indices
|
||||
|
||||
Over-estimation occurs when a group of q_block_size tokens contains
|
||||
multiple sequence IDs (doc_ids). In this case, we fetch ALL blocks for
|
||||
each sequence represented in the group, even though individual query
|
||||
tokens may only need a subset of those blocks based on causal masking
|
||||
and their position.
|
||||
|
||||
"""
|
||||
page_to_block_ratio = self.kv_block_size // self.block_size
|
||||
if page_to_block_ratio != 1:
|
||||
raise ValueError(
|
||||
f"FlexAttention currently requires the cache block size "
|
||||
f"({self.block_size}) to be equal to the kv_block_size "
|
||||
f"({self.kv_block_size}). Please check your model's "
|
||||
f"configuration."
|
||||
)
|
||||
|
||||
used_pages = self.block_table[
|
||||
self.doc_ids, : cdiv(self.max_seq_len, self.block_size)
|
||||
]
|
||||
used_pages_padded = pad_to_multiple(
|
||||
used_pages, multiple=self.q_block_size, dim=0
|
||||
)
|
||||
used_pages_padded = used_pages_padded.reshape(
|
||||
used_pages_padded.shape[0] // self.q_block_size, -1
|
||||
)
|
||||
used_pages_padded = used_pages_padded // page_to_block_ratio
|
||||
kv_indices = unique_static_unsorted(
|
||||
(used_pages_padded.long()), M=self.num_blocks
|
||||
).to(torch.int32)
|
||||
|
||||
kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32)
|
||||
block_mask_kwargs = {
|
||||
"seq_lengths": (self.num_actual_tokens, self.total_cache_tokens),
|
||||
"kv_num_blocks": kv_num_blocks[None, None],
|
||||
"kv_indices": kv_indices[None, None],
|
||||
"full_kv_num_blocks": None,
|
||||
"full_kv_indices": None,
|
||||
"BLOCK_SIZE": (self.q_block_size, self.kv_block_size),
|
||||
"mask_mod": self.mask_mod,
|
||||
}
|
||||
|
||||
# compute_q_blocks parameter is available in PyTorch 2.9+
|
||||
if is_torch_equal_or_newer("2.9.0.dev0"):
|
||||
block_mask_kwargs["compute_q_blocks"] = False
|
||||
return BlockMask.from_kv_blocks(**block_mask_kwargs)
|
||||
|
||||
def build_block_mask(self) -> BlockMask:
|
||||
mask_mod = self.get_mask_mod()
|
||||
kv_len = self.total_cache_tokens if self.causal else self.num_actual_tokens
|
||||
return create_block_mask_compiled(
|
||||
mask_mod,
|
||||
None,
|
||||
None,
|
||||
self.num_actual_tokens,
|
||||
kv_len,
|
||||
device=self.block_table.device,
|
||||
BLOCK_SIZE=(self.q_block_size, self.kv_block_size),
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.use_cascade is False, "Not implemented yet."
|
||||
assert self.common_prefix_len == 0, "Not implemented yet."
|
||||
assert self.cu_prefix_query_lens is None, "Not implemented yet."
|
||||
assert self.prefix_kv_lens is None, "Not implemented yet."
|
||||
assert self.suffix_kv_lens is None, "Not implemented yet."
|
||||
# Create a lookup mapping from query indices -> request number
|
||||
self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc)
|
||||
self.num_blocks = self.total_cache_tokens // self.block_size
|
||||
|
||||
self.mask_mod = self.get_mask_mod()
|
||||
self.transformed_score_mod = self.get_transformed_score_mod()
|
||||
|
||||
if self.direct_build and self.causal:
|
||||
self.block_mask = self._build_block_mask_direct()
|
||||
else:
|
||||
self.block_mask = self.build_block_mask()
|
||||
|
||||
|
||||
class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]):
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
self.model_config = vllm_config.model_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
|
||||
self.num_heads_q = self.model_config.get_num_attention_heads(
|
||||
self.parallel_config
|
||||
)
|
||||
self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||
self.headdim = self.model_config.get_head_size()
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
supports_small_blocks = is_torch_equal_or_newer("2.9.0.dev0")
|
||||
self.direct_build: bool = supports_small_blocks
|
||||
self.q_block_size: int = 16 if supports_small_blocks else 128
|
||||
self.kv_block_size: int = self.block_size if supports_small_blocks else 128
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> FlexAttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
num_blocks_per_seq = cdiv(seq_lens, self.block_size)
|
||||
|
||||
use_cascade = common_prefix_len > 0
|
||||
cu_prefix_query_lens = None
|
||||
prefix_kv_lens = None
|
||||
suffix_kv_lens = None
|
||||
if use_cascade:
|
||||
raise NotImplementedError("Not yet my friend")
|
||||
|
||||
block_size = self.kv_cache_spec.block_size
|
||||
max_possible_seq_len = self.model_config.max_model_len
|
||||
num_gpu_blocks = self.cache_config.num_gpu_blocks
|
||||
|
||||
assert num_gpu_blocks is not None, (
|
||||
"FlexAttention requires num_gpu_blocks to be set"
|
||||
)
|
||||
total_cache_tokens = num_gpu_blocks * block_size
|
||||
|
||||
inverse_block_table = physical_to_logical_mapping(
|
||||
block_table_tensor, seq_lens, block_size, num_gpu_blocks
|
||||
)
|
||||
|
||||
offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
out = FlexAttentionMetadata(
|
||||
causal=common_attn_metadata.causal,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=query_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
use_cascade=use_cascade,
|
||||
common_prefix_len=common_prefix_len,
|
||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||
prefix_kv_lens=prefix_kv_lens,
|
||||
suffix_kv_lens=suffix_kv_lens,
|
||||
block_size=block_size,
|
||||
max_possible_sequence_length=max_possible_seq_len,
|
||||
num_reqs=num_reqs,
|
||||
physical_to_logical=inverse_block_table,
|
||||
total_cache_tokens=total_cache_tokens,
|
||||
decode_offset=offset_tensor,
|
||||
num_blocks_per_seq=num_blocks_per_seq,
|
||||
# FIXME(Isotr0py): direct build has issue to build bidirectional
|
||||
# attention block mask for encoder-only models, disable it temporarily.
|
||||
# see: https://github.com/vllm-project/vllm/pull/27329#issuecomment-3431484053
|
||||
direct_build=(self.direct_build and common_attn_metadata.causal),
|
||||
q_block_size=self.q_block_size,
|
||||
kv_block_size=self.kv_block_size,
|
||||
)
|
||||
return out
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class FlexAttentionImpl(AttentionImpl):
|
||||
sliding_window: int | None
|
||||
alibi_slopes: torch.Tensor | None
|
||||
logits_soft_cap: float | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: str | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.attn_type = attn_type
|
||||
|
||||
if attn_type not in (AttentionType.ENCODER_ONLY, AttentionType.DECODER):
|
||||
raise NotImplementedError(
|
||||
f"FlexAttention does not support {attn_type} attention"
|
||||
)
|
||||
|
||||
if alibi_slopes is not None:
|
||||
raise NotImplementedError(
|
||||
"FlexAttention does not support alibi slopes yet."
|
||||
)
|
||||
else:
|
||||
self.alibi_slopes = None
|
||||
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
if self.logits_soft_cap is not None:
|
||||
raise NotImplementedError(
|
||||
"FlexAttention does not support logits soft cap yet."
|
||||
)
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("FlexAttention does not support kv sharing yet.")
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"FlexAttention does not support quantized kv-cache. Yet"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def view_as_4d(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""View a 3d tensor as 4D."""
|
||||
if tensor.ndim == 4:
|
||||
return tensor
|
||||
assert tensor.ndim == 3
|
||||
return tensor[None, :, :, :]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlexAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FLexAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape =
|
||||
[2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported for FlexAttentionImpl"
|
||||
)
|
||||
|
||||
enable_gqa = self.num_kv_heads != self.num_heads
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output.fill_(0)
|
||||
# query = self.view_as_4d(query).permute(0, 2, 1, 3)
|
||||
# return torch.empty_like(query)
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
if attn_metadata.sliding_window != self.sliding_window:
|
||||
attn_metadata.sliding_window = self.sliding_window
|
||||
if attn_metadata.direct_build:
|
||||
# TODO: Support skipping the computation of sliding window
|
||||
# in direct block mask building code path.
|
||||
logger.warning_once(
|
||||
"Using direct block mask building with sliding window, "
|
||||
"which is suboptimal now. Performance may be degraded."
|
||||
)
|
||||
# update mask mod in attention metadata
|
||||
attn_metadata.mask_mod = attn_metadata.get_mask_mod()
|
||||
attn_metadata.block_mask = attn_metadata._build_block_mask_direct()
|
||||
else:
|
||||
attn_metadata.block_mask = attn_metadata.build_block_mask()
|
||||
|
||||
if not attn_metadata.causal:
|
||||
assert self.attn_type == AttentionType.ENCODER_ONLY
|
||||
|
||||
query, key_tensor, value_tensor = map(
|
||||
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
|
||||
(query, key, value),
|
||||
)
|
||||
|
||||
query = query[:, :, :num_actual_tokens, :]
|
||||
if (key_tensor.size(-2) > num_actual_tokens) or (
|
||||
value_tensor.size(-2) > num_actual_tokens
|
||||
):
|
||||
# In the encoder-only model with torch.compile,
|
||||
# qkv might be padded, which might cause exception.
|
||||
# see: https://github.com/vllm-project/vllm/pull/24872#discussion_r2353252290
|
||||
key_tensor = key_tensor[:, :, :num_actual_tokens, :]
|
||||
value_tensor = value_tensor[:, :, :num_actual_tokens, :]
|
||||
|
||||
else:
|
||||
assert self.attn_type == AttentionType.DECODER
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
# View out the block_size dim
|
||||
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
|
||||
value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
|
||||
query, key_tensor, value_tensor = map(
|
||||
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
|
||||
(query, key_cache, value_cache),
|
||||
)
|
||||
|
||||
query = query[:, :, :num_actual_tokens, :]
|
||||
|
||||
# Doesn't work for now -> constraint violation
|
||||
# torch._dynamo.try_mark_dynamic(query, 2)
|
||||
|
||||
assert attn_metadata.block_mask is not None
|
||||
block_m, block_n = attn_metadata.block_mask.BLOCK_SIZE
|
||||
|
||||
kernel_options = get_kernel_options(
|
||||
query, block_m, block_n, attn_metadata.direct_build
|
||||
)
|
||||
out = flex_attention_compiled(
|
||||
query,
|
||||
key_tensor,
|
||||
value_tensor,
|
||||
attn_metadata.transformed_score_mod,
|
||||
attn_metadata.block_mask,
|
||||
self.scale,
|
||||
enable_gqa=enable_gqa,
|
||||
kernel_options=kernel_options,
|
||||
)
|
||||
|
||||
# Flex doesn't have an out variant today, rely on epilogue fusion
|
||||
out = out.permute(0, 2, 1, 3).squeeze(0)
|
||||
output[:num_actual_tokens, :, :].copy_(out)
|
||||
return output
|
||||
|
||||
|
||||
def get_kernel_options(
|
||||
query, block_m, block_n, use_direct_build: bool
|
||||
) -> dict[str, int | bool]:
|
||||
kernel_options: dict[str, int | bool] = {
|
||||
"FORCE_USE_FLEX_ATTENTION": True,
|
||||
}
|
||||
|
||||
def ensure_divisible(candidate: int, block_size: int) -> int:
|
||||
"""Pick a kernel block size that divides the logical block."""
|
||||
if block_size <= 0:
|
||||
return candidate
|
||||
candidate = min(candidate, block_size)
|
||||
if candidate <= 0:
|
||||
return block_size
|
||||
if block_size % candidate == 0:
|
||||
return candidate
|
||||
|
||||
candidate = math.gcd(candidate, block_size)
|
||||
if candidate <= 1:
|
||||
return block_size
|
||||
return candidate
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
kernel_options["BLOCK_M"] = 16
|
||||
kernel_options["BLOCK_N"] = 16
|
||||
kernel_options["IS_DIVISIBLE"] = False
|
||||
return kernel_options
|
||||
if use_direct_build:
|
||||
kernel_options["BLOCK_M"] = block_m
|
||||
kernel_options["BLOCK_N"] = block_n
|
||||
return kernel_options
|
||||
else:
|
||||
preferred_block = 32 if query.dtype == torch.float32 else 64
|
||||
block_lower_bound = 16
|
||||
|
||||
block_m_candidate = ensure_divisible(preferred_block, block_m)
|
||||
block_n_candidate = ensure_divisible(preferred_block, block_n)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device_props = torch.cuda.get_device_properties()
|
||||
max_shared_memory = device_props.shared_memory_per_block_optin
|
||||
if max_shared_memory < 144 * 1024:
|
||||
block_m_candidate = ensure_divisible(
|
||||
max(1, block_m_candidate // 2), block_m
|
||||
)
|
||||
block_n_candidate = ensure_divisible(
|
||||
max(1, block_n_candidate // 2), block_n
|
||||
)
|
||||
|
||||
block_m_candidate = max(block_m_candidate, block_lower_bound)
|
||||
block_n_candidate = max(block_n_candidate, block_lower_bound)
|
||||
|
||||
kernel_options["BLOCK_M"] = block_m_candidate
|
||||
kernel_options["BLOCK_N"] = block_n_candidate
|
||||
|
||||
return kernel_options
|
||||
387
vllm_old/v1/attention/backends/gdn_attn.py
Normal file
387
vllm_old/v1/attention/backends/gdn_attn.py
Normal file
@@ -0,0 +1,387 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Backend for GatedDeltaNet attention."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
compute_causal_conv1d_metadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
|
||||
class GDNAttentionBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
|
||||
return GDNAttentionMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
|
||||
class GDNAttentionMetadata:
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_spec_decodes: int
|
||||
num_spec_decode_tokens: int
|
||||
num_actual_tokens: int
|
||||
|
||||
has_initial_state: torch.Tensor | None = None
|
||||
|
||||
spec_query_start_loc: torch.Tensor | None = None # shape: [num_spec_decodes + 1,]
|
||||
non_spec_query_start_loc: torch.Tensor | None = (
|
||||
None # shape: [batch - num_spec_decodes + 1,]
|
||||
)
|
||||
|
||||
spec_state_indices_tensor: torch.Tensor | None = None # shape: [batch, num_spec]
|
||||
non_spec_state_indices_tensor: torch.Tensor | None = (
|
||||
None # shape: [batch - num_spec_decodes,]
|
||||
)
|
||||
spec_sequence_masks: torch.Tensor | None = None # shape: [batch,]
|
||||
spec_token_indx: torch.Tensor | None = None
|
||||
non_spec_token_indx: torch.Tensor | None = None
|
||||
|
||||
num_accepted_tokens: torch.Tensor | None = None # shape: [batch,]
|
||||
|
||||
# The following attributes are for triton implementation of causal_conv1d
|
||||
nums_dict: dict | None = None
|
||||
batch_ptr: torch.Tensor | None = None
|
||||
token_chunk_offset_ptr: torch.Tensor | None = None
|
||||
|
||||
|
||||
class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
|
||||
_cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
reorder_batch_threshold: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
assert isinstance(kv_cache_spec, MambaSpec)
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
if self.speculative_config:
|
||||
self.num_spec = self.speculative_config.num_speculative_tokens
|
||||
else:
|
||||
self.num_spec = 0
|
||||
self.use_spec_decode = self.num_spec > 0
|
||||
self._init_reorder_batch_threshold(1, self.use_spec_decode)
|
||||
|
||||
self.use_full_cuda_graph = (
|
||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
)
|
||||
self.decode_cudagraph_max_bs = min(
|
||||
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
|
||||
self.compilation_config.max_cudagraph_capture_size,
|
||||
)
|
||||
|
||||
self.spec_state_indices_tensor = torch.empty(
|
||||
(self.decode_cudagraph_max_bs, self.num_spec + 1),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.non_spec_state_indices_tensor = torch.empty(
|
||||
(self.decode_cudagraph_max_bs,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.spec_sequence_masks = torch.empty(
|
||||
(self.decode_cudagraph_max_bs,),
|
||||
dtype=torch.bool,
|
||||
device=device,
|
||||
)
|
||||
self.spec_token_indx = torch.empty(
|
||||
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.non_spec_token_indx = torch.empty(
|
||||
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.spec_query_start_loc = torch.empty(
|
||||
(self.decode_cudagraph_max_bs + 1,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.non_spec_query_start_loc = torch.empty(
|
||||
(self.decode_cudagraph_max_bs + 1,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.num_accepted_tokens = torch.empty(
|
||||
(self.decode_cudagraph_max_bs,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def build( # type: ignore[override]
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
num_decode_draft_tokens_cpu: torch.Tensor | None = None,
|
||||
fast_build: bool = False,
|
||||
) -> GDNAttentionMetadata:
|
||||
m = common_attn_metadata
|
||||
|
||||
query_start_loc = m.query_start_loc
|
||||
context_lens = m.num_computed_tokens_cpu
|
||||
context_lens_tensor = context_lens.to(query_start_loc.device)
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
if (
|
||||
not self.use_spec_decode
|
||||
or num_decode_draft_tokens_cpu is None
|
||||
or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0]
|
||||
.sum()
|
||||
.item()
|
||||
== 0
|
||||
):
|
||||
spec_sequence_masks = None
|
||||
num_spec_decodes = 0
|
||||
else:
|
||||
spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
|
||||
num_spec_decodes = spec_sequence_masks.sum().item()
|
||||
if num_spec_decodes == 0:
|
||||
spec_sequence_masks = None
|
||||
else:
|
||||
spec_sequence_masks = spec_sequence_masks.to(
|
||||
query_start_loc.device, non_blocking=True
|
||||
)
|
||||
|
||||
if spec_sequence_masks is None:
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(m, decode_threshold=1)
|
||||
)
|
||||
num_spec_decode_tokens = 0
|
||||
spec_token_indx = None
|
||||
non_spec_token_indx = None
|
||||
spec_state_indices_tensor = None
|
||||
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
|
||||
spec_query_start_loc = None
|
||||
non_spec_query_start_loc = query_start_loc
|
||||
num_accepted_tokens = None
|
||||
else:
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
|
||||
non_spec_query_lens = query_lens[~spec_sequence_masks]
|
||||
num_decodes = (non_spec_query_lens == 1).sum().item()
|
||||
num_prefills = non_spec_query_lens.size(0) - num_decodes
|
||||
num_decode_tokens = num_decodes
|
||||
num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
|
||||
num_spec_decode_tokens = (
|
||||
query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
|
||||
)
|
||||
|
||||
if num_prefills == 0 and num_decodes == 0:
|
||||
spec_token_size = min(
|
||||
num_spec_decodes * (self.num_spec + 1),
|
||||
query_start_loc[-1].item(),
|
||||
)
|
||||
spec_token_indx = torch.arange(
|
||||
spec_token_size,
|
||||
dtype=torch.int32,
|
||||
device=query_start_loc.device,
|
||||
)
|
||||
non_spec_token_indx = torch.empty(
|
||||
0, dtype=torch.int32, device=query_start_loc.device
|
||||
)
|
||||
spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
|
||||
non_spec_state_indices_tensor = None
|
||||
spec_query_start_loc = query_start_loc
|
||||
non_spec_query_start_loc = None
|
||||
else:
|
||||
spec_token_masks = torch.repeat_interleave(
|
||||
spec_sequence_masks, query_lens
|
||||
)
|
||||
index = torch.argsort(spec_token_masks)
|
||||
num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
|
||||
non_spec_token_indx = index[:num_non_spec_tokens]
|
||||
spec_token_indx = index[num_non_spec_tokens:]
|
||||
|
||||
spec_state_indices_tensor = m.block_table_tensor[
|
||||
spec_sequence_masks, : self.num_spec + 1
|
||||
]
|
||||
non_spec_state_indices_tensor = m.block_table_tensor[
|
||||
~spec_sequence_masks, 0
|
||||
]
|
||||
|
||||
spec_query_start_loc = torch.zeros(
|
||||
num_spec_decodes + 1,
|
||||
dtype=torch.int32,
|
||||
device=query_start_loc.device,
|
||||
)
|
||||
torch.cumsum(
|
||||
query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
|
||||
)
|
||||
non_spec_query_start_loc = torch.zeros(
|
||||
query_lens.size(0) - num_spec_decodes + 1,
|
||||
dtype=torch.int32,
|
||||
device=query_start_loc.device,
|
||||
)
|
||||
torch.cumsum(
|
||||
query_lens[~spec_sequence_masks],
|
||||
dim=0,
|
||||
out=non_spec_query_start_loc[1:],
|
||||
)
|
||||
|
||||
assert num_accepted_tokens is not None
|
||||
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
|
||||
|
||||
if num_prefills > 0:
|
||||
has_initial_state = context_lens_tensor > 0
|
||||
if spec_sequence_masks is not None:
|
||||
has_initial_state = has_initial_state[~spec_sequence_masks]
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = (
|
||||
compute_causal_conv1d_metadata(non_spec_query_start_loc)
|
||||
)
|
||||
else:
|
||||
has_initial_state = None
|
||||
num_actual_tokens = (
|
||||
num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
|
||||
)
|
||||
|
||||
# prepare tensors for cudagraph
|
||||
#
|
||||
# With speculative decoding, the xgrammar backend may rollback tokens
|
||||
# and causing some sequences has less draft tokens than self.num_spec.
|
||||
#
|
||||
# In above cases, the max possible batch size for n tokens, can be
|
||||
# min(n, cudagraph_max_bs).
|
||||
if (
|
||||
self.use_full_cuda_graph
|
||||
and num_prefills == 0
|
||||
and num_decodes == 0
|
||||
and num_spec_decodes <= self.decode_cudagraph_max_bs
|
||||
and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
|
||||
):
|
||||
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
|
||||
batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)
|
||||
|
||||
self.spec_state_indices_tensor[:num_spec_decodes].copy_(
|
||||
spec_state_indices_tensor, non_blocking=True
|
||||
)
|
||||
spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
|
||||
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)
|
||||
|
||||
self.spec_sequence_masks[:num_spec_decodes].copy_(
|
||||
spec_sequence_masks, non_blocking=True
|
||||
)
|
||||
spec_sequence_masks = self.spec_sequence_masks[:batch_size]
|
||||
spec_sequence_masks[num_spec_decodes:].fill_(False)
|
||||
|
||||
assert non_spec_token_indx is not None and spec_token_indx is not None
|
||||
self.non_spec_token_indx[: non_spec_token_indx.size(0)].copy_(
|
||||
non_spec_token_indx, non_blocking=True
|
||||
)
|
||||
non_spec_token_indx = self.non_spec_token_indx[
|
||||
: non_spec_token_indx.size(0)
|
||||
]
|
||||
|
||||
self.spec_token_indx[: spec_token_indx.size(0)].copy_(
|
||||
spec_token_indx, non_blocking=True
|
||||
)
|
||||
spec_token_indx = self.spec_token_indx[: spec_token_indx.size(0)]
|
||||
|
||||
self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
|
||||
spec_query_start_loc, non_blocking=True
|
||||
)
|
||||
spec_num_query_tokens = spec_query_start_loc[-1] # type: ignore[index]
|
||||
spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1]
|
||||
spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens)
|
||||
|
||||
self.num_accepted_tokens[:num_spec_decodes].copy_(
|
||||
num_accepted_tokens, non_blocking=True
|
||||
)
|
||||
num_accepted_tokens = self.num_accepted_tokens[:batch_size]
|
||||
num_accepted_tokens[num_spec_decodes:].fill_(1)
|
||||
|
||||
if (
|
||||
self.use_full_cuda_graph
|
||||
and num_prefills == 0
|
||||
and num_spec_decodes == 0
|
||||
and num_decodes <= self.decode_cudagraph_max_bs
|
||||
):
|
||||
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
|
||||
batch_size = num_actual_tokens
|
||||
|
||||
self.non_spec_state_indices_tensor[:num_decodes].copy_(
|
||||
non_spec_state_indices_tensor, non_blocking=True
|
||||
)
|
||||
non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
|
||||
:batch_size
|
||||
]
|
||||
non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)
|
||||
|
||||
self.non_spec_query_start_loc[: num_decodes + 1].copy_(
|
||||
non_spec_query_start_loc, non_blocking=True
|
||||
)
|
||||
non_spec_num_query_tokens = non_spec_query_start_loc[-1] # type: ignore[index]
|
||||
non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1]
|
||||
non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens)
|
||||
|
||||
attn_metadata = GDNAttentionMetadata(
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_spec_decodes=num_spec_decodes,
|
||||
num_spec_decode_tokens=num_spec_decode_tokens,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
has_initial_state=has_initial_state,
|
||||
spec_query_start_loc=spec_query_start_loc,
|
||||
non_spec_query_start_loc=non_spec_query_start_loc,
|
||||
spec_state_indices_tensor=spec_state_indices_tensor,
|
||||
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
|
||||
spec_sequence_masks=spec_sequence_masks,
|
||||
spec_token_indx=spec_token_indx,
|
||||
non_spec_token_indx=non_spec_token_indx,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
nums_dict=nums_dict,
|
||||
batch_ptr=batch_ptr,
|
||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
):
|
||||
"""
|
||||
This method builds the metadata for full cudagraph capture.
|
||||
Currently, only decode is supported for full cudagraphs with Mamba.
|
||||
"""
|
||||
m = common_attn_metadata
|
||||
|
||||
assert (
|
||||
m.num_reqs <= self.decode_cudagraph_max_bs
|
||||
and m.num_actual_tokens <= self.decode_cudagraph_max_bs
|
||||
), (
|
||||
f"GDN only supports decode-only full CUDAGraph capture. "
|
||||
f"Make sure batch size ({m.num_reqs}) <= "
|
||||
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
|
||||
f"and number of tokens ({m.num_actual_tokens}) <= "
|
||||
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs})."
|
||||
)
|
||||
|
||||
num_accepted_tokens = torch.diff(m.query_start_loc)
|
||||
num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
|
||||
m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
|
||||
|
||||
return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)
|
||||
74
vllm_old/v1/attention/backends/linear_attn.py
Normal file
74
vllm_old/v1/attention/backends/linear_attn.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
|
||||
class LinearAttentionBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
|
||||
return LinearAttentionMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
|
||||
class LinearAttentionMetadata:
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
query_start_loc: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
state_indices_tensor: torch.Tensor # shape: [batch,]
|
||||
|
||||
|
||||
class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMetadata]):
|
||||
reorder_batch_threshold: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
assert isinstance(kv_cache_spec, MambaSpec)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> LinearAttentionMetadata:
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
||||
)
|
||||
)
|
||||
|
||||
attn_metadata = LinearAttentionMetadata(
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens=seq_lens,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
)
|
||||
return attn_metadata
|
||||
165
vllm_old/v1/attention/backends/mamba1_attn.py
Normal file
165
vllm_old/v1/attention/backends/mamba1_attn.py
Normal file
@@ -0,0 +1,165 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
|
||||
class Mamba1AttentionBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
|
||||
return Mamba1AttentionMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mamba1AttentionMetadata:
|
||||
query_start_loc_p: torch.Tensor
|
||||
state_indices_tensor: torch.Tensor
|
||||
has_initial_states_p: torch.Tensor | None
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_padded_decodes: int
|
||||
|
||||
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
|
||||
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
|
||||
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
|
||||
num_computed_tokens_p: torch.Tensor # shape: [batch,]
|
||||
|
||||
|
||||
class Mamba1AttentionMetadataBuilder(
|
||||
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
assert isinstance(kv_cache_spec, MambaSpec)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> Mamba1AttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
||||
)
|
||||
)
|
||||
|
||||
has_initial_states_p = None
|
||||
query_start_loc_p = None
|
||||
padded_decodes = num_decodes
|
||||
num_computed_tokens, num_computed_tokens_p = None, None
|
||||
block_idx_first_scheduled_token = None
|
||||
block_idx_first_scheduled_token_p = None
|
||||
|
||||
# TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here.
|
||||
# We should consolidate this code
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
# Return a tensor of shape (#requests, #max blocks)
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor
|
||||
mamba_block_size = self.kv_cache_spec.block_size
|
||||
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
|
||||
self.device
|
||||
)
|
||||
(
|
||||
block_idx_last_computed_token,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
) = self._compute_prefix_caching_block_indices(
|
||||
common_attn_metadata, mamba_block_size
|
||||
)
|
||||
else:
|
||||
# Always return just a single block per each request:
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
block_idx_last_scheduled_token = None
|
||||
block_idx_last_computed_token = None
|
||||
|
||||
if num_prefills > 0:
|
||||
query_start_loc_p = (
|
||||
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
has_initial_states_cpu = (
|
||||
common_attn_metadata.num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
> 0
|
||||
)
|
||||
has_initial_states_p = has_initial_states_cpu.to(
|
||||
common_attn_metadata.query_start_loc.device
|
||||
)
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
assert num_computed_tokens is not None
|
||||
num_computed_tokens_p = num_computed_tokens[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
assert block_idx_first_scheduled_token is not None
|
||||
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
|
||||
elif (
|
||||
num_decodes > 0
|
||||
and num_decodes <= self.decode_cudagraph_max_bs
|
||||
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
):
|
||||
padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
|
||||
self.state_indices_tensor[:num_decodes].copy_(
|
||||
state_indices_tensor, non_blocking=True
|
||||
)
|
||||
state_indices_tensor = self.state_indices_tensor[:padded_decodes]
|
||||
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
self.block_idx_last_scheduled_token[:num_decodes].copy_(
|
||||
block_idx_last_scheduled_token, non_blocking=True
|
||||
)
|
||||
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
|
||||
:padded_decodes
|
||||
]
|
||||
block_idx_last_scheduled_token[num_decodes:] = 0
|
||||
|
||||
self.block_idx_last_computed_token[:num_decodes].copy_(
|
||||
block_idx_last_computed_token, non_blocking=True
|
||||
)
|
||||
block_idx_last_computed_token = self.block_idx_last_computed_token[
|
||||
:padded_decodes
|
||||
]
|
||||
block_idx_last_computed_token[num_decodes:] = 0
|
||||
|
||||
return Mamba1AttentionMetadata(
|
||||
query_start_loc_p=query_start_loc_p,
|
||||
has_initial_states_p=has_initial_states_p,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_padded_decodes=padded_decodes,
|
||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
|
||||
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
|
||||
block_idx_last_computed_token=block_idx_last_computed_token,
|
||||
num_computed_tokens_p=num_computed_tokens_p,
|
||||
)
|
||||
354
vllm_old/v1/attention/backends/mamba2_attn.py
Normal file
354
vllm_old/v1/attention/backends/mamba2_attn.py
Normal file
@@ -0,0 +1,354 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PAD_SLOT_ID,
|
||||
CommonAttentionMetadata,
|
||||
compute_causal_conv1d_metadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
def compute_varlen_chunk_metadata(
|
||||
query_start_loc: torch.Tensor,
|
||||
chunk_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Build chunk-aligned, variable-length metadata used by Mamba2 SSD kernels.
|
||||
|
||||
Given per-sequence cumulative token starts `query_start_loc` of shape [B+1]
|
||||
and a physical `chunk_size`, returns three tensors on the same device:
|
||||
- cu_chunk_seqlens: (nchunks+1,) int32 exclusive prefix-sum of
|
||||
logical-chunk lengths (each logical chunk never crosses a sequence or
|
||||
physical-chunk boundary).
|
||||
- last_chunk_indices: (B,) int32 index of the last logical chunk
|
||||
for each sequence (=-1 for empty sequences).
|
||||
- seq_idx_chunks: (nchunks,) int32 sequence index for each logical
|
||||
chunk in order.
|
||||
|
||||
This is intentionally lightweight and CPU-side; it mirrors the metadata
|
||||
produced by the V1 Mamba2 meta-data builder and is exported so tests
|
||||
(and other callers) can avoid duplicating the logic.
|
||||
"""
|
||||
assert query_start_loc.ndim == 1, "query_start_loc must be 1-D [B+1]"
|
||||
assert int(query_start_loc[0].item()) == 0, "query_start_loc[0] must be 0"
|
||||
device = query_start_loc.device
|
||||
|
||||
qsl64 = query_start_loc.to(torch.int64)
|
||||
starts = qsl64[:-1].tolist()
|
||||
ends = qsl64[1:].tolist()
|
||||
total = int(qsl64[-1].item())
|
||||
|
||||
chunk_lens: list[int] = []
|
||||
seq_idx_chunks: list[int] = []
|
||||
last_chunk_indices: list[int] = [-1] * len(starts)
|
||||
|
||||
for b, (s, e) in enumerate(zip(starts, ends)):
|
||||
if e <= s:
|
||||
# empty sequence
|
||||
continue
|
||||
pos = s
|
||||
while pos < e:
|
||||
# split at both sequence boundaries and physical chunk boundaries
|
||||
room = chunk_size - (pos % chunk_size)
|
||||
take = min(room, e - pos)
|
||||
chunk_lens.append(int(take))
|
||||
seq_idx_chunks.append(b)
|
||||
last_chunk_indices[b] = len(chunk_lens) - 1
|
||||
pos += take
|
||||
|
||||
# Exclusive prefix sum over logical-chunk lengths
|
||||
if chunk_lens:
|
||||
cu_chunk_seqlens = torch.tensor(
|
||||
[0] + list(itertools.accumulate(chunk_lens)),
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
# Final boundary must equal total tokens
|
||||
assert int(cu_chunk_seqlens[-1].item()) == total
|
||||
else:
|
||||
cu_chunk_seqlens = torch.tensor([0], device=device, dtype=torch.int32)
|
||||
|
||||
last_chunk_indices_t = (
|
||||
torch.tensor(last_chunk_indices, device=device, dtype=torch.int32)
|
||||
if len(starts) > 0
|
||||
else torch.empty((0,), device=device, dtype=torch.int32)
|
||||
)
|
||||
seq_idx_chunks_t = torch.tensor(seq_idx_chunks, device=device, dtype=torch.int32)
|
||||
return cu_chunk_seqlens, last_chunk_indices_t, seq_idx_chunks_t
|
||||
|
||||
|
||||
class Mamba2AttentionBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
|
||||
return Mamba2AttentionMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mamba2AttentionMetadata:
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
query_start_loc_p: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
prep_initial_states: bool
|
||||
chunk_size: int
|
||||
|
||||
# The following tensors only contain prefill requests and will be None if
|
||||
# the batch has no prefill request.
|
||||
has_initial_states_p: torch.Tensor | None
|
||||
seq_idx_p: torch.Tensor | None
|
||||
|
||||
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
|
||||
# each chunk, its offests into the varlen sequence dimension. It is defined
|
||||
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
|
||||
# cu_chunk_seqlen_p[i+1].
|
||||
cu_chunk_seqlen_p: torch.Tensor | None
|
||||
|
||||
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
|
||||
# index of the last chunk for every sequence in the (prefill) batch.
|
||||
last_chunk_indices_p: torch.Tensor | None
|
||||
|
||||
state_indices_tensor: torch.Tensor # shape: [batch,]
|
||||
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
|
||||
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
|
||||
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
|
||||
num_computed_tokens_p: torch.Tensor # shape: [batch,]
|
||||
|
||||
# The following attributes are for triton implementation of causal_conv1d
|
||||
nums_dict: dict | None = None
|
||||
batch_ptr: torch.Tensor | None = None
|
||||
token_chunk_offset_ptr: torch.Tensor | None = None
|
||||
|
||||
|
||||
class Mamba2AttentionMetadataBuilder(
|
||||
BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
|
||||
assert self.chunk_size is not None, (
|
||||
"chunk_size needs to be set in the model config for Mamba2 models"
|
||||
)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> Mamba2AttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
|
||||
query_start_loc_p = None
|
||||
seq_idx_p = None
|
||||
cu_chunk_seqlen_p = None
|
||||
last_chunk_indices_p = None
|
||||
|
||||
# Need flags to indicate if there are initial states
|
||||
has_initial_states_p = None
|
||||
prep_initial_states = False
|
||||
|
||||
# for causal_conv1d
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
num_computed_tokens, num_computed_tokens_p = None, None
|
||||
block_idx_first_scheduled_token = None
|
||||
block_idx_first_scheduled_token_p = None
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
# Return a tensor of shape (#requests, #max blocks)
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor
|
||||
# Additional cache-related varaiables:
|
||||
mamba_block_size = self.kv_cache_spec.block_size
|
||||
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
|
||||
self.device
|
||||
)
|
||||
(
|
||||
block_idx_last_computed_token,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
) = self._compute_prefix_caching_block_indices(
|
||||
common_attn_metadata, mamba_block_size
|
||||
)
|
||||
else:
|
||||
# Always return just a single block per each request:
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
# Additional cache-related varaiables:
|
||||
block_idx_last_scheduled_token = None
|
||||
block_idx_last_computed_token = None
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
||||
)
|
||||
)
|
||||
|
||||
# Compute seq_idx for prefill only
|
||||
if num_prefills > 0:
|
||||
# [batch,]
|
||||
has_initial_states_cpu = (
|
||||
common_attn_metadata.num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
> 0
|
||||
)
|
||||
prep_initial_states = torch.any(has_initial_states_cpu).item()
|
||||
has_initial_states_p = has_initial_states_cpu.to(
|
||||
common_attn_metadata.query_start_loc.device
|
||||
)
|
||||
|
||||
query_start_loc_p = (
|
||||
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
assert num_computed_tokens is not None
|
||||
num_computed_tokens_p = num_computed_tokens[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
assert block_idx_first_scheduled_token is not None
|
||||
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
query_start_loc_p_cpu = (
|
||||
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
|
||||
# The code below carefully constructs the chunks such that:
|
||||
# 1. Chunks contain tokens from a *single* sequence only.
|
||||
# 2. For every sequence, we are guaranteed that we can
|
||||
# retrieve the mamba state *every* chunk_size tokens.
|
||||
# Constraint (1) dramatically simplifies the mamba2 kernels.
|
||||
# Constraint (2) dramatically simplifies the implementation
|
||||
# of prefix caching for mamba2 (wip). We need to take care
|
||||
# of the interaction with chunked prefill in order to
|
||||
# satisfy constraint (2).
|
||||
# TODO (tdoublep): This code could probably be optimized.
|
||||
cu_chunk_seqlen = []
|
||||
seq_idx = []
|
||||
last_chunk_indices = []
|
||||
seqlen_pos = 0
|
||||
for req_idx in range(num_prefills):
|
||||
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
|
||||
this_new_tokens = (
|
||||
query_start_loc_p_cpu[req_idx + 1].item()
|
||||
- query_start_loc_p_cpu[req_idx].item()
|
||||
)
|
||||
|
||||
# if computed tokens are not chunk-aligned, use the first
|
||||
# chunk to finish it off
|
||||
if this_num_computed % self.chunk_size != 0:
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
# how many tokens to finish the chunk?
|
||||
chunk_len = (
|
||||
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
|
||||
- this_num_computed
|
||||
)
|
||||
# we can only use at most this_new_tokens
|
||||
chunk_len = min(chunk_len, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
n_chunks = cdiv(this_new_tokens, self.chunk_size)
|
||||
for chunk in range(n_chunks):
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
chunk_len = min(self.chunk_size, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
assert this_new_tokens == 0
|
||||
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
|
||||
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
|
||||
seq_idx_p = torch.as_tensor(
|
||||
seq_idx, device=query_start_loc_p.device, dtype=torch.int32
|
||||
)
|
||||
cu_chunk_seqlen_p = torch.as_tensor(
|
||||
cu_chunk_seqlen, device=query_start_loc_p.device, dtype=torch.int32
|
||||
)
|
||||
last_chunk_indices_p = torch.as_tensor(
|
||||
last_chunk_indices, device=query_start_loc_p.device, dtype=torch.int32
|
||||
)
|
||||
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = (
|
||||
compute_causal_conv1d_metadata(query_start_loc_p)
|
||||
)
|
||||
|
||||
elif (
|
||||
num_decodes <= self.decode_cudagraph_max_bs
|
||||
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
):
|
||||
# Pad state tensor for CUDA graph
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
|
||||
self.state_indices_tensor[:num_decodes].copy_(
|
||||
state_indices_tensor, non_blocking=True
|
||||
)
|
||||
state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
|
||||
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
self.block_idx_last_scheduled_token[:num_decodes].copy_(
|
||||
block_idx_last_scheduled_token, non_blocking=True
|
||||
)
|
||||
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
|
||||
:num_input_tokens
|
||||
]
|
||||
block_idx_last_scheduled_token[num_decodes:] = 0
|
||||
|
||||
self.block_idx_last_computed_token[:num_decodes].copy_(
|
||||
block_idx_last_computed_token, non_blocking=True
|
||||
)
|
||||
block_idx_last_computed_token = self.block_idx_last_computed_token[
|
||||
:num_input_tokens
|
||||
]
|
||||
block_idx_last_computed_token[num_decodes:] = 0
|
||||
|
||||
attn_metadata = Mamba2AttentionMetadata(
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
query_start_loc_p=query_start_loc_p,
|
||||
seq_lens=seq_lens,
|
||||
prep_initial_states=prep_initial_states,
|
||||
chunk_size=self.chunk_size,
|
||||
has_initial_states_p=has_initial_states_p,
|
||||
seq_idx_p=seq_idx_p,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
|
||||
last_chunk_indices_p=last_chunk_indices_p,
|
||||
nums_dict=nums_dict,
|
||||
batch_ptr=batch_ptr,
|
||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
|
||||
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
|
||||
block_idx_last_computed_token=block_idx_last_computed_token,
|
||||
num_computed_tokens_p=num_computed_tokens_p,
|
||||
)
|
||||
return attn_metadata
|
||||
115
vllm_old/v1/attention/backends/mamba_attn.py
Normal file
115
vllm_old/v1/attention/backends/mamba_attn.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import abc
|
||||
from typing import ClassVar, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
M = TypeVar("M")
|
||||
|
||||
|
||||
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
reorder_batch_threshold: int = 1
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
assert isinstance(kv_cache_spec, MambaSpec)
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.decode_cudagraph_max_bs = min(
|
||||
self.vllm_config.scheduler_config.max_num_seqs,
|
||||
self.compilation_config.max_cudagraph_capture_size,
|
||||
)
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
self.state_indices_tensor = torch.empty(
|
||||
(
|
||||
self.decode_cudagraph_max_bs,
|
||||
cdiv(
|
||||
self.vllm_config.model_config.max_model_len,
|
||||
self.kv_cache_spec.block_size,
|
||||
),
|
||||
),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.block_idx_last_scheduled_token = torch.empty(
|
||||
(self.decode_cudagraph_max_bs,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.block_idx_last_computed_token = torch.empty(
|
||||
(self.decode_cudagraph_max_bs,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.state_indices_tensor = torch.empty(
|
||||
(self.decode_cudagraph_max_bs,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
) -> M:
|
||||
"""
|
||||
This method builds the metadata for full cudagraph capture.
|
||||
Currently, only decode is supported for full cudagraphs with Mamba.
|
||||
"""
|
||||
m = common_attn_metadata
|
||||
|
||||
assert m.num_reqs == m.num_actual_tokens, (
|
||||
"Mamba only supports decode-only full CUDAGraph capture. "
|
||||
"Make sure all cudagraph capture sizes <= max_num_seq."
|
||||
)
|
||||
|
||||
m.max_query_len = 1 # decode-only
|
||||
|
||||
return self.build(0, m)
|
||||
|
||||
def _compute_prefix_caching_block_indices(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
mamba_block_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
|
||||
self.device
|
||||
)
|
||||
# Block index of the last computed token
|
||||
block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1
|
||||
# which is <= block index for the first scheduled token
|
||||
block_idx_first_scheduled_token = (
|
||||
cdiv(num_computed_tokens + 1, mamba_block_size) - 1
|
||||
)
|
||||
# which is <= block index of the last scheduled token
|
||||
block_idx_last_scheduled_token = (
|
||||
cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
|
||||
)
|
||||
# -1 in case it's non-computed and causes later issues with indexing
|
||||
block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
|
||||
|
||||
return (
|
||||
block_idx_last_computed_token,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
)
|
||||
0
vllm_old/v1/attention/backends/mla/__init__.py
Normal file
0
vllm_old/v1/attention/backends/mla/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
2200
vllm_old/v1/attention/backends/mla/common.py
Normal file
2200
vllm_old/v1/attention/backends/mla/common.py
Normal file
File diff suppressed because it is too large
Load Diff
275
vllm_old/v1/attention/backends/mla/cutlass_mla.py
Normal file
275
vllm_old/v1/attention/backends/mla/cutlass_mla.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||
# enable full CUDA Graph support for decode-only capture
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
)
|
||||
|
||||
|
||||
class CutlassMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "CUTLASS_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["CutlassMLAImpl"]:
|
||||
return CutlassMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
|
||||
return CutlassMLAMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major == 10
|
||||
|
||||
|
||||
class SM100Workspace:
|
||||
def __init__(self, initial_workspace_size):
|
||||
self._workspace_buf = torch.empty(
|
||||
initial_workspace_size, device="cuda", dtype=torch.uint8
|
||||
)
|
||||
|
||||
self._block_size = 128 # Forced to 128
|
||||
|
||||
# Pre-compute sm_count to avoid recomputing it. Use device 0 as a proxy
|
||||
# (assumes all devices are similar)
|
||||
properties = torch.cuda.get_device_properties(torch.device("cuda:0"))
|
||||
self._sm_count = properties.multi_processor_count
|
||||
|
||||
def get_buf(self):
|
||||
return self._workspace_buf
|
||||
|
||||
def ensure_size(self, attn_metadata: MLACommonMetadata, num_kv_splits: int):
|
||||
batch_size = attn_metadata.num_reqs
|
||||
max_seq_len = attn_metadata.max_query_len
|
||||
|
||||
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
|
||||
max_seq_len * self._block_size,
|
||||
batch_size,
|
||||
self._sm_count,
|
||||
num_kv_splits=num_kv_splits,
|
||||
)
|
||||
|
||||
if self._workspace_buf.shape[0] < workspace_size:
|
||||
self._workspace_buf.resize_(workspace_size)
|
||||
|
||||
|
||||
g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB
|
||||
|
||||
MAX_HEADS = 128
|
||||
|
||||
|
||||
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
q_pad_num_heads=MAX_HEADS,
|
||||
**mla_args,
|
||||
)
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"CutlassMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"CutlassMLAImpl"
|
||||
)
|
||||
|
||||
# TODO: Currently, num_kv_splits is limited to 16 to avoid hanging
|
||||
# issues. In case the code hangs, use:
|
||||
# FORCE_NUM_KV_SPLITS=1
|
||||
force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS", None)
|
||||
if force_num_kv_splits:
|
||||
logger.debug_once("Forcing num_kv_splits to %d", int(force_num_kv_splits))
|
||||
self._num_kv_splits = int(force_num_kv_splits)
|
||||
else:
|
||||
self._num_kv_splits = -1 # => Auto-detect
|
||||
|
||||
# Share workspace buffer across all executions
|
||||
self._workspace = g_sm100_workspace
|
||||
|
||||
def _sm100_cutlass_mla_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
page_table: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
sm_scale: float,
|
||||
num_kv_splits: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
|
||||
assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
|
||||
assert kv_c_and_k_pe_cache.ndim == 3, (
|
||||
"kv_c_and_k_pe_cache must be a 3D tensor, but got {}".format(
|
||||
kv_c_and_k_pe_cache.ndim
|
||||
)
|
||||
)
|
||||
|
||||
B_q, H, D_q_nope = q_nope.shape
|
||||
B_q_2, H_2, D_q_pe = q_pe.shape
|
||||
assert (B_q == B_q_2) and (H == H_2)
|
||||
|
||||
_, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape
|
||||
|
||||
D_latent = 512
|
||||
D_rope = 64
|
||||
assert D_q_nope == D_latent
|
||||
assert D_q_pe == D_rope
|
||||
assert D_ckv == D_latent + D_rope
|
||||
|
||||
MAX_HEADS = 128
|
||||
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
|
||||
|
||||
assert len(page_table.shape) == 2
|
||||
B_block_table, block_num = page_table.shape
|
||||
assert B_block_table == B_q
|
||||
assert block_num > 0, f"block num must be greater than 0, got {block_num}"
|
||||
assert block_num % (128 / PAGE_SIZE) == 0
|
||||
|
||||
assert q_nope.dtype in (torch.float16, torch.bfloat16, torch.float8_e4m3fn), (
|
||||
f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got {q_nope.dtype}."
|
||||
)
|
||||
assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
|
||||
assert seq_lens.dtype == torch.int32, (
|
||||
f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
|
||||
)
|
||||
assert page_table.dtype == torch.int32, (
|
||||
f"page_table.dtype needs to be int32 but got {page_table.dtype}."
|
||||
)
|
||||
|
||||
dtype = (
|
||||
torch.bfloat16
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype)
|
||||
else q_nope.dtype
|
||||
)
|
||||
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
|
||||
lse = (
|
||||
torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
|
||||
if self.need_to_return_lse_for_decode
|
||||
else torch.Tensor()
|
||||
)
|
||||
|
||||
ops.sm100_cutlass_mla_decode(
|
||||
out,
|
||||
lse,
|
||||
q_nope,
|
||||
q_pe,
|
||||
kv_c_and_k_pe_cache,
|
||||
seq_lens,
|
||||
page_table,
|
||||
workspace,
|
||||
sm_scale,
|
||||
num_kv_splits,
|
||||
)
|
||||
|
||||
if H < MAX_HEADS:
|
||||
# Extract the subsets of the outputs
|
||||
lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
|
||||
out = out[:, :H]
|
||||
|
||||
return out, lse
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if type(q) is tuple:
|
||||
q_nope, q_pe = q
|
||||
else:
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
||||
# Adjust workspace size (if necessary)
|
||||
self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
|
||||
|
||||
# Run MLA
|
||||
o, lse = self._sm100_cutlass_mla_decode(
|
||||
q_nope,
|
||||
q_pe,
|
||||
kv_c_and_k_pe_cache,
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.block_table,
|
||||
self._workspace.get_buf(),
|
||||
self.scale,
|
||||
self._num_kv_splits,
|
||||
)
|
||||
|
||||
return o, (lse if self.need_to_return_lse_for_decode else None)
|
||||
337
vllm_old/v1/attention/backends/mla/flashattn_mla.py
Normal file
337
vllm_old/v1/attention/backends/mla/flashattn_mla.py
Normal file
@@ -0,0 +1,337 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.attention.utils.fa_utils import (
|
||||
flash_attn_supports_mla,
|
||||
get_flash_attn_version,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashAttnMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASH_ATTN_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
|
||||
return FlashAttnMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashAttnMLAImpl"]:
|
||||
return FlashAttnMLAImpl
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major == 9
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: CacheDType | None,
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: DeviceCapability,
|
||||
) -> str | None:
|
||||
if not flash_attn_supports_mla():
|
||||
return "FlashAttention MLA not supported on this device"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
query_start_loc: torch.Tensor
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
scheduler_metadata: torch.Tensor | None = None
|
||||
max_num_splits: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
|
||||
pass
|
||||
|
||||
|
||||
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
|
||||
reorder_batch_threshold: int = 512 # process small prefills with decode pathway
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_spec,
|
||||
layer_names,
|
||||
vllm_config,
|
||||
device,
|
||||
FlashAttnMLAMetadata,
|
||||
supports_dcp_with_varlen=True,
|
||||
)
|
||||
self.max_num_splits = 0 # No upper bound on the number of splits.
|
||||
self.fa_aot_schedule = get_flash_attn_version() == 3
|
||||
|
||||
self.use_full_cuda_graph = (
|
||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
)
|
||||
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
|
||||
|
||||
if self.use_full_cuda_graph and self.fa_aot_schedule:
|
||||
self.scheduler_metadata = torch.zeros(
|
||||
vllm_config.scheduler_config.max_num_seqs + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
# When using cuda graph, we need to set the upper bound of the
|
||||
# number of splits so that large enough intermediate buffers are
|
||||
# pre-allocated during capture.
|
||||
self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
self.max_num_splits = 1
|
||||
|
||||
def _schedule_decode(
|
||||
self,
|
||||
num_reqs,
|
||||
cu_query_lens,
|
||||
max_query_len,
|
||||
seqlens,
|
||||
max_seq_len,
|
||||
causal,
|
||||
max_num_splits,
|
||||
):
|
||||
if self.fa_aot_schedule:
|
||||
return get_scheduler_metadata(
|
||||
batch_size=num_reqs,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_seq_len,
|
||||
num_heads_q=self.num_heads * self.dcp_world_size,
|
||||
num_heads_kv=1,
|
||||
headdim=self.mla_dims.qk_rope_head_dim,
|
||||
cache_seqlens=seqlens,
|
||||
qkv_dtype=self.kv_cache_spec.dtype,
|
||||
headdim_v=self.mla_dims.kv_lora_rank,
|
||||
page_size=self.page_size,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
causal=causal,
|
||||
num_splits=max_num_splits,
|
||||
)
|
||||
return None
|
||||
|
||||
def _build_decode(
|
||||
self,
|
||||
block_table_tensor: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
seq_lens_device: torch.Tensor,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int,
|
||||
dcp_tot_seq_lens_device: torch.Tensor | None,
|
||||
) -> FlashAttnMLADecodeMetadata:
|
||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
max_query_len = query_lens_cpu.max().item()
|
||||
max_seq_len = seq_lens_cpu.max().item()
|
||||
|
||||
# For Flash Attention MLA + full cudagraph
|
||||
max_num_splits = 0
|
||||
if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size:
|
||||
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
|
||||
# usage, because the intermediate buffers of size [num_splits,
|
||||
# num_heads, num_tokens, head_size] are allocated. Therefore,
|
||||
# we only set num_splits when using cuda graphs.
|
||||
max_num_splits = self.max_num_splits
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
max_num_splits = 1
|
||||
|
||||
scheduler_metadata = self._schedule_decode(
|
||||
num_reqs=seq_lens_cpu.numel(),
|
||||
cu_query_lens=query_start_loc_device,
|
||||
max_query_len=max_query_len,
|
||||
seqlens=seq_lens_device,
|
||||
max_seq_len=max_seq_len,
|
||||
causal=True,
|
||||
max_num_splits=max_num_splits,
|
||||
)
|
||||
|
||||
if self.use_full_cuda_graph and scheduler_metadata is not None:
|
||||
n = scheduler_metadata.shape[0]
|
||||
# Ensure the persistent buffer is large enough
|
||||
assert n <= self.scheduler_metadata.shape[0], (
|
||||
f"Scheduler metadata size {n} exceeds buffer size "
|
||||
+ f"{self.scheduler_metadata.shape[0]}"
|
||||
)
|
||||
self.scheduler_metadata[:n] = scheduler_metadata
|
||||
# NOTE(woosuk): We should zero out the rest of the scheduler
|
||||
# metadata to guarantee the correctness. Otherwise, some thread
|
||||
# blocks may use the invalid scheduler metadata and overwrite the
|
||||
# output buffer.
|
||||
self.scheduler_metadata[n:] = 0
|
||||
scheduler_metadata = self.scheduler_metadata[:n]
|
||||
|
||||
metadata = FlashAttnMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens_device,
|
||||
query_start_loc=query_start_loc_device,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
max_num_splits=max_num_splits,
|
||||
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
|
||||
)
|
||||
return metadata
|
||||
|
||||
|
||||
class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
|
||||
assert flash_attn_supports_mla(), "FlashAttnMLA is not supported on this device"
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashAttnMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashAttnMLAImpl"
|
||||
)
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"FlashAttnMLA V1 with FP8 KV cache not yet supported"
|
||||
)
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttnMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if type(q) is tuple:
|
||||
q_nope, q_pe = q
|
||||
else:
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 FlashAttention MLA not yet supported")
|
||||
|
||||
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
|
||||
k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :]
|
||||
|
||||
# NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
|
||||
# kernel uses this to calculate grid dimensions. Ensure it's at least 1
|
||||
# to prevent invalid grid configuration during graph capture.
|
||||
max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
|
||||
|
||||
attn_out = flash_attn_varlen_func(
|
||||
q=q_pe,
|
||||
k=k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
v=kv_c_cache.unsqueeze(-2), # Add head dim of 1
|
||||
q_v=q_nope,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
cu_seqlens_q=attn_metadata.decode.query_start_loc,
|
||||
max_seqlen_k=attn_metadata.decode.max_seq_len,
|
||||
seqused_k=attn_metadata.decode.seq_lens,
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
return_softmax_lse=self.need_to_return_lse_for_decode,
|
||||
fa_version=3, # only version 3 is supported
|
||||
scheduler_metadata=attn_metadata.decode.scheduler_metadata,
|
||||
num_splits=attn_metadata.decode.max_num_splits,
|
||||
cp_world_size=self.dcp_world_size,
|
||||
cp_rank=self.dcp_rank,
|
||||
cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
|
||||
)
|
||||
|
||||
if self.need_to_return_lse_for_decode:
|
||||
o, lse = attn_out
|
||||
# FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
|
||||
return o, lse.transpose(0, 1) # [ H, B ] -> [ B, H ]
|
||||
else:
|
||||
o = attn_out
|
||||
return o, None
|
||||
171
vllm_old/v1/attention/backends/mla/flashinfer_mla.py
Normal file
171
vllm_old/v1/attention/backends/mla/flashinfer_mla.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport, KVCacheLayoutType
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
||||
|
||||
|
||||
class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
||||
|
||||
|
||||
class FlashInferMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHINFER_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashInferMLAImpl"]:
|
||||
return FlashInferMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
|
||||
return FlashInferMLAMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major == 10
|
||||
|
||||
@classmethod
|
||||
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
|
||||
return "HND"
|
||||
|
||||
|
||||
g_fi_workspace = torch.zeros(
|
||||
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
|
||||
dtype=torch.uint8,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
|
||||
class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashInferMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashInferMLAImpl"
|
||||
)
|
||||
|
||||
self._workspace_buffer = g_fi_workspace
|
||||
self.bmm1_scale: float | None = None
|
||||
self.bmm2_scale: float | None = None
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if isinstance(q, tuple):
|
||||
q_nope, q_pe = q
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
|
||||
# trtllm API requires extra dimension q_len_per_request for MTP
|
||||
if attn_metadata.num_decode_tokens % attn_metadata.num_decodes != 0:
|
||||
logger.warning_once(
|
||||
"""FlashInferMLAImpl got a query of uneven length.
|
||||
This usually indicates an issue in batch reordering
|
||||
or incorrect setup in dummy_run."""
|
||||
)
|
||||
q = q.unsqueeze(1)
|
||||
else:
|
||||
q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1])
|
||||
|
||||
if self.bmm1_scale is None:
|
||||
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
|
||||
if self.bmm2_scale is None:
|
||||
self.bmm2_scale = layer._v_scale_float
|
||||
|
||||
o = trtllm_batch_decode_with_kv_cache_mla(
|
||||
query=q,
|
||||
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
seq_lens=attn_metadata.decode.seq_lens,
|
||||
max_seq_len=attn_metadata.max_seq_len,
|
||||
bmm1_scale=self.bmm1_scale,
|
||||
bmm2_scale=self.bmm2_scale,
|
||||
)
|
||||
|
||||
# Flatten the output for consistent shape
|
||||
o = o.view(-1, o.shape[-2], o.shape[-1])
|
||||
|
||||
# TODO: Return LSE pending support from Flashinfer API:
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/1566
|
||||
return o, None
|
||||
314
vllm_old/v1/attention/backends/mla/flashmla.py
Normal file
314
vllm_old/v1/attention/backends/mla/flashmla.py
Normal file
@@ -0,0 +1,314 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
|
||||
from vllm.attention.ops.flashmla import (
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_dense_supported,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
reshape_attn_output_for_spec_decode,
|
||||
reshape_query_for_spec_decode,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA"
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
|
||||
return FlashMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashMLAImpl"]:
|
||||
return FlashMLAImpl
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major in [9, 10]
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: CacheDType | None,
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: DeviceCapability,
|
||||
) -> str | None:
|
||||
if use_sparse:
|
||||
from vllm.attention.ops.flashmla import is_flashmla_sparse_supported
|
||||
|
||||
return is_flashmla_sparse_supported()[1]
|
||||
else:
|
||||
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
|
||||
return is_flashmla_dense_supported()[1]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
tile_scheduler_metadata: torch.Tensor
|
||||
num_splits: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
|
||||
pass
|
||||
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
||||
reorder_batch_threshold: int = 128 # process small prefills with decode pathway
|
||||
# ^ TODO(matt): tune this
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata
|
||||
)
|
||||
|
||||
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
|
||||
self.cg_buf_tile_scheduler_metadata = None
|
||||
self.cg_buf_num_splits = None
|
||||
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")
|
||||
|
||||
device_properties = torch.cuda.get_device_properties(self.device)
|
||||
num_sms = device_properties.multi_processor_count
|
||||
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.cg_buf_tile_scheduler_metadata = torch.zeros(
|
||||
# Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
|
||||
# TileSchedulerMetaDataSize = 8
|
||||
(num_sms, 8),
|
||||
device=self.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
self.cg_buf_num_splits = torch.empty(
|
||||
(vllm_config.scheduler_config.max_num_seqs + 1),
|
||||
device=self.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
def _build_decode(
|
||||
self,
|
||||
block_table_tensor: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
seq_lens_device: torch.Tensor,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int,
|
||||
dcp_tot_seq_lens_device: torch.Tensor | None,
|
||||
) -> FlashMLADecodeMetadata:
|
||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
# we use the max but all should be the same due to uniform length requirement
|
||||
max_query_len = query_lens_cpu.max().item()
|
||||
num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
|
||||
tile_scheduler_metadata, num_splits = get_mla_metadata(
|
||||
seq_lens_device,
|
||||
num_q_tokens_per_head_k,
|
||||
1, # MQA for the decode path
|
||||
is_fp8_kvcache=self.is_fp8_kvcache,
|
||||
)
|
||||
|
||||
# TODO: we can disambiguate between decode and mixed-prefill decode here
|
||||
# so we can only use the persistent buffer if a cudagraph is actually
|
||||
# being used.
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
assert self.cg_buf_tile_scheduler_metadata is not None
|
||||
assert self.cg_buf_num_splits is not None
|
||||
|
||||
sm_parts = tile_scheduler_metadata.size(0)
|
||||
# Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
|
||||
assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
|
||||
tile_scheduler_metadata_view = self.cg_buf_tile_scheduler_metadata[
|
||||
:sm_parts
|
||||
]
|
||||
tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
|
||||
tile_scheduler_metadata = tile_scheduler_metadata_view
|
||||
|
||||
# Num splits is per-batch, varying size (batch_size,)
|
||||
n = num_splits.size(0)
|
||||
# make sure static buffer is large enough
|
||||
assert n <= self.cg_buf_num_splits.size(0)
|
||||
num_splits_view = self.cg_buf_num_splits[:n]
|
||||
num_splits_view.copy_(num_splits)
|
||||
# Num splits needs to monotonically increasing
|
||||
# (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
|
||||
# it needs to monotonically increasing by 1)
|
||||
self.cg_buf_num_splits[n:].fill_(num_splits[-1])
|
||||
num_splits = num_splits_view
|
||||
|
||||
return FlashMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens_device,
|
||||
tile_scheduler_metadata=tile_scheduler_metadata,
|
||||
num_splits=num_splits,
|
||||
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
|
||||
)
|
||||
|
||||
|
||||
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
|
||||
is_supported, reason = is_flashmla_dense_supported()
|
||||
assert is_supported, reason
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashMLAImpl"
|
||||
)
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
# TODO: (zyongye) decode function for mla here
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if type(q) is tuple:
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
# mypy assertion: q is now always a tensor
|
||||
assert isinstance(q, torch.Tensor)
|
||||
|
||||
num_decodes = attn_metadata.num_decodes
|
||||
q = reshape_query_for_spec_decode(q, num_decodes)
|
||||
|
||||
tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
|
||||
num_splits = attn_metadata.decode.num_splits
|
||||
if vllm_is_batch_invariant():
|
||||
device = q.device
|
||||
dtype = torch.int32
|
||||
|
||||
B = q.shape[0]
|
||||
# block_table shape: [batch_size, max_num_blocks_per_seq]
|
||||
# The number of blocks per sequence is in the second dimension
|
||||
topk = attn_metadata.decode.block_table.shape[-1]
|
||||
B_TOPK = 64
|
||||
assert topk % B_TOPK == 0, f"topk ({topk}) must be divisible by {B_TOPK}"
|
||||
end_block_idx = topk // B_TOPK
|
||||
|
||||
# Single partition => num_sm_parts = 1
|
||||
# TileSchedulerMetaDataSize = 8, layout:
|
||||
# [begin_idx, begin_block_idx, end_idx, end_block_idx,
|
||||
# begin_n_split_idx, _, _, _]
|
||||
tile_scheduler_metadata = torch.zeros((1, 8), dtype=dtype, device=device)
|
||||
tile_scheduler_metadata[0, 0] = 0 # begin_idx
|
||||
tile_scheduler_metadata[0, 1] = 0 # sched_begin_block_idx
|
||||
tile_scheduler_metadata[0, 2] = B - 1 # end_idx
|
||||
tile_scheduler_metadata[0, 3] = end_block_idx
|
||||
tile_scheduler_metadata[0, 4] = 0 # begin_n_split_idx
|
||||
# fields [5..7] stay 0
|
||||
|
||||
# Non-split path ignores num_splits, but the API requires it:
|
||||
# zeros of length B+1
|
||||
num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
|
||||
|
||||
o, lse = flash_mla_with_kvcache(
|
||||
q=q,
|
||||
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
cache_seqlens=attn_metadata.decode.seq_lens,
|
||||
head_dim_v=self.kv_lora_rank,
|
||||
tile_scheduler_metadata=tile_scheduler_metadata,
|
||||
num_splits=num_splits,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
descale_q=layer._q_scale.reshape(1),
|
||||
descale_k=layer._k_scale.reshape(1),
|
||||
)
|
||||
|
||||
o = reshape_attn_output_for_spec_decode(o)
|
||||
|
||||
return o, lse
|
||||
560
vllm_old/v1/attention/backends/mla/flashmla_sparse.py
Normal file
560
vllm_old/v1/attention/backends/mla/flashmla_sparse.py
Normal file
@@ -0,0 +1,560 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionLayer,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.backends.utils import get_mla_dims
|
||||
from vllm.attention.ops.flashmla import (
|
||||
flash_mla_sparse_prefill,
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.deepseek_v2 import Indexer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
"""
|
||||
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
|
||||
|
||||
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
|
||||
structured as:
|
||||
- **First 512 bytes:** The "quantized NoPE" part, containing 512
|
||||
`float8_e4m3` values.
|
||||
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
|
||||
The first `float32` is the scale for the first 128 `float8_e4m3` values,
|
||||
the second for the next 128, and so on.
|
||||
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
|
||||
part is not quantized for accuracy.
|
||||
"""
|
||||
|
||||
|
||||
class FlashMLASparseBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA_SPARSE"
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
|
||||
return FlashMLASparseMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashMLASparseImpl"]:
|
||||
return FlashMLASparseImpl
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [576]
|
||||
|
||||
@classmethod
|
||||
def is_mla(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def is_sparse(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major in [9, 10]
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if cache_dtype_str == "fp8_ds_mla":
|
||||
# custom storage fromat is 656 bytes
|
||||
# see FlashMLA readme.md for details
|
||||
return (num_blocks, block_size, 656)
|
||||
else:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLASparseMetadata:
|
||||
num_reqs: int
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
query_start_loc: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
block_table: torch.Tensor
|
||||
req_id_per_token: torch.Tensor
|
||||
block_size: int = 64
|
||||
topk_tokens: int = 2048
|
||||
|
||||
@dataclass
|
||||
class FP8KernelMetadata:
|
||||
scheduler_metadata: torch.Tensor | None
|
||||
num_splits: torch.Tensor
|
||||
dummy_block_table: torch.Tensor
|
||||
cache_lens: torch.Tensor
|
||||
|
||||
fp8_extra_metadata: FP8KernelMetadata | None = None
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _convert_req_index_to_global_index_kernel(
|
||||
req_id_ptr, # int32 [num_tokens]
|
||||
block_table_ptr, # int32 [num_requests, max_num_blocks_per_req]
|
||||
token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
# shapes (compile-time where possible)
|
||||
max_num_blocks_per_req: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr, # tile width along columns
|
||||
# strides (in elements)
|
||||
bt_stride0,
|
||||
bt_stride1,
|
||||
ti_stride0,
|
||||
ti_stride1,
|
||||
out_stride0,
|
||||
out_stride1,
|
||||
):
|
||||
# program_id(0) -> token_id (row)
|
||||
# program_id(1) -> tile index along columns
|
||||
token_id = tl.program_id(0)
|
||||
tile_id = tl.program_id(1)
|
||||
|
||||
# Each program covers BLOCK_N consecutive columns
|
||||
indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
# Load request id for this token (no mask: grid is exact)
|
||||
req = tl.load(req_id_ptr + token_id)
|
||||
|
||||
# Load token indices for this tile
|
||||
ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
|
||||
tok = tl.load(ti_ptr) # int32
|
||||
|
||||
# Only token == -1 should propagate as -1
|
||||
is_invalid_tok = tok < 0
|
||||
|
||||
# Compute block id and in-block offset
|
||||
block_id = tok // BLOCK_SIZE
|
||||
inblock_off = tok % BLOCK_SIZE
|
||||
|
||||
# Guard block_table access
|
||||
valid_block = block_id < max_num_blocks_per_req
|
||||
bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
|
||||
base = tl.load(bt_ptr, mask=valid_block, other=0)
|
||||
|
||||
# If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset
|
||||
out_val = tl.where(
|
||||
is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off
|
||||
)
|
||||
|
||||
# Store results
|
||||
out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
|
||||
tl.store(out_ptr_ij, out_val)
|
||||
|
||||
|
||||
def triton_convert_req_index_to_global_index(
|
||||
req_id: torch.Tensor, # int32 [num_tokens]
|
||||
block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req]
|
||||
token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
BLOCK_SIZE: int = 64,
|
||||
NUM_TOPK_TOKENS: int = 2048,
|
||||
BLOCK_N: int = 128, # tile width along columns
|
||||
):
|
||||
"""
|
||||
out[token_id, indice_id] =
|
||||
block_table[req_id[token_id],
|
||||
token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
|
||||
+ token_indices[token_id, indice_id] % BLOCK_SIZE
|
||||
|
||||
Only when token_indices[token_id, indice_id] == -1 do we output -1.
|
||||
For safety, we also output -1 if the derived block_id would be
|
||||
out-of-bounds.
|
||||
"""
|
||||
assert req_id.dtype == torch.int32
|
||||
assert block_table.dtype == torch.int32
|
||||
assert token_indices.dtype == torch.int32
|
||||
assert token_indices.shape[1] == NUM_TOPK_TOKENS
|
||||
assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
|
||||
f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible byBLOCK_N ({BLOCK_N})"
|
||||
)
|
||||
|
||||
num_tokens = req_id.shape[0]
|
||||
num_requests, max_num_blocks_per_req = block_table.shape
|
||||
tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N
|
||||
|
||||
# Ensure contiguous tensors on the same device
|
||||
req_id_c = req_id.contiguous()
|
||||
block_table_c = block_table.contiguous()
|
||||
token_indices_c = token_indices.contiguous()
|
||||
out = torch.empty_like(token_indices_c)
|
||||
|
||||
# Strides in elements
|
||||
bt_stride0, bt_stride1 = block_table_c.stride()
|
||||
ti_stride0, ti_stride1 = token_indices_c.stride()
|
||||
out_stride0, out_stride1 = out.stride()
|
||||
|
||||
# Exact 2D grid: tokens × column tiles
|
||||
grid = (num_tokens, tiles_per_row)
|
||||
|
||||
_convert_req_index_to_global_index_kernel[grid](
|
||||
req_id_c,
|
||||
block_table_c,
|
||||
token_indices_c,
|
||||
out,
|
||||
# shapes / constexprs
|
||||
max_num_blocks_per_req,
|
||||
BLOCK_SIZE,
|
||||
BLOCK_N,
|
||||
# strides
|
||||
bt_stride0,
|
||||
bt_stride1,
|
||||
ti_stride0,
|
||||
ti_stride1,
|
||||
out_stride0,
|
||||
out_stride1,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
cache_config = vllm_config.cache_config
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.device = device
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
sm_count = props.multi_processor_count
|
||||
|
||||
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
|
||||
self.mla_dims = get_mla_dims(self.model_config)
|
||||
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
|
||||
self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
|
||||
self.topk_tokens_tensor = torch.tensor(
|
||||
[self.topk_tokens], device=device, dtype=torch.int32
|
||||
)
|
||||
self.max_model_len_tensor = torch.tensor(
|
||||
[self.model_config.max_model_len], device=device, dtype=torch.int32
|
||||
)
|
||||
# this is ignored by `flash_mla_with_kvcache` if indices not None
|
||||
self.dummy_block_table = torch.empty(
|
||||
(1, 1), dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
# Equation taken from FlashMLA/csrc/pybind.cpp
|
||||
h_q, h_k = self.num_heads, 1
|
||||
s_q = 1 # inversely proportional to s_q, so s_q = 1 is the largest
|
||||
max_num_sm_parts = int(
|
||||
max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)
|
||||
)
|
||||
if current_platform.is_device_capability(100):
|
||||
max_num_sm_parts *= 2
|
||||
self.tile_scheduler_metadata_buffer = torch.empty(
|
||||
# TileSchedulerMetaDataSize = 8
|
||||
# see: FlashMLA/csrc/params.h
|
||||
(max_num_sm_parts, 8),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.num_splits_buffer = torch.empty(
|
||||
# We pack all the tokens into one batch for sparse attention.
|
||||
# Otherwise, we can exceed the sm of `get_mla_metadata`.
|
||||
(2,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.req_id_per_token_buffer = torch.empty(
|
||||
(vllm_config.scheduler_config.max_num_batched_tokens,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> FlashMLASparseMetadata:
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
|
||||
seg_lengths = np.diff(starts)
|
||||
req_id_per_token = np.repeat(
|
||||
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
|
||||
)
|
||||
# Zero-fill for cudagraphs
|
||||
self.req_id_per_token_buffer.fill_(0)
|
||||
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
|
||||
torch.from_numpy(req_id_per_token), non_blocking=True
|
||||
)
|
||||
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
|
||||
|
||||
fp8_extra_metadata = None
|
||||
if self.use_fp8_kv_cache:
|
||||
tile_scheduler_metadata, num_splits = get_mla_metadata(
|
||||
cache_seqlens=self.topk_tokens_tensor,
|
||||
num_q_tokens_per_head_k=num_tokens * self.num_heads,
|
||||
topk=self.topk_tokens,
|
||||
num_heads_q=self.num_heads,
|
||||
num_heads_k=1,
|
||||
is_fp8_kvcache=True,
|
||||
)
|
||||
|
||||
num_sm_parts = tile_scheduler_metadata.size(0)
|
||||
# Copy to persistent buffer for full-CG support
|
||||
tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[
|
||||
:num_sm_parts
|
||||
]
|
||||
tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
|
||||
self.num_splits_buffer.copy_(num_splits)
|
||||
|
||||
fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
|
||||
scheduler_metadata=tile_scheduler_metadata_buffer,
|
||||
num_splits=self.num_splits_buffer,
|
||||
# cache_lens and block_table are basically unused in sparse case
|
||||
# but the decode kernel will treat -1 and indices >= cache_lens
|
||||
# as invalid so we make sure cache_lens is large enough to not
|
||||
# accidentally mark indices invalid, we will use -1 exclusively
|
||||
# to mark invalid indices
|
||||
cache_lens=self.max_model_len_tensor,
|
||||
dummy_block_table=self.dummy_block_table,
|
||||
)
|
||||
|
||||
metadata = FlashMLASparseMetadata(
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
max_seq_len=common_attn_metadata.max_seq_len,
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
block_table=common_attn_metadata.block_table_tensor,
|
||||
req_id_per_token=req_id_per_token,
|
||||
block_size=self.kv_cache_spec.block_size,
|
||||
topk_tokens=self.topk_tokens,
|
||||
fp8_extra_metadata=fp8_extra_metadata,
|
||||
)
|
||||
return metadata
|
||||
|
||||
|
||||
class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
topk_indice_buffer: torch.Tensor | None = None,
|
||||
indexer: Optional["Indexer"] = None,
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
self.softmax_scale = scale
|
||||
assert indexer is not None
|
||||
self.topk_indices_buffer = indexer.topk_indices_buffer
|
||||
self.padding = 128 if current_platform.is_device_capability(100) else 64
|
||||
|
||||
def _forward_bf16_kv(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = q.shape[0]
|
||||
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
|
||||
-1, 1, kv_c_and_k_pe_cache.shape[-1]
|
||||
)
|
||||
|
||||
# NOTE(Chen): kernel requires num_local_head to be a multiple of
|
||||
# 64 on hopper and 128 on blackwell
|
||||
if self.num_heads % self.padding != 0:
|
||||
assert self.padding % self.num_heads == 0
|
||||
logger.warning_once(
|
||||
f"padding num_heads to {self.padding} \
|
||||
due to sparse attn kernel requirement"
|
||||
)
|
||||
q_padded = q.new_empty((q.shape[0], self.padding, q.shape[2]))
|
||||
q_padded[:, : self.num_heads, :] = q
|
||||
q = q_padded
|
||||
|
||||
topk_indices = topk_indices.view(num_tokens, 1, -1)
|
||||
output = flash_mla_sparse_prefill(
|
||||
q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
|
||||
)
|
||||
output = output[:, : self.num_heads, :]
|
||||
return output
|
||||
|
||||
def _forward_fp8_kv(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert attn_metadata.fp8_extra_metadata is not None
|
||||
extra_metadata = attn_metadata.fp8_extra_metadata
|
||||
|
||||
_attn_out, _ = flash_mla_with_kvcache(
|
||||
q=q.unsqueeze(0), # unsqueeze to add batch_dim
|
||||
k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
|
||||
block_table=extra_metadata.dummy_block_table,
|
||||
head_dim_v=512,
|
||||
cache_seqlens=extra_metadata.cache_lens,
|
||||
tile_scheduler_metadata=extra_metadata.scheduler_metadata,
|
||||
num_splits=extra_metadata.num_splits,
|
||||
is_fp8_kvcache=True,
|
||||
indices=topk_indices.unsqueeze(0), # unsqueeze to add batch_dim
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
|
||||
return _attn_out
|
||||
|
||||
def forward_prepare(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
) -> None:
|
||||
self.positions = positions
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
q: torch.Tensor,
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
kv_cache_scale: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
|
||||
# MQA 576/512 approach for both prefill and decode
|
||||
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported for MLACommonImpl"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# The zero fill is required when used with DP + EP
|
||||
# to ensure all ranks within a DP group compute the
|
||||
# same expert outputs.
|
||||
output = torch.empty(output.shape[0], self.v_head_dim * self.num_heads, device=q.device,
|
||||
dtype=q.dtype)
|
||||
return output
|
||||
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
q = q[:num_actual_toks, ...]
|
||||
k_c_normed = k_c_normed[:num_actual_toks, ...]
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
|
||||
dim=-1)
|
||||
q_pe, k_pe = self.rotary_emb(self.positions[:num_actual_toks], q_pe, k_pe)
|
||||
|
||||
q_nope = self._k_up_proj(q_nope)
|
||||
q_nope = q_nope.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
|
||||
topk_indices = self.topk_indices_buffer[:num_actual_toks]
|
||||
|
||||
# TODO: handle index / kv_cache correctly
|
||||
topk_indices_global = triton_convert_req_index_to_global_index(
|
||||
attn_metadata.req_id_per_token,
|
||||
attn_metadata.block_table,
|
||||
topk_indices,
|
||||
BLOCK_SIZE=attn_metadata.block_size,
|
||||
NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
|
||||
)
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
ops.concat_and_cache_mla(
|
||||
k_c_normed,
|
||||
k_pe,
|
||||
kv_cache,
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
scale=layer._k_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype != "fp8_ds_mla":
|
||||
attn_out = self._forward_bf16_kv(
|
||||
q, kv_cache, topk_indices_global, attn_metadata
|
||||
)
|
||||
else:
|
||||
attn_out = self._forward_fp8_kv(
|
||||
q, kv_cache, topk_indices_global, attn_metadata
|
||||
)
|
||||
output = torch.empty(output.shape[0],
|
||||
self.num_heads, self.v_head_dim,
|
||||
device=q.device,
|
||||
dtype=q.dtype)
|
||||
|
||||
output[:num_actual_toks] = self._v_up_proj(attn_out)
|
||||
return output.view(output.shape[0], self.v_head_dim * self.num_heads)
|
||||
362
vllm_old/v1/attention/backends/mla/indexer.py
Normal file
362
vllm_old/v1/attention/backends/mla/indexer.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepseekV32IndexerBackend(AttentionBackend):
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 128]
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]:
|
||||
return DeepseekV32IndexerMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
assert num_kv_heads == 1
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_stride_order() -> tuple[int, ...]:
|
||||
return (0, 1, 2)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV32IndexerPrefillChunkMetadata:
|
||||
block_table: torch.Tensor
|
||||
cu_seqlen_ks: torch.Tensor
|
||||
cu_seqlen_ke: torch.Tensor
|
||||
cu_seq_lens: torch.Tensor
|
||||
total_seq_lens: int
|
||||
token_start: int
|
||||
token_end: int
|
||||
num_reqs: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV32IndexerPrefillMetadata:
|
||||
chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepSeekV32IndexerDecodeMetadata:
|
||||
block_table: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
decode_lens: torch.Tensor
|
||||
requires_padding: bool
|
||||
# schedule_metadata: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV32IndexerMetadata:
|
||||
# FIXME (zyongye)
|
||||
# hacky way to access the data now, need to be in chunked meta
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
num_reqs: int
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
query_start_loc: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
# The dimension of the attention heads
|
||||
head_dim: int
|
||||
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling prefill decode split
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
|
||||
decode: DeepSeekV32IndexerDecodeMetadata | None = None
|
||||
prefill: DeepseekV32IndexerPrefillMetadata | None = None
|
||||
|
||||
|
||||
# TODO (zyongye) optimize this, this is now vibe coded
|
||||
def kv_spans_from_batches(
|
||||
start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, device: torch.device
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
start_seq_loc: 1D long tensor [B+1], cumulative counts of
|
||||
selected tokens per batch.
|
||||
Example: [0, 2, 4, 7] ->
|
||||
batch sizes (selected) [2, 2, 3], N=7 tokens total.
|
||||
seq_len_per_batch: 1D long tensor [B],
|
||||
full sequence length (KV length) of each batch.
|
||||
Example: [5, 9, 4].
|
||||
|
||||
Returns:
|
||||
start_tensor: 1D long tensor [N], start offset in the
|
||||
concatenated KV cache for each token's batch.
|
||||
end_location: 1D long tensor [N],
|
||||
**exclusive** end = start + token's local position.
|
||||
(So the attended KV slice is kv[start:end].)
|
||||
|
||||
Assumes each batch contributes its full `seq_len_per_batch[i]`
|
||||
keys to the KV cache, andthe selected tokens within a batch
|
||||
are the **last** `counts[i]` positions of that sequence.
|
||||
"""
|
||||
q = start_seq_loc.to(dtype=torch.long)
|
||||
L = seq_len_per_batch.to(dtype=torch.long)
|
||||
assert q.dim() == 1 and L.dim() == 1
|
||||
assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"
|
||||
|
||||
# Selected tokens per batch and totals
|
||||
counts = q[1:] - q[:-1] # [B]
|
||||
N = int(q[-1].item()) # total selected tokens
|
||||
B = L.numel()
|
||||
|
||||
if N == 0:
|
||||
return (
|
||||
torch.empty(0, dtype=torch.long, device=device),
|
||||
torch.empty(0, dtype=torch.long, device=device),
|
||||
)
|
||||
|
||||
# KV start offsets per batch in the concatenated KV cache
|
||||
kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B]
|
||||
|
||||
# For each selected token, which batch does it belong to?
|
||||
batch_id = torch.repeat_interleave(torch.arange(B), counts) # [N]
|
||||
|
||||
# Map batch KV start to each token
|
||||
start_tensor = kv_starts_per_batch[batch_id] # [N]
|
||||
|
||||
# End-align local positions inside each batch:
|
||||
# local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b
|
||||
L_expand = torch.repeat_interleave(L, counts) # [N]
|
||||
m_expand = torch.repeat_interleave(counts, counts) # [N]
|
||||
# position within the selected block: 1..counts[b]
|
||||
pos_within = (
|
||||
torch.arange(N, dtype=torch.long) - torch.repeat_interleave(q[:-1], counts) + 1
|
||||
)
|
||||
|
||||
local_pos = L_expand - m_expand + pos_within # [N], 1-based
|
||||
end_location = start_tensor + local_pos # exclusive end
|
||||
|
||||
return start_tensor.int().to(device), end_location.int().to(device)
|
||||
|
||||
|
||||
def get_max_prefill_buffer_size(vllm_config: VllmConfig):
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
# NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
|
||||
# May be tuned later.
|
||||
return max_model_len * 2
|
||||
|
||||
|
||||
def split_prefill_chunks(
|
||||
seq_lens_cpu: torch.Tensor, max_prefill_buffer_size: int, reqs_start: int
|
||||
) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Split the prefill chunks into a list of tuples of (reqs_start, reqs_end)
|
||||
such that the total sequence length of each chunk is less than the
|
||||
maximum prefill buffer size.
|
||||
|
||||
Args:
|
||||
seq_lens_cpu: The sequence lengths of the prefill requests.
|
||||
max_prefill_buffer_size: The maximum prefill buffer size.
|
||||
reqs_start: The start index of the prefill requests.
|
||||
|
||||
Returns:
|
||||
A list of tuples of (reqs_start, reqs_end).
|
||||
"""
|
||||
chunk_seq_ids = []
|
||||
total_seq_lens = 0
|
||||
for i in range(reqs_start, len(seq_lens_cpu)):
|
||||
cur_seq_len = seq_lens_cpu[i].item()
|
||||
assert cur_seq_len <= max_prefill_buffer_size
|
||||
total_seq_lens += cur_seq_len
|
||||
if total_seq_lens > max_prefill_buffer_size:
|
||||
chunk_seq_ids.append((reqs_start, i))
|
||||
reqs_start = i
|
||||
total_seq_lens = cur_seq_len
|
||||
if total_seq_lens > 0:
|
||||
chunk_seq_ids.append((reqs_start, len(seq_lens_cpu)))
|
||||
return chunk_seq_ids
|
||||
|
||||
|
||||
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
)
|
||||
|
||||
reorder_batch_threshold: int = 1
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
scheduler_config = self.vllm_config.scheduler_config
|
||||
# NOTE(Chen):an estimated max size of flattened_kv. Need to double check.
|
||||
self.max_prefill_buffer_size = get_max_prefill_buffer_size(self.vllm_config)
|
||||
self.num_speculative_tokens = (
|
||||
self.vllm_config.speculative_config.num_speculative_tokens
|
||||
if self.vllm_config.speculative_config
|
||||
else 0
|
||||
)
|
||||
# Now deepgemm fp8_paged_mqa_logits does not support next_n > 2
|
||||
self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)
|
||||
|
||||
props = torch.cuda.get_device_properties(self.device)
|
||||
sm_count = props.multi_processor_count
|
||||
self.num_sms = sm_count
|
||||
|
||||
self.decode_lens_buffer = torch.empty(
|
||||
(scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
# See: DeepGMM/csrc/apis/attention.hpp
|
||||
self.scheduler_metadata_buffer = torch.empty(
|
||||
(self.num_sms + 1, 2), dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
def build_one_prefill_chunk(
|
||||
self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
|
||||
):
|
||||
prefill_query_start_loc = (
|
||||
query_start_loc_cpu[reqs_start : reqs_end + 1]
|
||||
- query_start_loc_cpu[reqs_start]
|
||||
)
|
||||
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
|
||||
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
|
||||
)
|
||||
token_start = query_start_loc_cpu[reqs_start].item()
|
||||
token_end = query_start_loc_cpu[reqs_end].item()
|
||||
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
|
||||
assert total_seq_lens <= self.max_prefill_buffer_size
|
||||
cu_seq_lens = (
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros(1, dtype=torch.int32),
|
||||
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
|
||||
]
|
||||
)
|
||||
.to(torch.int32)
|
||||
.to(self.device)
|
||||
)
|
||||
return DeepseekV32IndexerPrefillChunkMetadata(
|
||||
cu_seqlen_ks=cu_seqlen_ks,
|
||||
cu_seqlen_ke=cu_seqlen_ke,
|
||||
cu_seq_lens=cu_seq_lens,
|
||||
total_seq_lens=total_seq_lens,
|
||||
block_table=block_table[reqs_start:reqs_end],
|
||||
token_start=token_start,
|
||||
token_end=token_end,
|
||||
num_reqs=reqs_end - reqs_start,
|
||||
)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> DeepseekV32IndexerMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
||||
)
|
||||
)
|
||||
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == num_tokens
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
chunk_seq_ids = split_prefill_chunks(
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
self.max_prefill_buffer_size,
|
||||
num_decodes,
|
||||
)
|
||||
chunks = [
|
||||
self.build_one_prefill_chunk(
|
||||
reqs_start,
|
||||
reqs_end,
|
||||
query_start_loc_cpu,
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
common_attn_metadata.block_table_tensor,
|
||||
)
|
||||
for reqs_start, reqs_end in chunk_seq_ids
|
||||
]
|
||||
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
|
||||
chunks=chunks,
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
torch.diff(
|
||||
common_attn_metadata.query_start_loc[: num_decodes + 1],
|
||||
out=self.decode_lens_buffer[:num_decodes],
|
||||
)
|
||||
decode_lens = self.decode_lens_buffer[:num_decodes]
|
||||
decode_lens_cpu = torch.diff(
|
||||
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
|
||||
)
|
||||
|
||||
# Use CPU to avoid GPU sync; breaking async scheduling
|
||||
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
|
||||
|
||||
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
|
||||
|
||||
# self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
|
||||
# seq_lens, self.kv_cache_spec.block_size, self.num_sms
|
||||
# )
|
||||
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
|
||||
block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
|
||||
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
|
||||
decode_lens=decode_lens,
|
||||
requires_padding=requires_padding,
|
||||
# schedule_metadata=self.scheduler_metadata_buffer,
|
||||
)
|
||||
|
||||
attn_metadata = DeepseekV32IndexerMetadata(
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
max_seq_len=common_attn_metadata.max_seq_len,
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
head_dim=128,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
)
|
||||
|
||||
# if get_tensor_model_parallel_rank() == 0:
|
||||
# logger.info(f"attn_metadata: {attn_metadata}")
|
||||
return attn_metadata
|
||||
294
vllm_old/v1/attention/backends/mla/rocm_aiter_mla.py
Normal file
294
vllm_old/v1/attention/backends/mla/rocm_aiter_mla.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.abstract import AttentionLayer
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
class AiterMLABackend(MLACommonBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_AITER_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["AiterMLAImpl"]:
|
||||
return AiterMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
|
||||
return AiterMLAMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
|
||||
class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
# The indptr of the paged kv cache, shape: [batch_size + 1]
|
||||
paged_kv_indptr: torch.Tensor | None = None
|
||||
# The page indices of the paged kv cache
|
||||
paged_kv_indices: torch.Tensor | None = None
|
||||
# The number of entries in the last page of each request in
|
||||
# the paged kv cache, shape: [batch_size]
|
||||
paged_kv_last_page_len: torch.Tensor | None = None
|
||||
# The query indptr, shape : [num_decode + 1]
|
||||
qo_indptr: torch.Tensor | None = None
|
||||
|
||||
|
||||
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
||||
pass
|
||||
|
||||
|
||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
# TODO(luka, lucas): audit this as part of:
|
||||
# https://github.com/vllm-project/vllm/issues/22945
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata
|
||||
)
|
||||
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
max_num_pages_per_req = cdiv(
|
||||
vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size
|
||||
)
|
||||
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
|
||||
max_num_pages = max_num_reqs * max_num_pages_per_req
|
||||
|
||||
# Preparing persistent buffers
|
||||
# TODO: we can disambiguate between decode and mixed-prefill decode here
|
||||
# so we can only use the persistent buffer if a cudagraph is actually
|
||||
# being used.
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.block_table_remapping = torch.zeros(
|
||||
[max_num_reqs, max_num_pages_per_req * self.kv_cache_spec.block_size],
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.paged_kv_indptr = torch.zeros(
|
||||
max_num_reqs + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
self.paged_kv_indices = torch.zeros(
|
||||
max_num_pages, dtype=torch.int32, device=device
|
||||
)
|
||||
self.paged_kv_last_page_len = torch.zeros(
|
||||
max_num_reqs, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
self.qo_indptr = torch.arange(
|
||||
0, max_num_reqs + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
def _build_decode(
|
||||
self,
|
||||
block_table_tensor: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
seq_lens_device: torch.Tensor,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int,
|
||||
dcp_tot_seq_lens_device: torch.Tensor | None,
|
||||
) -> AiterMLADecodeMetadata:
|
||||
page_size = self.kv_cache_spec.block_size
|
||||
device = self.device
|
||||
num_reqs = seq_lens_device.size(0)
|
||||
bs, _ = block_table_tensor.shape
|
||||
block_table_tensor = (
|
||||
block_table_tensor.unsqueeze(-1).expand(-1, -1, page_size) * page_size
|
||||
)
|
||||
block_table_tensor = (
|
||||
block_table_tensor
|
||||
+ torch.arange(
|
||||
0,
|
||||
page_size,
|
||||
device=block_table_tensor.device,
|
||||
dtype=block_table_tensor.dtype,
|
||||
)[None, None, :]
|
||||
)
|
||||
block_table_tensor = block_table_tensor.view(bs, -1)
|
||||
|
||||
# after remapping, we assume the block size already equals to 1
|
||||
|
||||
max_blk_size_per_req = block_table_tensor.shape[-1]
|
||||
mask = torch.arange(
|
||||
block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
|
||||
).unsqueeze(0) < seq_lens_device.unsqueeze(1)
|
||||
paged_kv_indices = block_table_tensor[mask]
|
||||
|
||||
paged_kv_last_page_len = seq_lens_device % page_size
|
||||
paged_kv_last_page_len = torch.where(
|
||||
paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len
|
||||
)
|
||||
|
||||
paged_kv_indptr = torch.cat(
|
||||
[
|
||||
torch.zeros(1, dtype=seq_lens_device.dtype, device=device),
|
||||
seq_lens_device.cumsum(dim=0, dtype=torch.int32),
|
||||
]
|
||||
)
|
||||
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
num_actual_pages = paged_kv_indices.size(0)
|
||||
self.block_table_remapping[:num_reqs, :max_blk_size_per_req].copy_(
|
||||
block_table_tensor, non_blocking=True
|
||||
)
|
||||
block_table_tensor = self.block_table_remapping[
|
||||
:num_reqs, :max_blk_size_per_req
|
||||
]
|
||||
|
||||
self.paged_kv_indices[:num_actual_pages].copy_(
|
||||
paged_kv_indices, non_blocking=True
|
||||
)
|
||||
self.paged_kv_indices[num_actual_pages:].fill_(-1)
|
||||
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
|
||||
|
||||
self.paged_kv_indptr[: 1 + num_reqs].copy_(
|
||||
paged_kv_indptr, non_blocking=True
|
||||
)
|
||||
self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1])
|
||||
paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs]
|
||||
|
||||
self.paged_kv_last_page_len[:num_reqs].copy_(
|
||||
paged_kv_last_page_len, non_blocking=True
|
||||
)
|
||||
self.paged_kv_last_page_len[num_reqs:].fill_(1)
|
||||
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
|
||||
|
||||
qo_indptr = self.qo_indptr[: 1 + num_reqs]
|
||||
|
||||
else:
|
||||
qo_indptr = torch.arange(
|
||||
0, num_reqs + 1, step=1, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
attn_metadata = AiterMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens_device,
|
||||
paged_kv_indptr=paged_kv_indptr,
|
||||
paged_kv_indices=paged_kv_indices,
|
||||
paged_kv_last_page_len=paged_kv_last_page_len,
|
||||
qo_indptr=qo_indptr,
|
||||
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
|
||||
)
|
||||
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
assert num_heads == 16 or num_heads == 128, (
|
||||
f"Aiter MLA only supports 16 or 128 number of heads.\n"
|
||||
f"Provided {num_heads} number of heads.\n"
|
||||
"Try adjusting tensor_parallel_size value."
|
||||
)
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"Aiter MLA does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
from aiter import flash_attn_varlen_func
|
||||
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(
|
||||
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
|
||||
):
|
||||
output = self.flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
softmax_scale=softmax_scale,
|
||||
return_lse=return_softmax_lse,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: AiterMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if type(q) is tuple:
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
assert isinstance(q, torch.Tensor)
|
||||
B = q.shape[0]
|
||||
o = torch.zeros(
|
||||
B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
|
||||
)
|
||||
|
||||
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
|
||||
# max_seqlen_qo must be 1 except for MTP
|
||||
# TODO: Find the best value for MTP
|
||||
max_seqlen_qo = 1
|
||||
rocm_aiter_ops.mla_decode_fwd(
|
||||
q,
|
||||
kv_buffer,
|
||||
o,
|
||||
self.scale,
|
||||
attn_metadata.decode.qo_indptr,
|
||||
max_seqlen_qo,
|
||||
attn_metadata.decode.paged_kv_indptr,
|
||||
attn_metadata.decode.paged_kv_indices,
|
||||
attn_metadata.decode.paged_kv_last_page_len,
|
||||
)
|
||||
|
||||
return o, None
|
||||
206
vllm_old/v1/attention/backends/mla/triton_mla.py
Normal file
206
vllm_old/v1/attention/backends/mla/triton_mla.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
)
|
||||
import ixformer.inference.functions as ixf_ops
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TritonMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["TritonMLAImpl"]:
|
||||
return TritonMLAImpl
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"TritonMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"TritonMLAImpl"
|
||||
)
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"TritonMLA V1 with FP8 KV cache not yet supported"
|
||||
)
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(
|
||||
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
|
||||
):
|
||||
return super()._flash_attn_varlen_diff_headdims(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
softmax_scale=softmax_scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
k_c_normed: torch.Tensor | None,
|
||||
k_pe: torch.Tensor | None,
|
||||
kv_c_and_k_pe_cache_scale: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Triton MLA not yet supported")
|
||||
|
||||
decode_meta = attn_metadata.decode
|
||||
q_nope = self._k_up_proj(q_nope)
|
||||
q_nope = q_nope.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
|
||||
B = q_nope.shape[0]
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
q = get_dcp_group().all_gather(q, dim=1)
|
||||
o = torch.empty(B,
|
||||
q.shape[1],
|
||||
self.kv_lora_rank,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
if envs.VLLM_USE_INT8_MLA:
|
||||
q_int8, q_scale = ops.quant_kv(q)
|
||||
attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla_int8(
|
||||
o,
|
||||
q_int8,
|
||||
q_scale,
|
||||
kv_c_and_k_pe_cache,
|
||||
kv_c_and_k_pe_cache_scale,
|
||||
self.scale,
|
||||
attn_metadata.decode.block_table,
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.max_decode_seq_len,
|
||||
return_softmax_lse=True
|
||||
)
|
||||
else:
|
||||
attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla(
|
||||
output=o,
|
||||
query=q,
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
scale=self.scale,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
context_lens=attn_metadata.decode.seq_lens,
|
||||
max_context_len=decode_meta.max_decode_seq_len,
|
||||
return_softmax_lse=True)
|
||||
return attn_out, softmax_lse
|
||||
|
||||
o = torch.empty(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
|
||||
if envs.VLLM_USE_INT8_MLA:
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
q_int8, q_scale = ops.quant_kv(q)
|
||||
ixf_ops.vllm_paged_attention_mla_int8(
|
||||
o,
|
||||
q_int8,
|
||||
q_scale,
|
||||
kv_c_and_k_pe_cache,
|
||||
kv_c_and_k_pe_cache_scale,
|
||||
self.scale,
|
||||
attn_metadata.decode.block_table,
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.max_decode_seq_len,
|
||||
attn_metadata.decode.use_cuda_graph
|
||||
)
|
||||
else:
|
||||
# fused q concat & cache write
|
||||
ixf_ops.vllm_paged_attention_mla_fused(
|
||||
output=o,
|
||||
q_nope=q_nope,
|
||||
q_pe=q_pe.contiguous(),
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
scale=self.scale,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
context_lens=attn_metadata.decode.seq_lens,
|
||||
max_context_len=decode_meta.max_decode_seq_len,
|
||||
k_c_normed=k_c_normed,
|
||||
k_pe=k_pe,
|
||||
use_cuda_graph=decode_meta.use_cuda_graph
|
||||
)
|
||||
return self._v_up_proj(o), None
|
||||
436
vllm_old/v1/attention/backends/pallas.py
Normal file
436
vllm_old/v1/attention/backends/pallas.py
Normal file
@@ -0,0 +1,436 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import cdiv, next_power_of_2
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# TPU requires the head size to be a multiple of 128.
|
||||
TPU_HEAD_SIZE_ALIGNMENT = 128
|
||||
|
||||
# Note: TPU can fp8 as storage dtype but doesn't support converting from uint8
|
||||
# from to fp32 directly. That's why it has a dtype mapping different from GPU
|
||||
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"half": torch.half,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float": torch.float,
|
||||
"fp8": torch.float8_e4m3fn,
|
||||
"fp8_e4m3": torch.float8_e4m3fn,
|
||||
"fp8_e5m2": torch.float8_e5m2,
|
||||
"int8": torch.int8,
|
||||
"uint8": torch.uint8,
|
||||
}
|
||||
|
||||
try:
|
||||
import tpu_inference # noqa: F401
|
||||
except ImportError:
|
||||
# Lazy import torch_xla
|
||||
import torch_xla.core.xla_builder as xb
|
||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
||||
from torch.library import impl
|
||||
from torch_xla._internal.jax_workarounds import requires_jax
|
||||
from torch_xla.experimental.custom_kernel import XLA_LIB
|
||||
|
||||
@requires_jax
|
||||
def kv_cache_update_op_impl(
|
||||
kv: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_update_slices: torch.Tensor,
|
||||
page_size: int,
|
||||
num_slices_per_block: int,
|
||||
):
|
||||
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
|
||||
|
||||
new_kv_cache = xb.call_jax(
|
||||
kv_cache_update,
|
||||
(kv, slot_mapping, kv_cache, num_kv_update_slices),
|
||||
{"page_size": page_size, "num_slices_per_block": num_slices_per_block},
|
||||
)
|
||||
return new_kv_cache
|
||||
|
||||
XLA_LIB.define(
|
||||
"kv_cache_update_op(Tensor kv, Tensor slot_mapping,"
|
||||
"Tensor kv_cache, Tensor num_kv_update_slices, int page_size,"
|
||||
"int num_slices_per_block)"
|
||||
"-> Tensor",
|
||||
)
|
||||
|
||||
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
|
||||
def kv_cache_update_op_xla(
|
||||
kv: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_update_slices: torch.Tensor,
|
||||
page_size: int,
|
||||
num_slices_per_block: int,
|
||||
) -> torch.Tensor:
|
||||
new_kv_cache = kv_cache_update_op_impl(
|
||||
kv,
|
||||
slot_mapping,
|
||||
kv_cache,
|
||||
num_kv_update_slices,
|
||||
page_size,
|
||||
num_slices_per_block,
|
||||
)
|
||||
return new_kv_cache
|
||||
|
||||
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
|
||||
def kv_cache_update_op_non_xla(
|
||||
kv: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_update_slices: torch.Tensor,
|
||||
page_size: int,
|
||||
num_slices_per_block: int,
|
||||
) -> torch.Tensor:
|
||||
return kv_cache
|
||||
|
||||
|
||||
class PallasAttentionBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "PALLAS"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
|
||||
return PallasAttentionBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
padded_head_size = (
|
||||
cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
)
|
||||
return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
raise RuntimeError("swap_blocks is not used for the TPU backend.")
|
||||
|
||||
# In recent TPU generations, up to v6e, the SMEM size is 1MB. The
|
||||
# block_tables within the PallasMetadata constitute almost the entire SMEM
|
||||
# requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
|
||||
# we simply make sure that the size is smaller than half of SMEM capacity.
|
||||
@staticmethod
|
||||
def get_min_page_size(vllm_config: VllmConfig) -> int:
|
||||
max_num_page_per_req = (
|
||||
1024 * 1024 // 2 // vllm_config.scheduler_config.max_num_seqs // 4
|
||||
)
|
||||
min_page_size = cdiv(
|
||||
vllm_config.model_config.max_model_len, max_num_page_per_req
|
||||
)
|
||||
min_page_size = 1 << (min_page_size - 1).bit_length()
|
||||
return min_page_size
|
||||
|
||||
@staticmethod
|
||||
def get_max_num_seqs(model_len: int, page_size: int) -> int:
|
||||
num_page_per_req = cdiv(model_len, page_size)
|
||||
return 1024 * 1024 // 2 // num_page_per_req // 4
|
||||
|
||||
# TPU has limited SREGs (scalar registers), if page_size is too small, we
|
||||
# can spill SREGs easily which leads to bad performance. The strategy we
|
||||
# apply here is trying to split max-model-len to 16 pages which make the
|
||||
# spill less likely. Meanwhile we make sure the page size is in [16, 256].
|
||||
@staticmethod
|
||||
def get_page_size(vllm_config: VllmConfig) -> int:
|
||||
# TODO: This is a temporary fix for vmem OOM.
|
||||
# For long model length, we use 16 page-size to avoid too much
|
||||
# VMEM spill. A more robust solution should be implemented to
|
||||
# handle VREG spills.
|
||||
if vllm_config.model_config.max_model_len > 8192:
|
||||
return 16
|
||||
page_size = next_power_of_2(vllm_config.model_config.max_model_len) // 16
|
||||
if page_size <= 16:
|
||||
return 16
|
||||
if page_size >= 256:
|
||||
return 256
|
||||
return page_size
|
||||
|
||||
|
||||
@dataclass
|
||||
class PallasMetadata:
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Used in the PallasAttentionBackendImpl
|
||||
slot_mapping: torch.Tensor
|
||||
block_tables: torch.Tensor
|
||||
context_lens: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
num_seqs: torch.Tensor
|
||||
num_kv_update_slices: torch.Tensor
|
||||
num_slices_per_kv_cache_update_block: int
|
||||
|
||||
|
||||
class PallasAttentionBackendImpl(AttentionImpl):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: int | None = None,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
raise NotImplementedError("Alibi slopes is not supported.")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl"
|
||||
)
|
||||
|
||||
self.kv_cache_quantized_dtype = None
|
||||
if kv_cache_dtype != "auto":
|
||||
self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
|
||||
kv_cache_dtype.lower().strip()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: PallasMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Pallas attention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache: shape =
|
||||
[num_blocks, block_size, num_kv_heads * 2, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for PallasAttentionBackendImpl"
|
||||
)
|
||||
|
||||
# For determine_available_memory case.
|
||||
if kv_cache.numel() == 0:
|
||||
if output is None:
|
||||
output = torch.ones_like(query)
|
||||
return output
|
||||
|
||||
num_tokens, hidden_size = query.shape
|
||||
query = query.view(num_tokens, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
|
||||
padded_head_size = (
|
||||
cdiv(self.head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
)
|
||||
query = torch.nn.functional.pad(
|
||||
query, (0, padded_head_size - self.head_size), value=0.0
|
||||
)
|
||||
key = torch.nn.functional.pad(
|
||||
key, (0, padded_head_size - self.head_size), value=0.0
|
||||
)
|
||||
value = torch.nn.functional.pad(
|
||||
value, (0, padded_head_size - self.head_size), value=0.0
|
||||
)
|
||||
|
||||
if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
|
||||
# Write input keys and values to the KV cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
write_to_kv_cache(
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
slot_mapping,
|
||||
attn_metadata.num_slices_per_kv_cache_update_block,
|
||||
attn_metadata.num_kv_update_slices,
|
||||
self.kv_cache_quantized_dtype,
|
||||
layer._k_scale_float,
|
||||
layer._v_scale_float,
|
||||
)
|
||||
|
||||
if self.kv_cache_quantized_dtype is not None and (
|
||||
layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0
|
||||
):
|
||||
raise ValueError("k_scale_float and v_scale_float must be non-zero")
|
||||
output = torch.ops.xla.ragged_paged_attention(
|
||||
query,
|
||||
kv_cache,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.query_start_loc,
|
||||
attn_metadata.num_seqs,
|
||||
# By default, the system utilizes optimized block size and
|
||||
# vmem_limit_bytes parameters from the kernel repository. However,
|
||||
# these can be manually adjusted for debugging if necessary.
|
||||
num_kv_pages_per_block=None,
|
||||
num_queries_per_block=None,
|
||||
vmem_limit_bytes=None,
|
||||
use_kernel=True,
|
||||
sm_scale=self.scale,
|
||||
sliding_window=self.sliding_window,
|
||||
soft_cap=self.logits_soft_cap,
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
)
|
||||
|
||||
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
|
||||
output = output[:, :, : self.head_size]
|
||||
|
||||
return output.reshape(num_tokens, hidden_size)
|
||||
|
||||
|
||||
def write_to_kv_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
num_slices_per_kv_cache_update_block: int,
|
||||
num_kv_update_slices: torch.Tensor,
|
||||
kv_cache_quantized_dtype: torch.dtype | None = None,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
) -> None:
|
||||
"""Write the key and values to the KV cache.
|
||||
|
||||
Args:
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2, head_size]
|
||||
num_slices_per_kv_cache_update_block: int
|
||||
"""
|
||||
_, page_size, num_combined_kv_heads, head_size = kv_cache.shape
|
||||
head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
|
||||
if kv_cache_quantized_dtype is not None:
|
||||
dtype_info = torch.finfo(kv_cache_quantized_dtype)
|
||||
key = key.to(torch.float32) / k_scale
|
||||
# NOTE: clamp is added here to avoid out of range of quantized dtype
|
||||
key = torch.clamp(key, dtype_info.min, dtype_info.max)
|
||||
key = key.to(kv_cache_quantized_dtype)
|
||||
value = value.to(torch.float32) / v_scale
|
||||
value = torch.clamp(value, dtype_info.min, dtype_info.max)
|
||||
value = value.to(kv_cache_quantized_dtype)
|
||||
|
||||
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, head_size)
|
||||
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
|
||||
|
||||
kv_cache = kv_cache.flatten(0, 1)
|
||||
new_kv_cache = torch.ops.xla.kv_cache_update_op(
|
||||
kv,
|
||||
slot_mapping,
|
||||
kv_cache,
|
||||
num_kv_update_slices,
|
||||
page_size,
|
||||
num_slices_per_kv_cache_update_block,
|
||||
)
|
||||
# NOTE: the in-place copy will be optimized away by XLA compiler.
|
||||
kv_cache.copy_(new_kv_cache)
|
||||
|
||||
|
||||
# We can move this function to a common utils file if it's also useful for other
|
||||
# hardware.
|
||||
def dtype_bits(dtype: torch.dtype):
|
||||
if dtype.is_floating_point:
|
||||
try:
|
||||
return torch.finfo(dtype).bits
|
||||
except TypeError:
|
||||
pass
|
||||
elif dtype.is_complex:
|
||||
if dtype is torch.complex32:
|
||||
return 32
|
||||
elif dtype is torch.complex64:
|
||||
return 64
|
||||
elif dtype is torch.complex128:
|
||||
return 128
|
||||
else:
|
||||
try:
|
||||
return torch.iinfo(dtype).bits
|
||||
# torch.iinfo cannot support int4, int2, bits8...
|
||||
except TypeError:
|
||||
pass
|
||||
str_dtype = str(dtype)
|
||||
# support torch.int4, torch.int5, torch.uint5...
|
||||
if str_dtype.startswith("torch.int") or str_dtype.startswith("torch.uint"):
|
||||
return int(str_dtype[-1])
|
||||
raise TypeError(f"Getting the bit width of {dtype} is not supported")
|
||||
|
||||
|
||||
def get_dtype_packing(dtype):
|
||||
bits = dtype_bits(dtype)
|
||||
if 32 % bits != 0:
|
||||
raise ValueError(
|
||||
f"The bit width must be divisible by 32, but got bits={bits}, "
|
||||
"dtype={dtype}"
|
||||
)
|
||||
return 32 // bits
|
||||
|
||||
|
||||
def get_page_size_bytes(
|
||||
block_size: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype
|
||||
) -> int:
|
||||
"""Returns the size in bytes of one page of the KV cache."""
|
||||
padded_head_size = (
|
||||
cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
)
|
||||
num_combined_kv_heads = num_kv_heads * 2
|
||||
|
||||
# NOTE: for the implicit padding in XLA
|
||||
packing = get_dtype_packing(kv_cache_dtype)
|
||||
num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing
|
||||
|
||||
kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
|
||||
return (
|
||||
block_size * num_combined_kv_heads * padded_head_size * kv_cache_dtype_bits // 8
|
||||
)
|
||||
816
vllm_old/v1/attention/backends/rocm_aiter_fa.py
Normal file
816
vllm_old/v1/attention/backends/rocm_aiter_fa.py
Normal file
@@ -0,0 +1,816 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with AiterFlashAttention."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.platform_utils import get_cu_count
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_prefills_and_extends,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
_PARTITION_SIZE_ROCM = 256
|
||||
_CP_TOKENS_PER_ITER_ROCM = 32 * 1024
|
||||
|
||||
if current_platform.is_rocm():
|
||||
import aiter
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
def block_size(x, head_dim):
|
||||
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
|
||||
|
||||
def num_programs(total_tokens):
|
||||
return min(total_tokens, get_cu_count())
|
||||
|
||||
@triton.jit
|
||||
def cp_mha_gather_cache_kernel(
|
||||
key_cache_ptr, # [num_blocks, page_size, num_head, head_size]
|
||||
value_cache_ptr, # [num_blocks, page_size, num_head, head_size]
|
||||
key_ptr, # [num_tokens, num_heads, head_size]
|
||||
value_ptr, # [num_tokens, num_heads, head_size]
|
||||
block_table_ptr, # [num_batches, max_block_num]
|
||||
cu_seqlens_kv_ptr, # [num_batches + 1]
|
||||
token_to_batch_ptr, # [max_cum_tokens]
|
||||
seq_start_ptr, # [num_batches]
|
||||
k_scale_ptr,
|
||||
v_scale_ptr,
|
||||
num_heads,
|
||||
head_size,
|
||||
x,
|
||||
max_block_num,
|
||||
num_tokens,
|
||||
num_programs,
|
||||
DEQUANT: tl.constexpr,
|
||||
PAGE_SIZE: tl.constexpr,
|
||||
CACHE_FORMAT: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
bid = tl.program_id(0)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
if DEQUANT:
|
||||
k_scale = tl.load(k_scale_ptr)
|
||||
v_scale = tl.load(v_scale_ptr)
|
||||
|
||||
for token_id in tl.range(bid, num_tokens, num_programs):
|
||||
key_ptr_offset = key_ptr + token_id * head_size * num_heads
|
||||
value_ptr_offset = value_ptr + token_id * head_size * num_heads
|
||||
batch_idx = tl.load(token_to_batch_ptr + token_id)
|
||||
batch_start = tl.load(seq_start_ptr + batch_idx)
|
||||
token_start = tl.load(cu_seqlens_kv_ptr + batch_idx)
|
||||
batch_offset = token_id - token_start + batch_start
|
||||
block_offset = batch_offset // PAGE_SIZE
|
||||
block_id = tl.load(
|
||||
block_table_ptr + max_block_num * batch_idx + block_offset
|
||||
)
|
||||
slot_id = batch_offset % PAGE_SIZE
|
||||
|
||||
if CACHE_FORMAT == "NHD":
|
||||
# for kv cache layout as
|
||||
# K: [num_blocks, page_size, num_head, head_dim]
|
||||
# V: [num_blocks, page_size, num_head, head_dim]
|
||||
key_cache_ptr_offset = (
|
||||
key_cache_ptr
|
||||
+ block_id * num_heads * head_size * PAGE_SIZE
|
||||
+ slot_id * num_heads * head_size
|
||||
)
|
||||
value_cache_ptr_offset = (
|
||||
value_cache_ptr
|
||||
+ block_id * num_heads * head_size * PAGE_SIZE
|
||||
+ slot_id * num_heads * head_size
|
||||
)
|
||||
|
||||
for i in tl.range(0, head_size * num_heads, BLOCK_SIZE):
|
||||
mask = (col_offsets + i) < head_size * num_heads
|
||||
k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask)
|
||||
v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask)
|
||||
if DEQUANT:
|
||||
k_dtype = k_reg.dtype
|
||||
v_dtype = v_reg.dtype
|
||||
k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
|
||||
v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
|
||||
tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask)
|
||||
tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask)
|
||||
|
||||
def cp_mha_gather_cache(
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
k_scales: torch.Tensor,
|
||||
v_scales: torch.Tensor,
|
||||
cu_seqlens_kv: torch.Tensor,
|
||||
token_to_batch: torch.Tensor,
|
||||
seq_starts: torch.Tensor,
|
||||
dequant: bool,
|
||||
kv_cache_layout: str,
|
||||
total_tokens: int,
|
||||
):
|
||||
assert kv_cache_layout in ["v0", "NHD", "HND"], (
|
||||
"kv_cache_layout only support v0, NHD, HND"
|
||||
)
|
||||
head_dim = key.shape[2]
|
||||
x = 0
|
||||
# assert dequant is True, "Currently, we only support "\
|
||||
# "gather cache with dequant"
|
||||
# For k cache layout: [num_blocks, num_heads, page_size, head_dim]
|
||||
assert kv_cache_layout == "NHD", (
|
||||
"ROCM_AITER_FA_BACKEND Only support NHD kv cache layout for now"
|
||||
)
|
||||
assert head_dim == key_cache.shape[3], (
|
||||
"We assume your kv cache layout is [num_blocks, "
|
||||
"page_size, num_heads, head_dim], but got otherwise"
|
||||
)
|
||||
page_size = key_cache.shape[1]
|
||||
num_heads = key_cache.shape[2]
|
||||
|
||||
NUM_PRGMS = num_programs(total_tokens)
|
||||
BLOCK_SIZE = block_size(key_cache, head_dim)
|
||||
grid = lambda meta: (NUM_PRGMS,)
|
||||
cp_mha_gather_cache_kernel[grid](
|
||||
key_cache,
|
||||
value_cache,
|
||||
key,
|
||||
value,
|
||||
block_tables,
|
||||
cu_seqlens_kv,
|
||||
token_to_batch,
|
||||
seq_starts,
|
||||
k_scales,
|
||||
v_scales,
|
||||
num_heads,
|
||||
head_dim,
|
||||
x,
|
||||
block_tables.size(1),
|
||||
total_tokens,
|
||||
NUM_PRGMS,
|
||||
DEQUANT=dequant,
|
||||
PAGE_SIZE=page_size,
|
||||
CACHE_FORMAT=kv_cache_layout,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AiterFlashAttentionDecodeMetadata:
|
||||
max_query_len: int
|
||||
min_query_len: int
|
||||
max_seq_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class AiterFlashAttentionPrefillMetadata:
|
||||
max_query_len: int
|
||||
min_query_len: int
|
||||
max_seq_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class AiterChunkContextMetadata:
|
||||
workspace: torch.Tensor
|
||||
cu_seq_lens_chunk: torch.Tensor
|
||||
chunk_starts: torch.Tensor
|
||||
token_to_batch: torch.Tensor
|
||||
seq_tot: list[int]
|
||||
max_seq_lens: list[int]
|
||||
seq_lens: torch.Tensor
|
||||
num_chunks: int
|
||||
total_token_per_batch: list[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AiterFlashAttentionChunkPrefillMetadata:
|
||||
max_query_len: int
|
||||
min_query_len: int
|
||||
max_seq_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
chunk_context_metadata: AiterChunkContextMetadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class AiterFlashAttentionMetadata:
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
num_actual_kv_tokens: int
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
|
||||
# prefill and deocde split
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_extends: int
|
||||
num_extend_tokens: int
|
||||
|
||||
decode_metadata: AiterFlashAttentionDecodeMetadata | None
|
||||
prefill_metadata: AiterFlashAttentionPrefillMetadata | None
|
||||
extend_metadata: AiterFlashAttentionChunkPrefillMetadata | None
|
||||
|
||||
# For cascade attention.
|
||||
use_cascade: bool
|
||||
common_prefix_len: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class AiterFlashAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[AiterFlashAttentionMetadata]
|
||||
):
|
||||
_cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
reorder_batch_threshold: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
self.model_config = vllm_config.model_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
|
||||
self.num_heads_q = self.model_config.get_num_attention_heads(
|
||||
self.parallel_config
|
||||
)
|
||||
self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||
self.headdim = self.model_config.get_head_size()
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
# Sliding window size to be used with the AOT scheduler will be
|
||||
# populated on first build() call.
|
||||
self.aot_sliding_window: tuple[int, int] | None = None
|
||||
self.total_tokens: int = 0
|
||||
|
||||
self.extend_workspace = torch.empty(
|
||||
[2, _CP_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.headdim],
|
||||
dtype=self.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
):
|
||||
self.total_tokens = (
|
||||
self.model_config.max_model_len
|
||||
* self.vllm_config.scheduler_config.max_num_partial_prefills
|
||||
)
|
||||
res = self.build(common_prefix_len=0, common_attn_metadata=common_attn_metadata)
|
||||
self.total_tokens = 0
|
||||
return res
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> "AiterFlashAttentionMetadata":
|
||||
split_ret = split_decodes_prefills_and_extends(
|
||||
common_attn_metadata,
|
||||
decode_threshold=self.reorder_batch_threshold,
|
||||
)
|
||||
|
||||
(
|
||||
num_decodes,
|
||||
num_extends,
|
||||
num_prefills,
|
||||
num_decode_tokens,
|
||||
num_extend_tokens,
|
||||
num_prefill_tokens,
|
||||
) = split_ret
|
||||
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu
|
||||
|
||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
decode_metadata = AiterFlashAttentionDecodeMetadata(
|
||||
max_query_len=query_lens_cpu[:num_decodes].max().item(),
|
||||
min_query_len=query_lens_cpu[:num_decodes].min().item(),
|
||||
max_seq_len=seq_lens[:num_decodes].max().item(),
|
||||
query_start_loc=common_attn_metadata.query_start_loc[: num_decodes + 1],
|
||||
)
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
query_lens_for_prefill = query_lens_cpu[num_decodes + num_extends :]
|
||||
query_start_loc_device = common_attn_metadata.query_start_loc[
|
||||
num_decodes + num_extends :
|
||||
]
|
||||
prefill_metadata = AiterFlashAttentionPrefillMetadata(
|
||||
max_query_len=query_lens_for_prefill.max().item(),
|
||||
min_query_len=query_lens_for_prefill.min().item(),
|
||||
max_seq_len=seq_lens[num_decodes + num_extends :].max().item(),
|
||||
query_start_loc=query_start_loc_device - query_start_loc_device[0],
|
||||
)
|
||||
|
||||
extend_metadata = None
|
||||
if num_extends > 0:
|
||||
num_extends_slice = slice(num_decodes, num_decodes + num_extends)
|
||||
query_lens_for_extend = query_lens_cpu[num_extends_slice]
|
||||
seq_lens_for_extend = common_attn_metadata.seq_lens_cpu[num_extends_slice]
|
||||
computed_kv_lens = seq_lens_for_extend - query_lens_for_extend
|
||||
|
||||
# allocate the equal amount of workspace for
|
||||
# each chunk prefill request
|
||||
max_context_chunk = _CP_TOKENS_PER_ITER_ROCM // num_extends
|
||||
num_chunks = cdiv(computed_kv_lens.max().item(), max_context_chunk)
|
||||
|
||||
chunk_starts = (
|
||||
torch.arange(num_chunks, dtype=torch.int32)
|
||||
.unsqueeze(1)
|
||||
.expand(-1, num_extends)
|
||||
* max_context_chunk
|
||||
)
|
||||
chunk_ends = torch.min(
|
||||
computed_kv_lens.unsqueeze(0), chunk_starts + max_context_chunk
|
||||
)
|
||||
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(
|
||||
min=0
|
||||
) # [num_chunks, num_extends]
|
||||
cu_seq_lens_cpu = torch.zeros(
|
||||
[num_chunks, num_extends + 1], dtype=torch.int32, pin_memory=True
|
||||
)
|
||||
torch.cumsum(
|
||||
chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
|
||||
)
|
||||
max_cum_tokens = cu_seq_lens_cpu[:, -1].max().item()
|
||||
|
||||
range_idx = torch.arange(max_cum_tokens, dtype=torch.int32)[None, None, :]
|
||||
idx_to_batch_tensor = range_idx == cu_seq_lens_cpu[:, 1:][:, :, None]
|
||||
idx_to_batch_tensor = idx_to_batch_tensor.sum(
|
||||
dim=1
|
||||
) # [num_chunks, max_cum_tokens]
|
||||
token_to_batch_tensor = torch.cumsum(idx_to_batch_tensor, dim=1)
|
||||
|
||||
chunk_context_metadata = AiterChunkContextMetadata(
|
||||
workspace=self.extend_workspace,
|
||||
cu_seq_lens_chunk=cu_seq_lens_cpu.to(self.device, non_blocking=True),
|
||||
chunk_starts=chunk_starts.to(self.device, non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
seq_lens=chunk_seq_lens,
|
||||
token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True),
|
||||
num_chunks=num_chunks,
|
||||
total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist(),
|
||||
)
|
||||
|
||||
query_start_loc_device = common_attn_metadata.query_start_loc[
|
||||
num_decodes : num_decodes + num_extends + 1
|
||||
]
|
||||
seq_lens_device = common_attn_metadata.seq_lens[num_extends_slice]
|
||||
cu_seq_lens = torch.zeros(
|
||||
num_extends + 1, dtype=torch.int32, device=seq_lens_device.device
|
||||
)
|
||||
torch.cumsum(
|
||||
seq_lens_device, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:]
|
||||
)
|
||||
extend_metadata = AiterFlashAttentionChunkPrefillMetadata(
|
||||
max_query_len=query_lens_for_extend.max().item(),
|
||||
min_query_len=query_lens_for_extend.min().item(),
|
||||
max_seq_len=seq_lens[num_extends_slice].max().item(),
|
||||
query_start_loc=query_start_loc_device - query_start_loc_device[0],
|
||||
chunk_context_metadata=chunk_context_metadata,
|
||||
)
|
||||
|
||||
num_actual_kv_tokens = torch.sum(seq_lens).item()
|
||||
|
||||
use_cascade = common_prefix_len > 0
|
||||
|
||||
attn_metadata = AiterFlashAttentionMetadata(
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
num_actual_kv_tokens=num_actual_kv_tokens,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
max_seq_len=common_attn_metadata.max_seq_len,
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
block_table=common_attn_metadata.block_table_tensor,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_extends=num_extends,
|
||||
num_extend_tokens=num_extend_tokens,
|
||||
decode_metadata=decode_metadata,
|
||||
prefill_metadata=prefill_metadata,
|
||||
extend_metadata=extend_metadata,
|
||||
use_cascade=use_cascade,
|
||||
common_prefix_len=common_prefix_len,
|
||||
total_tokens=self.total_tokens,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class AiterFlashAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [64, 128, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASH_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["AiterFlashAttentionImpl"]:
|
||||
return AiterFlashAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AiterFlashAttentionMetadataBuilder"]:
|
||||
return AiterFlashAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
|
||||
class AiterFlashAttentionImpl(AttentionImpl):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: int | None = None,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
if sliding_window is None:
|
||||
self.sliding_window = [-1, -1]
|
||||
else:
|
||||
self.sliding_window = [sliding_window - 1, 0]
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
if logits_soft_cap is None:
|
||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||
logits_soft_cap = 0.0
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashAttentionImpl"
|
||||
)
|
||||
|
||||
def extend_forward(
|
||||
self,
|
||||
attn_metadata: AiterFlashAttentionMetadata,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
cu_seqlens_q: torch.Tensor,
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
min_seqlen_q: int,
|
||||
block_table: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
):
|
||||
out, lse = aiter.flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_q,
|
||||
min_seqlen_q=min_seqlen_q,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
return_lse=True,
|
||||
)
|
||||
assert attn_metadata.extend_metadata is not None
|
||||
chunk_context_metadata = attn_metadata.extend_metadata.chunk_context_metadata
|
||||
num_chunks = chunk_context_metadata.num_chunks
|
||||
workspace = chunk_context_metadata.workspace
|
||||
cu_seqlens_kv = chunk_context_metadata.cu_seq_lens_chunk
|
||||
max_seqlens = chunk_context_metadata.max_seq_lens
|
||||
chunk_starts = chunk_context_metadata.chunk_starts
|
||||
token_to_batch = chunk_context_metadata.token_to_batch
|
||||
total_token_per_batch = chunk_context_metadata.total_token_per_batch
|
||||
key_fetched, value_fetched = workspace[0], workspace[1]
|
||||
chunked_output = None
|
||||
chunked_lse = None
|
||||
for chunk_idx in range(num_chunks):
|
||||
cp_mha_gather_cache(
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
key=key_fetched,
|
||||
value=value_fetched,
|
||||
block_tables=block_table,
|
||||
k_scales=k_scale,
|
||||
v_scales=v_scale,
|
||||
cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
|
||||
token_to_batch=token_to_batch[chunk_idx],
|
||||
seq_starts=chunk_starts[chunk_idx],
|
||||
dequant=False,
|
||||
kv_cache_layout="NHD",
|
||||
total_tokens=total_token_per_batch[chunk_idx],
|
||||
)
|
||||
|
||||
suf_out, suf_lse = aiter.flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key_fetched,
|
||||
v=value_fetched,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_kv[chunk_idx],
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlens[chunk_idx],
|
||||
min_seqlen_q=min_seqlen_q,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=self.scale,
|
||||
causal=False,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
return_lse=True,
|
||||
)
|
||||
if chunked_output is None:
|
||||
chunked_output = suf_out
|
||||
chunked_lse = suf_lse
|
||||
else:
|
||||
tmp_output = torch.empty_like(out)
|
||||
tmp_lse = torch.empty_like(lse)
|
||||
merge_attn_states(
|
||||
output=tmp_output,
|
||||
output_lse=tmp_lse,
|
||||
prefix_output=chunked_output,
|
||||
prefix_lse=chunked_lse,
|
||||
suffix_output=suf_out,
|
||||
suffix_lse=suf_lse,
|
||||
)
|
||||
chunked_output = tmp_output
|
||||
chunked_lse = tmp_lse
|
||||
|
||||
merge_attn_states(
|
||||
output=output,
|
||||
prefix_output=chunked_output,
|
||||
prefix_lse=chunked_lse,
|
||||
suffix_output=out,
|
||||
suffix_lse=lse,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AiterFlashAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with AiterFlashAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape =
|
||||
[2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
NOTE: FP8 quantization, flash-attn expect the size of
|
||||
{q,k,v}_descale to be (num_sequences, num_kv_heads).
|
||||
We use torch's .expand() to avoid duplicating values
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported for FlashAttentionImpl"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output.fill_(0)
|
||||
|
||||
# IMPORTANT!
|
||||
# NOTE(woosuk): With piece-wise CUDA graphs, this method is
|
||||
# executed in eager-mode PyTorch. Thus, we need to be careful
|
||||
# about any CPU overhead in this method. For example, `view`
|
||||
# and `slice` (or `[:n]`) operations are surprisingly slow even
|
||||
# in the case they do not invoke any GPU ops.
|
||||
# Minimize the PyTorch ops in this method as much as possible.
|
||||
# Whenever making a change in this method, please benchmark the
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping
|
||||
# is not padded. However, we don't need to do
|
||||
# key[:num_actual_tokens] and value[:num_actual_tokens] because
|
||||
# the reshape_and_cache_flash op uses the slot_mapping's shape
|
||||
# to determine the number of actual tokens.
|
||||
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(current_platform.fp8_dtype())
|
||||
value_cache = value_cache.view(current_platform.fp8_dtype())
|
||||
|
||||
# decode:extend:prefill
|
||||
query = query[:num_actual_tokens]
|
||||
key = key[:num_actual_tokens]
|
||||
value = value[:num_actual_tokens]
|
||||
|
||||
output_actual_tokens = output[:num_actual_tokens]
|
||||
|
||||
num_decodes = attn_metadata.num_decodes
|
||||
num_prefills = attn_metadata.num_prefills
|
||||
num_extends = attn_metadata.num_extends
|
||||
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
num_extend_tokens = attn_metadata.num_extend_tokens
|
||||
if not attn_metadata.use_cascade:
|
||||
# calculate for pure prefills
|
||||
if num_prefills > 0:
|
||||
assert attn_metadata.prefill_metadata is not None
|
||||
|
||||
prefill_query = query[num_decode_tokens + num_extend_tokens :]
|
||||
prefill_key = key[num_decode_tokens + num_extend_tokens :]
|
||||
prefill_value = value[num_decode_tokens + num_extend_tokens :]
|
||||
|
||||
aiter.flash_attn_varlen_func(
|
||||
q=prefill_query,
|
||||
k=prefill_key,
|
||||
v=prefill_value,
|
||||
cu_seqlens_q=attn_metadata.prefill_metadata.query_start_loc,
|
||||
cu_seqlens_k=attn_metadata.prefill_metadata.query_start_loc,
|
||||
max_seqlen_q=attn_metadata.prefill_metadata.max_query_len,
|
||||
max_seqlen_k=attn_metadata.prefill_metadata.max_seq_len,
|
||||
min_seqlen_q=1,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
out=output_actual_tokens[num_decode_tokens + num_extend_tokens :],
|
||||
)
|
||||
|
||||
# calculate for extends
|
||||
if num_extends > 0:
|
||||
assert attn_metadata.extend_metadata is not None
|
||||
extend_tokens_slice = slice(
|
||||
num_decode_tokens, num_decode_tokens + num_extend_tokens
|
||||
)
|
||||
extend_querys = query[extend_tokens_slice]
|
||||
extend_keys = key[extend_tokens_slice]
|
||||
extend_values = value[extend_tokens_slice]
|
||||
extend_outputs = output[extend_tokens_slice]
|
||||
self.extend_forward(
|
||||
attn_metadata=attn_metadata,
|
||||
query=extend_querys,
|
||||
key=extend_keys,
|
||||
value=extend_values,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
output=extend_outputs,
|
||||
cu_seqlens_q=attn_metadata.extend_metadata.query_start_loc,
|
||||
max_seqlen_q=attn_metadata.extend_metadata.max_query_len,
|
||||
max_seqlen_k=attn_metadata.extend_metadata.max_seq_len,
|
||||
min_seqlen_q=1,
|
||||
block_table=attn_metadata.block_table[
|
||||
num_decodes : num_decodes + num_extends
|
||||
],
|
||||
slot_mapping=attn_metadata.slot_mapping[
|
||||
num_decodes : num_decodes + num_extends
|
||||
],
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
)
|
||||
|
||||
# calculate for decodes
|
||||
if num_decodes > 0:
|
||||
assert attn_metadata.decode_metadata is not None
|
||||
_, num_heads, head_size = query.shape
|
||||
nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
|
||||
num_seqs = attn_metadata.seq_lens.shape[0]
|
||||
max_num_partitions = (
|
||||
attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
|
||||
) // _PARTITION_SIZE_ROCM
|
||||
|
||||
workspace_buffer = torch.empty(
|
||||
(num_seqs * num_heads * max_num_partitions * head_size)
|
||||
* nbytes_per_qo_elem
|
||||
+ 2 * (num_seqs * num_heads * max_num_partitions) * 4,
|
||||
dtype=torch.uint8,
|
||||
device=output.device,
|
||||
)
|
||||
|
||||
torch.ops.aiter.paged_attention_v1(
|
||||
output[:num_decode_tokens],
|
||||
workspace_buffer,
|
||||
query[:num_decode_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
self.scale,
|
||||
attn_metadata.block_table[:num_decodes],
|
||||
attn_metadata.query_start_loc[:num_decodes],
|
||||
attn_metadata.seq_lens[:num_decodes],
|
||||
attn_metadata.max_seq_len,
|
||||
self.alibi_slopes,
|
||||
self.kv_cache_dtype,
|
||||
"NHD",
|
||||
self.logits_soft_cap,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
None,
|
||||
_PARTITION_SIZE_ROCM,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Cascade attention is not implemented for ROCM AITER"
|
||||
)
|
||||
|
||||
return output
|
||||
196
vllm_old/v1/attention/backends/rocm_aiter_unified_attn.py
Normal file
196
vllm_old/v1/attention/backends/rocm_aiter_unified_attn.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with PagedAttention and Triton prefix prefill."""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.rocm_attn import (
|
||||
RocmAttentionBackend,
|
||||
RocmAttentionImpl,
|
||||
RocmAttentionMetadataBuilder,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_AITER_UNIFIED_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]:
|
||||
return RocmAiterUnifiedAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
|
||||
return RocmAttentionMetadataBuilder
|
||||
|
||||
|
||||
class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||
return quant_key == kFp8StaticTensorSym
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: int | None = None,
|
||||
sinks: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
sinks,
|
||||
)
|
||||
logger.info_once(
|
||||
"Using aiter unified attention for RocmAiterUnifiedAttentionImpl"
|
||||
)
|
||||
from aiter.ops.triton.unified_attention import unified_attention
|
||||
|
||||
self.unified_attention = unified_attention
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape =
|
||||
[2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused block_scale output quantization is not yet supported"
|
||||
" for RocmAttentionImpl"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output.fill_(0)
|
||||
|
||||
assert attn_metadata.use_cascade is False
|
||||
|
||||
# IMPORTANT!
|
||||
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
|
||||
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
|
||||
# in this method. For example, `view` and `slice` (or `[:n]`) operations
|
||||
# are surprisingly slow even in the case they do not invoke any GPU ops.
|
||||
# Minimize the PyTorch ops in this method as much as possible.
|
||||
# Whenever making a change in this method, please benchmark the
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
assert layer._q_scale_float == 1.0, (
|
||||
"A non 1.0 q_scale is not currently supported."
|
||||
)
|
||||
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
|
||||
|
||||
self.unified_attention(
|
||||
q=query[:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[:num_actual_tokens],
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
seqused_k=seqused_k,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
q_descale=None, # Not supported
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
sinks=self.sinks,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
|
||||
return output
|
||||
362
vllm_old/v1/attention/backends/rocm_attn.py
Normal file
362
vllm_old/v1/attention/backends/rocm_attn.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with PagedAttention and Triton prefix prefill."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RocmAttentionMetadata:
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
# For cascade attention.
|
||||
use_cascade: bool
|
||||
common_prefix_len: int
|
||||
cu_prefix_query_lens: torch.Tensor | None
|
||||
prefix_kv_lens: torch.Tensor | None
|
||||
suffix_kv_lens: torch.Tensor | None
|
||||
|
||||
# Optional aot scheduling
|
||||
scheduler_metadata: torch.Tensor | None = None
|
||||
prefix_scheduler_metadata: torch.Tensor | None = None
|
||||
|
||||
|
||||
class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
self.num_heads_q = model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
|
||||
self.headdim = model_config.get_head_size()
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
) -> RocmAttentionMetadata:
|
||||
attn_metadata = self.build(0, common_attn_metadata)
|
||||
# When doing full graph capture, setting seq_lens to
|
||||
# max_model_len will cause graph capture to be extremely
|
||||
# slow, so here we set it to 1.
|
||||
attn_metadata.seq_lens.fill_(1)
|
||||
|
||||
# Here we set the query start locs to 0. This is to
|
||||
# cover up an invalid memory access in the prefix_prefil kernel
|
||||
# that we run into during graph capture (#25985)
|
||||
common_attn_metadata.query_start_loc.zero_()
|
||||
common_attn_metadata.query_start_loc_cpu.zero_()
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> RocmAttentionMetadata:
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
use_cascade = common_prefix_len > 0
|
||||
|
||||
if use_cascade:
|
||||
cu_prefix_query_lens = torch.tensor(
|
||||
[0, num_actual_tokens], dtype=torch.int32, device=self.device
|
||||
)
|
||||
prefix_kv_lens = torch.tensor(
|
||||
[common_prefix_len], dtype=torch.int32, device=self.device
|
||||
)
|
||||
suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len
|
||||
suffix_kv_lens = suffix_kv_lens.to(self.device)
|
||||
else:
|
||||
cu_prefix_query_lens = None
|
||||
prefix_kv_lens = None
|
||||
suffix_kv_lens = None
|
||||
prefix_scheduler_metadata = None
|
||||
|
||||
attn_metadata = RocmAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=query_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
use_cascade=use_cascade,
|
||||
common_prefix_len=common_prefix_len,
|
||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||
prefix_kv_lens=prefix_kv_lens,
|
||||
suffix_kv_lens=suffix_kv_lens,
|
||||
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class RocmAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
if not cls.supports_head_size(head_size):
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["RocmAttentionImpl"]:
|
||||
return RocmAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
|
||||
return RocmAttentionMetadataBuilder
|
||||
|
||||
|
||||
class RocmAttentionImpl(AttentionImpl):
|
||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||
return quant_key == kFp8StaticTensorSym
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: int | None = None,
|
||||
sinks: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
if sliding_window is None:
|
||||
self.sliding_window = (-1, -1)
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
if logits_soft_cap is None:
|
||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||
logits_soft_cap = 0
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
RocmAttentionBackend.validate_head_size(head_size)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"RocmAttentionImpl"
|
||||
)
|
||||
|
||||
self.fp8_dtype = current_platform.fp8_dtype()
|
||||
|
||||
self.sinks = sinks
|
||||
if sinks is not None:
|
||||
assert sinks.shape[0] == num_heads, (
|
||||
"Sinks must have the same number of heads as the number of "
|
||||
f"heads in the layer. Sinks shape: {sinks.shape}, "
|
||||
f"num_heads: {num_heads}."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape =
|
||||
[2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused block_scale output quantization is not yet supported"
|
||||
" for RocmAttentionImpl"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output.fill_(0)
|
||||
|
||||
assert attn_metadata.use_cascade is False
|
||||
|
||||
# IMPORTANT!
|
||||
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
|
||||
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
|
||||
# in this method. For example, `view` and `slice` (or `[:n]`) operations
|
||||
# are surprisingly slow even in the case they do not invoke any GPU ops.
|
||||
# Minimize the PyTorch ops in this method as much as possible.
|
||||
# Whenever making a change in this method, please benchmark the
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size
|
||||
)
|
||||
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
assert layer._q_scale_float == 1.0, (
|
||||
"A non 1.0 q_scale is not currently supported."
|
||||
)
|
||||
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
|
||||
# Compute attention and update output up to `num_actual_tokens`.
|
||||
chunked_prefill_paged_decode(
|
||||
query=query[:num_actual_tokens],
|
||||
key=key[:num_actual_tokens],
|
||||
value=value[:num_actual_tokens],
|
||||
output=output[:num_actual_tokens],
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
block_table=block_table,
|
||||
query_start_loc=cu_seqlens_q,
|
||||
seq_lens=seqused_k,
|
||||
max_seq_len=max_seqlen_k,
|
||||
max_query_len=max_seqlen_q,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window[0],
|
||||
sm_scale=self.scale,
|
||||
output_scale=output_scale,
|
||||
sinks=self.sinks,
|
||||
)
|
||||
|
||||
return output
|
||||
105
vllm_old/v1/attention/backends/short_conv_attn.py
Normal file
105
vllm_old/v1/attention/backends/short_conv_attn.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PAD_SLOT_ID,
|
||||
CommonAttentionMetadata,
|
||||
compute_causal_conv1d_metadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
|
||||
|
||||
class ShortConvAttentionBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
|
||||
return ShortConvAttentionMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShortConvAttentionMetadata:
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
|
||||
query_start_loc: torch.Tensor
|
||||
state_indices_tensor: torch.Tensor
|
||||
has_initial_states_p: torch.Tensor | None
|
||||
|
||||
# For causal_conv1d
|
||||
nums_dict: dict | None = None
|
||||
batch_ptr: torch.Tensor | None = None
|
||||
token_chunk_offset_ptr: torch.Tensor | None = None
|
||||
|
||||
|
||||
class ShortConvAttentionMetadataBuilder(
|
||||
BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata]
|
||||
):
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> ShortConvAttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
|
||||
# for causal_conv1d
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
||||
)
|
||||
)
|
||||
|
||||
has_initial_states_p = None
|
||||
if num_prefills > 0:
|
||||
has_initial_states_cpu = (
|
||||
common_attn_metadata.num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
> 0
|
||||
)
|
||||
has_initial_states_p = has_initial_states_cpu.to(query_start_loc.device)
|
||||
|
||||
query_start_loc_p = (
|
||||
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = (
|
||||
compute_causal_conv1d_metadata(query_start_loc_p)
|
||||
)
|
||||
|
||||
elif (
|
||||
num_decodes > 0
|
||||
and num_decodes <= self.decode_cudagraph_max_bs
|
||||
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
):
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
|
||||
self.state_indices_tensor[:num_decodes].copy_(
|
||||
state_indices_tensor, non_blocking=True
|
||||
)
|
||||
state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
|
||||
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
|
||||
|
||||
attn_metadata = ShortConvAttentionMetadata(
|
||||
query_start_loc=query_start_loc,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
has_initial_states_p=has_initial_states_p,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
nums_dict=nums_dict,
|
||||
batch_ptr=batch_ptr,
|
||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||
)
|
||||
return attn_metadata
|
||||
425
vllm_old/v1/attention/backends/tree_attn.py
Normal file
425
vllm_old/v1/attention/backends/tree_attn.py
Normal file
@@ -0,0 +1,425 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with TreeAttention."""
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TreeAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TREE_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["TreeAttentionImpl"]:
|
||||
return TreeAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
|
||||
return TreeAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class TreeAttentionMetadata:
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
num_prefill_tokens: int = 0
|
||||
num_decode_tokens: int = 0
|
||||
num_prefills: int = 0
|
||||
num_decodes: int = 0
|
||||
|
||||
tree_attn_bias: torch.Tensor | None = None
|
||||
|
||||
# Cached Prefill/decode metadata.
|
||||
_cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None
|
||||
_cached_decode_metadata: Optional["TreeAttentionMetadata"] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
# Recover cached prefill-phase attention
|
||||
# metadata structure
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
q_start_loc = self.query_start_loc[self.num_decodes :]
|
||||
q_seqlens = torch.diff(q_start_loc)
|
||||
kv_seqlens = self.seq_lens[self.num_decodes :]
|
||||
# Construct & cache prefill-phase attention metadata structure
|
||||
self._cached_prefill_metadata = TreeAttentionMetadata(
|
||||
num_actual_tokens=self.num_prefill_tokens,
|
||||
max_query_len=int(q_seqlens.max().item()),
|
||||
query_start_loc=q_start_loc - q_start_loc[0],
|
||||
max_seq_len=int(kv_seqlens.max().item()),
|
||||
seq_lens=kv_seqlens,
|
||||
block_table=self.block_table[self.num_decodes :],
|
||||
slot_mapping=self.slot_mapping[self.num_decode_tokens :],
|
||||
)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["TreeAttentionMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
# Recover cached decode-phase attention
|
||||
# metadata structure
|
||||
return self._cached_decode_metadata
|
||||
|
||||
q_start_loc = self.query_start_loc[: self.num_decodes + 1]
|
||||
q_seqlens = torch.diff(q_start_loc)
|
||||
kv_seqlens = self.seq_lens[: self.num_decodes]
|
||||
# Construct & cache decode-phase attention metadata structure
|
||||
self._cached_decode_metadata = TreeAttentionMetadata(
|
||||
num_actual_tokens=self.num_decode_tokens,
|
||||
max_query_len=int(q_seqlens.max().item()),
|
||||
query_start_loc=q_start_loc,
|
||||
max_seq_len=int(kv_seqlens.max().item()),
|
||||
seq_lens=kv_seqlens,
|
||||
block_table=self.block_table[: self.num_decodes],
|
||||
slot_mapping=self.slot_mapping[: self.num_decode_tokens],
|
||||
tree_attn_bias=self.tree_attn_bias,
|
||||
)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadata]):
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
|
||||
spec_config = vllm_config.speculative_config
|
||||
spec_token_tree = (spec := spec_config) and spec.speculative_token_tree
|
||||
tree_choices: list[tuple[int, ...]] = (
|
||||
ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)]
|
||||
)
|
||||
# Construct the tree attention bias.
|
||||
depth_counts = _get_depth_counts(tree_choices)
|
||||
self.tree_attn_bias = _prepare_tree_attn_bias(
|
||||
tree_choices,
|
||||
depth_counts,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.reorder_batch_threshold = self.tree_attn_bias.shape[0]
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> TreeAttentionMetadata:
|
||||
decode_threshold = self.tree_attn_bias.shape[0]
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=decode_threshold
|
||||
)
|
||||
)
|
||||
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
q_start_loc = common_attn_metadata.query_start_loc
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
kv_seqlens = common_attn_metadata.seq_lens
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
return TreeAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
num_decodes=num_decodes,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=q_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=kv_seqlens,
|
||||
block_table=block_table,
|
||||
slot_mapping=slot_mapping,
|
||||
tree_attn_bias=self.tree_attn_bias,
|
||||
)
|
||||
|
||||
def build_for_drafting(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
draft_index: int,
|
||||
) -> TreeAttentionMetadata:
|
||||
# Cache the original tree attention bias.
|
||||
orig_tree_attn_bias = self.tree_attn_bias
|
||||
|
||||
if draft_index == 0:
|
||||
# Use prefill for drafting at the root level.
|
||||
self.tree_attn_bias = torch.empty(0)
|
||||
else:
|
||||
# Slice the tree attention bias for drafting. Exclude
|
||||
# the root level.
|
||||
start, end = 1, 1 + common_attn_metadata.max_query_len
|
||||
self.tree_attn_bias = self.tree_attn_bias[start:end, start:end].contiguous()
|
||||
|
||||
# Build attention bias.
|
||||
attn_metadata = self.build(0, common_attn_metadata, fast_build=True)
|
||||
|
||||
# Reset the tree attention bias to the original value.
|
||||
self.tree_attn_bias = orig_tree_attn_bias
|
||||
return attn_metadata
|
||||
|
||||
|
||||
def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
|
||||
# Count the number of choices at each depth of the tree.
|
||||
depth_counts = []
|
||||
prev_depth = 0
|
||||
for path in sorted_tree_choices:
|
||||
depth = len(path)
|
||||
if depth != prev_depth:
|
||||
depth_counts.append(0)
|
||||
depth_counts[depth - 1] += 1
|
||||
prev_depth = depth
|
||||
return depth_counts
|
||||
|
||||
|
||||
def _prepare_tree_attn_bias(
|
||||
sorted_tree_choices: list[tuple[int, ...]],
|
||||
depth_counts: list[int],
|
||||
dtype: torch.dtype | None,
|
||||
device: torch.device | None,
|
||||
) -> torch.Tensor:
|
||||
# +1 comes from the additional root node.
|
||||
tree_len = len(sorted_tree_choices) + 1
|
||||
tree_attn_mask = torch.full(
|
||||
(tree_len, tree_len), -torch.inf, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# Set diagonal to all zeros. Each token should
|
||||
# attend to itself.
|
||||
mask_val = 0
|
||||
for i in range(tree_len):
|
||||
tree_attn_mask[i, i] = mask_val
|
||||
|
||||
# Set root to all zeros. All tokens attend to it.
|
||||
tree_attn_mask[:, 0] = mask_val
|
||||
|
||||
# Set all ancestors to zeros.
|
||||
start = 0
|
||||
for i in range(len(depth_counts)):
|
||||
for j in range(depth_counts[i]):
|
||||
cur_tree_choice = sorted_tree_choices[start + j]
|
||||
# Retrieve ancestor position.
|
||||
if len(cur_tree_choice) == 1:
|
||||
continue
|
||||
ancestor_idx = []
|
||||
for c in range(len(cur_tree_choice) - 1):
|
||||
ancestor_idx.append(
|
||||
sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1
|
||||
)
|
||||
tree_attn_mask[j + start + 1, ancestor_idx] = mask_val
|
||||
start += depth_counts[i]
|
||||
return tree_attn_mask
|
||||
|
||||
|
||||
class TreeAttentionImpl(AttentionImpl):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: str | None = None,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
if logits_soft_cap is None:
|
||||
# Setting logits_soft_cap to 0 means no soft cap.
|
||||
logits_soft_cap = 0
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
if sliding_window is None:
|
||||
self.sliding_window = (-1, -1)
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"TreeAttentionImpl."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: TreeAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with TreeAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape =
|
||||
[2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported for TreeAttentionImpl"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output.fill_(0)
|
||||
|
||||
# Cache the input KVs.
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
|
||||
# not padded. However, we don't need to do key[:num_actual_tokens]
|
||||
# and value[:num_actual_tokens] because the reshape_and_cache_flash
|
||||
# op uses the slot_mapping's shape to determine the number of
|
||||
# actual tokens.
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, key.shape[1])
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
unified_attention(
|
||||
q=query[num_decode_tokens:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[num_decode_tokens:num_actual_tokens],
|
||||
cu_seqlens_q=prefill_meta.query_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_query_len,
|
||||
seqused_k=prefill_meta.seq_lens,
|
||||
max_seqlen_k=prefill_meta.max_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
block_table=prefill_meta.block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
q_descale=None, # Not supported
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
unified_attention(
|
||||
q=query[:num_decode_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[:num_decode_tokens],
|
||||
cu_seqlens_q=decode_meta.query_start_loc,
|
||||
max_seqlen_q=decode_meta.max_query_len,
|
||||
seqused_k=decode_meta.seq_lens,
|
||||
max_seqlen_k=decode_meta.max_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
qq_bias=decode_meta.tree_attn_bias,
|
||||
window_size=self.sliding_window,
|
||||
block_table=decode_meta.block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
q_descale=None, # Not supported
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
return output
|
||||
373
vllm_old/v1/attention/backends/triton_attn.py
Normal file
373
vllm_old/v1/attention/backends/triton_attn.py
Normal file
@@ -0,0 +1,373 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""High-Performance Triton-only Attention layer."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash,
|
||||
)
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TritonAttentionMetadata:
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
# For cascade attention.
|
||||
use_cascade: bool
|
||||
common_prefix_len: int
|
||||
cu_prefix_query_lens: torch.Tensor | None
|
||||
prefix_kv_lens: torch.Tensor | None
|
||||
suffix_kv_lens: torch.Tensor | None
|
||||
|
||||
# Optional aot scheduling
|
||||
scheduler_metadata: torch.Tensor | None = None
|
||||
prefix_scheduler_metadata: torch.Tensor | None = None
|
||||
|
||||
|
||||
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
self.num_heads_q = model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
|
||||
self.headdim = model_config.get_head_size()
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
) -> TritonAttentionMetadata:
|
||||
attn_metadata = self.build(0, common_attn_metadata)
|
||||
# When doing full graph capture, setting seq_lens to
|
||||
# max_model_len will cause graph capture to be extremely
|
||||
# slow, so here we set it to 1.
|
||||
attn_metadata.seq_lens.fill_(1)
|
||||
return attn_metadata
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> TritonAttentionMetadata:
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
use_cascade = common_prefix_len > 0
|
||||
|
||||
if use_cascade:
|
||||
cu_prefix_query_lens = torch.tensor(
|
||||
[0, num_actual_tokens], dtype=torch.int32, device=self.device
|
||||
)
|
||||
prefix_kv_lens = torch.tensor(
|
||||
[common_prefix_len], dtype=torch.int32, device=self.device
|
||||
)
|
||||
suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len
|
||||
suffix_kv_lens = suffix_kv_lens.to(self.device)
|
||||
else:
|
||||
cu_prefix_query_lens = None
|
||||
prefix_kv_lens = None
|
||||
suffix_kv_lens = None
|
||||
prefix_scheduler_metadata = None
|
||||
|
||||
attn_metadata = TritonAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=query_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
use_cascade=use_cascade,
|
||||
common_prefix_len=common_prefix_len,
|
||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||
prefix_kv_lens=prefix_kv_lens,
|
||||
suffix_kv_lens=suffix_kv_lens,
|
||||
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class TritonAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.float32,
|
||||
]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
"fp8_e5m2",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["TritonAttentionImpl"]:
|
||||
return TritonAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (num_blocks, 2, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
|
||||
return TritonAttentionMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def supports_head_size(cls, head_size: int) -> bool:
|
||||
return head_size >= 32
|
||||
|
||||
@classmethod
|
||||
def supports_sink(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class TritonAttentionImpl(AttentionImpl):
|
||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||
return quant_key == kFp8StaticTensorSym
|
||||
|
||||
def supports_quant_query_input(self) -> bool:
|
||||
return current_platform.is_cuda()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: int | None = None,
|
||||
sinks: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
if sliding_window is None:
|
||||
self.sliding_window = (-1, -1)
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
if logits_soft_cap is None:
|
||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||
logits_soft_cap = 0
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"TritonAttentionImpl"
|
||||
)
|
||||
|
||||
self.fp8_dtype = current_platform.fp8_dtype()
|
||||
|
||||
self.sinks = sinks
|
||||
if sinks is not None:
|
||||
assert sinks.shape[0] == num_heads, (
|
||||
"Sinks must have the same number of heads as the number of "
|
||||
f"heads in the layer. Sinks shape: {sinks.shape}, "
|
||||
f"num_heads: {num_heads}."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: TritonAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Paged Attention impl. in Triton.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape =
|
||||
[num_blocks, 2, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused block_scale output quantization is not yet supported"
|
||||
" for TritonAttentionImpl"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output.fill_(0)
|
||||
|
||||
assert attn_metadata.use_cascade is False
|
||||
|
||||
# IMPORTANT!
|
||||
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
|
||||
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
|
||||
# in this method. For example, `view` and `slice` (or `[:n]`) operations
|
||||
# are surprisingly slow even in the case they do not invoke any GPU ops.
|
||||
# Minimize the PyTorch ops in this method as much as possible.
|
||||
# Whenever making a change in this method, please benchmark the
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
# triton kernel does not support uint8 kv_cache
|
||||
# (because some explicit casts (e.g. float8_e4m3fnuz)
|
||||
# are not supported)
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
if key_cache.dtype != self.fp8_dtype:
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
assert layer._q_scale_float == 1.0, (
|
||||
"A non 1.0 q_scale is not currently supported."
|
||||
)
|
||||
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
|
||||
|
||||
unified_attention(
|
||||
q=query[:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[:num_actual_tokens],
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
seqused_k=seqused_k,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
q_descale=None, # Not supported
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
sinks=self.sinks,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
|
||||
return output
|
||||
1117
vllm_old/v1/attention/backends/utils.py
Normal file
1117
vllm_old/v1/attention/backends/utils.py
Normal file
File diff suppressed because it is too large
Load Diff
417
vllm_old/v1/attention/backends/xformers.py
Normal file
417
vllm_old/v1/attention/backends/xformers.py
Normal file
@@ -0,0 +1,417 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with XFormersAttention."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import (
|
||||
AttentionBias,
|
||||
PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
|
||||
)
|
||||
|
||||
XFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
XFORMERS_AVAILABLE = False
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XFormersAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [
|
||||
32,
|
||||
40,
|
||||
48,
|
||||
56,
|
||||
64,
|
||||
72,
|
||||
80,
|
||||
88,
|
||||
96,
|
||||
104,
|
||||
112,
|
||||
120,
|
||||
128,
|
||||
136,
|
||||
144,
|
||||
152,
|
||||
160,
|
||||
168,
|
||||
176,
|
||||
184,
|
||||
192,
|
||||
200,
|
||||
208,
|
||||
216,
|
||||
224,
|
||||
232,
|
||||
240,
|
||||
248,
|
||||
256,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "XFORMERS"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["XFormersAttentionImpl"]:
|
||||
return XFormersAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["XFormersAttentionMetadataBuilder"]:
|
||||
return XFormersAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class XFormersAttentionMetadata:
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
num_prefill_tokens: int = 0
|
||||
num_decode_tokens: int = 0
|
||||
num_prefills: int = 0
|
||||
num_decodes: int = 0
|
||||
|
||||
# Biases for different attention types.
|
||||
attn_bias: Optional["AttentionBias"] = None
|
||||
|
||||
# Self-attention prefill/decode metadata cache
|
||||
_cached_prefill_metadata: Optional["XFormersAttentionMetadata"] = None
|
||||
_cached_decode_metadata: Optional["XFormersAttentionMetadata"] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["XFormersAttentionMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
# Recover cached prefill-phase attention
|
||||
# metadata structure
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
q_start_loc = self.query_start_loc[self.num_decodes :]
|
||||
q_seqlens = torch.diff(q_start_loc)
|
||||
kv_seqlens = self.seq_lens[self.num_decodes :]
|
||||
# Construct & cache prefill-phase attention metadata structure
|
||||
self._cached_prefill_metadata = XFormersAttentionMetadata(
|
||||
num_actual_tokens=self.num_prefill_tokens,
|
||||
max_query_len=int(q_seqlens.max().item()),
|
||||
query_start_loc=q_start_loc - q_start_loc[0],
|
||||
max_seq_len=int(kv_seqlens.max().item()),
|
||||
seq_lens=kv_seqlens,
|
||||
block_table=self.block_table[self.num_decodes :],
|
||||
slot_mapping=self.slot_mapping[self.num_decode_tokens :],
|
||||
)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
# Recover cached decode-phase attention
|
||||
# metadata structure
|
||||
return self._cached_decode_metadata
|
||||
|
||||
q_start_loc = self.query_start_loc
|
||||
q_seqlens = torch.diff(q_start_loc)
|
||||
decode_kv_seqlens = self.seq_lens[: self.num_decodes]
|
||||
# Construct & cache decode-phase attention metadata structure
|
||||
self._cached_decode_metadata = XFormersAttentionMetadata(
|
||||
num_actual_tokens=self.num_decode_tokens,
|
||||
max_query_len=int(q_seqlens[: self.num_decodes].max().item()),
|
||||
query_start_loc=q_start_loc[: self.num_decodes + 1],
|
||||
max_seq_len=int(decode_kv_seqlens.max().item()),
|
||||
seq_lens=decode_kv_seqlens,
|
||||
block_table=self.block_table[: self.num_decodes],
|
||||
slot_mapping=self.slot_mapping[: self.num_decode_tokens],
|
||||
attn_bias=self.attn_bias,
|
||||
)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
class XFormersAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[XFormersAttentionMetadata]
|
||||
):
|
||||
reorder_batch_threshold: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
assert XFORMERS_AVAILABLE
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self._num_decodes = 0
|
||||
self._num_decode_tokens = 0
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> XFormersAttentionMetadata:
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
||||
)
|
||||
)
|
||||
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
q_start_loc = common_attn_metadata.query_start_loc
|
||||
q_seqlens = torch.diff(q_start_loc)
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
kv_seqlens = common_attn_metadata.seq_lens
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
bias = None
|
||||
if num_decodes > 0:
|
||||
# Construct the decoder bias.
|
||||
decode_q_seqlens = q_seqlens[:num_decodes]
|
||||
decode_kv_seqlens = kv_seqlens[:num_decodes]
|
||||
bias = PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
|
||||
q_seqlen=decode_q_seqlens.tolist(),
|
||||
kv_seqlen=decode_kv_seqlens.tolist(),
|
||||
page_size=self.block_size,
|
||||
block_tables=block_table[:num_decodes],
|
||||
device=block_table.device,
|
||||
)
|
||||
|
||||
return XFormersAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
num_decodes=num_decodes,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=q_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=kv_seqlens,
|
||||
block_table=block_table,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_bias=bias,
|
||||
)
|
||||
|
||||
|
||||
class XFormersAttentionImpl(AttentionImpl):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: str | None = None,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0.")
|
||||
if alibi_slopes is not None:
|
||||
raise NotImplementedError("XFormers does not support alibi slopes yet.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
if sliding_window is None:
|
||||
self.sliding_window = (-1, -1)
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
|
||||
if logits_soft_cap is None:
|
||||
# Setting logits_soft_cap to 0 means no soft cap.
|
||||
logits_soft_cap = 0
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"XFormersAttentionImpl."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: XFormersAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with XFormers.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape =
|
||||
[2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for XFormersAttentionImpl"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output.fill_(0)
|
||||
|
||||
# Cache the input KVs.
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
|
||||
# not padded. However, we don't need to do key[:num_actual_tokens]
|
||||
# and value[:num_actual_tokens] because the reshape_and_cache_flash
|
||||
# op uses the slot_mapping's shape to determine the number of
|
||||
# actual tokens.
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, key.shape[1])
|
||||
unified_attention(
|
||||
q=query[num_decode_tokens:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[num_decode_tokens:num_actual_tokens],
|
||||
cu_seqlens_q=prefill_meta.query_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_query_len,
|
||||
seqused_k=prefill_meta.seq_lens,
|
||||
max_seqlen_k=prefill_meta.max_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
block_table=prefill_meta.block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
q_descale=None, # Not supported
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[:num_decode_tokens]
|
||||
# Reshape query to [1, B_T, G, H, D].
|
||||
q = decode_query.view(
|
||||
1, -1, self.num_kv_heads, self.num_queries_per_kv, self.head_size
|
||||
)
|
||||
# Reshape the k and v caches to [1, Bkv_T, G, H, D]
|
||||
cache_k = key_cache.view(
|
||||
1, -1, self.num_kv_heads, 1, self.head_size
|
||||
).expand(
|
||||
1,
|
||||
-1,
|
||||
self.num_kv_heads,
|
||||
self.num_queries_per_kv,
|
||||
self.head_size,
|
||||
)
|
||||
cache_v = value_cache.view(
|
||||
1, -1, self.num_kv_heads, 1, self.head_size
|
||||
).expand(
|
||||
1,
|
||||
-1,
|
||||
self.num_kv_heads,
|
||||
self.num_queries_per_kv,
|
||||
self.head_size,
|
||||
)
|
||||
|
||||
attn_bias = decode_meta.attn_bias
|
||||
output[:num_decode_tokens] = xops.memory_efficient_attention_forward(
|
||||
q,
|
||||
cache_k,
|
||||
cache_v,
|
||||
attn_bias=attn_bias,
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
).view(decode_query.shape)
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output
|
||||
Reference in New Issue
Block a user