Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions


@@ -4,7 +4,7 @@
import copy
from dataclasses import dataclass
from typing import ClassVar
from typing import ClassVar, Optional, Union, List
import numpy as np
import torch
@@ -23,15 +23,15 @@ from vllm.v1.attention.backends.fa_utils import (
is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from ixformer.contrib.vllm_flash_attn import merge_attn_states
if is_flash_attn_varlen_func_available():
from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_sinks,
flash_attn_varlen_func,
flash_attn_with_kvcache,
# get_scheduler_metadata,
reshape_and_cache_flash,
flash_attn_varlen_int8_func
)
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
@@ -50,9 +50,12 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import (
get_dcp_local_seq_lens,
get_kv_cache_layout,
split_decodes_and_prefills,
split_decodes_and_prefills
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import _custom_ops as ops
import vllm.envs as envs
import ixformer.inference.functions as ixf_ops
logger = init_logger(__name__)
@@ -63,23 +66,7 @@ class FlashAttentionBackend(AttentionBackend):
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
if (
model_config
and model_config.is_hybrid
and (
cache_config.mamba_ssm_cache_dtype == "float32"
or cache_config.mamba_cache_dtype == "float32"
)
):
# NOTE(tdoublep): while in principle, FA supports
# MultipleOf(16), these are the block sizes that do not
# suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974
return [16, 32, 64]
return [MultipleOf(16)]
return [16, 32, 64]
forward_includes_kv_cache_update: bool = False
@@ -120,7 +107,8 @@ class FlashAttentionBackend(AttentionBackend):
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
# return (2, num_blocks, block_size, num_kv_heads, head_size)
if envs.VLLM_ATTN_OPT_LEVEL == 2:
return (3, num_blocks, num_kv_heads, block_size, head_size)
return (2, num_blocks, num_kv_heads, block_size, head_size)
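A minimal sketch (not part of this diff) of how the extra leading plane is consumed when envs.VLLM_ATTN_OPT_LEVEL == 2, assuming the cache tensor is allocated as int8: plane 0 holds the int8 K cache, and the remaining two int8 planes are reinterpreted as a single fp16 V cache, mirroring the view/reshape that forward() performs later.

import torch

# Illustrative sizes only.
num_blocks, num_kv_heads, block_size, head_size = 4, 8, 16, 128

# Opt level 2 cache: (3, num_blocks, num_kv_heads, block_size, head_size), int8 storage (assumed).
kv_cache = torch.zeros(3, num_blocks, num_kv_heads, block_size, head_size, dtype=torch.int8)

i8_key_cache = kv_cache[0]                 # plane 0: int8 K cache
value_cache = (kv_cache[1:]                # planes 1-2: raw int8 bytes ...
               .view(torch.float16)        # ... reinterpreted as fp16 ...
               .reshape(num_blocks, num_kv_heads, block_size, head_size))  # ... one fp16 V cache
assert value_cache.numel() == i8_key_cache.numel()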
@staticmethod
@@ -139,7 +127,7 @@ class FlashAttentionBackend(AttentionBackend):
# (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
return (2, 4, 0, 1, 3, 5)
elif cache_layout == "HND":
stride_order = (0, 1, 3, 2, 4)
stride_order = (0, 1, 2, 3, 4)
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return stride_order
@@ -188,24 +176,22 @@ class FlashAttentionBackend(AttentionBackend):
if has_sink and device_capability < DeviceCapability(9, 0):
return "sink not supported on compute capability < 9.0"
return None
@dataclass
class FlashAttentionPrefillMetadata:
"""Prefill Specific Metadata"""
""" Prefill Specific Metadata """
block_table: torch.Tensor
query_start_loc: torch.Tensor
key_start_loc: torch.Tensor
max_query_len: int
@dataclass
class FlashAttentionDecodeMetadata:
block_table: torch.Tensor
query_start_loc: torch.Tensor
seq_lens: torch.Tensor
max_query_len: int
max_decode_seq_len: int
use_graph: bool
@dataclass
class FlashAttentionMetadata:
@@ -220,11 +206,12 @@ class FlashAttentionMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
key_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
num_decodes: int
num_decode_tokens: int
num_prefills: int
@@ -235,7 +222,6 @@ class FlashAttentionMetadata:
cu_prefix_query_lens: torch.Tensor | None
prefix_kv_lens: torch.Tensor | None
suffix_kv_lens: torch.Tensor | None
cu_prefix_kv_lens: torch.Tensor | None
cu_suffix_kv_lens: torch.Tensor | None
@@ -247,7 +233,7 @@ class FlashAttentionMetadata:
scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None
max_num_splits: int = 0
prefill: FlashAttentionPrefillMetadata | None = None
decode: FlashAttentionDecodeMetadata | None = None
@@ -291,7 +277,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
else AttentionCGSupport.UNIFORM_BATCH
)
supports_update_block_table: bool = True
reorder_batch_threshold: ClassVar[int] = 1
@classmethod
@@ -316,6 +302,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.compilation_config = vllm_config.compilation_config
self.attention_config = vllm_config.attention_config
self.decode_use_graph = vllm_config.compilation_config.cudagraph_mode.decode_use_graph()
self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config
)
@@ -325,7 +312,6 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.block_size = kv_cache_spec.block_size
self.max_num_splits = 0 # No upper bound on the number of splits.
# self.aot_schedule = get_flash_attn_version() == 3
self.aot_schedule = False
try:
@@ -346,6 +332,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
# Align decode/prefill split threshold with speculative decode query length
# when backend supports treating spec requests as decode.
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
if self.use_full_cuda_graph and self.aot_schedule:
# FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
@@ -388,15 +377,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
max_query_len = common_attn_metadata.max_query_len
max_seq_len = common_attn_metadata.max_seq_len
query_start_loc = common_attn_metadata.query_start_loc
key_start_loc = common_attn_metadata.key_start_loc
seq_lens = common_attn_metadata.seq_lens
seq_lens_np = common_attn_metadata.seq_lens_np
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
causal = common_attn_metadata.causal
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata)
)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
@@ -467,11 +458,11 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
dcp_context_kv_lens = None
cu_prefix_query_lens = None
cu_prefix_kv_lens = None
cu_suffix_kv_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
prefix_scheduler_metadata = None
cu_prefix_kv_lens = None
cu_suffix_kv_lens = None
if self.dcp_world_size > 1:
query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
@@ -507,11 +498,11 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
prefix_kv_lens = torch.tensor(
[common_prefix_len], dtype=torch.int32, device=self.device
)
# Use GPU tensor directly - no CPU sync needed
suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
cu_prefix_kv_lens = torch.tensor([0, common_prefix_len],
dtype=torch.int32,
device=self.device)
# Use GPU tensor directly - no CPU sync needed
suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
cu_suffix_kv_lens = torch.tensor([0,] + suffix_kv_lens.tolist(),
dtype=torch.int32,
@@ -542,7 +533,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
causal=causal,
)
# For FA3 + full cudagraph
max_num_splits = 0
max_num_splits = 0
if self.use_full_cuda_graph and scheduler_metadata is not None:
n = scheduler_metadata.shape[0]
self.scheduler_metadata[:n] = scheduler_metadata
@@ -552,50 +543,59 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
# output buffer.
self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n]
if num_actual_tokens <= self.max_cudagraph_size:
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
prefill_metadata = None
if num_prefills > 0:
reqs_start = num_decodes
prefill_query_start_loc = (
query_start_loc[reqs_start:] - query_start_loc[reqs_start]
)
prefill_key_start_loc = (
query_start_loc[reqs_start:] - query_start_loc[reqs_start]
)
reqs_start = num_decodes # prefill_start
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
prefill_key_start_loc = key_start_loc[
reqs_start:] - key_start_loc[reqs_start]
prefill_metadata = FlashAttentionPrefillMetadata(
block_table=block_table_tensor[reqs_start:, ...],
query_start_loc=prefill_query_start_loc,
key_start_loc=prefill_key_start_loc,
max_query_len=max_query_len,
)
block_table=block_table_tensor[reqs_start:, ...],
query_start_loc=prefill_query_start_loc,
key_start_loc=prefill_key_start_loc,
max_query_len=max_query_len,
)
decode_metadata = None
if num_decodes > 0:
reqs_start = num_decodes
reqs_start = num_decodes # prefill_start
decode_query_start_loc = query_start_loc[: reqs_start + 1]
decode_query_lens = (
decode_query_start_loc[1:] - decode_query_start_loc[:-1]
)
decode_metadata = FlashAttentionDecodeMetadata(
block_table=block_table_tensor[:reqs_start, ...],
query_start_loc=decode_query_start_loc,
seq_lens=seq_lens[:reqs_start],
max_decode_seq_len=torch.max(seq_lens[:reqs_start]).item(),
max_query_len=decode_query_lens.max().item(),
max_decode_seq_len=np.max(seq_lens_np[:reqs_start]).item(),
use_graph=num_prefills==0 and self.decode_use_graph
)
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
key_start_loc=key_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
max_dcp_context_kv_len=max_dcp_context_kv_len,
dcp_context_kv_lens=dcp_context_kv_lens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
max_dcp_context_kv_len=max_dcp_context_kv_len,
dcp_context_kv_lens=dcp_context_kv_lens,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
scheduler_metadata=scheduler_metadata,
@@ -607,8 +607,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
prefix_scheduler_metadata=prefix_scheduler_metadata,
max_num_splits=max_num_splits,
causal=causal,
prefill=prefill_metadata,
decode=decode_metadata,
prefill = prefill_metadata,
decode = decode_metadata,
)
return attn_metadata
@@ -621,6 +621,19 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
new_metadata = copy.copy(metadata)
new_metadata.block_table = blk_table
new_metadata.slot_mapping = slot_mapping
# Keep nested prefill/decode block tables in sync. Decode path consumes
# `attn_metadata.decode.block_table`, so updating only the top-level
# `block_table` is insufficient when metadata is reused across groups.
if metadata.decode is not None:
new_decode = copy.copy(metadata.decode)
reqs_start = metadata.num_decodes
new_decode.block_table = blk_table[:reqs_start, ...]
new_metadata.decode = new_decode
if metadata.prefill is not None:
new_prefill = copy.copy(metadata.prefill)
reqs_start = metadata.num_decodes
new_prefill.block_table = blk_table[reqs_start:, ...]
new_metadata.prefill = new_prefill
return new_metadata
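As a small reference sketch (the helper name is illustrative, not part of the diff), this is the decode-first ordering that both nested block tables rely on:

def split_block_table(blk_table, num_decodes):
    # Requests are reordered so decode requests occupy the leading rows:
    #   rows [0, num_decodes)         -> decode requests
    #   rows [num_decodes, num_reqs)  -> prefill requests
    return blk_table[:num_decodes, ...], blk_table[num_decodes:, ...]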
def use_cascade_attention(self, *args, **kwargs) -> bool:
@@ -667,7 +680,15 @@ class FlashAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
self.vllm_flash_attn_version = get_flash_attn_version(
requires_alibi=alibi_slopes is not None,
head_size=head_size,
)
logger.info_once(
"Using FlashAttention version %s",
self.vllm_flash_attn_version,
scope="local",
)
# Cache the batch invariant result for use in forward passes
self.batch_invariant_enabled = vllm_is_batch_invariant()
@@ -677,6 +698,7 @@ class FlashAttentionImpl(AttentionImpl):
)
self.sinks = sinks
if self.sinks is not None:
assert flash_attn_supports_sinks(), (
"Sinks are only supported in FlashAttention 3"
@@ -687,6 +709,28 @@ class FlashAttentionImpl(AttentionImpl):
)
self.supports_quant_query_input = True
self.supports_per_head_quant_scales = (
self.vllm_flash_attn_version >= 3
if self.vllm_flash_attn_version is not None
else False
)
assert envs.VLLM_ATTN_OPT_LEVEL in [0, 1, 2], "VLLM_ATTN_OPT_LEVEL must be 0 (non-quant), 1 (I8Q_I8K_I8V), or 2 (I8Q_I8K_F16V); got {}".format(envs.VLLM_ATTN_OPT_LEVEL)
'''
quant_type = 0:
attention: fp16 q/k/v
cache: fp16 kv cache
quant_type = 1:
attention: int8 q, int8 k, int8 v
cache:
int8 k cache + fp32 k cache scale
int8 v cache + fp32 v cache scale (loaded from file, not updated)
quant_type = 2:
attention: int8 q, int8 k, fp16 v
cache:
int8 k cache + fp32 k cache scale
fp16 v cache
'''
self.quant_type = int(envs.VLLM_ATTN_OPT_LEVEL)
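As a rough illustration (a sketch, not the actual cache allocator) of the scale tensors each level implies, consistent with the shape and dtype assertions in forward():

import torch

num_blocks, num_kv_heads, block_size, head_size = 4, 8, 16, 128  # illustrative sizes

# quant_type == 1 (I8Q_I8K_I8V): int8 K/V caches plus fp32 scales.
#   key_scale_cache:   one scale per cached K token -> (num_blocks, num_kv_heads, block_size)
#   value_scale_cache: one scale per V channel      -> (num_kv_heads, head_size), loaded from file
key_scale_cache = torch.zeros(num_blocks, num_kv_heads, block_size, dtype=torch.float32)
value_scale_cache = torch.zeros(num_kv_heads, head_size, dtype=torch.float32)
kv_cache_scale_lvl1 = [key_scale_cache, value_scale_cache]  # what forward() unpacks for quant_type == 1

# quant_type == 2 (I8Q_I8K_F16V): only K is int8, so only the K scale cache is passed.
kv_cache_scale_lvl2 = key_scale_cache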
def forward(
self,
@@ -698,7 +742,7 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata,
output: torch.Tensor | None = None,
sqrt_alibi: bool = False,
kv_cache_scale: torch.Tensor | None = None,
kv_cache_scale: Union[torch.Tensor, List[torch.Tensor]] | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
@@ -711,6 +755,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
kv_cache_scale: [key_scale_cache of shape (num_blocks, num_kv_heads, block_size), value_scale_cache of shape (num_kv_heads, head_size)]
Returns:
shape = [num_tokens, num_heads * head_size]
NOTE: FP8 quantization, flash-attn expect the size of
@@ -718,9 +763,9 @@ class FlashAttentionImpl(AttentionImpl):
We use torch's .expand() to avoid duplicating values
"""
assert output is not None, "Output tensor must be provided."
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
# assert self.vllm_flash_attn_version is not None, (
# "FlashAttention version not detected."
# )
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
@@ -729,13 +774,12 @@ class FlashAttentionImpl(AttentionImpl):
if attn_metadata is None:
# Profiling run.
# return output.fill_(0)
return output.fill_(0).view(-1, self.num_heads * self.head_size)
return output.view(-1, self.num_heads * self.head_size)
softmax_scale: float = self.scale
window_size = self.sliding_window
alibi_slopes: torch.Tensor | None = self.alibi_slopes
logits_soft_cap: float | None = self.logits_soft_cap
alibi_slopes: torch.Tensor = self.alibi_slopes
logits_soft_cap: float = self.logits_soft_cap
attn_type = self.attn_type
@@ -761,18 +805,140 @@ class FlashAttentionImpl(AttentionImpl):
output[:num_actual_tokens],
attn_metadata,
layer,
)
).view(-1, self.num_heads * self.head_size)
# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(0)
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
decode_only = has_decode and not has_prefill
num_decode_tokens = attn_metadata.num_decode_tokens
if self.quant_type == 0:
key_cache, value_cache = kv_cache.unbind(0)
elif self.quant_type == 1:
i8_key_cache, i8_value_cache = kv_cache.unbind(0)
num_blocks, num_kv_heads, block_size, head_size = i8_key_cache.shape
key_scale_cache, value_scale_cache = kv_cache_scale
assert key_scale_cache.shape == (num_blocks, num_kv_heads, block_size) and key_scale_cache.dtype == torch.float32, f"key_scale_cache.shape {key_scale_cache.shape} != (num_blocks, num_kv_heads, block_size) or key_scale_cache.dtype {key_scale_cache.dtype} != torch.float32"
assert value_scale_cache.shape == (num_kv_heads, head_size) and value_scale_cache.dtype == torch.float32, f"value_scale_cache.shape {value_scale_cache.shape} != (num_kv_heads, head_size) or value_scale_cache.dtype {value_scale_cache.dtype} != torch.float32"
value_cache_info = (i8_value_cache, value_scale_cache)
elif self.quant_type == 2:
# For quant_type 2 the key cache is int8 and the value cache is fp16 (stored in the remaining int8 planes).
i8_key_cache = kv_cache[0]
num_blocks, num_kv_heads, block_size, head_size = i8_key_cache.shape
value_cache = kv_cache[1:].view(query.dtype).reshape(num_blocks, num_kv_heads, block_size, head_size)
key_scale_cache = kv_cache_scale
value_cache_info = (value_cache, None)
decode_q = query[:num_decode_tokens]
prefill_q = query[num_decode_tokens:]
prefill_output = output[num_decode_tokens:]
decode_output = output[:num_decode_tokens]
if self.quant_type == 1:
if decode_only:
int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
query, 2, transpose_scale=False
)
i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
key, 2, transpose_scale=False
)
i8_value, _value_scale = ixf_ops.scaled_int8_quant_for_attn(
value, 0, transpose_scale=False, scale=value_cache_info[1]
)
else:
int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
query, 2, transpose_scale=True
)
i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
key, 2, transpose_scale=False
)
i8_value, _value_scale = ixf_ops.scaled_int8_quant_for_attn(
value, 0, transpose_scale=False, scale=value_cache_info[1]
)
elif self.quant_type == 2:
'''
Original key cache:
(num_blocks, num_kv_heads, block_size, head_size), fp16
Reformatted key cache:
key_cache_i8: (num_blocks, num_kv_heads, block_size, head_size), int8
key_scale_cache: (num_blocks, num_kv_heads, block_size), fp32
'''
if decode_only:
int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
query, 2, transpose_scale=False
)
i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
key, 2, transpose_scale=False
)
else:
int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
query, 2, transpose_scale=True
)
i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
key, 2, transpose_scale=False
)
else:
if layer.quant_manager is not None and layer.quant_manager.check_enable():
i8_value, value_scale = ixf_ops.scaled_int8_quant_for_attn(
value, 0, transpose_scale=False
)
layer.quant_manager.update_data(value_scale)
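A conceptual sketch of the per-token int8 quantization that the ixf_ops.scaled_int8_quant_for_attn calls above perform (only the idea, not the ixformer kernel; the function below is a stand-in): each row is scaled by its absolute maximum into int8, and the fp32 scale is kept so the attention kernel can rescale the int8 matmul results.

import torch

def int8_quant_per_token(x: torch.Tensor):
    # x: (num_tokens, num_heads, head_size); quantize each (token, head) row to int8.
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6)  # per-row absolute max
    scale = (amax / 127.0).to(torch.float32)                   # fp32 dequantization scale
    q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return q, scale.squeeze(-1)                                # int8 data, (num_tokens, num_heads) scales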
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
if self.quant_type == 1:
if has_prefill:
ixf_ops.reshape_and_cache_flash_int8(
key=i8_key,
value=i8_value,
k_scale=key_scale,
key_cache=i8_key_cache,
value_cache=value_cache_info[0],
key_scale_cache=key_scale_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache_dtype="",
)
elif self.quant_type == 2:
if has_prefill:
ixf_ops.reshape_and_cache_flash_mix(
key=i8_key,
value=value,
k_scale=key_scale,
key_cache=i8_key_cache,
value_cache=value_cache_info[0],
key_scale_cache=key_scale_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache_dtype="",
)
else:
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
# queries are quantized in the attention layer
@@ -783,19 +949,6 @@ class FlashAttentionImpl(AttentionImpl):
value_cache = value_cache.view(dtype)
if not attn_metadata.use_cascade:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
q_descale = layer._q_scale.expand(descale_shape)
k_descale = layer._k_scale.expand(descale_shape)
v_descale = layer._v_scale.expand(descale_shape)
if self.dcp_world_size > 1:
self._forward_with_dcp(
query[:num_actual_tokens],
@@ -805,79 +958,140 @@ class FlashAttentionImpl(AttentionImpl):
value_cache,
output[:num_actual_tokens],
attn_metadata,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
return output.view(-1, self.num_heads * self.head_size)
else:
sliding_window_size = (
list(self.sliding_window)
if self.sliding_window is not None
else None
)
if has_prefill:
flash_attn_varlen_func(
q=prefill_q,
k=key_cache,
v=value_cache,
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.query_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.max_query_len,
softmax_scale=softmax_scale,
causal=True,
window_size=sliding_window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
sqrt_alibi=sqrt_alibi,
sinks=self.sinks,
out=prefill_output,
block_table=attn_metadata.prefill.block_table,
)
# key = key[num_decode_tokens:]
# value = value[num_decode_tokens:]
# int8 attn
if self.quant_type > 0:
flash_attn_varlen_int8_func(
q=int8_query[num_decode_tokens:],
k=i8_key_cache,
v=value_cache_info[0],
q_scale=query_scale[:, num_decode_tokens:],
k_scale=key_scale_cache,
v_scale=value_cache_info[1],
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.key_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.max_query_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
sqrt_alibi=sqrt_alibi,
out=prefill_output,
block_table=attn_metadata.prefill.block_table,
output_dtype=query.dtype
)
else:
flash_attn_varlen_func(
q=prefill_q,
k=key_cache,
v=value_cache,
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.key_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.max_query_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
sqrt_alibi=sqrt_alibi,
sinks=self.sinks,
out=prefill_output,
block_table=attn_metadata.prefill.block_table,
)
if has_decode:
flash_attn_with_kvcache(
q=decode_q.unsqueeze(1),
k_cache=key_cache.contiguous(),
v_cache=value_cache.contiguous(),
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
softmax_scale=softmax_scale,
causal=True,
window_size=sliding_window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
use_sqrt_alibi=sqrt_alibi,
out=decode_output.unsqueeze(1),
max_context_len=attn_metadata.decode.max_decode_seq_len,
# sinks=self.sinks,
)
# For MTP (multi-token prediction) decode + CUDA graph
max_q_len = attn_metadata.decode.max_query_len if attn_metadata.decode is not None else attn_metadata.max_query_len
max_ct_len = attn_metadata.decode.max_decode_seq_len if attn_metadata.decode is not None else attn_metadata.max_seq_len
if self.quant_type in [1, 2]:
para_dict = dict(
output=decode_output,
query=int8_query[:num_decode_tokens],
key_cache=i8_key_cache,
query_scale=query_scale[:num_decode_tokens] if decode_only else query_scale[:, :num_decode_tokens].t().contiguous(),
key_scale_cache=key_scale_cache,
num_kv_heads=self.num_kv_heads,
scale=softmax_scale,
block_tables=attn_metadata.decode.block_table,
context_lens=attn_metadata.decode.seq_lens,
block_size=i8_key_cache.shape[-2],
softcap=logits_soft_cap,
alibi_slopes=alibi_slopes,
causal=True,
window_left=window_size[0],
window_right=window_size[1],
use_sqrt_alibi = sqrt_alibi,
use_cuda_graph=attn_metadata.decode.use_graph if decode_only else False,
max_context_len=max_ct_len,
# mtp
cu_query_lens=attn_metadata.decode.query_start_loc,
max_query_len=max_q_len,
)
if self.quant_type == 1:
para_dict.update(
dict(
value_cache=value_cache_info[0],
value_scale_cache=value_cache_info[1],
)
)
# Fused KV-cache and k_scale write (decode-only path)
if decode_only:
para_dict.update(
dict(
save_key=i8_key[:num_decode_tokens],
save_value=i8_value[:num_decode_tokens],
save_key_scale=key_scale[:num_decode_tokens],
)
)
ixf_ops.vllm_paged_attention_int8(**para_dict)
elif self.quant_type == 2:
para_dict.update(
dict(
value_cache=value_cache,
)
)
if decode_only:
para_dict.update(
dict(
save_key=i8_key[:num_decode_tokens],
save_value=value[:num_decode_tokens].contiguous(),
save_key_scale=key_scale[:num_decode_tokens],
)
)
ixf_ops.vllm_paged_attention_mix(
**para_dict
)
else:
flash_attn_with_kvcache(
q=decode_q.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
max_query_len=max_q_len,
cu_query_lens=attn_metadata.decode.query_start_loc,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
use_sqrt_alibi=sqrt_alibi,
sinks=self.sinks,
out=decode_output.unsqueeze(1),
use_cuda_graph=attn_metadata.decode.use_graph,
max_context_len=max_ct_len
)
# Compute attention and update output up to `num_actual_tokens`.
return output.view(-1, self.num_heads * self.head_size)
# flash_attn_varlen_func(
# q=query[:num_actual_tokens],
# k=key_cache,
# v=value_cache,
# out=output[:num_actual_tokens],
# cu_seqlens_q=cu_seqlens_q,
# max_seqlen_q=max_seqlen_q,
# seqused_k=seqused_k,
# max_seqlen_k=max_seqlen_k,
# softmax_scale=self.scale,
# causal=attn_metadata.causal,
# alibi_slopes=self.alibi_slopes,
# window_size=sliding_window_size,
# block_table=block_table,
# softcap=self.logits_soft_cap,
# scheduler_metadata=scheduler_metadata,
# fa_version=self.vllm_flash_attn_version,
# q_descale=q_descale,
# k_descale=k_descale,
# v_descale=v_descale,
# num_splits=attn_metadata.max_num_splits,
# s_aux=self.sinks,
# )
# return output
# Cascade attention (rare case).
cascade_attention(
@@ -906,12 +1120,7 @@ class FlashAttentionImpl(AttentionImpl):
v_descale=layer._v_scale,
s_aux=self.sinks,
)
# return output
return (
output[:num_actual_tokens]
.contiguous()
.view(-1, self.num_heads * self.head_size)
)
return output.view(-1, self.num_heads * self.head_size)
def do_kv_cache_update(
self,
@@ -935,7 +1144,7 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
reshape_and_cache_flash(
ops.reshape_and_cache_flash(
key,
value,
key_cache,
@@ -959,9 +1168,9 @@ class FlashAttentionImpl(AttentionImpl):
k_descale: torch.Tensor | None = None,
v_descale: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
# assert self.vllm_flash_attn_version is not None, (
# "FlashAttention version not detected."
# )
cu_seqlens_q = attn_metadata.query_start_loc
max_seqlen_q = attn_metadata.max_query_len
@@ -969,27 +1178,22 @@ class FlashAttentionImpl(AttentionImpl):
query = query.contiguous()
query_across_dcp = get_dcp_group().all_gather(query, dim=1)
cu_dcp_kv_klens = attn_metadata.dcp_context_kv_lens.cumsum(dim=0, dtype=torch.int32)
new_tensor = torch.tensor([0],
device=attn_metadata.dcp_context_kv_lens.device,
dtype=attn_metadata.dcp_context_kv_lens.dtype)
cu_seqlens_k = torch.cat([new_tensor, cu_dcp_kv_klens])
sliding_window_size = (
list(self.sliding_window) if self.sliding_window is not None else None
)
cu_seqlens_k = torch.cat(
[
torch.zeros(1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype),
attn_metadata.dcp_context_kv_lens.cumsum(
dim=0, dtype=cu_seqlens_q.dtype
),
],
dim=0,
)
context_attn_out, context_lse = flash_attn_varlen_func(
q=query_across_dcp,
k=key_cache,
v=value_cache,
out=None,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
# seqused_k=attn_metadata.dcp_context_kv_lens,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
softmax_scale=self.scale,
causal=False,
@@ -998,11 +1202,6 @@ class FlashAttentionImpl(AttentionImpl):
block_table=block_table,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
# scheduler_metadata=attn_metadata.scheduler_metadata,
# fa_version=self.vllm_flash_attn_version,
# q_descale=q_descale,
# k_descale=k_descale,
# v_descale=v_descale,
)
# FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
@@ -1028,10 +1227,6 @@ class FlashAttentionImpl(AttentionImpl):
window_size=sliding_window_size,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
# fa_version=self.vllm_flash_attn_version,
# q_descale=q_descale,
# k_descale=k_descale,
# v_descale=v_descale,
)
assert context_attn_out_cor.shape == query_attn_out.shape
assert context_lse_cor.shape == query_lse.shape
@@ -1040,7 +1235,7 @@ class FlashAttentionImpl(AttentionImpl):
context_lse_cor,
query_attn_out,
query_lse,
output,
output
)
def _forward_encoder_attention(
@@ -1062,9 +1257,9 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: Encoder attention metadata
layer: The attention layer
"""
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
# assert self.vllm_flash_attn_version is not None, (
# "FlashAttention version not detected."
# )
# For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"):
@@ -1101,18 +1296,9 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
softcap=self.logits_soft_cap,
# fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
# num_splits=1 if self.batch_invariant_enabled else 0,
)
return (
output[: attn_metadata.num_actual_tokens]
.contiguous()
.view(-1, self.num_heads * self.head_size)
)
return output
def use_cascade_attention(
@@ -1203,8 +1389,6 @@ def cascade_attention(
cu_prefix_query_lens: torch.Tensor,
cu_prefix_kv_lens: torch.Tensor,
cu_suffix_kv_lens: torch.Tensor,
# prefix_kv_lens: torch.Tensor,
# suffix_kv_lens: torch.Tensor,
max_kv_len: int,
softmax_scale: float,
alibi_slopes: torch.Tensor | None,
@@ -1228,12 +1412,13 @@ def cascade_attention(
)
num_tokens = query.shape[0]
# block_size = key_cache.shape[-3]
block_size = key_cache.shape[-2]
assert common_prefix_len % block_size == 0
num_common_kv_blocks = common_prefix_len // block_size
assert num_common_kv_blocks > 0
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
assert q_descale is None or q_descale == 1, f"q_descale must be None or 1, got {q_descale}"
assert k_descale is None or k_descale == 1, f"k_descale must be None or 1, got {k_descale}"
assert v_descale is None or v_descale == 1, f"v_descale must be None or 1, got {v_descale}"
# Process shared prefix.
prefix_output, prefix_lse = flash_attn_varlen_func(
@@ -1241,7 +1426,6 @@ def cascade_attention(
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_prefix_query_lens,
# seqused_k=prefix_kv_lens,
cu_seqlens_k=cu_prefix_kv_lens,
max_seqlen_q=num_tokens,
max_seqlen_k=common_prefix_len,
@@ -1251,26 +1435,14 @@ def cascade_attention(
block_table=block_table[:1],
softcap=logits_soft_cap,
return_softmax_lse=True,
# scheduler_metadata=prefix_scheduler_metadata,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge.
# s_aux=s_aux,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
)
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
# Process suffix per query.
suffix_output, suffix_lse = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
# seqused_k=suffix_kv_lens,
cu_seqlens_k=cu_suffix_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len - common_prefix_len,
@@ -1280,14 +1452,6 @@ def cascade_attention(
block_table=block_table[:, num_common_kv_blocks:],
softcap=logits_soft_cap,
return_softmax_lse=True,
# scheduler_metadata=suffix_scheduler_metadata,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
)
# Merge prefix and suffix outputs, and store the result in output.
# merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse)
merge_attn_states(prefix_output, prefix_lse, suffix_output, suffix_lse, output)
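For reference, the merge that merge_attn_states performs can be expressed with log-sum-exp weights (the ixformer variant visible above passes the output tensor last instead of first). A small sketch of the underlying math, not the kernel itself:

import torch

def merge_two_attn_states(out1, lse1, out2, lse2):
    # out*: (num_tokens, num_heads, head_size); lse*: (num_tokens, num_heads) log-sum-exp values.
    lse = torch.logaddexp(lse1, lse2)          # combined log-sum-exp
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)   # weight of the first partial result (e.g. shared prefix)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)   # weight of the second partial result (e.g. per-query suffix)
    return w1 * out1 + w2 * out2               # merged attention output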