Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -15,13 +15,11 @@ logger = init_logger(__name__)
_ROCM_FLASH_ATTN_AVAILABLE = False

if current_platform.is_cuda():
    from vllm._custom_ops import reshape_and_cache_flash
    # from vllm.vllm_flash_attn import (  # type: ignore[attr-defined]
    #     flash_attn_varlen_func,
    #     get_scheduler_metadata,
    # )
    from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
    from vllm import _custom_ops as ops

    reshape_and_cache_flash = ops.reshape_and_cache_flash
    from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache, flash_attn_varlen_int8_func

elif current_platform.is_xpu():
    from vllm import _custom_ops as ops
    from vllm._xpu_ops import xpu_ops
@@ -53,67 +51,93 @@ elif current_platform.is_rocm():
    reshape_and_cache_flash = ops.reshape_and_cache_flash


def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
    # import here to avoid circular dependencies
    from vllm.platforms import current_platform

    return 3

def get_flash_attn_version(
    requires_alibi: bool = False, head_size: int | None = None
) -> int | None:
    if current_platform.is_xpu():
        return 2
    if current_platform.is_rocm():
        # ROCm doesn't use vllm_flash_attn; return None to skip fa_version arg
        return None
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason,
            is_fa_version_supported,
        )
        return None
        # try:
        #     from vllm.vllm_flash_attn.flash_attn_interface import (
        #         fa_version_unsupported_reason,
        #         is_fa_version_supported,
        #     )

        device_capability = current_platform.get_device_capability()
        # device_capability = current_platform.get_device_capability()

        assert device_capability is not None
        # assert device_capability is not None

        # 1. default version depending on platform
        fa_version = (
            3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
        )
        # # 1. default version depending on platform
        # if device_capability.major == 9 and is_fa_version_supported(3):
        #     # Hopper (SM90): prefer FA3
        #     fa_version = 3
        # elif device_capability.major == 10 and is_fa_version_supported(4):
        #     # Blackwell (SM100+, restrict to SM100 for now): prefer FA4
        #     fa_version = 4
        # else:
        #     # Fallback to FA2
        #     fa_version = 2

        # 2. override if passed by environment or config
        from vllm.config import get_current_vllm_config_or_none
        # # 2. override if passed by environment or config
        # from vllm.config import get_current_vllm_config_or_none

        vllm_config = get_current_vllm_config_or_none()
        if (
            vllm_config is not None
            and vllm_config.attention_config.flash_attn_version is not None
        ):
            fa_version = vllm_config.attention_config.flash_attn_version
        # vllm_config = get_current_vllm_config_or_none()
        # if (
        #     vllm_config is not None
        #     and vllm_config.attention_config.flash_attn_version is not None
        # ):
        #     fa_version = vllm_config.attention_config.flash_attn_version

        # 3. fallback for unsupported combinations
        if device_capability.major == 10 and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 on Blackwell platform, "
                "defaulting to FA version 2."
            )
            fa_version = 2
        # # 3. fallback for unsupported combinations
        # if device_capability.major >= 10 and fa_version == 3:
        #     logger.warning_once(
        #         "Cannot use FA version 3 on Blackwell platform, "
        #         "defaulting to FA version 4 if supported, otherwise FA2."
        #     )
        #     fa_version = 4 if is_fa_version_supported(4) else 2

        if requires_alibi and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
            )
            fa_version = 2
        # if requires_alibi and fa_version == 3:
        #     logger.warning_once(
        #         "Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
        #     )
        #     fa_version = 2

        if not is_fa_version_supported(fa_version):
            logger.error(
                "FA version %d is not supported due to %s",
                fa_version,
                fa_version_unsupported_reason(fa_version),
            )
        # if requires_alibi and fa_version == 4:
        #     logger.warning_once(
        #         "Cannot use FA version 4 with ALiBi, defaulting to FA version 2."
        #     )
        #     fa_version = 2

        assert is_fa_version_supported(fa_version)
        return fa_version
    except (ImportError, AssertionError):
        return None
        # # FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
        # # supported head dimensions.
        # # See: https://github.com/Dao-AILab/flash-attention/issues/1959
        # if (
        #     fa_version == 4
        #     and device_capability.major >= 10
        #     and head_size is not None
        #     and head_size > 128
        # ):
        #     logger.warning_once(
        #         "FA4 on Blackwell does not support head_size=%d due to TMEM "
        #         "capacity limits, defaulting to FA version 2.",
        #         head_size,
        #     )
        #     fa_version = 2

        # if not is_fa_version_supported(fa_version):
        #     logger.error(
        #         "FA version %d is not supported due to %s",
        #         fa_version,
        #         fa_version_unsupported_reason(fa_version),
        #     )

        # assert is_fa_version_supported(fa_version)
        # return fa_version
    # except (ImportError, AssertionError):
    #     return None


def flash_attn_supports_fp8() -> bool:
@@ -124,10 +148,7 @@ def flash_attn_supports_fp8() -> bool:


def flash_attn_supports_sinks() -> bool:
    if current_platform.is_xpu():
        return True
    else:
        return get_flash_attn_version() == 3
    return True


def flash_attn_supports_mla():
@@ -142,6 +163,10 @@ def flash_attn_supports_mla():
        return is_fa_version_supported(
            3
        ) and current_platform.is_device_capability_family(90)

        # NOTE(Lucas): FA4 CuteDSL does NOT currently support MLA's non-standard
        # head dimensions (576 for qk, 512 for v) due to TMEM capacity limits.

    except (ImportError, AssertionError):
        pass
    return False

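Review note: after this overlay, the effective version-selection flow is much flatter than upstream's. A minimal sketch of the resulting behavior, assuming only the is_xpu()/is_rocm() predicates used above matter (the CUDA path returns None because the ixformer kernels take no fa_version argument; the helper name is illustrative):

def effective_fa_version(platform) -> int | None:
    # XPU keeps FA2; every other platform returns None so callers skip
    # the fa_version argument when dispatching to the ixformer kernels.
    if platform.is_xpu():
        return 2
    return None
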
@@ -4,7 +4,7 @@

import copy
from dataclasses import dataclass
from typing import ClassVar
from typing import ClassVar, Optional, Union, List

import numpy as np
import torch
@@ -23,15 +23,15 @@ from vllm.v1.attention.backends.fa_utils import (
    is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from ixformer.contrib.vllm_flash_attn import merge_attn_states

if is_flash_attn_varlen_func_available():
    from vllm.v1.attention.backends.fa_utils import (
        flash_attn_supports_sinks,
        flash_attn_varlen_func,
        flash_attn_with_kvcache,
        # get_scheduler_metadata,
        reshape_and_cache_flash,
        flash_attn_varlen_int8_func
    )
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
@@ -50,9 +50,12 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import (
    get_dcp_local_seq_lens,
    get_kv_cache_layout,
    split_decodes_and_prefills,
    split_decodes_and_prefills
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import _custom_ops as ops
import vllm.envs as envs
import ixformer.inference.functions as ixf_ops

logger = init_logger(__name__)

@@ -63,23 +66,7 @@ class FlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        if (
            model_config
            and model_config.is_hybrid
            and (
                cache_config.mamba_ssm_cache_dtype == "float32"
                or cache_config.mamba_cache_dtype == "float32"
            )
        ):
            # NOTE(tdoublep): while in principle, FA supports
            # MultipleOf(16), these are the block sizes that do not
            # suffer from the NaN propagation problem described here:
            # https://github.com/Dao-AILab/flash-attention/issues/1974
            return [16, 32, 64]
        return [MultipleOf(16)]
        return [16, 32, 64]

    forward_includes_kv_cache_update: bool = False

@@ -120,7 +107,8 @@ class FlashAttentionBackend(AttentionBackend):
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        # return (2, num_blocks, block_size, num_kv_heads, head_size)
        if envs.VLLM_ATTN_OPT_LEVEL == 2:
            return (3, num_blocks, num_kv_heads, block_size, head_size)
        return (2, num_blocks, num_kv_heads, block_size, head_size)

    @staticmethod
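Review note: a worked example of the shape change above. At VLLM_ATTN_OPT_LEVEL == 2 the cache gains a third plane: plane 0 holds the int8 key cache, and planes 1-2 are later reinterpreted as a single fp16 value plane (two int8 bytes per fp16 element). A sketch under those assumptions, with an illustrative helper name:

def kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size, opt_level):
    # opt level 2 -> (int8 K plane, two int8 planes viewed as one fp16 V plane)
    planes = 3 if opt_level == 2 else 2
    return (planes, num_blocks, num_kv_heads, block_size, head_size)

assert kv_cache_shape(128, 16, 8, 64, 2)[0] == 3
assert kv_cache_shape(128, 16, 8, 64, 0)[0] == 2
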
@@ -139,7 +127,7 @@ class FlashAttentionBackend(AttentionBackend):
            # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
            return (2, 4, 0, 1, 3, 5)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
            stride_order = (0, 1, 2, 3, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order
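Review note: the HND branch now returns the identity permutation, i.e. the physical layout already matches the addressed (num_blocks, num_kv_heads, block_size, head_size) order, whereas the old (0, 1, 3, 2, 4) swapped the block_size and num_kv_heads axes. A small illustration:

full_shape = (2, 128, 8, 16, 64)   # (K/V, blocks, kv_heads, block_size, head)
new_order = (0, 1, 2, 3, 4)        # identity: no physical permutation
assert tuple(full_shape[i] for i in new_order) == full_shape
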
@@ -188,24 +176,22 @@ class FlashAttentionBackend(AttentionBackend):
        if has_sink and device_capability < DeviceCapability(9, 0):
            return "sink not supported on compute capability < 9.0"
        return None


@dataclass
class FlashAttentionPrefillMetadata:
    """Prefill Specific Metadata"""

    """ Prefill Specific Metadata """
    block_table: torch.Tensor
    query_start_loc: torch.Tensor
    key_start_loc: torch.Tensor
    max_query_len: int


@dataclass
class FlashAttentionDecodeMetadata:
    block_table: torch.Tensor
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    max_query_len: int
    max_decode_seq_len: int

    use_graph: bool

@dataclass
class FlashAttentionMetadata:
@@ -220,11 +206,12 @@ class FlashAttentionMetadata:
    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    key_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
@@ -235,7 +222,6 @@ class FlashAttentionMetadata:
    cu_prefix_query_lens: torch.Tensor | None
    prefix_kv_lens: torch.Tensor | None
    suffix_kv_lens: torch.Tensor | None

    cu_prefix_kv_lens: torch.Tensor | None
    cu_suffix_kv_lens: torch.Tensor | None

@@ -247,7 +233,7 @@ class FlashAttentionMetadata:
    scheduler_metadata: torch.Tensor | None = None
    prefix_scheduler_metadata: torch.Tensor | None = None
    max_num_splits: int = 0

    prefill: FlashAttentionPrefillMetadata | None = None
    decode: FlashAttentionDecodeMetadata | None = None

@@ -291,7 +277,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
        else AttentionCGSupport.UNIFORM_BATCH
    )
    supports_update_block_table: bool = True

    reorder_batch_threshold: ClassVar[int] = 1

    @classmethod
@@ -316,6 +302,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
        self.compilation_config = vllm_config.compilation_config
        self.attention_config = vllm_config.attention_config

        self.decode_use_graph = vllm_config.compilation_config.cudagraph_mode.decode_use_graph()
        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config
        )
@@ -325,7 +312,6 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
        self.block_size = kv_cache_spec.block_size

        self.max_num_splits = 0  # No upper bound on the number of splits.
        # self.aot_schedule = get_flash_attn_version() == 3
        self.aot_schedule = False

        try:
@@ -346,6 +332,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        )
        self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
        # Align decode/prefill split threshold with speculative decode query length
        # when backend supports treating spec requests as decode.
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)

        if self.use_full_cuda_graph and self.aot_schedule:
            # FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
@@ -388,15 +377,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        key_start_loc = common_attn_metadata.key_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_np = common_attn_metadata.seq_lens_np
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = common_attn_metadata.causal

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(common_attn_metadata)
        )

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
            )
        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_actual_tokens

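Review note: a pure-Python sketch of the contract asserted here, assuming requests are already reordered decodes-first and a request counts as a decode while its query length is at most decode_threshold (illustrative only, not the vllm implementation):

def split_decodes_and_prefills_sketch(query_lens, decode_threshold=1):
    num_decodes = 0
    for qlen in query_lens:
        if qlen > decode_threshold:
            break
        num_decodes += 1
    num_prefills = len(query_lens) - num_decodes
    num_decode_tokens = sum(query_lens[:num_decodes])
    num_prefill_tokens = sum(query_lens[num_decodes:])
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

# e.g. two single-token decodes followed by two prefills:
assert split_decodes_and_prefills_sketch([1, 1, 5, 7]) == (2, 2, 2, 12)
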
@@ -467,11 +458,11 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
            dcp_context_kv_lens = None

        cu_prefix_query_lens = None
        cu_prefix_kv_lens = None
        cu_suffix_kv_lens = None
        prefix_kv_lens = None
        suffix_kv_lens = None
        prefix_scheduler_metadata = None
        cu_prefix_kv_lens = None
        cu_suffix_kv_lens = None

        if self.dcp_world_size > 1:
            query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
@@ -507,11 +498,11 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
            prefix_kv_lens = torch.tensor(
                [common_prefix_len], dtype=torch.int32, device=self.device
            )
            # Use GPU tensor directly - no CPU sync needed
            suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
            cu_prefix_kv_lens = torch.tensor([0, common_prefix_len],
                                             dtype=torch.int32,
                                             device=self.device)
            # Use GPU tensor directly - no CPU sync needed
            suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len

            cu_suffix_kv_lens = torch.tensor([0,] + suffix_kv_lens.tolist(),
                                             dtype=torch.int32,
@@ -542,7 +533,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
                causal=causal,
            )
        # For FA3 + full cudagraph
        max_num_splits = 0
        max_num_splits = 0
        if self.use_full_cuda_graph and scheduler_metadata is not None:
            n = scheduler_metadata.shape[0]
            self.scheduler_metadata[:n] = scheduler_metadata
@@ -552,50 +543,59 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
            # output buffer.
            self.scheduler_metadata[n:] = 0
            scheduler_metadata = self.scheduler_metadata[:n]

        if num_actual_tokens <= self.max_cudagraph_size:
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

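Review note: the gating above in isolation. Split-KV intermediate buffers cost on the order of num_splits * num_heads * num_tokens * head_size elements, so the cap is only applied for batches small enough to be replayed by CUDA graphs; a sketch:

def pick_max_num_splits(num_actual_tokens, max_cudagraph_size, max_num_splits):
    # Outside the CUDA-graph range, leave splitting unbounded (0 = no cap).
    return max_num_splits if num_actual_tokens <= max_cudagraph_size else 0
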
        prefill_metadata = None
        if num_prefills > 0:
            reqs_start = num_decodes
            prefill_query_start_loc = (
                query_start_loc[reqs_start:] - query_start_loc[reqs_start]
            )
            prefill_key_start_loc = (
                query_start_loc[reqs_start:] - query_start_loc[reqs_start]
            )
            reqs_start = num_decodes  # prefill_start

            prefill_query_start_loc = query_start_loc[
                reqs_start:] - query_start_loc[reqs_start]
            prefill_key_start_loc = key_start_loc[
                reqs_start:] - key_start_loc[reqs_start]
            prefill_metadata = FlashAttentionPrefillMetadata(
                block_table=block_table_tensor[reqs_start:, ...],
                query_start_loc=prefill_query_start_loc,
                key_start_loc=prefill_key_start_loc,
                max_query_len=max_query_len,
            )
                block_table=block_table_tensor[reqs_start:, ...],
                query_start_loc=prefill_query_start_loc,
                key_start_loc=prefill_key_start_loc,
                max_query_len=max_query_len,
            )
        decode_metadata = None
        if num_decodes > 0:
            reqs_start = num_decodes
            reqs_start = num_decodes  # prefill_start
            decode_query_start_loc = query_start_loc[: reqs_start + 1]
            decode_query_lens = (
                decode_query_start_loc[1:] - decode_query_start_loc[:-1]
            )
            decode_metadata = FlashAttentionDecodeMetadata(
                block_table=block_table_tensor[:reqs_start, ...],
                query_start_loc=decode_query_start_loc,
                seq_lens=seq_lens[:reqs_start],
                max_decode_seq_len=torch.max(seq_lens[:reqs_start]).item(),
                max_query_len=decode_query_lens.max().item(),
                max_decode_seq_len=np.max(seq_lens_np[:reqs_start]).item(),
                use_graph=num_prefills == 0 and self.decode_use_graph,
            )

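Review note: the start-loc rebasing used for the prefill slice, in isolation; subtracting the slice's first entry makes the cumulative sub-array start at zero:

import torch

query_start_loc = torch.tensor([0, 1, 2, 7, 12])  # 2 decodes, then 2 prefills
reqs_start = 2
prefill_query_start_loc = query_start_loc[reqs_start:] - query_start_loc[reqs_start]
assert prefill_query_start_loc.tolist() == [0, 5, 10]
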
        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            key_start_loc=key_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            max_dcp_context_kv_len=max_dcp_context_kv_len,
            dcp_context_kv_lens=dcp_context_kv_lens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            max_dcp_context_kv_len=max_dcp_context_kv_len,
            dcp_context_kv_lens=dcp_context_kv_lens,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
@@ -607,8 +607,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            causal=causal,
            prefill=prefill_metadata,
            decode=decode_metadata,
            prefill = prefill_metadata,
            decode = decode_metadata,
        )
        return attn_metadata

@@ -621,6 +621,19 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
        new_metadata = copy.copy(metadata)
        new_metadata.block_table = blk_table
        new_metadata.slot_mapping = slot_mapping
        # Keep nested prefill/decode block tables in sync. Decode path consumes
        # `attn_metadata.decode.block_table`, so updating only the top-level
        # `block_table` is insufficient when metadata is reused across groups.
        if metadata.decode is not None:
            new_decode = copy.copy(metadata.decode)
            reqs_start = metadata.num_decodes
            new_decode.block_table = blk_table[:reqs_start, ...]
            new_metadata.decode = new_decode
        if metadata.prefill is not None:
            new_prefill = copy.copy(metadata.prefill)
            reqs_start = metadata.num_decodes
            new_prefill.block_table = blk_table[reqs_start:, ...]
            new_metadata.prefill = new_prefill
        return new_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
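Review note: the fix works because the nested metadata holds plain row slices of the top-level table, so both views must be re-derived from any replacement table; a minimal illustration of the slicing invariant:

import torch

blk_table = torch.arange(12).reshape(4, 3)   # 4 requests x 3 blocks
num_decodes = 2
decode_view = blk_table[:num_decodes, ...]   # rows consumed by the decode path
prefill_view = blk_table[num_decodes:, ...]  # rows consumed by the prefill path
assert decode_view.shape[0] + prefill_view.shape[0] == blk_table.shape[0]
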
@@ -667,7 +680,15 @@ class FlashAttentionImpl(AttentionImpl):
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.attn_type = attn_type
        self.vllm_flash_attn_version = get_flash_attn_version()
        self.vllm_flash_attn_version = get_flash_attn_version(
            requires_alibi=alibi_slopes is not None,
            head_size=head_size,
        )
        logger.info_once(
            "Using FlashAttention version %s",
            self.vllm_flash_attn_version,
            scope="local",
        )
        # Cache the batch invariant result for use in forward passes
        self.batch_invariant_enabled = vllm_is_batch_invariant()

@@ -677,6 +698,7 @@ class FlashAttentionImpl(AttentionImpl):
        )

        self.sinks = sinks

        if self.sinks is not None:
            assert flash_attn_supports_sinks(), (
                "Sinks are only supported in FlashAttention 3"
@@ -687,6 +709,28 @@ class FlashAttentionImpl(AttentionImpl):
            )

        self.supports_quant_query_input = True
        self.supports_per_head_quant_scales = (
            self.vllm_flash_attn_version >= 3
            if self.vllm_flash_attn_version is not None
            else False
        )
        assert envs.VLLM_ATTN_OPT_LEVEL in [0, 1, 2], (
            "VLLM_ATTN_OPT_LEVEL only supports 0 (non-quant), 1 (I8Q_I8K_I8V), "
            "and 2 (I8Q_I8K_F16V), but got {}".format(envs.VLLM_ATTN_OPT_LEVEL)
        )
        '''
        quant_type = 0
            attention: f16 qkv
            cache: f16 kv cache
        quant_type = 1
            attention: int8 q, int8 k, int8 v
            cache:
                int8 k cache && fp32 k cache scale
                int8 v cache && fp32 v cache scale (loaded from file, not updated)
        quant_type = 2
            attention: int8 q, int8 k, fp16 v
            cache:
                int8 k cache && fp32 k cache scale
                fp16 v cache
        '''
        self.quant_type = int(envs.VLLM_ATTN_OPT_LEVEL)

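Review note: the docstring above as a lookup table, convenient when auditing call sites; a hypothetical helper, not part of the diff:

QUANT_LEVELS = {
    0: dict(attention="f16 q/k/v", k_cache="f16", v_cache="f16"),
    1: dict(attention="int8 q/k/v",
            k_cache="int8 + fp32 scale",
            v_cache="int8 + fp32 scale (loaded from file, not updated)"),
    2: dict(attention="int8 q/k, fp16 v",
            k_cache="int8 + fp32 scale",
            v_cache="fp16"),
}
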
    def forward(
        self,
@@ -698,7 +742,7 @@ class FlashAttentionImpl(AttentionImpl):
        attn_metadata: FlashAttentionMetadata,
        output: torch.Tensor | None = None,
        sqrt_alibi: bool = False,
        kv_cache_scale: torch.Tensor | None = None,
        kv_cache_scale: Union[torch.Tensor, List[torch.Tensor]] | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
@@ -711,6 +755,7 @@ class FlashAttentionImpl(AttentionImpl):
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
            kv_cache_scale: [num_blocks, num_kv_heads, block_size] + [num_kv_heads, head_size]
        Returns:
            shape = [num_tokens, num_heads * head_size]
        NOTE: FP8 quantization, flash-attn expect the size of
@@ -718,9 +763,9 @@ class FlashAttentionImpl(AttentionImpl):
        We use torch's .expand() to avoid duplicating values
        """
        assert output is not None, "Output tensor must be provided."
        assert self.vllm_flash_attn_version is not None, (
            "FlashAttention version not detected."
        )
        # assert self.vllm_flash_attn_version is not None, (
        #     "FlashAttention version not detected."
        # )

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
@@ -729,13 +774,12 @@ class FlashAttentionImpl(AttentionImpl):

        if attn_metadata is None:
            # Profiling run.
            # return output.fill_(0)
            return output.fill_(0).view(-1, self.num_heads * self.head_size)

            return output.view(-1, self.num_heads * self.head_size)

        softmax_scale: float = self.scale
        window_size = self.sliding_window
        alibi_slopes: torch.Tensor | None = self.alibi_slopes
        logits_soft_cap: float | None = self.logits_soft_cap
        alibi_slopes: torch.Tensor = self.alibi_slopes
        logits_soft_cap: float = self.logits_soft_cap

        attn_type = self.attn_type

@@ -761,18 +805,140 @@ class FlashAttentionImpl(AttentionImpl):
                output[:num_actual_tokens],
                attn_metadata,
                layer,
            )
            ).view(-1, self.num_heads * self.head_size)

        # For decoder and cross-attention, use KV cache as before
        key_cache, value_cache = kv_cache.unbind(0)
        has_decode = attn_metadata.num_decodes > 0
        has_prefill = attn_metadata.num_prefills > 0
        decode_only = has_decode and not has_prefill
        num_decode_tokens = attn_metadata.num_decode_tokens

        if self.quant_type == 0:
            key_cache, value_cache = kv_cache.unbind(0)
        elif self.quant_type == 1:
            i8_key_cache, i8_value_cache = kv_cache.unbind(0)
            num_blocks, num_kv_heads, block_size, head_size = i8_key_cache.shape

            key_scale_cache, value_scale_cache = kv_cache_scale
            assert key_scale_cache.shape == (num_blocks, num_kv_heads, block_size) and key_scale_cache.dtype == torch.float32, (
                f"expected key_scale_cache with shape (num_blocks, num_kv_heads, block_size) and dtype torch.float32, "
                f"got shape {key_scale_cache.shape} and dtype {key_scale_cache.dtype}"
            )
            assert value_scale_cache.shape == (num_kv_heads, head_size) and value_scale_cache.dtype == torch.float32, (
                f"expected value_scale_cache with shape (num_kv_heads, head_size) and dtype torch.float32, "
                f"got shape {value_scale_cache.shape} and dtype {value_scale_cache.dtype}"
            )
            value_cache_info = (i8_value_cache, value_scale_cache)

        elif self.quant_type == 2:
            # key_cache is int8 (with fp32 scale cache); value_cache is f16
            i8_key_cache = kv_cache[0]
            num_blocks, num_kv_heads, block_size, head_size = i8_key_cache.shape
            value_cache = kv_cache[1:].view(query.dtype).reshape(num_blocks, num_kv_heads, block_size, head_size)
            key_scale_cache = kv_cache_scale
            value_cache_info = (value_cache, None)

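Review note: the kv_cache[1:].view(query.dtype) trick above works because two int8 planes hold exactly the bytes of one fp16 plane; a self-contained check of the arithmetic (assuming an fp16 query dtype):

import torch

num_blocks, num_kv_heads, block_size, head_size = 2, 1, 4, 8
kv_cache = torch.zeros(3, num_blocks, num_kv_heads, block_size, head_size,
                       dtype=torch.int8)
value_cache = kv_cache[1:].view(torch.float16).reshape(
    num_blocks, num_kv_heads, block_size, head_size)
assert value_cache.shape == kv_cache[0].shape
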
        decode_q = query[:num_decode_tokens]
        prefill_q = query[num_decode_tokens:]
        prefill_output = output[num_decode_tokens:]
        decode_output = output[:num_decode_tokens]

        if self.quant_type == 1:
            if decode_only:
                int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
                    query, 2, transpose_scale=False
                )
                i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
                    key, 2, transpose_scale=False
                )
                i8_value, _value_scale = ixf_ops.scaled_int8_quant_for_attn(
                    value, 0, transpose_scale=False, scale=value_cache_info[1]
                )
            else:
                int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
                    query, 2, transpose_scale=True
                )
                i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
                    key, 2, transpose_scale=False
                )
                i8_value, _value_scale = ixf_ops.scaled_int8_quant_for_attn(
                    value, 0, transpose_scale=False, scale=value_cache_info[1]
                )
        elif self.quant_type == 2:
            '''
            origin key cache
                num_blocks, num_kv_heads, block_size, head_size  f16
            reformat key cache
                key_cache_i8 : num_blocks, num_kv_heads, block_size, head_size  int8
                key_scale_cache : num_blocks, num_kv_heads, block_size  fp32
            '''

            if decode_only:
                int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
                    query, 2, transpose_scale=False
                )
                i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
                    key, 2, transpose_scale=False
                )
            else:
                int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
                    query, 2, transpose_scale=True
                )
                i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
                    key, 2, transpose_scale=False
                )
        else:
            if layer.quant_manager is not None and layer.quant_manager.check_enable():
                i8_value, value_scale = ixf_ops.scaled_int8_quant_for_attn(
                    value, 0, transpose_scale=False
                )
                layer.quant_manager.update_data(value_scale)

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        if (
            self.kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
        ):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            if self.quant_type == 1:
                if has_prefill:
                    ixf_ops.reshape_and_cache_flash_int8(
                        key=i8_key,
                        value=i8_value,
                        k_scale=key_scale,
                        key_cache=i8_key_cache,
                        value_cache=value_cache_info[0],
                        key_scale_cache=key_scale_cache,
                        slot_mapping=attn_metadata.slot_mapping,
                        kv_cache_dtype="",
                    )
            elif self.quant_type == 2:
                if has_prefill:
                    ixf_ops.reshape_and_cache_flash_mix(
                        key=i8_key,
                        value=value,
                        k_scale=key_scale,
                        key_cache=i8_key_cache,
                        value_cache=value_cache_info[0],
                        key_scale_cache=key_scale_cache,
                        slot_mapping=attn_metadata.slot_mapping,
                        kv_cache_dtype="",
                    )

            else:
                ops.reshape_and_cache_flash(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )

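Review note: for readers new to paged KV caches, a pure-PyTorch sketch of the semantics the reshape_and_cache-style ops above implement (reference behavior only; the real ops are fused CUDA kernels that additionally handle scales and cache dtypes):

def write_kv_reference(key, value, key_cache, value_cache, slot_mapping, block_size):
    # key/value: [num_tokens, num_kv_heads, head_size]
    # *_cache:   [num_blocks, num_kv_heads, block_size, head_size]
    for tok, slot in enumerate(slot_mapping.tolist()):
        block, offset = divmod(slot, block_size)
        key_cache[block, :, offset] = key[tok]
        value_cache[block, :, offset] = value[tok]
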
        if self.kv_cache_dtype.startswith("fp8"):
            # queries are quantized in the attention layer
@@ -783,19 +949,6 @@ class FlashAttentionImpl(AttentionImpl):
            value_cache = value_cache.view(dtype)

        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

            q_descale = layer._q_scale.expand(descale_shape)
            k_descale = layer._k_scale.expand(descale_shape)
            v_descale = layer._v_scale.expand(descale_shape)

            if self.dcp_world_size > 1:
                self._forward_with_dcp(
                    query[:num_actual_tokens],
@@ -805,79 +958,140 @@ class FlashAttentionImpl(AttentionImpl):
                    value_cache,
                    output[:num_actual_tokens],
                    attn_metadata,
                    q_descale=q_descale,
                    k_descale=k_descale,
                    v_descale=v_descale,
                )
                return output.view(-1, self.num_heads * self.head_size)
            else:
                sliding_window_size = (
                    list(self.sliding_window)
                    if self.sliding_window is not None
                    else None
                )
                if has_prefill:
                    flash_attn_varlen_func(
                        q=prefill_q,
                        k=key_cache,
                        v=value_cache,
                        cu_seqlens_q=attn_metadata.prefill.query_start_loc,
                        cu_seqlens_k=attn_metadata.prefill.query_start_loc,
                        max_seqlen_q=attn_metadata.prefill.max_query_len,
                        max_seqlen_k=attn_metadata.max_query_len,
                        softmax_scale=softmax_scale,
                        causal=True,
                        window_size=sliding_window_size,
                        alibi_slopes=alibi_slopes,
                        softcap=logits_soft_cap,
                        sqrt_alibi=sqrt_alibi,
                        sinks=self.sinks,
                        out=prefill_output,
                        block_table=attn_metadata.prefill.block_table,
                    )
                    # key = key[num_decode_tokens:]
                    # value = value[num_decode_tokens:]

                    # int8 attn
                    if self.quant_type > 0:
                        flash_attn_varlen_int8_func(
                            q=int8_query[num_decode_tokens:],
                            k=i8_key_cache,
                            v=value_cache_info[0],
                            q_scale=query_scale[:, num_decode_tokens:],
                            k_scale=key_scale_cache,
                            v_scale=value_cache_info[1],
                            cu_seqlens_q=attn_metadata.prefill.query_start_loc,
                            cu_seqlens_k=attn_metadata.prefill.key_start_loc,
                            max_seqlen_q=attn_metadata.prefill.max_query_len,
                            max_seqlen_k=attn_metadata.max_query_len,
                            softmax_scale=softmax_scale,
                            causal=True,
                            window_size=window_size,
                            alibi_slopes=alibi_slopes,
                            softcap=logits_soft_cap,
                            sqrt_alibi=sqrt_alibi,
                            out=prefill_output,
                            block_table=attn_metadata.prefill.block_table,
                            output_dtype=query.dtype,
                        )
                    else:
                        flash_attn_varlen_func(
                            q=prefill_q,
                            k=key_cache,
                            v=value_cache,
                            cu_seqlens_q=attn_metadata.prefill.query_start_loc,
                            cu_seqlens_k=attn_metadata.prefill.key_start_loc,
                            max_seqlen_q=attn_metadata.prefill.max_query_len,
                            max_seqlen_k=attn_metadata.max_query_len,
                            softmax_scale=softmax_scale,
                            causal=True,
                            window_size=window_size,
                            alibi_slopes=alibi_slopes,
                            softcap=logits_soft_cap,
                            sqrt_alibi=sqrt_alibi,
                            sinks=self.sinks,
                            out=prefill_output,
                            block_table=attn_metadata.prefill.block_table,
                        )
                if has_decode:
                    flash_attn_with_kvcache(
                        q=decode_q.unsqueeze(1),
                        k_cache=key_cache.contiguous(),
                        v_cache=value_cache.contiguous(),
                        block_table=attn_metadata.decode.block_table,
                        cache_seqlens=attn_metadata.decode.seq_lens,
                        softmax_scale=softmax_scale,
                        causal=True,
                        window_size=sliding_window_size,
                        alibi_slopes=alibi_slopes,
                        softcap=logits_soft_cap,
                        use_sqrt_alibi=sqrt_alibi,
                        out=decode_output.unsqueeze(1),
                        max_context_len=attn_metadata.decode.max_decode_seq_len,
                        # sinks=self.sinks,
                    )
                    # for mtp + cuda graph
                    max_q_len = attn_metadata.decode.max_query_len if attn_metadata.decode is not None else attn_metadata.max_query_len
                    max_ct_len = attn_metadata.decode.max_decode_seq_len if attn_metadata.decode is not None else attn_metadata.max_seq_len
                    if self.quant_type in [1, 2]:
                        para_dict = dict(
                            output=decode_output,
                            query=int8_query[:num_decode_tokens],
                            key_cache=i8_key_cache,
                            query_scale=query_scale[:num_decode_tokens] if decode_only else query_scale[:, :num_decode_tokens].t().contiguous(),
                            key_scale_cache=key_scale_cache,
                            num_kv_heads=self.num_kv_heads,
                            scale=softmax_scale,
                            block_tables=attn_metadata.decode.block_table,
                            context_lens=attn_metadata.decode.seq_lens,
                            block_size=i8_key_cache.shape[-2],
                            softcap=logits_soft_cap,
                            alibi_slopes=alibi_slopes,
                            causal=True,
                            window_left=window_size[0],
                            window_right=window_size[1],
                            use_sqrt_alibi=sqrt_alibi,
                            use_cuda_graph=attn_metadata.decode.use_graph if decode_only else False,
                            max_context_len=max_ct_len,
                            # mtp
                            cu_query_lens=attn_metadata.decode.query_start_loc,
                            max_query_len=max_q_len,
                        )

                        if self.quant_type == 1:
                            para_dict.update(
                                dict(
                                    value_cache=value_cache_info[0],
                                    value_scale_cache=value_cache_info[1],
                                )
                            )
                            # for kv + k_scale write fusion
                            if decode_only:
                                para_dict.update(
                                    dict(
                                        save_key=i8_key[:num_decode_tokens],
                                        save_value=i8_value[:num_decode_tokens],
                                        save_key_scale=key_scale[:num_decode_tokens],
                                    )
                                )
                            ixf_ops.vllm_paged_attention_int8(**para_dict)
                        elif self.quant_type == 2:
                            para_dict.update(
                                dict(
                                    value_cache=value_cache,
                                )
                            )
                            if decode_only:
                                para_dict.update(
                                    dict(
                                        save_key=i8_key[:num_decode_tokens],
                                        save_value=value[:num_decode_tokens].contiguous(),
                                        save_key_scale=key_scale[:num_decode_tokens],
                                    )
                                )
                            ixf_ops.vllm_paged_attention_mix(
                                **para_dict
                            )
                    else:
                        flash_attn_with_kvcache(
                            q=decode_q.unsqueeze(1),
                            k_cache=key_cache,
                            v_cache=value_cache,
                            block_table=attn_metadata.decode.block_table,
                            cache_seqlens=attn_metadata.decode.seq_lens,
                            max_query_len=max_q_len,
                            cu_query_lens=attn_metadata.decode.query_start_loc,
                            softmax_scale=softmax_scale,
                            causal=True,
                            window_size=window_size,
                            alibi_slopes=alibi_slopes,
                            softcap=logits_soft_cap,
                            use_sqrt_alibi=sqrt_alibi,
                            sinks=self.sinks,
                            out=decode_output.unsqueeze(1),
                            use_cuda_graph=attn_metadata.decode.use_graph,
                            max_context_len=max_ct_len,
                        )
                # Compute attention and update output up to `num_actual_tokens`.
                return output.view(-1, self.num_heads * self.head_size)

            # flash_attn_varlen_func(
            #     q=query[:num_actual_tokens],
            #     k=key_cache,
            #     v=value_cache,
            #     out=output[:num_actual_tokens],
            #     cu_seqlens_q=cu_seqlens_q,
            #     max_seqlen_q=max_seqlen_q,
            #     seqused_k=seqused_k,
            #     max_seqlen_k=max_seqlen_k,
            #     softmax_scale=self.scale,
            #     causal=attn_metadata.causal,
            #     alibi_slopes=self.alibi_slopes,
            #     window_size=sliding_window_size,
            #     block_table=block_table,
            #     softcap=self.logits_soft_cap,
            #     scheduler_metadata=scheduler_metadata,
            #     fa_version=self.vllm_flash_attn_version,
            #     q_descale=q_descale,
            #     k_descale=k_descale,
            #     v_descale=v_descale,
            #     num_splits=attn_metadata.max_num_splits,
            #     s_aux=self.sinks,
            # )
            # return output

        # Cascade attention (rare case).
        cascade_attention(
@@ -906,12 +1120,7 @@ class FlashAttentionImpl(AttentionImpl):
            v_descale=layer._v_scale,
            s_aux=self.sinks,
        )
        # return output
        return (
            output[:num_actual_tokens]
            .contiguous()
            .view(-1, self.num_heads * self.head_size)
        )
        return output.view(-1, self.num_heads * self.head_size)

    def do_kv_cache_update(
        self,
@@ -935,7 +1144,7 @@ class FlashAttentionImpl(AttentionImpl):
        # and value[:num_actual_tokens] because the reshape_and_cache_flash
        # op uses the slot_mapping's shape to determine the number of
        # actual tokens.
        reshape_and_cache_flash(
        ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
@@ -959,9 +1168,9 @@ class FlashAttentionImpl(AttentionImpl):
        k_descale: torch.Tensor | None = None,
        v_descale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert self.vllm_flash_attn_version is not None, (
            "FlashAttention version not detected."
        )
        # assert self.vllm_flash_attn_version is not None, (
        #     "FlashAttention version not detected."
        # )

        cu_seqlens_q = attn_metadata.query_start_loc
        max_seqlen_q = attn_metadata.max_query_len
@@ -969,27 +1178,22 @@ class FlashAttentionImpl(AttentionImpl):

        query = query.contiguous()
        query_across_dcp = get_dcp_group().all_gather(query, dim=1)
        cu_dcp_kv_lens = attn_metadata.dcp_context_kv_lens.cumsum(dim=0, dtype=torch.int32)
        new_tensor = torch.tensor([0],
                                  device=attn_metadata.dcp_context_kv_lens.device,
                                  dtype=attn_metadata.dcp_context_kv_lens.dtype)
        cu_seqlens_k = torch.cat([new_tensor, cu_dcp_kv_lens])
        sliding_window_size = (
            list(self.sliding_window) if self.sliding_window is not None else None
        )
        cu_seqlens_k = torch.cat(
            [
                torch.zeros(1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype),
                attn_metadata.dcp_context_kv_lens.cumsum(
                    dim=0, dtype=cu_seqlens_q.dtype
                ),
            ],
            dim=0,
        )
        context_attn_out, context_lse = flash_attn_varlen_func(
            q=query_across_dcp,
            k=key_cache,
            v=value_cache,
            out=None,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            # seqused_k=attn_metadata.dcp_context_kv_lens,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
            softmax_scale=self.scale,
            causal=False,
@@ -998,11 +1202,6 @@ class FlashAttentionImpl(AttentionImpl):
            block_table=block_table,
            softcap=self.logits_soft_cap,
            return_softmax_lse=True,
            # scheduler_metadata=attn_metadata.scheduler_metadata,
            # fa_version=self.vllm_flash_attn_version,
            # q_descale=q_descale,
            # k_descale=k_descale,
            # v_descale=v_descale,
        )
        # FA returns LSE in shape [H, B] but cp_lse_ag_out_rs wants [B, H]
        context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
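Review note: the new cu_seqlens_k construction, in isolation; prepend a zero to the cumulative sum of per-request context KV lengths:

import torch

dcp_context_kv_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlens_k = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    dcp_context_kv_lens.cumsum(dim=0, dtype=torch.int32),
])
assert cu_seqlens_k.tolist() == [0, 3, 8, 10]
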
@@ -1028,10 +1227,6 @@ class FlashAttentionImpl(AttentionImpl):
            window_size=sliding_window_size,
            softcap=self.logits_soft_cap,
            return_softmax_lse=True,
            # fa_version=self.vllm_flash_attn_version,
            # q_descale=q_descale,
            # k_descale=k_descale,
            # v_descale=v_descale,
        )
        assert context_attn_out_cor.shape == query_attn_out.shape
        assert context_lse_cor.shape == query_lse.shape
@@ -1040,7 +1235,7 @@ class FlashAttentionImpl(AttentionImpl):
            context_lse_cor,
            query_attn_out,
            query_lse,
            output,
            output
        )

    def _forward_encoder_attention(
@@ -1062,9 +1257,9 @@ class FlashAttentionImpl(AttentionImpl):
            attn_metadata: Encoder attention metadata
            layer: The attention layer
        """
        assert self.vllm_flash_attn_version is not None, (
            "FlashAttention version not detected."
        )
        # assert self.vllm_flash_attn_version is not None, (
        #     "FlashAttention version not detected."
        # )

        # For encoder attention, process FP8 quantization if needed
        if self.kv_cache_dtype.startswith("fp8"):
@@ -1101,18 +1296,9 @@ class FlashAttentionImpl(AttentionImpl):
            alibi_slopes=self.alibi_slopes,
            window_size=sliding_window_size,
            softcap=self.logits_soft_cap,
            # fa_version=self.vllm_flash_attn_version,
            # q_descale=layer._q_scale.expand(descale_shape),
            # k_descale=layer._k_scale.expand(descale_shape),
            # v_descale=layer._v_scale.expand(descale_shape),
            # num_splits=1 if self.batch_invariant_enabled else 0,
        )

        return (
            output[: attn_metadata.num_actual_tokens]
            .contiguous()
            .view(-1, self.num_heads * self.head_size)
        )
        return output


def use_cascade_attention(
@@ -1203,8 +1389,6 @@ def cascade_attention(
    cu_prefix_query_lens: torch.Tensor,
    cu_prefix_kv_lens: torch.Tensor,
    cu_suffix_kv_lens: torch.Tensor,
    # prefix_kv_lens: torch.Tensor,
    # suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: torch.Tensor | None,
@@ -1228,12 +1412,13 @@ def cascade_attention(
    )

    num_tokens = query.shape[0]
    # block_size = key_cache.shape[-3]
    block_size = key_cache.shape[-2]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0
    descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
    assert q_descale is None or q_descale == 1, f"q_descale must be None or 1, got {q_descale}"
    assert k_descale is None or k_descale == 1, f"k_descale must be None or 1, got {k_descale}"
    assert v_descale is None or v_descale == 1, f"v_descale must be None or 1, got {v_descale}"

    # Process shared prefix.
    prefix_output, prefix_lse = flash_attn_varlen_func(
@@ -1241,7 +1426,6 @@ def cascade_attention(
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        # seqused_k=prefix_kv_lens,
        cu_seqlens_k=cu_prefix_kv_lens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
@@ -1251,26 +1435,14 @@ def cascade_attention(
        block_table=block_table[:1],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        # scheduler_metadata=prefix_scheduler_metadata,
        # fa_version=fa_version,
        # q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
        # k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
        # v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
        # s_aux is incorporated into prefix_lse inside the GPU kernel,
        # enabling its effect during the final attention merge.
        # s_aux=s_aux,
        # num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
    )

    descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])

    # Process suffix per query.
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        # seqused_k=suffix_kv_lens,
        cu_seqlens_k=cu_suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
@@ -1280,14 +1452,6 @@ def cascade_attention(
        block_table=block_table[:, num_common_kv_blocks:],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        # scheduler_metadata=suffix_scheduler_metadata,
        # fa_version=fa_version,
        # q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
        # k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
        # v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
        # num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
    )

    # Merge prefix and suffix outputs, and store the result in output.
    # merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse)
    merge_attn_states(prefix_output, prefix_lse, suffix_output, suffix_lse, output)

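Review note: a reference sketch of the log-sum-exp merge that a merge_attn_states-style op performs (the ixformer version above takes (prefix_output, prefix_lse, suffix_output, suffix_lse, output) and writes in place; this shows the underlying math, not its kernel):

import torch

def merge_states_reference(out_a, lse_a, out_b, lse_b):
    # lse_*: [num_tokens, num_heads]; out_*: [num_tokens, num_heads, head_size]
    max_lse = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - max_lse)
    w_b = torch.exp(lse_b - max_lse)
    denom = w_a + w_b
    return (out_a * (w_a / denom).unsqueeze(-1)
            + out_b * (w_b / denom).unsqueeze(-1))
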
@@ -13,7 +13,7 @@ from flashinfer import (
    BatchPrefillWithRaggedKVCacheWrapper,
    MultiLevelCascadeAttentionWrapper,
)
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.decode import fast_decode_plan, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor
from typing_extensions import override
@@ -199,14 +199,14 @@ class BatchDCPPrefillWrapper:
    ):
        """Plan the prefill operation with given parameters."""
        self._context.plan(
            qo_indptr_cpu,
            paged_kv_indptr_cpu,
            paged_kv_indices,
            paged_kv_last_page_len_cpu,
            num_qo_heads * dcp_world_size,
            num_kv_heads,
            head_dim,
            page_size,
            qo_indptr=qo_indptr_cpu,
            paged_kv_indptr=paged_kv_indptr_cpu,
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len=paged_kv_last_page_len_cpu,
            num_qo_heads=num_qo_heads * dcp_world_size,
            num_kv_heads=num_kv_heads,
            head_dim_qk=head_dim,
            page_size=page_size,
            causal=False,  # This is a context run
            sm_scale=sm_scale,
            window_left=window_left,
@@ -374,13 +374,13 @@ class FlashInferBackend(AttentionBackend):

    @classmethod
    def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
        from vllm.platforms import current_platform

        capability = current_platform.get_device_capability()
        if capability is not None and capability.major == 10:
            return "HND"
        return None

    forward_includes_kv_cache_update: bool = False


@dataclass
class FIPrefill:
@@ -573,20 +573,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
        # try to use fp8 q if kv cache is fp8, and will fall back to model dtype
        # if TRTLLM attention kernel is not used when building attn metadata
        can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)

        # TRTLLM attention requires strictly contiguous KV cache tensors.
        # When KV transfer (P/D disaggregation) is enabled, the KV cache may be
        # permuted into non-contiguous views, which causes assertion failures.
        self._kv_transfer_enabled = vllm_config.kv_transfer_config is not None
        if can_use_trtllm and self._kv_transfer_enabled:
            logger.info_once(
                "TRTLLM attention is disabled because KV transfer "
                "(P/D disaggregation) is enabled. TRTLLM attention requires "
                "strictly contiguous KV cache tensors which may not be "
                "guaranteed with KV transfer."
            )
            can_use_trtllm = False

        if (
            can_use_trtllm
            and not vllm_config.attention_config.disable_flashinfer_q_quantization
@@ -816,6 +802,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
            page_size,
            paged_kv_last_page_len_np,
        )
        self.paged_kv_last_page_len.gpu[:num_reqs].copy_(
            self.paged_kv_last_page_len.cpu[:num_reqs], non_blocking=True
        )
        return paged_kv_indices

    def build(
@@ -860,9 +849,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
            has_sinks=self.has_sinks,
            has_spec=uses_spec_reorder,
        )
        # KV transfer requires non-contiguous KV cache views, incompatible with TRTLLM
        if self._kv_transfer_enabled:
            prefill_use_trtllm = False
        decode_use_trtllm = (
            self.use_trtllm_decode_attention and self.dcp_world_size <= 1
        )
@@ -997,14 +983,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
            attn_metadata.cascade_wrapper.plan(
                [shared_qo_indptr_cpu, qo_indptr_cpu],
                [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
                [shared_kv_page_indices_cpu, paged_kv_indices],
                [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu],
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                qo_indptr_arr=[shared_qo_indptr_cpu, qo_indptr_cpu],
                paged_kv_indptr_arr=[shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
                paged_kv_indices_arr=[shared_kv_page_indices_cpu, paged_kv_indices],
                paged_kv_last_page_len=[
                    shared_kv_last_page_len_cpu,
                    paged_kv_last_page_len_cpu,
                ],
                num_qo_heads=self.num_qo_heads,
                num_kv_heads=self.num_kv_heads,
                head_dim=self.head_dim,
                page_size=self.page_size,
                causal=True,
                sm_scale=self.sm_scale,
                window_left=self.window_left,
@@ -1082,14 +1071,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                BatchPrefillWithPagedKVCacheWrapper,
            )
            prefill_wrapper.plan(
                qo_indptr_prefill_cpu,
                paged_kv_indptr_prefill_cpu,
                paged_kv_indices,
                paged_kv_last_page_len_prefill_cpu,
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                qo_indptr=qo_indptr_prefill_cpu,
                paged_kv_indptr=paged_kv_indptr_prefill_cpu,
                paged_kv_indices=paged_kv_indices,
                paged_kv_last_page_len=paged_kv_last_page_len_prefill_cpu,
                num_qo_heads=self.num_qo_heads,
                num_kv_heads=self.num_kv_heads,
                head_dim_qk=self.head_dim,
                page_size=self.page_size,
                causal=True,
                sm_scale=self.sm_scale,
                window_left=self.window_left,
window_left=self.window_left,
|
||||
@@ -1130,14 +1119,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
# in atten_metadata when using cudagraph.
|
||||
fast_plan_decode(
|
||||
decode_wrapper,
|
||||
self.paged_kv_indptr.cpu[: num_input_tokens + 1],
|
||||
paged_kv_indices,
|
||||
self.paged_kv_last_page_len.cpu[:num_input_tokens],
|
||||
seq_lens_cpu[:num_input_tokens],
|
||||
self.num_qo_heads * self.dcp_world_size,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
self.page_size,
|
||||
indptr_cpu=self.paged_kv_indptr.cpu[: num_input_tokens + 1],
|
||||
indices=paged_kv_indices,
|
||||
last_page_len_cpu=self.paged_kv_last_page_len.cpu[
|
||||
:num_input_tokens
|
||||
],
|
||||
num_qo_heads=self.num_qo_heads * self.dcp_world_size,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_dim=self.head_dim,
|
||||
page_size=self.page_size,
|
||||
# Disable flashinfer's pos encoding and use vllm's rope.
|
||||
pos_encoding_mode="NONE",
|
||||
sm_scale=self.sm_scale,
|
||||
@@ -1330,32 +1320,15 @@ class FlashInferImpl(AttentionImpl):

        num_actual_tokens = attn_metadata.num_actual_tokens

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
        # to process the cache when the kv_cache_dtype is fp8
        if self.kv_sharing_target_layer_name is None and self.kv_cache_dtype.startswith(
            "fp8"
        ):
            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.kv_cache_dtype
            )

            # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
            # to process the cache when the kv_cache_dtype is fp8
            if self.kv_cache_dtype.startswith("fp8"):
                torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                    self.kv_cache_dtype
                )
            kv_cache = kv_cache.view(torch_dtype)
                kv_cache = kv_cache.view(torch_dtype)

        # Inputs and outputs may be padded for CUDA graphs
        query = query[:num_actual_tokens]
@@ -1599,13 +1572,39 @@ class FlashInferImpl(AttentionImpl):
        )
        return output_padded

    def do_kv_cache_update(
        self,
        layer: torch.nn.Module,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> None:
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

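Review note: a hypothetical runner-side call sequence (names illustrative, not a vllm API): with forward_includes_kv_cache_update set to False, the cache write happens first and forward() then reads an already-updated cache.

impl.do_kv_cache_update(layer, key, value, kv_cache, slot_mapping)
output = impl.forward(layer, query, key, value, kv_cache, attn_metadata, output=output)
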
def fast_plan_decode(
    self,  # decode wrapper
    indptr_cpu: torch.Tensor,
    indices: torch.Tensor,
    last_page_len_cpu: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
@@ -1642,110 +1641,56 @@ def fast_plan_decode
# this warm up is to generate the _cached_module for the decode wrapper.
if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True):
self.plan(
indptr_cpu,
indices,
last_page_len_cpu,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
pos_encoding_mode,
window_left,
logits_soft_cap,
q_data_type,
kv_data_type,
o_data_type,
data_type,
sm_scale,
rope_scale,
rope_theta,
non_blocking,
None,  # block_tables
None,  # seq_lens
fixed_split_size,
disable_split_kv,
indptr=indptr_cpu,
indices=indices,
last_page_len=last_page_len_cpu,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
page_size=page_size,
pos_encoding_mode=pos_encoding_mode,
window_left=window_left,
logits_soft_cap=logits_soft_cap,
q_data_type=q_data_type,
kv_data_type=kv_data_type,
o_data_type=o_data_type,
data_type=data_type,
sm_scale=sm_scale,
rope_scale=rope_scale,
rope_theta=rope_theta,
non_blocking=non_blocking,
block_tables=None,
seq_lens=None,
fixed_split_size=fixed_split_size,
disable_split_kv=disable_split_kv,
)
self.vllm_first_call = False
return

assert self.is_cuda_graph_enabled, "Should be cudagraph only here"

batch_size = len(last_page_len_cpu)
if logits_soft_cap is None:
logits_soft_cap = 0.0

# Handle data types consistently
if data_type is not None:
if q_data_type is None:
q_data_type = data_type
if kv_data_type is None:
kv_data_type = data_type
elif q_data_type is None:
q_data_type = "float16"

if kv_data_type is None:
kv_data_type = q_data_type
q_data_type = (
getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
fast_decode_plan(
self,
indptr=indptr_cpu,
indices=indices,
last_page_len=last_page_len_cpu,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
page_size=page_size,
pos_encoding_mode=pos_encoding_mode,
window_left=window_left,
logits_soft_cap=logits_soft_cap,
q_data_type=q_data_type,
kv_data_type=kv_data_type,
data_type=data_type,
sm_scale=sm_scale,
rope_scale=rope_scale,
rope_theta=rope_theta,
non_blocking=non_blocking,
fixed_split_size=fixed_split_size,
disable_split_kv=disable_split_kv,
)
kv_data_type = (
getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type
)

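# Minimal sketch of the dtype-resolution rule above: data_type serves as a
# default for both q and kv, strings resolve via getattr(torch, name), and
# q falls back to float16. (Hypothetical helper name, for illustration only.)
import torch

def resolve_dtypes(data_type=None, q_data_type=None, kv_data_type=None):
    if data_type is not None:
        q_data_type = q_data_type if q_data_type is not None else data_type
        kv_data_type = kv_data_type if kv_data_type is not None else data_type
    elif q_data_type is None:
        q_data_type = "float16"
    if kv_data_type is None:
        kv_data_type = q_data_type
    resolve = lambda d: getattr(torch, d) if isinstance(d, str) else d
    return resolve(q_data_type), resolve(kv_data_type)

assert resolve_dtypes() == (torch.float16, torch.float16)
assert resolve_dtypes(data_type="bfloat16") == (torch.bfloat16, torch.bfloat16)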
if batch_size != self._fixed_batch_size:
raise ValueError(
"The batch size should be fixed in cudagraph mode, the runtime "
"batch size {} mismatches the batch size set during "
"initialization {}".format(batch_size, self._fixed_batch_size)
)
if len(indices) > len(self._paged_kv_indices_buf):
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)

# host-to-device copy for the indptr buffer
self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True)
# host-to-device copy for the last_page_len buffer
self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True)

qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

try:
# Make sure we pass exactly 19 arguments for tensor core version
args = [
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_host,
indptr_cpu,
seq_lens_cpu,
batch_size,  # total_num_rows
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
head_dim,
head_dim,
False,  # causal
window_left,
]
if self._backend == "fa2":
args.append(fixed_split_size)
args.append(disable_split_kv)
args.append(0)  # num_colocated_ctas
self._plan_info = self._cached_module.plan(
*args,
)
except Exception as e:
raise RuntimeError(f"Error in tensor core plan: {e}") from e

self._pos_encoding_mode = pos_encoding_mode
self._window_left = window_left
self._logits_soft_cap = logits_soft_cap
self._sm_scale = sm_scale
self._rope_scale = rope_scale
self._rope_theta = rope_theta

@triton.jit

@@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from dataclasses import dataclass, replace
from typing import Any

from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
@@ -29,3 +30,31 @@ class Mamba1AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
):
metadata_cls = Mamba1AttentionMetadata

def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
**kwargs: Any,
) -> Mamba1AttentionMetadata:
common = self._compute_common_metadata(common_attn_metadata)

if (
common.num_prefills > 0
and self.vllm_config.cache_config.mamba_cache_mode == "all"
):
cu_chunk_seqlen_p, _, last_chunk_indices_p = (
self._build_chunk_metadata_tensors(
self.kv_cache_spec.block_size,
common,
common_attn_metadata,
)
)
return replace(
common,
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
last_chunk_indices_p=last_chunk_indices_p,
)

return common

@@ -7,7 +7,6 @@ from typing import Any
import torch

from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (
AttentionBackend,
CommonAttentionMetadata,
@@ -105,14 +104,6 @@ class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):

# Chunk-related metadata (only for prefill)
seq_idx_p: torch.Tensor | None = None
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offsets into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p: torch.Tensor | None = None
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p: torch.Tensor | None = None


class Mamba2AttentionMetadataBuilder(
@@ -134,68 +125,6 @@ class Mamba2AttentionMetadataBuilder(
)
self.chunk_size: int = chunk_size

def _compute_chunk_metadata(
self,
num_prefills: int,
num_computed_tokens_p_cpu: torch.Tensor,
query_start_loc_p_cpu: torch.Tensor,
) -> tuple[list[int], list[int], list[int]]:
"""
Compute chunk-specific metadata for Mamba2.

The code below carefully constructs the chunks such that:
1. Chunks contain tokens from a *single* sequence only.
2. For every sequence, we are guaranteed that we can
retrieve the mamba state *every* chunk_size tokens.
Constraint (1) dramatically simplifies the mamba2 kernels.
Constraint (2) dramatically simplifies the implementation
of prefix caching for mamba2 (wip). We need to take care
of the interaction with chunked prefill in order to
satisfy constraint (2).
"""
# TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen = []
seq_idx = []
last_chunk_indices = []
seqlen_pos = 0

for req_idx in range(num_prefills):
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
this_new_tokens = (
query_start_loc_p_cpu[req_idx + 1].item()
- query_start_loc_p_cpu[req_idx].item()
)

# if computed tokens are not chunk-aligned, use the first
# chunk to finish it off
if this_num_computed % self.chunk_size != 0:
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
# how many tokens to finish the chunk?
chunk_len = (
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
- this_num_computed
)
# we can only use at most this_new_tokens
chunk_len = min(chunk_len, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len

n_chunks = cdiv(this_new_tokens, self.chunk_size)
for chunk in range(n_chunks):
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
chunk_len = min(self.chunk_size, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len

assert this_new_tokens == 0
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)

cu_chunk_seqlen.append(seqlen_pos)

return cu_chunk_seqlen, seq_idx, last_chunk_indices

def build(
self,
common_prefix_len: int,
@@ -220,41 +149,12 @@ class Mamba2AttentionMetadataBuilder(
else False
)

num_reqs = common.num_reqs
num_prefills = common.num_prefills
num_decode_tokens = common.num_decode_tokens

num_computed_tokens_cpu = (
common_attn_metadata.compute_num_computed_tokens().cpu()
)
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
query_start_loc_p_cpu = (
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
- num_decode_tokens
)

cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
num_prefills,
num_computed_tokens_p_cpu,
query_start_loc_p_cpu,
)

seq_idx_p = torch.as_tensor(
seq_idx,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
cu_chunk_seqlen_p = torch.as_tensor(
cu_chunk_seqlen,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
last_chunk_indices_p = torch.as_tensor(
last_chunk_indices,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p = (
self._build_chunk_metadata_tensors(
self.chunk_size,
common,
common_attn_metadata,
)
)

return replace(

@@ -59,6 +59,15 @@ class BaseMambaAttentionMetadata:
# The following tensor is only used for prefix caching in align mode
seq_lens: torch.Tensor

# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offsets into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p: torch.Tensor | None = None
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p: torch.Tensor | None = None

# The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
@@ -185,6 +194,118 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
common_attn_metadata, num_accepted_tokens=num_accepted_tokens
)

def _compute_chunk_metadata(
self,
chunk_size: int,
num_prefills: int,
num_computed_tokens_p_cpu: torch.Tensor,
query_start_loc_p_cpu: torch.Tensor,
) -> tuple[list[int], list[int], list[int]]:
"""
Compute chunk-specific metadata for Mamba models.

The code below carefully constructs the chunks such that:
1. Chunks contain tokens from a *single* sequence only.
2. For every sequence, we are guaranteed that we can
retrieve the mamba state *every* chunk_size tokens.
Constraint (1) dramatically simplifies the mamba kernels.
Constraint (2) dramatically simplifies the implementation
of prefix caching for mamba (wip). We need to take care
of the interaction with chunked prefill in order to
satisfy constraint (2).
"""
# TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen = []
seq_idx = []
last_chunk_indices = []
seqlen_pos = 0

for req_idx in range(num_prefills):
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
this_new_tokens = (
query_start_loc_p_cpu[req_idx + 1].item()
- query_start_loc_p_cpu[req_idx].item()
)

# if computed tokens are not chunk-aligned, use the first
# chunk to finish it off
if this_num_computed % chunk_size != 0:
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
# how many tokens to finish the chunk?
chunk_len = (
cdiv(this_num_computed, chunk_size) * chunk_size - this_num_computed
)
# we can only use at most this_new_tokens
chunk_len = min(chunk_len, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len

n_chunks = cdiv(this_new_tokens, chunk_size)
for chunk in range(n_chunks):
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
chunk_len = min(chunk_size, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len

assert this_new_tokens == 0
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)

cu_chunk_seqlen.append(seqlen_pos)

return cu_chunk_seqlen, seq_idx, last_chunk_indices

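# A standalone sketch of the chunking rule above, on hypothetical inputs
# with chunk_size=4: req0 has 6 computed tokens and 5 new ones, req1 has 0
# computed and 4 new. req0's first chunk only "finishes" the partially
# filled chunk (2 tokens), keeping state retrievable every 4 tokens.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)

def chunk_plan(chunk_size, computed, new):
    cu, seq_idx, last = [], [], []
    pos = 0
    for i, (c, n) in enumerate(zip(computed, new)):
        if c % chunk_size != 0:
            seq_idx.append(i)
            cu.append(pos)
            step = min(cdiv(c, chunk_size) * chunk_size - c, n)
            pos += step
            n -= step
        for _ in range(cdiv(n, chunk_size)):
            seq_idx.append(i)
            cu.append(pos)
            step = min(chunk_size, n)
            pos += step
            n -= step
        last.append(len(cu) - 1)
    cu.append(pos)
    return cu, seq_idx, last

assert chunk_plan(4, [6, 0], [5, 4]) == ([0, 2, 5, 9], [0, 0, 1], [1, 2])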
def _build_chunk_metadata_tensors(
self,
chunk_size: int,
common: M,
common_attn_metadata: CommonAttentionMetadata,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute chunk metadata and return as device tensors.
Returns (cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p).
"""
num_reqs = common.num_reqs
num_prefills = common.num_prefills
num_decode_tokens = common.num_decode_tokens

num_computed_tokens_cpu = (
common_attn_metadata.compute_num_computed_tokens().cpu()
)
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
query_start_loc_p_cpu = (
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
- num_decode_tokens
)

cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
chunk_size,
num_prefills,
num_computed_tokens_p_cpu,
query_start_loc_p_cpu,
)

device = common_attn_metadata.query_start_loc.device
cu_chunk_seqlen_p = torch.as_tensor(
cu_chunk_seqlen,
device=device,
dtype=torch.int32,
)
seq_idx_p = torch.as_tensor(
seq_idx,
device=device,
dtype=torch.int32,
)
last_chunk_indices_p = torch.as_tensor(
last_chunk_indices,
device=device,
dtype=torch.int32,
)
return cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p

def _compute_prefix_caching_block_indices(
self,
common_attn_metadata: CommonAttentionMetadata,

@@ -191,6 +191,8 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None,
max_decode_seq_len: int = 0,
use_cuda_graph: bool = False,
) -> FlashAttnMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_query_len = query_lens_cpu.max().item()
@@ -239,12 +241,14 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
metadata = FlashAttnMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_decode_seq_len=max_decode_seq_len,
use_cuda_graph=use_cuda_graph,
query_start_loc=query_start_loc_device,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
scheduler_metadata=scheduler_metadata,
max_num_splits=max_num_splits,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)
return metadata


@@ -156,6 +156,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None,
max_decode_seq_len: int = 0,
use_cuda_graph: bool = False,
) -> FlashMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
# we use the max but all should be the same due to uniform length requirement
@@ -179,8 +181,10 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
return FlashMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
scheduler_metadata=scheduler_metadata,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_decode_seq_len=max_decode_seq_len,
use_cuda_graph=use_cuda_graph,
scheduler_metadata=scheduler_metadata,
)


@@ -13,6 +13,11 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
get_mla_dims,
)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
LinearBase,
UnquantizedLinearMethod,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units
@@ -37,13 +42,17 @@ from vllm.v1.attention.backends.utils import (
)
from vllm.v1.attention.ops.flashmla import (
FlashMLASchedMeta,
flash_mla_sparse_fwd,
flash_mla_sparse_prefill,
flash_mla_with_kvcache,
get_mla_metadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager

import functools
from vllm import envs
from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize
import ixformer.inference.functions as ixf_ops
import numpy as np
if TYPE_CHECKING:
from vllm.model_executor.models.deepseek_v2 import Indexer

@@ -74,7 +83,15 @@ structured as:
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
part is not quantized for accuracy.
"""

def dynamic_per_batched_tensor_quant(
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
):
DTYPE_MAX = torch.finfo(dtype).max
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
scale = DTYPE_MAX / amax
x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()

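# A minimal round-trip check for the quant helper above (a sketch; assumes a
# torch build with float8_e4m3fn support). The second return value is the
# dequantization scale (reciprocal of the quantization scale), so
# x ≈ x_q.float() * scale within fp8-e4m3 precision (~6% relative error).
import torch
x = torch.randn(4, 8)
x_q, scale = dynamic_per_batched_tensor_quant(x)
assert torch.allclose(x_q.to(torch.float32) * scale, x, rtol=0.25, atol=0.05)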
class FlashMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
@@ -558,6 +575,11 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.kv_lora_rank: int = mla_args["kv_lora_rank"]
self.qk_nope_head_dim = mla_args["qk_nope_head_dim"]
self.qk_rope_head_dim = mla_args["qk_rope_head_dim"]
self.qk_head_dim = mla_args["qk_head_dim"]
self.v_head_dim = mla_args["v_head_dim"]
self.kv_b_proj = mla_args["kv_b_proj"]
self.softmax_scale = scale
assert indexer is not None
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
@@ -580,6 +602,65 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
(self.prefill_workspace_shape, torch.bfloat16)
)
)
def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
for attr in WEIGHT_NAMES:
if hasattr(layer, attr):
return getattr(layer, attr)
raise AttributeError(
f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
)

def get_and_maybe_dequant_weights(layer: LinearBase):
if layer.quant_method is not None and not isinstance(
layer.quant_method, UnquantizedLinearMethod
):
# NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye(
layer.input_size_per_partition,
dtype=act_dtype,
device=get_layer_weight(layer).device,
)
dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight

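# Why the identity trick above recovers the weights (a plain-torch sketch):
# quant_method.apply computes y = x @ W^T, so feeding x = I returns W^T, and
# transposing yields the dequantized weight in (out, in) layout.
import torch
W = torch.randn(6, 4)   # stands in for the dequantized weight, (out, in)
eye = torch.eye(4)
assert torch.allclose((eye @ W.T).T, W)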
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}"
)
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.v_head_dim,
)

W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
self.W_UV = W_UV
self.W_UK = W_UK
# self.W_UK_T = W_UK.permute(1, 2, 0)

def _v_up_proj(self, x: torch.Tensor):
return torch.einsum("bnl,lnv->bnv", x, self.W_UV)

def _k_up_proj(self, q_nope):
return torch.einsum("bnp,lnp->bnl", q_nope, self.W_UK).view(-1, self.num_heads, self.kv_lora_rank)

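# Shape sketch for the absorbed up-projections above (hypothetical sizes):
# x: (batch, num_heads, kv_lora_rank), W_UV: (kv_lora_rank, num_heads,
# v_head_dim); the einsum is a per-head matmul over the lora rank.
import torch
b, n, l, v = 2, 8, 16, 4
x = torch.randn(b, n, l)
W_UV = torch.randn(l, n, v)
out = torch.einsum("bnl,lnv->bnv", x, W_UV)
ref = torch.stack([x[:, i] @ W_UV[:, i] for i in range(n)], dim=1)
assert torch.allclose(out, ref, atol=1e-5)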
def _forward_bf16_kv(
self,
@@ -590,12 +671,11 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
) -> torch.Tensor:
# Convert per-request indices to global slots (decode) or workspace
# offsets (prefill).
topk_indices = triton_convert_req_index_to_global_index(
topk_indices = ops.dsa_convert_req_index_to_global_index(
attn_metadata.req_id_per_token,
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=topk_indices.shape[1],
attn_metadata.block_size,
)

return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices)
@@ -790,22 +870,10 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
-1, 1, kv_c_and_k_pe_cache.shape[-1]
)

# NOTE(Chen): kernel requires num_local_head to be a multiple of
# 64 on hopper and 128 on blackwell
if self.num_heads % self.prefill_padding != 0:
assert self.prefill_padding % self.num_heads == 0
logger.warning_once(
f"Padding num_heads from {self.num_heads} to "
f"{self.prefill_padding} for BF16 sparse prefill kernel"
)
q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
q_padded[:, : self.num_heads, :] = q
q = q_padded

topk_indices = topk_indices.view(num_tokens, 1, -1)
output = flash_mla_sparse_fwd(
output = flash_mla_sparse_prefill(
q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
)[0]
)
output = output[:, : self.num_heads, :]
return output

@@ -843,5 +911,5 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
attn_out = self._forward_fp8_kv_separate_prefill_decode(
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
)

return attn_out, None

return attn_out

@@ -8,7 +8,11 @@ import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm
from vllm.utils.deep_gemm import (
get_paged_mqa_logits_metadata,
is_deep_gemm_supported,
)
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
AttentionBackend,
@@ -21,6 +25,7 @@ from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
split_prefill_chunks,
)
from vllm.v1.worker.cp_utils import get_total_cp_world_size

logger = init_logger(__name__)

@@ -68,11 +73,15 @@ class DeepseekV32IndexerPrefillChunkMetadata:
cu_seqlen_ks: torch.Tensor
cu_seqlen_ke: torch.Tensor
cu_seq_lens: torch.Tensor
cu_seqlens_q: torch.Tensor
token_to_seq: torch.Tensor
total_seq_lens: int
token_start: int
token_end: int
num_reqs: int
max_context_len: int
max_q_len: int  # Maximum query length for dsa_indexer_mqa_logits_with_blocks
max_kv_len: int  # Maximum key-value length for dsa_indexer_mqa_logits_with_blocks


@dataclass
@@ -86,9 +95,16 @@ class DeepSeekV32IndexerDecodeMetadata:
seq_lens: torch.Tensor
decode_lens: torch.Tensor
requires_padding: bool
schedule_metadata: torch.Tensor
# schedule_metadata: torch.Tensor
use_large_context_topk: bool
offsets: torch.Tensor | None  # Precomputed offsets for speculative decoding
cu_seqlen_ks: torch.Tensor
cu_seqlen_ke: torch.Tensor
cu_seqlens_kv: torch.Tensor
cu_seqlens_q: torch.Tensor
max_context_len: int
max_q_len: int
max_kv_len: int


@dataclass
@@ -211,20 +227,39 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
if self.vllm_config.speculative_config
else 0
)
if self.num_speculative_tokens > 1:
raise ValueError(
"Sparse MLA only supports "
"num_speculative_tokens <= 1 because the DeepGEMM "
"fp8_paged_mqa_logits kernel does not support next_n > 2. "
f"Got num_speculative_tokens={self.num_speculative_tokens}."
)
self.reorder_batch_threshold += self.num_speculative_tokens

sm_count = num_compute_units(self.device.index)
self.num_sms = sm_count

self.decode_lens_buffer = torch.empty(
(scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
(scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=self.device,
)

# Pre-allocated buffers for flattening (spec decode).
self.arange_buffer = torch.arange(
scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens),
dtype=torch.int32,
device=self.device,
)
self.expanded_seq_lens_buffer = torch.zeros(
(scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=self.device,
)
max_num_blocks_per_req = cdiv(
self.vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size * get_total_cp_world_size(),
)
self.expanded_block_table_buffer = torch.zeros(
(
scheduler_config.max_num_batched_tokens,
max_num_blocks_per_req,
),
dtype=torch.int32,
device=self.device,
)

# See: DeepGEMM/csrc/apis/attention.hpp
@@ -260,18 +295,88 @@
.to(torch.int32)
.to(self.device)
)
cu_seqlens_q = prefill_query_start_loc.to(torch.int32).to(self.device)
max_context_len = seq_lens_cpu[reqs_start:reqs_end].max().item()
# max_q_len is the maximum query length among all requests in this chunk;
# prefill_query_start_loc is a cumsum of lengths with shape [batch+1].
max_q_len = (prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]).max().item()
return DeepseekV32IndexerPrefillChunkMetadata(
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
cu_seq_lens=cu_seq_lens,
token_to_seq=token_to_seq,
total_seq_lens=total_seq_lens,
cu_seqlens_q=cu_seqlens_q,
block_table=block_table[reqs_start:reqs_end],
token_start=token_start,
token_end=token_end,
num_reqs=reqs_end - reqs_start,
max_context_len=max_context_len,
max_q_len=max_q_len,
max_kv_len=max_context_len,
)

def build_decode_metadata(
self, common_attn_metadata, num_decodes, decode_lens, use_large_context_topk, offsets
):
decode_lens_cpu = torch.diff(
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
)
assert (
decode_lens_cpu.max().item()
== decode_lens_cpu.min().item()
== 1
), "Only support single token decode in dsa_indexer backend"

# Calculate decode metadata parameters
seq_lens_decode = common_attn_metadata.seq_lens_cpu[:num_decodes]
max_context_len = seq_lens_decode.max().item()
max_kv_len = max_context_len
max_q_len = 1  # Single token decode

# Create cu_seqlens_q: cumulative sum of query lengths (all 1s)
cu_seqlens_q = torch.arange(
num_decodes + 1, dtype=torch.int32, device=self.device
)

# Create cu_seqlens_kv and related tensors using kv_spans_from_batches
decode_query_start_loc = torch.arange(
num_decodes + 1, dtype=torch.long
)
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
decode_query_start_loc, seq_lens_decode, self.device
)

cu_seqlens_kv = torch.cat(
[
torch.zeros(1, dtype=torch.int32, device=self.device),
torch.cumsum(seq_lens_decode.to(self.device), dim=0)
.to(torch.int32),
]
)

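# Worked example for the construction above: decode seq_lens [5, 3, 7] give
# the exclusive prefix sum [0, 5, 8, 15] as cu_seqlens_kv, while cu_seqlens_q
# is a plain arange since each decode contributes exactly one query token.
import torch
seq_lens = torch.tensor([5, 3, 7])
cu_kv = torch.cat([torch.zeros(1, dtype=torch.int32),
                   torch.cumsum(seq_lens, dim=0).to(torch.int32)])
assert cu_kv.tolist() == [0, 5, 8, 15]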
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=common_attn_metadata.block_table_tensor[
:num_decodes, ...
],
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
decode_lens=decode_lens,
requires_padding=(
decode_lens_cpu.max() > decode_lens_cpu.min()
).item(),
use_large_context_topk=use_large_context_topk,
offsets=offsets,
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q=cu_seqlens_q,
max_context_len=max_context_len,
max_q_len=max_q_len,
max_kv_len=max_kv_len,
# schedule_metadata=self.scheduler_metadata_buffer,
)
return decode_metadata

def build(
self,
common_prefix_len: int,
@@ -323,45 +428,103 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
)

# Use CPU values to avoid a GPU sync that would break async scheduling
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()

# Decide which top-k kernel to use based on batch size and sequence length
batch_size = num_decodes
_is_large_context = common_attn_metadata.max_seq_len > 8192

# Decision logic based on micro-benchmark results:
# - large_context_topk wins for batch <= 128 and seq_len > 8K
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
use_large_context_topk = batch_size <= 128 and _is_large_context

next_n = 1 + self.num_speculative_tokens
if next_n > 1:
offsets = torch.arange(next_n, device=self.device, dtype=torch.int32)
else:
offsets = None

seq_lens = common_attn_metadata.seq_lens[:num_decodes]

# DeepGEMM is required for the paged MQA logits on CUDA devices
if current_platform.is_cuda() and has_deep_gemm():
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms
)
block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]

# Padded CUDA graph requests have block_table entries of -1.
# Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
# This is safe because padded requests have seq_lens=0, so the
# kernel produces no meaningful output for those rows.
block_table.clamp_(min=0)
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=block_table,
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
decode_lens=decode_lens,
requires_padding=requires_padding,
schedule_metadata=self.scheduler_metadata_buffer,
use_large_context_topk=use_large_context_topk,
offsets=offsets,

max_decode_len = int(decode_lens_cpu.max().item())
if max_decode_len > 1:
# Flatten multi-token decode requests into single-token
# batch entries, expanding seq_lens and block tables so
# the kernel always sees next_n=1.

# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
# padding) and decode_lens [3, 1, 4, 0] in the below example comments.
# The context lengths are therefore
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].

# 3 + 1 + 4 + 0 = 8
actual_expanded = int(decode_lens_cpu.sum().item())

# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
expanded_base = torch.repeat_interleave(
seq_lens - decode_lens, decode_lens
)

# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
expanded_starts = torch.repeat_interleave(
common_attn_metadata.query_start_loc[:num_decodes], decode_lens
)

# [0, 1, 2, 0, 0, 1, 2, 3]
positions_within = (
self.arange_buffer[:actual_expanded] - expanded_starts
)

# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
self.expanded_seq_lens_buffer[:actual_expanded] = (
expanded_base + positions_within + 1
)
self.expanded_seq_lens_buffer[actual_expanded:] = 0
seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens]

# Give each of the flattened entries the same block table row as the
# original request.
self.expanded_block_table_buffer[:actual_expanded] = (
torch.repeat_interleave(block_table, decode_lens, dim=0)
)
if actual_expanded < num_decode_tokens:
self.expanded_block_table_buffer[
actual_expanded:num_decode_tokens, 0
] = 0
block_table = self.expanded_block_table_buffer[:num_decode_tokens]

# All reqs now have decode_len=1
self.decode_lens_buffer[:num_decode_tokens] = 1
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
offsets = None
batch_size = num_decode_tokens
else:
next_n = 1 + self.num_speculative_tokens
if next_n > 1:
offsets = torch.arange(
next_n, device=self.device, dtype=torch.int32
)
else:
offsets = None
batch_size = num_decodes

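# A standalone sketch of the expansion above, using the numbers from the
# example comments: seq_lens [10, 7, 12, 0], decode_lens [3, 1, 4, 0],
# query_start_loc [0, 3, 4, 8].
import torch
seq_lens = torch.tensor([10, 7, 12, 0])
decode_lens = torch.tensor([3, 1, 4, 0])
query_start_loc = torch.tensor([0, 3, 4, 8])
base = torch.repeat_interleave(seq_lens - decode_lens, decode_lens)
starts = torch.repeat_interleave(query_start_loc, decode_lens)
pos = torch.arange(int(decode_lens.sum())) - starts
assert (base + pos + 1).tolist() == [8, 9, 10, 7, 9, 10, 11, 12]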
# DeepGEMM is required for the paged MQA logits on CUDA devices
if current_platform.is_cuda() and is_deep_gemm_supported():
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens,
self.kv_cache_spec.block_size,
self.num_sms,
)

# Decide which top-k kernel to use based on batch size and sequence length.
# Decision logic based on micro-benchmark results:
# - large_context_topk wins for batch <= 128 and seq_len > 8K
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
_is_large_context = common_attn_metadata.max_seq_len > 8192
use_large_context_topk = batch_size <= 128 and _is_large_context

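# The kernel-selection rule above, restated as a tiny predicate
# (hypothetical helper, for illustration only):
def _use_large_context_topk(batch_size: int, max_seq_len: int) -> bool:
    return batch_size <= 128 and max_seq_len > 8192

assert _use_large_context_topk(64, 16384)
assert not _use_large_context_topk(256, 16384)
assert not _use_large_context_topk(64, 4096)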
# decode_metadata = DeepSeekV32IndexerDecodeMetadata(
# block_table=block_table,
# seq_lens=seq_lens,
# decode_lens=decode_lens,
# requires_padding=False,
# # schedule_metadata=self.scheduler_metadata_buffer,
# use_large_context_topk=use_large_context_topk,
# offsets=offsets,
# )
decode_metadata = self.build_decode_metadata(
common_attn_metadata, num_decodes, decode_lens, use_large_context_topk, offsets
)

attn_metadata = DeepseekV32IndexerMetadata(

@@ -115,6 +115,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None,
max_decode_seq_len: int = 0,
use_cuda_graph: bool = False,
) -> AiterMLADecodeMetadata:
# kernel block size is always 1, although the kv block size is not 1.
device = self.device
@@ -170,11 +172,13 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_decode_seq_len=max_decode_seq_len,
use_cuda_graph=use_cuda_graph,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_qo_len=max_qo_len,
attn_out_dtype=self.decode_attn_out_dtype,
)

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.distributed.parallel_state import get_dcp_group
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionLayer,
@@ -22,20 +23,19 @@ from vllm.v1.attention.backend import (
is_quantized_kv_cache,
)
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd

import ixformer.inference.functions as ixf_ops
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.parallel_state import get_dcp_group

logger = init_logger(__name__)


class TritonMLABackend(MLACommonBackend):
# supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
# supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
# "auto",
# "bfloat16",
# ]
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
]

@staticmethod
def get_name() -> str:
@@ -120,10 +120,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
# layer: AttentionLayer,
k_c_normed: torch.Tensor |None = None,
k_pe: torch.Tensor |None = None,
kv_c_and_k_pe_cache_scale: torch.Tensor |None = None,
k_c_normed: torch.Tensor | None,
k_pe: torch.Tensor | None,
kv_c_and_k_pe_cache_scale: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
@@ -136,7 +135,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
q_nope = q_nope.view(-1, self.num_heads, self.kv_lora_rank)

B = q_nope.shape[0]

if self.dcp_world_size > 1:
q = torch.cat([q_nope, q_pe], dim=-1)
q = get_dcp_group().all_gather(q, dim=1)
@@ -147,7 +146,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
device=q_nope.device)
if envs.VLLM_USE_INT8_MLA:
q_int8, q_scale = ops.quant_kv(q)
attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla_int8(
attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla_int8(
o,
q_int8,
q_scale,
@@ -160,7 +159,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
return_softmax_lse=True
)
else:
attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla(
attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla(
output=o,
query=q,
kv_cache=kv_c_and_k_pe_cache,
@@ -170,12 +169,12 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
max_context_len=decode_meta.max_decode_seq_len,
return_softmax_lse=True)
return attn_out, softmax_lse

o = torch.empty(B,
self.num_heads,
self.kv_lora_rank,
dtype=q_nope.dtype,
device=q_nope.device)
device=q_nope.device)

if envs.VLLM_USE_INT8_MLA:
q = torch.cat([q_nope, q_pe], dim=-1)
@@ -193,18 +192,30 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
attn_metadata.decode.use_cuda_graph
)
else:
# fused q concat & cache write
ixf_ops.vllm_paged_attention_mla_fused(
output=o,
q_nope=q_nope,
q_pe=q_pe.contiguous(),
kv_cache=kv_c_and_k_pe_cache,
scale=self.scale,
block_tables=attn_metadata.decode.block_table,
context_lens=attn_metadata.decode.seq_lens,
max_context_len=decode_meta.max_decode_seq_len,
k_c_normed=k_c_normed,
k_pe=k_pe,
use_cuda_graph=decode_meta.use_cuda_graph
)
if k_c_normed is None:
q = torch.cat([q_nope, q_pe.contiguous()], dim=-1)
ixf_ops.vllm_paged_attention_mla(
output=o,
query=q,
kv_cache=kv_c_and_k_pe_cache,
scale=self.scale,
block_tables=attn_metadata.decode.block_table,
context_lens=attn_metadata.decode.seq_lens,
max_context_len=decode_meta.max_decode_seq_len,
use_cuda_graph=decode_meta.use_cuda_graph,
)
else:
ixf_ops.vllm_paged_attention_mla_fused(
output=o,
q_nope=q_nope.contiguous(),
q_pe=q_pe.contiguous(),
kv_cache=kv_c_and_k_pe_cache,
scale=self.scale,
block_tables=attn_metadata.decode.block_table,
context_lens=attn_metadata.decode.seq_lens,
max_context_len=decode_meta.max_decode_seq_len,
k_c_normed=k_c_normed,
k_pe=k_pe,
use_cuda_graph=decode_meta.use_cuda_graph,
)
return self._v_up_proj(o), None

@@ -55,6 +55,16 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
return RocmAttentionMetadataBuilder

@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""RocmAiterUnifiedAttention supports all attention types."""
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
AttentionType.ENCODER_DECODER,
)


class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
@@ -143,6 +153,19 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):

num_actual_tokens = attn_metadata.num_actual_tokens

# Handle encoder attention differently - no KV cache needed
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return self._forward_encoder_attention(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata,
layer,
)

key_cache, value_cache = kv_cache.unbind(0)

if self.kv_cache_dtype.startswith("fp8"):
@@ -195,6 +218,10 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
key_cache, value_cache = kv_cache.unbind(0)

# Reshape the input keys and values and store them in the cache.
@@ -224,6 +251,10 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
key_cache, value_cache = kv_cache.unbind(0)
flash_layout = True

@@ -182,7 +182,7 @@ class RocmAttentionBackend(AttentionBackend):

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
return [32, 64, 80, 96, 128, 160, 192, 224, 256]

@classmethod
def validate_head_size(cls, head_size: int) -> None:
@@ -205,6 +205,16 @@ class RocmAttentionBackend(AttentionBackend):
def get_impl_cls() -> type["RocmAttentionImpl"]:
return RocmAttentionImpl

@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""RocmAttention supports all attention types."""
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
AttentionType.ENCODER_DECODER,
)

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
@@ -244,6 +254,7 @@ class RocmAttentionImpl(AttentionImpl):
kv_sharing_target_layer_name: int | None = None,
sinks: torch.Tensor | None = None,
) -> None:
self.attn_type = attn_type
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -266,11 +277,6 @@ class RocmAttentionImpl(AttentionImpl):

RocmAttentionBackend.validate_head_size(head_size)

if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
"Encoder self-attention is not implemented for RocmAttentionImpl"
)

self.fp8_dtype = current_platform.fp8_dtype()

self.sinks = sinks
@@ -281,6 +287,54 @@ class RocmAttentionImpl(AttentionImpl):
f"num_heads: {num_heads}."
)

def _forward_encoder_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
layer: torch.nn.Module,
) -> torch.Tensor:
"""Forward pass for encoder attention without KV cache.

Args:
query: shape = [num_encoder_tokens, num_heads, head_size]
key: shape = [num_encoder_tokens, num_kv_heads, head_size]
value: shape = [num_encoder_tokens, num_kv_heads, head_size]
output: shape = [num_encoder_tokens, num_heads, head_size]
attn_metadata: Encoder attention metadata
layer: The attention layer
"""
# For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError(
"quantization is not supported for encoder attention"
)

# Use encoder-specific metadata for sequence information
query_start_loc = attn_metadata.query_start_loc
seq_lens = attn_metadata.seq_lens
max_query_len = attn_metadata.max_query_len

# Call flash attention directly on Q, K, V tensors
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd

context_attention_fwd(
q=query,
k=key,
v=value,
o=output,
b_start_loc=query_start_loc,
b_seq_len=seq_lens,
max_input_len=max_query_len,
is_causal=False,
softmax_scale=self.scale,
sliding_window_q=self.sliding_window[0],
sliding_window_k=self.sliding_window[1],
)
return output

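# Varlen-packing sketch for the call above: two encoder sequences of lengths
# 3 and 5 are concatenated along dim 0, and b_start_loc/b_seq_len tell the
# kernel where each one begins; is_causal=False because encoder
# self-attention is bidirectional.
import torch
seq_lens = torch.tensor([3, 5], dtype=torch.int32)
query_start_loc = torch.cat([torch.zeros(1, dtype=torch.int32),
                             torch.cumsum(seq_lens, 0).to(torch.int32)])
assert query_start_loc.tolist() == [0, 3, 8]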
def forward(
self,
layer: torch.nn.Module,
@@ -330,6 +384,16 @@ class RocmAttentionImpl(AttentionImpl):

num_actual_tokens = attn_metadata.num_actual_tokens

if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
return self._forward_encoder_attention(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata,
layer,
)

key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
@@ -380,6 +444,8 @@ class RocmAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
return
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
@@ -432,6 +498,8 @@ class RocmAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
return
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache,
layer.num_kv_heads,  # type: ignore[attr-defined]