Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -15,13 +15,11 @@ logger = init_logger(__name__)
_ROCM_FLASH_ATTN_AVAILABLE = False
if current_platform.is_cuda():
from vllm._custom_ops import reshape_and_cache_flash
# from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
# flash_attn_varlen_func,
# get_scheduler_metadata,
# )
from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from vllm import _custom_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache, flash_attn_varlen_int8_func
elif current_platform.is_xpu():
from vllm import _custom_ops as ops
from vllm._xpu_ops import xpu_ops
@@ -53,67 +51,93 @@ elif current_platform.is_rocm():
reshape_and_cache_flash = ops.reshape_and_cache_flash
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
# import here to avoid circular dependencies
from vllm.platforms import current_platform
return 3
def get_flash_attn_version(
requires_alibi: bool = False, head_size: int | None = None
) -> int | None:
if current_platform.is_xpu():
return 2
if current_platform.is_rocm():
# ROCm doesn't use vllm_flash_attn; return None to skip fa_version arg
return None
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
fa_version_unsupported_reason,
is_fa_version_supported,
)
return None
# try:
# from vllm.vllm_flash_attn.flash_attn_interface import (
# fa_version_unsupported_reason,
# is_fa_version_supported,
# )
device_capability = current_platform.get_device_capability()
# device_capability = current_platform.get_device_capability()
assert device_capability is not None
# assert device_capability is not None
# 1. default version depending on platform
fa_version = (
3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
)
# # 1. default version depending on platform
# if device_capability.major == 9 and is_fa_version_supported(3):
# # Hopper (SM90): prefer FA3
# fa_version = 3
# elif device_capability.major == 10 and is_fa_version_supported(4):
# # Blackwell (SM100+, restrict to SM100 for now): prefer FA4
# fa_version = 4
# else:
# # Fallback to FA2
# fa_version = 2
# 2. override if passed by environment or config
from vllm.config import get_current_vllm_config_or_none
# # 2. override if passed by environment or config
# from vllm.config import get_current_vllm_config_or_none
vllm_config = get_current_vllm_config_or_none()
if (
vllm_config is not None
and vllm_config.attention_config.flash_attn_version is not None
):
fa_version = vllm_config.attention_config.flash_attn_version
# vllm_config = get_current_vllm_config_or_none()
# if (
# vllm_config is not None
# and vllm_config.attention_config.flash_attn_version is not None
# ):
# fa_version = vllm_config.attention_config.flash_attn_version
# 3. fallback for unsupported combinations
if device_capability.major == 10 and fa_version == 3:
logger.warning_once(
"Cannot use FA version 3 on Blackwell platform, "
"defaulting to FA version 2."
)
fa_version = 2
# # 3. fallback for unsupported combinations
# if device_capability.major >= 10 and fa_version == 3:
# logger.warning_once(
# "Cannot use FA version 3 on Blackwell platform, "
# "defaulting to FA version 4 if supported, otherwise FA2."
# )
# fa_version = 4 if is_fa_version_supported(4) else 2
if requires_alibi and fa_version == 3:
logger.warning_once(
"Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
)
fa_version = 2
# if requires_alibi and fa_version == 3:
# logger.warning_once(
# "Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
# )
# fa_version = 2
if not is_fa_version_supported(fa_version):
logger.error(
"Cannot use FA version %d is not supported due to %s",
fa_version,
fa_version_unsupported_reason(fa_version),
)
# if requires_alibi and fa_version == 4:
# logger.warning_once(
# "Cannot use FA version 4 with ALiBi, defaulting to FA version 2."
# )
# fa_version = 2
assert is_fa_version_supported(fa_version)
return fa_version
except (ImportError, AssertionError):
return None
# # FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
# # supported head dimensions.
# # See: https://github.com/Dao-AILab/flash-attention/issues/1959
# if (
# fa_version == 4
# and device_capability.major >= 10
# and head_size is not None
# and head_size > 128
# ):
# logger.warning_once(
# "FA4 on Blackwell does not support head_size=%d due to TMEM "
# "capacity limits, defaulting to FA version 2.",
# head_size,
# )
# fa_version = 2
# if not is_fa_version_supported(fa_version):
# logger.error(
# "Cannot use FA version %d is not supported due to %s",
# fa_version,
# fa_version_unsupported_reason(fa_version),
# )
# assert is_fa_version_supported(fa_version)
# return fa_version
# except (ImportError, AssertionError):
# return None
def flash_attn_supports_fp8() -> bool:
@@ -124,10 +148,7 @@ def flash_attn_supports_fp8() -> bool:
def flash_attn_supports_sinks() -> bool:
if current_platform.is_xpu():
return True
else:
return get_flash_attn_version() == 3
return True
def flash_attn_supports_mla():
@@ -142,6 +163,10 @@ def flash_attn_supports_mla():
return is_fa_version_supported(
3
) and current_platform.is_device_capability_family(90)
# NOTE(Lucas): FA4 CuteDSL does NOT currently support MLA's non-standard
# head dimensions (576 for qk, 512 for v) due to TMEM capacity limits.
except (ImportError, AssertionError):
pass
return False
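# For reference, a minimal standalone sketch of the version-selection policy that
# the commented-out upstream block above preserves (the overlay itself now returns
# a fixed value per platform). The helper `is_fa_version_supported` and the
# compute-capability major are assumed to mirror vLLM's helpers; this is an
# illustration, not the overlay's code path.
def _sketch_pick_fa_version(
    capability_major: int,
    requires_alibi: bool,
    head_size: int | None,
    is_fa_version_supported,
) -> int:
    # 1. default by compute capability: FA3 on Hopper (SM90), FA4 on Blackwell (SM100+)
    if capability_major == 9 and is_fa_version_supported(3):
        fa_version = 3
    elif capability_major >= 10 and is_fa_version_supported(4):
        fa_version = 4
    else:
        fa_version = 2
    # 2. ALiBi is only handled by FA2
    if requires_alibi and fa_version in (3, 4):
        fa_version = 2
    # 3. FA4 on Blackwell has TMEM capacity limits above head_size 128
    if fa_version == 4 and head_size is not None and head_size > 128:
        fa_version = 2
    return fa_version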

View File

@@ -4,7 +4,7 @@
import copy
from dataclasses import dataclass
from typing import ClassVar
from typing import ClassVar, Optional, Union, List
import numpy as np
import torch
@@ -23,15 +23,15 @@ from vllm.v1.attention.backends.fa_utils import (
is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from ixformer.contrib.vllm_flash_attn import merge_attn_states
if is_flash_attn_varlen_func_available():
from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_sinks,
flash_attn_varlen_func,
flash_attn_with_kvcache,
# get_scheduler_metadata,
reshape_and_cache_flash,
flash_attn_varlen_int8_func
)
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
@@ -50,9 +50,12 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import (
get_dcp_local_seq_lens,
get_kv_cache_layout,
split_decodes_and_prefills,
split_decodes_and_prefills
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import _custom_ops as ops
import vllm.envs as envs
import ixformer.inference.functions as ixf_ops
logger = init_logger(__name__)
@@ -63,23 +66,7 @@ class FlashAttentionBackend(AttentionBackend):
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
if (
model_config
and model_config.is_hybrid
and (
cache_config.mamba_ssm_cache_dtype == "float32"
or cache_config.mamba_cache_dtype == "float32"
)
):
# NOTE(tdoublep): while in principle, FA supports
# MultipleOf(16), these are the block sizes that do not
# suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974
return [16, 32, 64]
return [MultipleOf(16)]
return [16, 32, 64]
forward_includes_kv_cache_update: bool = False
@@ -120,7 +107,8 @@ class FlashAttentionBackend(AttentionBackend):
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
# return (2, num_blocks, block_size, num_kv_heads, head_size)
if envs.VLLM_ATTN_OPT_LEVEL == 2:
return (3, num_blocks, num_kv_heads, block_size, head_size)
return (2, num_blocks, num_kv_heads, block_size, head_size)
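# Layout note (a sketch of an assumption borne out by the quant_type == 2 path in
# FlashAttentionImpl.forward below): with VLLM_ATTN_OPT_LEVEL == 2 the cache is
# allocated as three int8 planes -- plane 0 holds the int8 key cache, and planes
# 1..2 are reinterpreted as a single fp16 value cache (two int8 bytes per fp16
# element), e.g.:
#     kv_cache = torch.empty(3, num_blocks, num_kv_heads, block_size, head_size,
#                            dtype=torch.int8)
#     i8_key_cache = kv_cache[0]
#     value_cache = kv_cache[1:].view(torch.float16).reshape(
#         num_blocks, num_kv_heads, block_size, head_size)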
@staticmethod
@@ -139,7 +127,7 @@ class FlashAttentionBackend(AttentionBackend):
# (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
return (2, 4, 0, 1, 3, 5)
elif cache_layout == "HND":
stride_order = (0, 1, 3, 2, 4)
stride_order = (0, 1, 2, 3, 4)
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return stride_order
@@ -188,24 +176,22 @@ class FlashAttentionBackend(AttentionBackend):
if has_sink and device_capability < DeviceCapability(9, 0):
return "sink not supported on compute capability < 9.0"
return None
@dataclass
class FlashAttentionPrefillMetadata:
"""Prefill Specific Metadata"""
""" Prefill Specific Metadata """
block_table: torch.Tensor
query_start_loc: torch.Tensor
key_start_loc: torch.Tensor
max_query_len: int
@dataclass
class FlashAttentionDecodeMetadata:
block_table: torch.Tensor
query_start_loc: torch.Tensor
seq_lens: torch.Tensor
max_query_len: int
max_decode_seq_len: int
use_graph: bool
@dataclass
class FlashAttentionMetadata:
@@ -220,11 +206,12 @@ class FlashAttentionMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
key_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
num_decodes: int
num_decode_tokens: int
num_prefills: int
@@ -235,7 +222,6 @@ class FlashAttentionMetadata:
cu_prefix_query_lens: torch.Tensor | None
prefix_kv_lens: torch.Tensor | None
suffix_kv_lens: torch.Tensor | None
cu_prefix_kv_lens: torch.Tensor | None
cu_suffix_kv_lens: torch.Tensor | None
@@ -247,7 +233,7 @@ class FlashAttentionMetadata:
scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None
max_num_splits: int = 0
prefill: FlashAttentionPrefillMetadata | None = None
decode: FlashAttentionDecodeMetadata | None = None
@@ -291,7 +277,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
else AttentionCGSupport.UNIFORM_BATCH
)
supports_update_block_table: bool = True
reorder_batch_threshold: ClassVar[int] = 1
@classmethod
@@ -316,6 +302,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.compilation_config = vllm_config.compilation_config
self.attention_config = vllm_config.attention_config
self.decode_use_graph = vllm_config.compilation_config.cudagraph_mode.decode_use_graph()
self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config
)
@@ -325,7 +312,6 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.block_size = kv_cache_spec.block_size
self.max_num_splits = 0 # No upper bound on the number of splits.
# self.aot_schedule = get_flash_attn_version() == 3
self.aot_schedule = False
try:
@@ -346,6 +332,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
# Align decode/prefill split threshold with speculative decode query length
# when backend supports treating spec requests as decode.
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
if self.use_full_cuda_graph and self.aot_schedule:
# FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
@@ -388,15 +377,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
max_query_len = common_attn_metadata.max_query_len
max_seq_len = common_attn_metadata.max_seq_len
query_start_loc = common_attn_metadata.query_start_loc
key_start_loc = common_attn_metadata.key_start_loc
seq_lens = common_attn_metadata.seq_lens
seq_lens_np = common_attn_metadata.seq_lens_np
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
causal = common_attn_metadata.causal
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata)
)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
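# Decode/prefill split illustration (an illustrative reading of the threshold,
# not new behavior): with decode_threshold=N, requests whose query length is at
# most N are routed to the decode path, e.g. a threshold of 2 keeps
# speculative-decode requests carrying 1 verified + 1 draft token on the decode path.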
@@ -467,11 +458,11 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
dcp_context_kv_lens = None
cu_prefix_query_lens = None
cu_prefix_kv_lens = None
cu_suffix_kv_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
prefix_scheduler_metadata = None
cu_prefix_kv_lens = None
cu_suffix_kv_lens = None
if self.dcp_world_size > 1:
query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
@@ -507,11 +498,11 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
prefix_kv_lens = torch.tensor(
[common_prefix_len], dtype=torch.int32, device=self.device
)
# Use GPU tensor directly - no CPU sync needed
suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
cu_prefix_kv_lens = torch.tensor([0, common_prefix_len],
dtype=torch.int32,
device=self.device)
# Use GPU tensor directly - no CPU sync needed
suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
cu_suffix_kv_lens = torch.tensor([0,] + suffix_kv_lens.tolist(),
dtype=torch.int32,
@@ -542,7 +533,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
causal=causal,
)
# For FA3 + full cudagraph
max_num_splits = 0
max_num_splits = 0
if self.use_full_cuda_graph and scheduler_metadata is not None:
n = scheduler_metadata.shape[0]
self.scheduler_metadata[:n] = scheduler_metadata
@@ -552,50 +543,59 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
# output buffer.
self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n]
if num_actual_tokens <= self.max_cudagraph_size:
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
prefill_metadata = None
if num_prefills > 0:
reqs_start = num_decodes
prefill_query_start_loc = (
query_start_loc[reqs_start:] - query_start_loc[reqs_start]
)
prefill_key_start_loc = (
query_start_loc[reqs_start:] - query_start_loc[reqs_start]
)
reqs_start = num_decodes # prefill_start
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
prefill_key_start_loc = key_start_loc[
reqs_start:] - key_start_loc[reqs_start]
prefill_metadata = FlashAttentionPrefillMetadata(
block_table=block_table_tensor[reqs_start:, ...],
query_start_loc=prefill_query_start_loc,
key_start_loc=prefill_key_start_loc,
max_query_len=max_query_len,
)
block_table=block_table_tensor[reqs_start:, ...],
query_start_loc=prefill_query_start_loc,
key_start_loc=prefill_key_start_loc,
max_query_len=max_query_len,
)
decode_metadata = None
if num_decodes > 0:
reqs_start = num_decodes
reqs_start = num_decodes # prefill_start
decode_query_start_loc = query_start_loc[: reqs_start + 1]
decode_query_lens = (
decode_query_start_loc[1:] - decode_query_start_loc[:-1]
)
decode_metadata = FlashAttentionDecodeMetadata(
block_table=block_table_tensor[:reqs_start, ...],
query_start_loc=decode_query_start_loc,
seq_lens=seq_lens[:reqs_start],
max_decode_seq_len=torch.max(seq_lens[:reqs_start]).item(),
max_query_len=decode_query_lens.max().item(),
max_decode_seq_len=np.max(seq_lens_np[:reqs_start]).item(),
use_graph=num_prefills==0 and self.decode_use_graph
)
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
key_start_loc=key_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
max_dcp_context_kv_len=max_dcp_context_kv_len,
dcp_context_kv_lens=dcp_context_kv_lens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
max_dcp_context_kv_len=max_dcp_context_kv_len,
dcp_context_kv_lens=dcp_context_kv_lens,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
scheduler_metadata=scheduler_metadata,
@@ -607,8 +607,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
prefix_scheduler_metadata=prefix_scheduler_metadata,
max_num_splits=max_num_splits,
causal=causal,
prefill=prefill_metadata,
decode=decode_metadata,
prefill = prefill_metadata,
decode = decode_metadata,
)
return attn_metadata
@@ -621,6 +621,19 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
new_metadata = copy.copy(metadata)
new_metadata.block_table = blk_table
new_metadata.slot_mapping = slot_mapping
# Keep nested prefill/decode block tables in sync. Decode path consumes
# `attn_metadata.decode.block_table`, so updating only the top-level
# `block_table` is insufficient when metadata is reused across groups.
if metadata.decode is not None:
new_decode = copy.copy(metadata.decode)
reqs_start = metadata.num_decodes
new_decode.block_table = blk_table[:reqs_start, ...]
new_metadata.decode = new_decode
if metadata.prefill is not None:
new_prefill = copy.copy(metadata.prefill)
reqs_start = metadata.num_decodes
new_prefill.block_table = blk_table[reqs_start:, ...]
new_metadata.prefill = new_prefill
return new_metadata
def use_cascade_attention(self, *args, **kwargs) -> bool:
@@ -667,7 +680,15 @@ class FlashAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
self.vllm_flash_attn_version = get_flash_attn_version(
requires_alibi=alibi_slopes is not None,
head_size=head_size,
)
logger.info_once(
"Using FlashAttention version %s",
self.vllm_flash_attn_version,
scope="local",
)
# Cache the batch invariant result for use in forward passes
self.batch_invariant_enabled = vllm_is_batch_invariant()
@@ -677,6 +698,7 @@ class FlashAttentionImpl(AttentionImpl):
)
self.sinks = sinks
if self.sinks is not None:
assert flash_attn_supports_sinks(), (
"Sinks are only supported in FlashAttention 3"
@@ -687,6 +709,28 @@ class FlashAttentionImpl(AttentionImpl):
)
self.supports_quant_query_input = True
self.supports_per_head_quant_scales = (
self.vllm_flash_attn_version >= 3
if self.vllm_flash_attn_version is not None
else False
)
assert envs.VLLM_ATTN_OPT_LEVEL in [0, 1, 2], (
    "VLLM_ATTN_OPT_LEVEL only supports 0 (non-quant), 1 (I8Q_I8K_I8V) and "
    "2 (I8Q_I8K_F16V), but got {}".format(envs.VLLM_ATTN_OPT_LEVEL)
)
'''
quant_type = 0:
    attention: f16 q/k/v
    cache: f16 KV cache
quant_type = 1:
    attention: int8 q, int8 k, int8 v
    cache:
        int8 k cache + fp32 k cache scale
        int8 v cache + fp32 v cache scale (loaded from file, not updated)
quant_type = 2:
    attention: int8 q, int8 k, fp16 v
    cache:
        int8 k cache + fp32 k cache scale
        fp16 v cache
'''
self.quant_type = int(envs.VLLM_ATTN_OPT_LEVEL)
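# Usage sketch (an assumption about deployment, not part of this module): the
# optimization level is selected via the environment before the engine starts, e.g.
#     VLLM_ATTN_OPT_LEVEL=2 python -m vllm.entrypoints.openai.api_server --model <model>
# 0 keeps the stock fp16 path, 1 enables int8 Q/K/V attention with an int8 KV
# cache, and 2 enables int8 Q/K with an fp16 value cache (see the docstring above).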
def forward(
self,
@@ -698,7 +742,7 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata,
output: torch.Tensor | None = None,
sqrt_alibi: bool = False,
kv_cache_scale: torch.Tensor | None = None,
kv_cache_scale: Union[torch.Tensor, List[torch.Tensor]] | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
@@ -711,6 +755,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
kv_cache_scale: [key_scale_cache of shape (num_blocks, num_kv_heads, block_size), value_scale_cache of shape (num_kv_heads, head_size)]
Returns:
shape = [num_tokens, num_heads * head_size]
NOTE: FP8 quantization, flash-attn expect the size of
@@ -718,9 +763,9 @@ class FlashAttentionImpl(AttentionImpl):
We use torch's .expand() to avoid duplicating values
"""
assert output is not None, "Output tensor must be provided."
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
# assert self.vllm_flash_attn_version is not None, (
# "FlashAttention version not detected."
# )
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
@@ -729,13 +774,12 @@ class FlashAttentionImpl(AttentionImpl):
if attn_metadata is None:
# Profiling run.
# return output.fill_(0)
return output.fill_(0).view(-1, self.num_heads * self.head_size)
return output.view(-1, self.num_heads * self.head_size)
softmax_scale: float = self.scale
window_size = self.sliding_window
alibi_slopes: torch.Tensor | None = self.alibi_slopes
logits_soft_cap: float | None = self.logits_soft_cap
alibi_slopes: torch.Tensor = self.alibi_slopes
logits_soft_cap: float = self.logits_soft_cap
attn_type = self.attn_type
@@ -761,18 +805,140 @@ class FlashAttentionImpl(AttentionImpl):
output[:num_actual_tokens],
attn_metadata,
layer,
)
).view(-1, self.num_heads * self.head_size)
# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(0)
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
decode_only = has_decode and not has_prefill
num_decode_tokens = attn_metadata.num_decode_tokens
if self.quant_type == 0:
key_cache, value_cache = kv_cache.unbind(0)
elif self.quant_type == 1:
i8_key_cache, i8_value_cache = kv_cache.unbind(0)
num_blocks, num_kv_heads, block_size, head_size = i8_key_cache.shape
key_scale_cache, value_scale_cache = kv_cache_scale
assert key_scale_cache.shape == (num_blocks, num_kv_heads, block_size) and key_scale_cache.dtype == torch.float32, \
    f"expected key_scale_cache of shape {(num_blocks, num_kv_heads, block_size)} and dtype torch.float32, got {key_scale_cache.shape} / {key_scale_cache.dtype}"
assert value_scale_cache.shape == (num_kv_heads, head_size) and value_scale_cache.dtype == torch.float32, \
    f"expected value_scale_cache of shape {(num_kv_heads, head_size)} and dtype torch.float32, got {value_scale_cache.shape} / {value_scale_cache.dtype}"
value_cache_info = (i8_value_cache, value_scale_cache)
elif self.quant_type == 2:
# quant_type 2 layout: plane 0 is the int8 key cache; the remaining planes are reinterpreted as the f16 value cache
i8_key_cache = kv_cache[0]
num_blocks, num_kv_heads, block_size, head_size = i8_key_cache.shape
value_cache = kv_cache[1:].view(query.dtype).reshape(num_blocks, num_kv_heads, block_size, head_size)
key_scale_cache = kv_cache_scale
value_cache_info = (value_cache, None)
decode_q = query[:num_decode_tokens]
prefill_q = query[num_decode_tokens:]
prefill_output = output[num_decode_tokens:]
decode_output = output[:num_decode_tokens]
if self.quant_type == 1:
if decode_only:
int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
query, 2, transpose_scale=False
)
i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
key, 2, transpose_scale=False
)
i8_value, _value_scale = ixf_ops.scaled_int8_quant_for_attn(
value, 0, transpose_scale=False, scale=value_cache_info[1]
)
else:
int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
query, 2, transpose_scale=True
)
i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
key, 2, transpose_scale=False
)
i8_value, _value_scale = ixf_ops.scaled_int8_quant_for_attn(
value, 0, transpose_scale=False, scale=value_cache_info[1]
)
elif self.quant_type == 2:
'''
original key cache:
    (num_blocks, num_kv_heads, block_size, head_size), f16
reformatted key cache:
    key_cache_i8:    (num_blocks, num_kv_heads, block_size, head_size), int8
    key_scale_cache: (num_blocks, num_kv_heads, block_size), fp32
'''
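# Illustrative reference only (not the ixformer scaled_int8_quant_for_attn
# kernel): a symmetric per-token int8 quantization that would produce this
# layout looks like
#     scale = key.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
#     i8_key = torch.round(key / scale).clamp(-127, 127).to(torch.int8)
#     # dequantization: key ~= i8_key.float() * scale
# with one fp32 scale per (head, token), matching key_scale_cache above.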
if decode_only:
int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
query, 2, transpose_scale=False
)
i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
key, 2, transpose_scale=False
)
else:
int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
query, 2, transpose_scale=True
)
i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
key, 2, transpose_scale=False
)
else:
if layer.quant_manager is not None and layer.quant_manager.check_enable():
i8_value, value_scale = ixf_ops.scaled_int8_quant_for_attn(
value, 0, transpose_scale=False
)
layer.quant_manager.update_data(value_scale)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
if self.quant_type == 1:
if has_prefill:
ixf_ops.reshape_and_cache_flash_int8(
key=i8_key,
value=i8_value,
k_scale=key_scale,
key_cache=i8_key_cache,
value_cache=value_cache_info[0],
key_scale_cache=key_scale_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache_dtype="",
)
elif self.quant_type == 2:
if has_prefill:
ixf_ops.reshape_and_cache_flash_mix(
key=i8_key,
value=value,
k_scale=key_scale,
key_cache=i8_key_cache,
value_cache=value_cache_info[0],
key_scale_cache=key_scale_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache_dtype="",
)
else:
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
# queries are quantized in the attention layer
@@ -783,19 +949,6 @@ class FlashAttentionImpl(AttentionImpl):
value_cache = value_cache.view(dtype)
if not attn_metadata.use_cascade:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
q_descale = layer._q_scale.expand(descale_shape)
k_descale = layer._k_scale.expand(descale_shape)
v_descale = layer._v_scale.expand(descale_shape)
if self.dcp_world_size > 1:
self._forward_with_dcp(
query[:num_actual_tokens],
@@ -805,79 +958,140 @@ class FlashAttentionImpl(AttentionImpl):
value_cache,
output[:num_actual_tokens],
attn_metadata,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
return output.view(-1, self.num_heads * self.head_size)
else:
sliding_window_size = (
list(self.sliding_window)
if self.sliding_window is not None
else None
)
if has_prefill:
flash_attn_varlen_func(
q=prefill_q,
k=key_cache,
v=value_cache,
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.query_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.max_query_len,
softmax_scale=softmax_scale,
causal=True,
window_size=sliding_window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
sqrt_alibi=sqrt_alibi,
sinks=self.sinks,
out=prefill_output,
block_table=attn_metadata.prefill.block_table,
)
# key = key[num_decode_tokens:]
# value = value[num_decode_tokens:]
# int8 attn
if self.quant_type > 0:
flash_attn_varlen_int8_func(
q=int8_query[num_decode_tokens:],
k=i8_key_cache,
v=value_cache_info[0],
q_scale=query_scale[:, num_decode_tokens:],
k_scale=key_scale_cache,
v_scale=value_cache_info[1],
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.key_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.max_query_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
sqrt_alibi=sqrt_alibi,
out=prefill_output,
block_table=attn_metadata.prefill.block_table,
output_dtype=query.dtype
)
else:
flash_attn_varlen_func(
q=prefill_q,
k=key_cache,
v=value_cache,
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.key_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.max_query_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
sqrt_alibi=sqrt_alibi,
sinks=self.sinks,
out=prefill_output,
block_table=attn_metadata.prefill.block_table,
)
if has_decode:
flash_attn_with_kvcache(
q=decode_q.unsqueeze(1),
k_cache=key_cache.contiguous(),
v_cache=value_cache.contiguous(),
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
softmax_scale=softmax_scale,
causal=True,
window_size=sliding_window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
use_sqrt_alibi=sqrt_alibi,
out=decode_output.unsqueeze(1),
max_context_len=attn_metadata.decode.max_decode_seq_len,
# sinks=self.sinks,
)
# For MTP (multi-token prediction) + CUDA graph: fall back to the top-level metadata when decode metadata is absent.
max_q_len = attn_metadata.decode.max_query_len if attn_metadata.decode is not None else attn_metadata.max_query_len
max_ct_len = attn_metadata.decode.max_decode_seq_len if attn_metadata.decode is not None else attn_metadata.max_seq_len
if self.quant_type in [1, 2]:
para_dict = dict(
output=decode_output,
query=int8_query[:num_decode_tokens],
key_cache=i8_key_cache,
query_scale=query_scale[:num_decode_tokens] if decode_only else query_scale[:, :num_decode_tokens].t().contiguous(),
key_scale_cache=key_scale_cache,
num_kv_heads=self.num_kv_heads,
scale=softmax_scale,
block_tables=attn_metadata.decode.block_table,
context_lens=attn_metadata.decode.seq_lens,
block_size=i8_key_cache.shape[-2],
softcap=logits_soft_cap,
alibi_slopes=alibi_slopes,
causal=True,
window_left=window_size[0],
window_right=window_size[1],
use_sqrt_alibi = sqrt_alibi,
use_cuda_graph=attn_metadata.decode.use_graph if decode_only else False,
max_context_len=max_ct_len,
# mtp
cu_query_lens=attn_metadata.decode.query_start_loc,
max_query_len=max_q_len,
)
if self.quant_type == 1:
para_dict.update(
dict(
value_cache=value_cache_info[0],
value_scale_cache=value_cache_info[1],
)
)
# KV + k_scale cache-write fusion: in decode-only batches the paged-attention kernel writes the new key/value (and key scale) into the cache itself, which is why the separate reshape_and_cache call above only runs when there is a prefill.
if decode_only:
para_dict.update(
dict(
save_key=i8_key[:num_decode_tokens],
save_value=i8_value[:num_decode_tokens],
save_key_scale=key_scale[:num_decode_tokens],
)
)
ixf_ops.vllm_paged_attention_int8(**para_dict)
elif self.quant_type == 2:
para_dict.update(
dict(
value_cache=value_cache,
)
)
if decode_only:
para_dict.update(
dict(
save_key=i8_key[:num_decode_tokens],
save_value=value[:num_decode_tokens].contiguous(),
save_key_scale=key_scale[:num_decode_tokens],
)
)
ixf_ops.vllm_paged_attention_mix(
**para_dict
)
else:
flash_attn_with_kvcache(
q=decode_q.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
max_query_len=max_q_len,
cu_query_lens=attn_metadata.decode.query_start_loc,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
use_sqrt_alibi=sqrt_alibi,
sinks=self.sinks,
out=decode_output.unsqueeze(1),
use_cuda_graph=attn_metadata.decode.use_graph,
max_context_len=max_ct_len
)
# Compute attention and update output up to `num_actual_tokens`.
return output.view(-1, self.num_heads * self.head_size)
# flash_attn_varlen_func(
# q=query[:num_actual_tokens],
# k=key_cache,
# v=value_cache,
# out=output[:num_actual_tokens],
# cu_seqlens_q=cu_seqlens_q,
# max_seqlen_q=max_seqlen_q,
# seqused_k=seqused_k,
# max_seqlen_k=max_seqlen_k,
# softmax_scale=self.scale,
# causal=attn_metadata.causal,
# alibi_slopes=self.alibi_slopes,
# window_size=sliding_window_size,
# block_table=block_table,
# softcap=self.logits_soft_cap,
# scheduler_metadata=scheduler_metadata,
# fa_version=self.vllm_flash_attn_version,
# q_descale=q_descale,
# k_descale=k_descale,
# v_descale=v_descale,
# num_splits=attn_metadata.max_num_splits,
# s_aux=self.sinks,
# )
# return output
# Cascade attention (rare case).
cascade_attention(
@@ -906,12 +1120,7 @@ class FlashAttentionImpl(AttentionImpl):
v_descale=layer._v_scale,
s_aux=self.sinks,
)
# return output
return (
output[:num_actual_tokens]
.contiguous()
.view(-1, self.num_heads * self.head_size)
)
return output.view(-1, self.num_heads * self.head_size)
def do_kv_cache_update(
self,
@@ -935,7 +1144,7 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
reshape_and_cache_flash(
ops.reshape_and_cache_flash(
key,
value,
key_cache,
@@ -959,9 +1168,9 @@ class FlashAttentionImpl(AttentionImpl):
k_descale: torch.Tensor | None = None,
v_descale: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
# assert self.vllm_flash_attn_version is not None, (
# "FlashAttention version not detected."
# )
cu_seqlens_q = attn_metadata.query_start_loc
max_seqlen_q = attn_metadata.max_query_len
@@ -969,27 +1178,22 @@ class FlashAttentionImpl(AttentionImpl):
query = query.contiguous()
query_across_dcp = get_dcp_group().all_gather(query, dim=1)
cu_dcp_kv_klens = attn_metadata.dcp_context_kv_lens.cumsum(dim=0, dtype=torch.int32)
new_tensor = torch.tensor([0],
device=attn_metadata.dcp_context_kv_lens.device,
dtype=attn_metadata.dcp_context_kv_lens.dtype)
cu_seqlens_k = torch.cat([new_tensor, cu_dcp_kv_klens])
sliding_window_size = (
list(self.sliding_window) if self.sliding_window is not None else None
)
cu_seqlens_k = torch.cat(
[
torch.zeros(1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype),
attn_metadata.dcp_context_kv_lens.cumsum(
dim=0, dtype=cu_seqlens_q.dtype
),
],
dim=0,
)
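# E.g. dcp_context_kv_lens = [3, 5, 2] -> cu_seqlens_k = [0, 3, 8, 10]; the call
# below passes cumulative key offsets (cu_seqlens_k) where the upstream call used
# the per-sequence seqused_k lengths.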
context_attn_out, context_lse = flash_attn_varlen_func(
q=query_across_dcp,
k=key_cache,
v=value_cache,
out=None,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
# seqused_k=attn_metadata.dcp_context_kv_lens,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
softmax_scale=self.scale,
causal=False,
@@ -998,11 +1202,6 @@ class FlashAttentionImpl(AttentionImpl):
block_table=block_table,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
# scheduler_metadata=attn_metadata.scheduler_metadata,
# fa_version=self.vllm_flash_attn_version,
# q_descale=q_descale,
# k_descale=k_descale,
# v_descale=v_descale,
)
# FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
@@ -1028,10 +1227,6 @@ class FlashAttentionImpl(AttentionImpl):
window_size=sliding_window_size,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
# fa_version=self.vllm_flash_attn_version,
# q_descale=q_descale,
# k_descale=k_descale,
# v_descale=v_descale,
)
assert context_attn_out_cor.shape == query_attn_out.shape
assert context_lse_cor.shape == query_lse.shape
@@ -1040,7 +1235,7 @@ class FlashAttentionImpl(AttentionImpl):
context_lse_cor,
query_attn_out,
query_lse,
output,
output
)
def _forward_encoder_attention(
@@ -1062,9 +1257,9 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: Encoder attention metadata
layer: The attention layer
"""
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
# assert self.vllm_flash_attn_version is not None, (
# "FlashAttention version not detected."
# )
# For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"):
@@ -1101,18 +1296,9 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
softcap=self.logits_soft_cap,
# fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
# num_splits=1 if self.batch_invariant_enabled else 0,
)
return (
output[: attn_metadata.num_actual_tokens]
.contiguous()
.view(-1, self.num_heads * self.head_size)
)
return output
def use_cascade_attention(
@@ -1203,8 +1389,6 @@ def cascade_attention(
cu_prefix_query_lens: torch.Tensor,
cu_prefix_kv_lens: torch.Tensor,
cu_suffix_kv_lens: torch.Tensor,
# prefix_kv_lens: torch.Tensor,
# suffix_kv_lens: torch.Tensor,
max_kv_len: int,
softmax_scale: float,
alibi_slopes: torch.Tensor | None,
@@ -1228,12 +1412,13 @@ def cascade_attention(
)
num_tokens = query.shape[0]
# block_size = key_cache.shape[-3]
block_size = key_cache.shape[-2]
assert common_prefix_len % block_size == 0
num_common_kv_blocks = common_prefix_len // block_size
assert num_common_kv_blocks > 0
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
assert q_descale is None or q_descale == 1, f"unsupported q_descale {q_descale}: the cascade path only supports None or 1"
assert k_descale is None or k_descale == 1, f"unsupported k_descale {k_descale}: the cascade path only supports None or 1"
assert v_descale is None or v_descale == 1, f"unsupported v_descale {v_descale}: the cascade path only supports None or 1"
# Process shared prefix.
prefix_output, prefix_lse = flash_attn_varlen_func(
@@ -1241,7 +1426,6 @@ def cascade_attention(
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_prefix_query_lens,
# seqused_k=prefix_kv_lens,
cu_seqlens_k=cu_prefix_kv_lens,
max_seqlen_q=num_tokens,
max_seqlen_k=common_prefix_len,
@@ -1251,26 +1435,14 @@ def cascade_attention(
block_table=block_table[:1],
softcap=logits_soft_cap,
return_softmax_lse=True,
# scheduler_metadata=prefix_scheduler_metadata,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge.
# s_aux=s_aux,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
)
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
# Process suffix per query.
suffix_output, suffix_lse = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
# seqused_k=suffix_kv_lens,
cu_seqlens_k=cu_suffix_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len - common_prefix_len,
@@ -1280,14 +1452,6 @@ def cascade_attention(
block_table=block_table[:, num_common_kv_blocks:],
softcap=logits_soft_cap,
return_softmax_lse=True,
# scheduler_metadata=suffix_scheduler_metadata,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
)
# Merge prefix and suffix outputs, and store the result in output.
# merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse)
merge_attn_states(prefix_output, prefix_lse, suffix_output, suffix_lse, output)
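# For reference, a sketch of the LSE-based merge that a `merge_attn_states`-style
# kernel computes (note the ixformer signature used above takes the destination
# `output` as the last argument). This reference implementation is an assumption
# for illustration, not the kernel itself; it assumes o1/o2 of shape
# [num_tokens, num_heads, head_size] and lse1/lse2 of shape [num_tokens, num_heads],
# with torch already imported in this module.
def _sketch_merge_attn_states(o1, lse1, o2, lse2, out):
    # Merged log-sum-exp of the two partial softmaxes.
    lse = torch.logaddexp(lse1, lse2)
    # Each partial output is reweighted by exp(lse_i - lse) before summing.
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    out.copy_(w1 * o1 + w2 * o2)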

View File

@@ -13,7 +13,7 @@ from flashinfer import (
BatchPrefillWithRaggedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper,
)
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.decode import fast_decode_plan, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor
from typing_extensions import override
@@ -199,14 +199,14 @@ class BatchDCPPrefillWrapper:
):
"""Plan the prefill operation with given parameters."""
self._context.plan(
qo_indptr_cpu,
paged_kv_indptr_cpu,
paged_kv_indices,
paged_kv_last_page_len_cpu,
num_qo_heads * dcp_world_size,
num_kv_heads,
head_dim,
page_size,
qo_indptr=qo_indptr_cpu,
paged_kv_indptr=paged_kv_indptr_cpu,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len_cpu,
num_qo_heads=num_qo_heads * dcp_world_size,
num_kv_heads=num_kv_heads,
head_dim_qk=head_dim,
page_size=page_size,
causal=False, # This is context run
sm_scale=sm_scale,
window_left=window_left,
@@ -374,13 +374,13 @@ class FlashInferBackend(AttentionBackend):
@classmethod
def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
from vllm.platforms import current_platform
capability = current_platform.get_device_capability()
if capability is not None and capability.major == 10:
return "HND"
return None
forward_includes_kv_cache_update: bool = False
@dataclass
class FIPrefill:
@@ -573,20 +573,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# try to use fp8 q if kv cache is fp8, and will fall back to model dtype
# if TRTLLM attention kernel is not used when building attn metadata
can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
# TRTLLM attention requires strictly contiguous KV cache tensors.
# When KV transfer (P/D disaggregation) is enabled, the KV cache may be
# permuted into non-contiguous views, which causes assertion failures.
self._kv_transfer_enabled = vllm_config.kv_transfer_config is not None
if can_use_trtllm and self._kv_transfer_enabled:
logger.info_once(
"TRTLLM attention is disabled because KV transfer "
"(P/D disaggregation) is enabled. TRTLLM attention requires "
"strictly contiguous KV cache tensors which may not be "
"guaranteed with KV transfer."
)
can_use_trtllm = False
if (
can_use_trtllm
and not vllm_config.attention_config.disable_flashinfer_q_quantization
@@ -816,6 +802,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
page_size,
paged_kv_last_page_len_np,
)
self.paged_kv_last_page_len.gpu[:num_reqs].copy_(
self.paged_kv_last_page_len.cpu[:num_reqs], non_blocking=True
)
return paged_kv_indices
def build(
@@ -860,9 +849,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder,
)
# KV transfer requires non-contiguous KV cache views, incompatible with TRTLLM
if self._kv_transfer_enabled:
prefill_use_trtllm = False
decode_use_trtllm = (
self.use_trtllm_decode_attention and self.dcp_world_size <= 1
)
@@ -997,14 +983,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
attn_metadata.cascade_wrapper.plan(
[shared_qo_indptr_cpu, qo_indptr_cpu],
[shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
[shared_kv_page_indices_cpu, paged_kv_indices],
[shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
qo_indptr_arr=[shared_qo_indptr_cpu, qo_indptr_cpu],
paged_kv_indptr_arr=[shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
paged_kv_indices_arr=[shared_kv_page_indices_cpu, paged_kv_indices],
paged_kv_last_page_len=[
shared_kv_last_page_len_cpu,
paged_kv_last_page_len_cpu,
],
num_qo_heads=self.num_qo_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
page_size=self.page_size,
causal=True,
sm_scale=self.sm_scale,
window_left=self.window_left,
@@ -1082,14 +1071,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
BatchPrefillWithPagedKVCacheWrapper,
)
prefill_wrapper.plan(
qo_indptr_prefill_cpu,
paged_kv_indptr_prefill_cpu,
paged_kv_indices,
paged_kv_last_page_len_prefill_cpu,
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
qo_indptr=qo_indptr_prefill_cpu,
paged_kv_indptr=paged_kv_indptr_prefill_cpu,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len_prefill_cpu,
num_qo_heads=self.num_qo_heads,
num_kv_heads=self.num_kv_heads,
head_dim_qk=self.head_dim,
page_size=self.page_size,
causal=True,
sm_scale=self.sm_scale,
window_left=self.window_left,
@@ -1130,14 +1119,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# in atten_metadata when using cudagraph.
fast_plan_decode(
decode_wrapper,
self.paged_kv_indptr.cpu[: num_input_tokens + 1],
paged_kv_indices,
self.paged_kv_last_page_len.cpu[:num_input_tokens],
seq_lens_cpu[:num_input_tokens],
self.num_qo_heads * self.dcp_world_size,
self.num_kv_heads,
self.head_dim,
self.page_size,
indptr_cpu=self.paged_kv_indptr.cpu[: num_input_tokens + 1],
indices=paged_kv_indices,
last_page_len_cpu=self.paged_kv_last_page_len.cpu[
:num_input_tokens
],
num_qo_heads=self.num_qo_heads * self.dcp_world_size,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
page_size=self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
sm_scale=self.sm_scale,
@@ -1330,32 +1320,15 @@ class FlashInferImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if self.kv_sharing_target_layer_name is None and self.kv_cache_dtype.startswith(
"fp8"
):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.kv_cache_dtype
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if self.kv_cache_dtype.startswith("fp8"):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.kv_cache_dtype
)
kv_cache = kv_cache.view(torch_dtype)
kv_cache = kv_cache.view(torch_dtype)
# Inputs and outputs may be padded for CUDA graphs
query = query[:num_actual_tokens]
@@ -1599,13 +1572,39 @@ class FlashInferImpl(AttentionImpl):
)
return output_padded
def do_kv_cache_update(
self,
layer: torch.nn.Module,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
def fast_plan_decode(
self, # decode wrapper
indptr_cpu: torch.Tensor,
indices: torch.Tensor,
last_page_len_cpu: torch.Tensor,
seq_lens_cpu: torch.Tensor,
num_qo_heads: int,
num_kv_heads: int,
head_dim: int,
@@ -1642,110 +1641,56 @@ def fast_plan_decode(
# This warm-up call generates the _cached_module for the decode wrapper.
if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True):
self.plan(
indptr_cpu,
indices,
last_page_len_cpu,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
pos_encoding_mode,
window_left,
logits_soft_cap,
q_data_type,
kv_data_type,
o_data_type,
data_type,
sm_scale,
rope_scale,
rope_theta,
non_blocking,
None, # block_tables
None, # seq_lens
fixed_split_size,
disable_split_kv,
indptr=indptr_cpu,
indices=indices,
last_page_len=last_page_len_cpu,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
page_size=page_size,
pos_encoding_mode=pos_encoding_mode,
window_left=window_left,
logits_soft_cap=logits_soft_cap,
q_data_type=q_data_type,
kv_data_type=kv_data_type,
o_data_type=o_data_type,
data_type=data_type,
sm_scale=sm_scale,
rope_scale=rope_scale,
rope_theta=rope_theta,
non_blocking=non_blocking,
block_tables=None,
seq_lens=None,
fixed_split_size=fixed_split_size,
disable_split_kv=disable_split_kv,
)
self.vllm_first_call = False
return
assert self.is_cuda_graph_enabled, "Should be cudagraph only here"
batch_size = len(last_page_len_cpu)
if logits_soft_cap is None:
logits_soft_cap = 0.0
# Handle data types consistently
if data_type is not None:
if q_data_type is None:
q_data_type = data_type
if kv_data_type is None:
kv_data_type = data_type
elif q_data_type is None:
q_data_type = "float16"
if kv_data_type is None:
kv_data_type = q_data_type
q_data_type = (
getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
fast_decode_plan(
self,
indptr=indptr_cpu,
indices=indices,
last_page_len=last_page_len_cpu,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
page_size=page_size,
pos_encoding_mode=pos_encoding_mode,
window_left=window_left,
logits_soft_cap=logits_soft_cap,
q_data_type=q_data_type,
kv_data_type=kv_data_type,
data_type=data_type,
sm_scale=sm_scale,
rope_scale=rope_scale,
rope_theta=rope_theta,
non_blocking=non_blocking,
fixed_split_size=fixed_split_size,
disable_split_kv=disable_split_kv,
)
kv_data_type = (
getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type
)
if batch_size != self._fixed_batch_size:
raise ValueError(
"The batch size should be fixed in cudagraph mode, the runtime "
"batch size {} mismatches the batch size set during "
"initialization {}".format(batch_size, self._fixed_batch_size)
)
if len(indices) > len(self._paged_kv_indices_buf):
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)
# host-to-device copy for the indptr buffer
self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True)
# host-to-device copy for the last_page_len buffer
self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True)
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
try:
# Make sure we pass exactly 19 arguments for tensor core version
args = [
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_host,
indptr_cpu,
seq_lens_cpu,
batch_size, # total_num_rows
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
head_dim,
head_dim,
False, # causal
window_left,
]
if self._backend == "fa2":
args.append(fixed_split_size)
args.append(disable_split_kv)
args.append(0) # num_colocated_ctas
self._plan_info = self._cached_module.plan(
*args,
)
except Exception as e:
raise RuntimeError(f"Error in tensor core plan: {e}") from e
self._pos_encoding_mode = pos_encoding_mode
self._window_left = window_left
self._logits_soft_cap = logits_soft_cap
self._sm_scale = sm_scale
self._rope_scale = rope_scale
self._rope_theta = rope_theta
@triton.jit

View File

@@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from dataclasses import dataclass, replace
from typing import Any
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
@@ -29,3 +30,31 @@ class Mamba1AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
):
metadata_cls = Mamba1AttentionMetadata
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
**kwargs: Any,
) -> Mamba1AttentionMetadata:
common = self._compute_common_metadata(common_attn_metadata)
if (
common.num_prefills > 0
and self.vllm_config.cache_config.mamba_cache_mode == "all"
):
cu_chunk_seqlen_p, _, last_chunk_indices_p = (
self._build_chunk_metadata_tensors(
self.kv_cache_spec.block_size,
common,
common_attn_metadata,
)
)
return replace(
common,
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
last_chunk_indices_p=last_chunk_indices_p,
)
return common

View File

@@ -7,7 +7,6 @@ from typing import Any
import torch
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (
AttentionBackend,
CommonAttentionMetadata,
@@ -105,14 +104,6 @@ class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
# Chunk-related metadata (only for prefill)
seq_idx_p: torch.Tensor | None = None
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offsets into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p: torch.Tensor | None = None
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p: torch.Tensor | None = None
class Mamba2AttentionMetadataBuilder(
@@ -134,68 +125,6 @@ class Mamba2AttentionMetadataBuilder(
)
self.chunk_size: int = chunk_size
def _compute_chunk_metadata(
self,
num_prefills: int,
num_computed_tokens_p_cpu: torch.Tensor,
query_start_loc_p_cpu: torch.Tensor,
) -> tuple[list[int], list[int], list[int]]:
"""
Compute chunk-specific metadata for Mamba2.
The code below carefully constructs the chunks such that:
1. Chunks contain tokens from a *single* sequence only.
2. For every sequence, we are guaranteed that we can
retrieve the mamba state *every* chunk_size tokens.
Constraint (1) dramatically simplifies the mamba2 kernels.
Constraint (2) dramatically simplifies the implementation
of prefix caching for mamba2 (wip). We need to take care
of the interaction with chunked prefill in order to
satisfy constraint (2).
"""
# TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen = []
seq_idx = []
last_chunk_indices = []
seqlen_pos = 0
for req_idx in range(num_prefills):
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
this_new_tokens = (
query_start_loc_p_cpu[req_idx + 1].item()
- query_start_loc_p_cpu[req_idx].item()
)
# if computed tokens are not chunk-aligned, use the first
# chunk to finish it off
if this_num_computed % self.chunk_size != 0:
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
# how many tokens to finish the chunk?
chunk_len = (
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
- this_num_computed
)
# we can only use at most this_new_tokens
chunk_len = min(chunk_len, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
n_chunks = cdiv(this_new_tokens, self.chunk_size)
for chunk in range(n_chunks):
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
chunk_len = min(self.chunk_size, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
assert this_new_tokens == 0
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
cu_chunk_seqlen.append(seqlen_pos)
return cu_chunk_seqlen, seq_idx, last_chunk_indices
def build(
self,
common_prefix_len: int,
@@ -220,41 +149,12 @@ class Mamba2AttentionMetadataBuilder(
else False
)
num_reqs = common.num_reqs
num_prefills = common.num_prefills
num_decode_tokens = common.num_decode_tokens
num_computed_tokens_cpu = (
common_attn_metadata.compute_num_computed_tokens().cpu()
)
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
query_start_loc_p_cpu = (
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
- num_decode_tokens
)
cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
num_prefills,
num_computed_tokens_p_cpu,
query_start_loc_p_cpu,
)
seq_idx_p = torch.as_tensor(
seq_idx,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
cu_chunk_seqlen_p = torch.as_tensor(
cu_chunk_seqlen,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
last_chunk_indices_p = torch.as_tensor(
last_chunk_indices,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p = (
self._build_chunk_metadata_tensors(
self.chunk_size,
common,
common_attn_metadata,
)
)
return replace(

View File

@@ -59,6 +59,15 @@ class BaseMambaAttentionMetadata:
# The following tensor is only used for prefix caching in align mode
seq_lens: torch.Tensor
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offsets into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p: torch.Tensor | None = None
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p: torch.Tensor | None = None
# The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
@@ -185,6 +194,118 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
common_attn_metadata, num_accepted_tokens=num_accepted_tokens
)
def _compute_chunk_metadata(
self,
chunk_size: int,
num_prefills: int,
num_computed_tokens_p_cpu: torch.Tensor,
query_start_loc_p_cpu: torch.Tensor,
) -> tuple[list[int], list[int], list[int]]:
"""
Compute chunk-specific metadata for Mamba models.
The code below carefully constructs the chunks such that:
1. Chunks contain tokens from a *single* sequence only.
2. For every sequence, we are guaranteed that we can
retrieve the mamba state *every* chunk_size tokens.
Constraint (1) dramatically simplifies the mamba kernels.
Constraint (2) dramatically simplifies the implementation
of prefix caching for mamba (wip). We need to take care
of the interaction with chunked prefill in order to
satisfy constraint (2).
"""
# TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen = []
seq_idx = []
last_chunk_indices = []
seqlen_pos = 0
for req_idx in range(num_prefills):
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
this_new_tokens = (
query_start_loc_p_cpu[req_idx + 1].item()
- query_start_loc_p_cpu[req_idx].item()
)
# if computed tokens are not chunk-aligned, use the first
# chunk to finish it off
if this_num_computed % chunk_size != 0:
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
# how many tokens to finish the chunk?
chunk_len = (
cdiv(this_num_computed, chunk_size) * chunk_size - this_num_computed
)
# we can only use at most this_new_tokens
chunk_len = min(chunk_len, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
n_chunks = cdiv(this_new_tokens, chunk_size)
for chunk in range(n_chunks):
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
chunk_len = min(chunk_size, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
assert this_new_tokens == 0
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
cu_chunk_seqlen.append(seqlen_pos)
return cu_chunk_seqlen, seq_idx, last_chunk_indices
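# Illustrative example of the chunking above: with chunk_size=4 and a single
# prefill request that has 6 already-computed tokens and 7 new tokens, the
# first (2-token) chunk re-aligns the sequence to the next chunk boundary and
# the remaining tokens are split into chunks of 4 and 1:
#   cu_chunk_seqlen    = [0, 2, 6, 7]
#   seq_idx            = [0, 0, 0]
#   last_chunk_indices = [2]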
def _build_chunk_metadata_tensors(
self,
chunk_size: int,
common: M,
common_attn_metadata: CommonAttentionMetadata,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute chunk metadata and return as device tensors.
Returns (cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p).
"""
num_reqs = common.num_reqs
num_prefills = common.num_prefills
num_decode_tokens = common.num_decode_tokens
num_computed_tokens_cpu = (
common_attn_metadata.compute_num_computed_tokens().cpu()
)
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
query_start_loc_p_cpu = (
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
- num_decode_tokens
)
cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
chunk_size,
num_prefills,
num_computed_tokens_p_cpu,
query_start_loc_p_cpu,
)
device = common_attn_metadata.query_start_loc.device
cu_chunk_seqlen_p = torch.as_tensor(
cu_chunk_seqlen,
device=device,
dtype=torch.int32,
)
seq_idx_p = torch.as_tensor(
seq_idx,
device=device,
dtype=torch.int32,
)
last_chunk_indices_p = torch.as_tensor(
last_chunk_indices,
device=device,
dtype=torch.int32,
)
return cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p
def _compute_prefix_caching_block_indices(
self,
common_attn_metadata: CommonAttentionMetadata,

View File

@@ -191,6 +191,8 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None,
max_decode_seq_len: int = 0,
use_cuda_graph: bool = False,
) -> FlashAttnMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_query_len = query_lens_cpu.max().item()
@@ -239,12 +241,14 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
metadata = FlashAttnMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_decode_seq_len=max_decode_seq_len,
use_cuda_graph=use_cuda_graph,
query_start_loc=query_start_loc_device,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
scheduler_metadata=scheduler_metadata,
max_num_splits=max_num_splits,
)
return metadata

View File

@@ -156,6 +156,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None,
max_decode_seq_len: int = 0,
use_cuda_graph: bool = False,
) -> FlashMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
# we use the max but all should be the same due to uniform length requirement
@@ -179,8 +181,10 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
return FlashMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
scheduler_metadata=scheduler_metadata,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_decode_seq_len=max_decode_seq_len,
use_cuda_graph=use_cuda_graph,
)

View File

@@ -13,6 +13,11 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
get_mla_dims,
)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
LinearBase,
UnquantizedLinearMethod,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units
@@ -37,13 +42,17 @@ from vllm.v1.attention.backends.utils import (
)
from vllm.v1.attention.ops.flashmla import (
FlashMLASchedMeta,
flash_mla_sparse_fwd,
flash_mla_sparse_prefill,
flash_mla_with_kvcache,
get_mla_metadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager
import functools
from vllm import envs
from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize
import ixformer.inference.functions as ixf_ops
import numpy as np
if TYPE_CHECKING:
from vllm.model_executor.models.deepseek_v2 import Indexer
@@ -74,7 +83,15 @@ structured as:
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
part is not quantized for accuracy.
"""
def dynamic_per_batched_tensor_quant(
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
):
DTYPE_MAX = torch.finfo(dtype).max
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
scale = DTYPE_MAX / amax
x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
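# Worked example for the helper above (assuming torch.float8_e4m3fn, whose
# finfo max is 448): for a tensor with absolute max 2.0, scale = 448 / 2.0 = 224,
# so values are multiplied by 224, clamped to [-448, 448] and cast to fp8, and
# the returned per-tensor dequant scale is 1 / 224 (i.e. amax / 448).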
class FlashMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
@@ -558,6 +575,11 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.kv_lora_rank: int = mla_args["kv_lora_rank"]
self.qk_nope_head_dim = mla_args["qk_nope_head_dim"]
self.qk_rope_head_dim = mla_args["qk_rope_head_dim"]
self.qk_head_dim = mla_args["qk_head_dim"]
self.v_head_dim = mla_args["v_head_dim"]
self.kv_b_proj = mla_args["kv_b_proj"]
self.softmax_scale = scale
assert indexer is not None
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
@@ -580,6 +602,65 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
(self.prefill_workspace_shape, torch.bfloat16)
)
)
def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
for attr in WEIGHT_NAMES:
if hasattr(layer, attr):
return getattr(layer, attr)
raise AttributeError(
f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
)
def get_and_maybe_dequant_weights(layer: LinearBase):
if layer.quant_method is not None and not isinstance(
layer.quant_method, UnquantizedLinearMethod
):
# NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye(
layer.input_size_per_partition,
dtype=act_dtype,
device=get_layer_weight(layer).device,
)
dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
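# Passing the identity through the quantized layer recovers the dequantized
# weight: apply() computes eye @ W^T, which is W^T itself.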
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight
# We currently do not have quantized bmm kernels, which would be needed for
# `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform the
# bmm's in 16-bit. The extra memory overhead of this is fairly low.
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}"
)
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.v_head_dim,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
self.W_UV = W_UV
self.W_UK = W_UK
# self.W_UK_T = W_UK.permute(1, 2, 0)
def _v_up_proj(self, x: torch.Tensor):
return torch.einsum("bnl,lnv->bnv", x, self.W_UV)
def _k_up_proj(self, q_nope):
return torch.einsum("bnp,lnp->bnl", q_nope, self.W_UK).view(-1, self.num_heads, self.kv_lora_rank)
def _forward_bf16_kv(
self,
@@ -590,12 +671,11 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
) -> torch.Tensor:
# Convert per-request indices to global slots (decode) or workspace
# offsets (prefill).
topk_indices = triton_convert_req_index_to_global_index(
topk_indices = ops.dsa_convert_req_index_to_global_index(
attn_metadata.req_id_per_token,
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=topk_indices.shape[1],
attn_metadata.block_size,
)
return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices)
@@ -790,22 +870,10 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
-1, 1, kv_c_and_k_pe_cache.shape[-1]
)
# NOTE(Chen): kernel requires num_local_head to be a multiple of
# 64 on hopper and 128 on blackwell
if self.num_heads % self.prefill_padding != 0:
assert self.prefill_padding % self.num_heads == 0
logger.warning_once(
f"Padding num_heads from {self.num_heads} to "
f"{self.prefill_padding} for BF16 sparse prefill kernel"
)
q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
q_padded[:, : self.num_heads, :] = q
q = q_padded
topk_indices = topk_indices.view(num_tokens, 1, -1)
output = flash_mla_sparse_fwd(
output = flash_mla_sparse_prefill(
q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
)[0]
)
output = output[:, : self.num_heads, :]
return output
@@ -843,5 +911,5 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
attn_out = self._forward_fp8_kv_separate_prefill_decode(
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
)
return attn_out, None
return attn_out

View File

@@ -8,7 +8,11 @@ import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm
from vllm.utils.deep_gemm import (
get_paged_mqa_logits_metadata,
is_deep_gemm_supported,
)
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
AttentionBackend,
@@ -21,6 +25,7 @@ from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
split_prefill_chunks,
)
from vllm.v1.worker.cp_utils import get_total_cp_world_size
logger = init_logger(__name__)
@@ -68,11 +73,15 @@ class DeepseekV32IndexerPrefillChunkMetadata:
cu_seqlen_ks: torch.Tensor
cu_seqlen_ke: torch.Tensor
cu_seq_lens: torch.Tensor
cu_seqlens_q: torch.Tensor
token_to_seq: torch.Tensor
total_seq_lens: int
token_start: int
token_end: int
num_reqs: int
max_context_len: int
max_q_len: int # Maximum query length for dsa_indexer_mqa_logits_with_blocks
max_kv_len: int # Maximum key-value length for dsa_indexer_mqa_logits_with_blocks
@dataclass
@@ -86,9 +95,16 @@ class DeepSeekV32IndexerDecodeMetadata:
seq_lens: torch.Tensor
decode_lens: torch.Tensor
requires_padding: bool
schedule_metadata: torch.Tensor
# schedule_metadata: torch.Tensor
use_large_context_topk: bool
offsets: torch.Tensor | None # Precomputed offsets for speculative decoding
cu_seqlen_ks: torch.Tensor
cu_seqlen_ke: torch.Tensor
cu_seqlens_kv: torch.Tensor
cu_seqlens_q: torch.Tensor
max_context_len: int
max_q_len: int
max_kv_len: int
@dataclass
@@ -211,20 +227,39 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
if self.vllm_config.speculative_config
else 0
)
if self.num_speculative_tokens > 1:
raise ValueError(
"Sparse MLA only supports "
"num_speculative_tokens <= 1 because the DeepGEMM "
"fp8_paged_mqa_logits kernel does not support next_n > 2. "
f"Got num_speculative_tokens={self.num_speculative_tokens}."
)
self.reorder_batch_threshold += self.num_speculative_tokens
sm_count = num_compute_units(self.device.index)
self.num_sms = sm_count
self.decode_lens_buffer = torch.empty(
(scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
(scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=self.device,
)
# Pre-allocated buffers for flattening (spec decode).
self.arange_buffer = torch.arange(
scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens),
dtype=torch.int32,
device=self.device,
)
self.expanded_seq_lens_buffer = torch.zeros(
(scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=self.device,
)
max_num_blocks_per_req = cdiv(
self.vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size * get_total_cp_world_size(),
)
self.expanded_block_table_buffer = torch.zeros(
(
scheduler_config.max_num_batched_tokens,
max_num_blocks_per_req,
),
dtype=torch.int32,
device=self.device,
)
# See: DeepGMM/csrc/apis/attention.hpp
@@ -260,18 +295,88 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
.to(torch.int32)
.to(self.device)
)
cu_seqlens_q = prefill_query_start_loc.to(torch.int32).to(self.device)
max_context_len = seq_lens_cpu[reqs_start:reqs_end].max().item()
# max_q_len is the maximum query length among the requests in this chunk.
# prefill_query_start_loc is the cumulative sum of query lengths, with shape [batch + 1].
max_q_len = (prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]).max().item()
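# e.g. prefill_query_start_loc = [0, 3, 8, 9] gives per-request query lengths
# [3, 5, 1], so max_q_len = 5.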
return DeepseekV32IndexerPrefillChunkMetadata(
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
cu_seq_lens=cu_seq_lens,
token_to_seq=token_to_seq,
total_seq_lens=total_seq_lens,
cu_seqlens_q=cu_seqlens_q,
block_table=block_table[reqs_start:reqs_end],
token_start=token_start,
token_end=token_end,
num_reqs=reqs_end - reqs_start,
max_context_len=max_context_len,
max_q_len=max_q_len,
max_kv_len=max_context_len
)
def build_decode_metadata(
self, common_attn_metadata, num_decodes, decode_lens, use_large_context_topk, offsets
):
decode_lens_cpu = torch.diff(
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
)
assert (
decode_lens_cpu.max().item()
== decode_lens_cpu.min().item()
== 1
), "Only support single token decode in dsa_indexer backend"
# Calculate decode metadata parameters
seq_lens_decode = common_attn_metadata.seq_lens_cpu[:num_decodes]
max_context_len = seq_lens_decode.max().item()
max_kv_len = max_context_len
max_q_len = 1 # Single token decode
# Create cu_seqlens_q: cumulative sum of query lengths (all 1s)
cu_seqlens_q = torch.arange(
num_decodes + 1, dtype=torch.int32, device=self.device
)
# Create cu_seqlens_kv and related tensors using kv_spans_from_batches
decode_query_start_loc = torch.arange(
num_decodes + 1, dtype=torch.long
)
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
decode_query_start_loc, seq_lens_decode, self.device
)
cu_seqlens_kv = torch.cat(
[
torch.zeros(1, dtype=torch.int32, device=self.device),
torch.cumsum(seq_lens_decode.to(self.device), dim=0)
.to(torch.int32),
]
)
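# Example (two single-token decodes with seq_lens [5, 3]): cu_seqlens_q is
# [0, 1, 2] and cu_seqlens_kv is [0, 5, 8]; cu_seqlen_ks / cu_seqlen_ke come
# from kv_spans_from_batches over the same lengths.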
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=common_attn_metadata.block_table_tensor[
:num_decodes, ...
],
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
decode_lens=decode_lens,
requires_padding=(
decode_lens_cpu.max() > decode_lens_cpu.min()
).item(),
use_large_context_topk=use_large_context_topk,
offsets=offsets,
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q=cu_seqlens_q,
max_context_len=max_context_len,
max_q_len=max_q_len,
max_kv_len=max_kv_len,
# schedule_metadata=self.scheduler_metadata_buffer,
)
return decode_metadata
def build(
self,
common_prefix_len: int,
@@ -323,45 +428,103 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
)
# Use CPU values to avoid a GPU sync, which would break async scheduling
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
# Decide which top-k kernel to use based on batch size and sequence length
batch_size = num_decodes
_is_large_context = common_attn_metadata.max_seq_len > 8192
# Decision logic based on micro-benchmark results:
# - large_context_topk wins for batch <= 128 and seq_len > 8K
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
use_large_context_topk = batch_size <= 128 and _is_large_context
next_n = 1 + self.num_speculative_tokens
if next_n > 1:
offsets = torch.arange(next_n, device=self.device, dtype=torch.int32)
else:
offsets = None
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
# DeepGEMM is required for the paged MQA logits on CUDA devices
if current_platform.is_cuda() and has_deep_gemm():
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms
)
block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
# Padded CUDA graph requests have block_table entries of -1.
# Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
# This is safe because padded requests have seq_lens=0, so the
# kernel produces no meaningful output for those rows.
block_table.clamp_(min=0)
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=block_table,
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
decode_lens=decode_lens,
requires_padding=requires_padding,
schedule_metadata=self.scheduler_metadata_buffer,
use_large_context_topk=use_large_context_topk,
offsets=offsets,
)
max_decode_len = int(decode_lens_cpu.max().item())
if max_decode_len > 1:
# Flatten multi-token decode requests into single-token
# batch entries, expanding seq_lens and block tables so
# the kernel always sees next_n=1.
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
# padding) and decode_lens [3, 1, 4, 0] in the below example comments.
# The context lengths are therefore
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
# 3 + 1 + 4 + 0 = 8
actual_expanded = int(decode_lens_cpu.sum().item())
# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
expanded_base = torch.repeat_interleave(
seq_lens - decode_lens, decode_lens
)
# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
expanded_starts = torch.repeat_interleave(
common_attn_metadata.query_start_loc[:num_decodes], decode_lens
)
# [0, 1, 2, 0, 0, 1, 2, 3]
positions_within = (
self.arange_buffer[:actual_expanded] - expanded_starts
)
# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
self.expanded_seq_lens_buffer[:actual_expanded] = (
expanded_base + positions_within + 1
)
self.expanded_seq_lens_buffer[actual_expanded:] = 0
seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens]
# Give each of the flattened entries the same block table row as the
# original request.
self.expanded_block_table_buffer[:actual_expanded] = (
torch.repeat_interleave(block_table, decode_lens, dim=0)
)
if actual_expanded < num_decode_tokens:
self.expanded_block_table_buffer[
actual_expanded:num_decode_tokens, 0
] = 0
block_table = self.expanded_block_table_buffer[:num_decode_tokens]
# All reqs now have decode_len=1
self.decode_lens_buffer[:num_decode_tokens] = 1
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
offsets = None
batch_size = num_decode_tokens
else:
next_n = 1 + self.num_speculative_tokens
if next_n > 1:
offsets = torch.arange(
next_n, device=self.device, dtype=torch.int32
)
else:
offsets = None
batch_size = num_decodes
# DeepGEMM is required for the paged MQA logits on CUDA devices
if current_platform.is_cuda() and is_deep_gemm_supported():
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens,
self.kv_cache_spec.block_size,
self.num_sms,
)
# Decide which top-k kernel to use based on batch size and sequence length
# Decision logic based on micro-benchmark results:
# - large_context_topk wins for batch <= 128 and seq_len > 8K
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
_is_large_context = common_attn_metadata.max_seq_len > 8192
use_large_context_topk = batch_size <= 128 and _is_large_context
# decode_metadata = DeepSeekV32IndexerDecodeMetadata(
# block_table=block_table,
# seq_lens=seq_lens,
# decode_lens=decode_lens,
# requires_padding=False,
# # schedule_metadata=self.scheduler_metadata_buffer,
# use_large_context_topk=use_large_context_topk,
# offsets=offsets,
# )
decode_metadata = self.build_decode_metadata(
common_attn_metadata, num_decodes, decode_lens, use_large_context_topk, offsets
)
attn_metadata = DeepseekV32IndexerMetadata(

View File

@@ -115,6 +115,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None,
max_decode_seq_len: int = 0,
use_cuda_graph: bool = False,
) -> AiterMLADecodeMetadata:
# kernel block size is always 1, although the kv block size is not 1.
device = self.device
@@ -170,11 +172,13 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_decode_seq_len=max_decode_seq_len,
use_cuda_graph=use_cuda_graph,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr,
max_qo_len=max_qo_len,
attn_out_dtype=self.decode_attn_out_dtype,
)

View File

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.distributed.parallel_state import get_dcp_group
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionLayer,
@@ -22,20 +23,19 @@ from vllm.v1.attention.backend import (
is_quantized_kv_cache,
)
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
import ixformer.inference.functions as ixf_ops
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.parallel_state import get_dcp_group
logger = init_logger(__name__)
class TritonMLABackend(MLACommonBackend):
# supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
# supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
# "auto",
# "bfloat16",
# ]
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
]
@staticmethod
def get_name() -> str:
@@ -120,10 +120,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
# layer: AttentionLayer,
k_c_normed: torch.Tensor |None = None,
k_pe: torch.Tensor |None = None,
kv_c_and_k_pe_cache_scale: torch.Tensor |None = None,
k_c_normed: torch.Tensor | None,
k_pe: torch.Tensor | None,
kv_c_and_k_pe_cache_scale: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
@@ -136,7 +135,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
q_nope = q_nope.view(-1, self.num_heads, self.kv_lora_rank)
B = q_nope.shape[0]
if self.dcp_world_size > 1:
q = torch.cat([q_nope, q_pe], dim=-1)
q = get_dcp_group().all_gather(q, dim=1)
@@ -147,7 +146,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
device=q_nope.device)
if envs.VLLM_USE_INT8_MLA:
q_int8, q_scale = ops.quant_kv(q)
attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla_int8(
attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla_int8(
o,
q_int8,
q_scale,
@@ -160,7 +159,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
return_softmax_lse=True
)
else:
attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla(
attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla(
output=o,
query=q,
kv_cache=kv_c_and_k_pe_cache,
@@ -170,12 +169,12 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
max_context_len=decode_meta.max_decode_seq_len,
return_softmax_lse=True)
return attn_out, softmax_lse
o = torch.empty(B,
self.num_heads,
self.kv_lora_rank,
dtype=q_nope.dtype,
device=q_nope.device)
if envs.VLLM_USE_INT8_MLA:
q = torch.cat([q_nope, q_pe], dim=-1)
@@ -193,18 +192,30 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
attn_metadata.decode.use_cuda_graph
)
else:
# fused q concat & cache write
ixf_ops.vllm_paged_attention_mla_fused(
output=o,
q_nope=q_nope,
q_pe=q_pe.contiguous(),
kv_cache=kv_c_and_k_pe_cache,
scale=self.scale,
block_tables=attn_metadata.decode.block_table,
context_lens=attn_metadata.decode.seq_lens,
max_context_len=decode_meta.max_decode_seq_len,
k_c_normed=k_c_normed,
k_pe=k_pe,
use_cuda_graph=decode_meta.use_cuda_graph
)
if k_c_normed is None:
q = torch.cat([q_nope, q_pe.contiguous()], dim=-1)
ixf_ops.vllm_paged_attention_mla(
output=o,
query=q,
kv_cache=kv_c_and_k_pe_cache,
scale=self.scale,
block_tables=attn_metadata.decode.block_table,
context_lens=attn_metadata.decode.seq_lens,
max_context_len=decode_meta.max_decode_seq_len,
use_cuda_graph=decode_meta.use_cuda_graph,
)
else:
ixf_ops.vllm_paged_attention_mla_fused(
output=o,
q_nope=q_nope.contiguous(),
q_pe=q_pe.contiguous(),
kv_cache=kv_c_and_k_pe_cache,
scale=self.scale,
block_tables=attn_metadata.decode.block_table,
context_lens=attn_metadata.decode.seq_lens,
max_context_len=decode_meta.max_decode_seq_len,
k_c_normed=k_c_normed,
k_pe=k_pe,
use_cuda_graph=decode_meta.use_cuda_graph,
)
return self._v_up_proj(o), None

View File

@@ -55,6 +55,16 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
return RocmAttentionMetadataBuilder
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""RocmAiterUnifiedAttention supports all attention types."""
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
AttentionType.ENCODER_DECODER,
)
class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
@@ -143,6 +153,19 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
# Handle encoder attention differently - no KV cache needed
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return self._forward_encoder_attention(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata,
layer,
)
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_cache_dtype.startswith("fp8"):
@@ -195,6 +218,10 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
key_cache, value_cache = kv_cache.unbind(0)
# Reshape the input keys and values and store them in the cache.
@@ -224,6 +251,10 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
key_cache, value_cache = kv_cache.unbind(0)
flash_layout = True

View File

@@ -182,7 +182,7 @@ class RocmAttentionBackend(AttentionBackend):
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
return [32, 64, 80, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
@@ -205,6 +205,16 @@ class RocmAttentionBackend(AttentionBackend):
def get_impl_cls() -> type["RocmAttentionImpl"]:
return RocmAttentionImpl
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""RocmAttention supports all attention types."""
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
AttentionType.ENCODER_DECODER,
)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
@@ -244,6 +254,7 @@ class RocmAttentionImpl(AttentionImpl):
kv_sharing_target_layer_name: int | None = None,
sinks: torch.Tensor | None = None,
) -> None:
self.attn_type = attn_type
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -266,11 +277,6 @@ class RocmAttentionImpl(AttentionImpl):
RocmAttentionBackend.validate_head_size(head_size)
if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
"Encoder self-attention is not implemented for RocmAttentionImpl"
)
self.fp8_dtype = current_platform.fp8_dtype()
self.sinks = sinks
@@ -281,6 +287,54 @@ class RocmAttentionImpl(AttentionImpl):
f"num_heads: {num_heads}."
)
def _forward_encoder_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
layer: torch.nn.Module,
) -> torch.Tensor:
"""Forward pass for encoder attention without KV cache.
Args:
query: shape = [num_encoder_tokens, num_heads, head_size]
key: shape = [num_encoder_tokens, num_kv_heads, head_size]
value: shape = [num_encoder_tokens, num_kv_heads, head_size]
output: shape = [num_encoder_tokens, num_heads, head_size]
attn_metadata: Encoder attention metadata
layer: The attention layer
"""
# For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError(
"quantization is not supported for encoder attention"
)
# Use encoder-specific metadata for sequence information
query_start_loc = attn_metadata.query_start_loc
seq_lens = attn_metadata.seq_lens
max_query_len = attn_metadata.max_query_len
# Call the Triton context attention kernel directly on the Q, K, V tensors
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
context_attention_fwd(
q=query,
k=key,
v=value,
o=output,
b_start_loc=query_start_loc,
b_seq_len=seq_lens,
max_input_len=max_query_len,
is_causal=False,
softmax_scale=self.scale,
sliding_window_q=self.sliding_window[0],
sliding_window_k=self.sliding_window[1],
)
return output
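# Note: is_causal=False makes this full bidirectional attention over each
# encoder sequence, batched in varlen form via query_start_loc / seq_lens.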
def forward(
self,
layer: torch.nn.Module,
@@ -330,6 +384,16 @@ class RocmAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
return self._forward_encoder_attention(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata,
layer,
)
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
@@ -380,6 +444,8 @@ class RocmAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
return
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
@@ -432,6 +498,8 @@ class RocmAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
return
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache,
layer.num_kv_heads, # type: ignore[attr-defined]