Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -4,7 +4,7 @@

import copy
from dataclasses import dataclass
from typing import ClassVar
from typing import ClassVar, Optional, Union, List

import numpy as np
import torch
@@ -23,15 +23,15 @@ from vllm.v1.attention.backends.fa_utils import (
    is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from ixformer.contrib.vllm_flash_attn import merge_attn_states

if is_flash_attn_varlen_func_available():
    from vllm.v1.attention.backends.fa_utils import (
        flash_attn_supports_sinks,
        flash_attn_varlen_func,
        flash_attn_with_kvcache,
        # get_scheduler_metadata,
        reshape_and_cache_flash,
        flash_attn_varlen_int8_func,
    )
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
@@ -50,9 +50,12 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import (
    get_dcp_local_seq_lens,
    get_kv_cache_layout,
    split_decodes_and_prefills,
    split_decodes_and_prefills
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import _custom_ops as ops
import vllm.envs as envs
import ixformer.inference.functions as ixf_ops

logger = init_logger(__name__)

@@ -63,23 +66,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
if (
|
||||
model_config
|
||||
and model_config.is_hybrid
|
||||
and (
|
||||
cache_config.mamba_ssm_cache_dtype == "float32"
|
||||
or cache_config.mamba_cache_dtype == "float32"
|
||||
)
|
||||
):
|
||||
# NOTE(tdoublep): while in principle, FA supports
|
||||
# MultipleOf(16), these are the block sizes that do not
|
||||
# suffer from the NaN propagation problem described here:
|
||||
# https://github.com/Dao-AILab/flash-attention/issues/1974
|
||||
return [16, 32, 64]
|
||||
return [MultipleOf(16)]
|
||||
return [16, 32, 64]
|
||||
|
||||
forward_includes_kv_cache_update: bool = False
|
||||
|
||||
@@ -120,7 +107,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
# return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
if envs.VLLM_ATTN_OPT_LEVEL == 2:
|
||||
return (3, num_blocks, num_kv_heads, block_size, head_size)
|
||||
return (2, num_blocks, num_kv_heads, block_size, head_size)
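# NOTE (inferred from the forward() path later in this diff): with
# VLLM_ATTN_OPT_LEVEL == 2 the leading dimension of 3 packs the int8 K cache in
# plane 0 and the fp16 V cache in planes 1:3 (two int8 bytes per fp16 element),
# so a single allocation backs both caches.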
|
||||
|
||||
@staticmethod
|
||||
@@ -139,7 +127,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
# (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
|
||||
return (2, 4, 0, 1, 3, 5)
|
||||
elif cache_layout == "HND":
|
||||
stride_order = (0, 1, 3, 2, 4)
|
||||
stride_order = (0, 1, 2, 3, 4)
|
||||
else:
|
||||
raise ValueError(f"Unknown cache layout format {cache_layout}.")
|
||||
return stride_order
|
||||
@@ -188,24 +176,22 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
if has_sink and device_capability < DeviceCapability(9, 0):
|
||||
return "sink not supported on compute capability < 9.0"
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionPrefillMetadata:
|
||||
"""Prefill Specific Metadata"""
|
||||
|
||||
""" Prefill Specific Metadata """
|
||||
block_table: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
key_start_loc: torch.Tensor
|
||||
max_query_len: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionDecodeMetadata:
|
||||
block_table: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
max_query_len: int
|
||||
max_decode_seq_len: int
|
||||
|
||||
use_graph: bool
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionMetadata:
|
||||
@@ -220,11 +206,12 @@ class FlashAttentionMetadata:
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
key_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
@@ -235,7 +222,6 @@ class FlashAttentionMetadata:
|
||||
cu_prefix_query_lens: torch.Tensor | None
|
||||
prefix_kv_lens: torch.Tensor | None
|
||||
suffix_kv_lens: torch.Tensor | None
|
||||
|
||||
cu_prefix_kv_lens: torch.Tensor | None
|
||||
cu_suffix_kv_lens: torch.Tensor | None
|
||||
|
||||
@@ -247,7 +233,7 @@ class FlashAttentionMetadata:
|
||||
scheduler_metadata: torch.Tensor | None = None
|
||||
prefix_scheduler_metadata: torch.Tensor | None = None
|
||||
max_num_splits: int = 0
|
||||
|
||||
|
||||
prefill: FlashAttentionPrefillMetadata | None = None
|
||||
decode: FlashAttentionDecodeMetadata | None = None
|
||||
|
||||
@@ -291,7 +277,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
else AttentionCGSupport.UNIFORM_BATCH
|
||||
)
|
||||
supports_update_block_table: bool = True
|
||||
|
||||
|
||||
reorder_batch_threshold: ClassVar[int] = 1
|
||||
|
||||
@classmethod
|
||||
@@ -316,6 +302,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.attention_config = vllm_config.attention_config
|
||||
|
||||
self.decode_use_graph = vllm_config.compilation_config.cudagraph_mode.decode_use_graph()
|
||||
self.num_heads_q = self.model_config.get_num_attention_heads(
|
||||
self.parallel_config
|
||||
)
|
||||
@@ -325,7 +312,6 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
|
||||
self.max_num_splits = 0 # No upper bound on the number of splits.
|
||||
# self.aot_schedule = get_flash_attn_version() == 3
|
||||
self.aot_schedule = False
|
||||
|
||||
try:
|
||||
@@ -346,6 +332,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
)
|
||||
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
|
||||
# Align decode/prefill split threshold with speculative decode query length
|
||||
# when backend supports treating spec requests as decode.
|
||||
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
|
||||
|
||||
if self.use_full_cuda_graph and self.aot_schedule:
|
||||
# FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
|
||||
@@ -388,15 +377,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
key_start_loc = common_attn_metadata.key_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
seq_lens_np = common_attn_metadata.seq_lens_np
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
causal = common_attn_metadata.causal
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(common_attn_metadata)
|
||||
)
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata,
|
||||
decode_threshold=self.reorder_batch_threshold,
|
||||
)
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
|
||||
|
||||
@@ -467,11 +458,11 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
dcp_context_kv_lens = None
|
||||
|
||||
cu_prefix_query_lens = None
|
||||
cu_prefix_kv_lens = None
|
||||
cu_suffix_kv_lens = None
|
||||
prefix_kv_lens = None
|
||||
suffix_kv_lens = None
|
||||
prefix_scheduler_metadata = None
|
||||
cu_prefix_kv_lens = None
|
||||
cu_suffix_kv_lens = None
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
@@ -507,11 +498,11 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
prefix_kv_lens = torch.tensor(
|
||||
[common_prefix_len], dtype=torch.int32, device=self.device
|
||||
)
|
||||
# Use GPU tensor directly - no CPU sync needed
|
||||
suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
|
||||
cu_prefix_kv_lens = torch.tensor([0, common_prefix_len],
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
# Use GPU tensor directly - no CPU sync needed
|
||||
suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
|
||||
|
||||
cu_suffix_kv_lens = torch.tensor([0,] + suffix_kv_lens.tolist(),
|
||||
dtype=torch.int32,
|
||||
@@ -542,7 +533,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
causal=causal,
|
||||
)
|
||||
# For FA3 + full cudagraph
|
||||
max_num_splits = 0
|
||||
max_num_splits = 0
|
||||
if self.use_full_cuda_graph and scheduler_metadata is not None:
|
||||
n = scheduler_metadata.shape[0]
|
||||
self.scheduler_metadata[:n] = scheduler_metadata
|
||||
@@ -552,50 +543,59 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
# output buffer.
|
||||
self.scheduler_metadata[n:] = 0
|
||||
scheduler_metadata = self.scheduler_metadata[:n]
|
||||
|
||||
|
||||
if num_actual_tokens <= self.max_cudagraph_size:
|
||||
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
|
||||
# usage, because the intermediate buffers of size [num_splits,
|
||||
# num_heads, num_tokens, head_size] are allocated. Therefore,
|
||||
# we only set num_splits when using cuda graphs.
|
||||
max_num_splits = self.max_num_splits
|
||||
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
reqs_start = num_decodes
|
||||
prefill_query_start_loc = (
|
||||
query_start_loc[reqs_start:] - query_start_loc[reqs_start]
|
||||
)
|
||||
prefill_key_start_loc = (
|
||||
query_start_loc[reqs_start:] - query_start_loc[reqs_start]
|
||||
)
|
||||
reqs_start = num_decodes # prefill_start
|
||||
|
||||
prefill_query_start_loc = query_start_loc[
|
||||
reqs_start:] - query_start_loc[reqs_start]
|
||||
prefill_key_start_loc = key_start_loc[
|
||||
reqs_start:] - key_start_loc[reqs_start]
|
||||
prefill_metadata = FlashAttentionPrefillMetadata(
|
||||
block_table=block_table_tensor[reqs_start:, ...],
|
||||
query_start_loc=prefill_query_start_loc,
|
||||
key_start_loc=prefill_key_start_loc,
|
||||
max_query_len=max_query_len,
|
||||
)
|
||||
block_table=block_table_tensor[reqs_start:, ...],
|
||||
query_start_loc=prefill_query_start_loc,
|
||||
key_start_loc=prefill_key_start_loc,
|
||||
max_query_len=max_query_len,
|
||||
)
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
reqs_start = num_decodes
|
||||
reqs_start = num_decodes # prefill_start
|
||||
decode_query_start_loc = query_start_loc[: reqs_start + 1]
|
||||
decode_query_lens = (
|
||||
decode_query_start_loc[1:] - decode_query_start_loc[:-1]
|
||||
)
|
||||
decode_metadata = FlashAttentionDecodeMetadata(
|
||||
block_table=block_table_tensor[:reqs_start, ...],
|
||||
query_start_loc=decode_query_start_loc,
|
||||
seq_lens=seq_lens[:reqs_start],
|
||||
max_decode_seq_len=torch.max(seq_lens[:reqs_start]).item(),
|
||||
max_query_len=decode_query_lens.max().item(),
|
||||
max_decode_seq_len=np.max(seq_lens_np[:reqs_start]).item(),
|
||||
use_graph=(num_prefills == 0) and self.decode_use_graph
|
||||
)
|
||||
|
||||
|
||||
attn_metadata = FlashAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=query_start_loc,
|
||||
key_start_loc=key_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
max_dcp_context_kv_len=max_dcp_context_kv_len,
|
||||
dcp_context_kv_lens=dcp_context_kv_lens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
max_dcp_context_kv_len=max_dcp_context_kv_len,
|
||||
dcp_context_kv_lens=dcp_context_kv_lens,
|
||||
use_cascade=use_cascade,
|
||||
common_prefix_len=common_prefix_len,
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
@@ -607,8 +607,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
||||
max_num_splits=max_num_splits,
|
||||
causal=causal,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
prefill = prefill_metadata,
|
||||
decode = decode_metadata,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
@@ -621,6 +621,19 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
new_metadata = copy.copy(metadata)
|
||||
new_metadata.block_table = blk_table
|
||||
new_metadata.slot_mapping = slot_mapping
|
||||
# Keep nested prefill/decode block tables in sync. Decode path consumes
|
||||
# `attn_metadata.decode.block_table`, so updating only the top-level
|
||||
# `block_table` is insufficient when metadata is reused across groups.
|
||||
if metadata.decode is not None:
|
||||
new_decode = copy.copy(metadata.decode)
|
||||
reqs_start = metadata.num_decodes
|
||||
new_decode.block_table = blk_table[:reqs_start, ...]
|
||||
new_metadata.decode = new_decode
|
||||
if metadata.prefill is not None:
|
||||
new_prefill = copy.copy(metadata.prefill)
|
||||
reqs_start = metadata.num_decodes
|
||||
new_prefill.block_table = blk_table[reqs_start:, ...]
|
||||
new_metadata.prefill = new_prefill
|
||||
return new_metadata
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
@@ -667,7 +680,15 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
self.attn_type = attn_type
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
self.vllm_flash_attn_version = get_flash_attn_version(
|
||||
requires_alibi=alibi_slopes is not None,
|
||||
head_size=head_size,
|
||||
)
|
||||
logger.info_once(
|
||||
"Using FlashAttention version %s",
|
||||
self.vllm_flash_attn_version,
|
||||
scope="local",
|
||||
)
|
||||
# Cache the batch invariant result for use in forward passes
|
||||
self.batch_invariant_enabled = vllm_is_batch_invariant()
|
||||
|
||||
@@ -677,6 +698,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
self.sinks = sinks
|
||||
|
||||
if self.sinks is not None:
|
||||
assert flash_attn_supports_sinks(), (
|
||||
"Sinks are only supported in FlashAttention 3"
|
||||
@@ -687,6 +709,28 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
self.supports_quant_query_input = True
|
||||
self.supports_per_head_quant_scales = (
|
||||
self.vllm_flash_attn_version >= 3
|
||||
if self.vllm_flash_attn_version is not None
|
||||
else False
|
||||
)
|
||||
assert envs.VLLM_ATTN_OPT_LEVEL in [0, 1, 2], (
    "VLLM_ATTN_OPT_LEVEL only supports 0 (no quantization), 1 (I8Q_I8K_I8V) "
    "and 2 (I8Q_I8K_F16V), but got {}".format(envs.VLLM_ATTN_OPT_LEVEL)
)
|
||||
'''
quant_type = 0
    attention: fp16 q / fp16 k / fp16 v
    cache:     fp16 kv cache
quant_type = 1
    attention: int8 q / int8 k / int8 v
    cache:     int8 k cache + fp32 k-cache scale
               int8 v cache + fp32 v-cache scale (loaded from file, not updated)
quant_type = 2
    attention: int8 q / int8 k / fp16 v
    cache:     int8 k cache + fp32 k-cache scale
               fp16 v cache
'''
|
||||
self.quant_type = int(envs.VLLM_ATTN_OPT_LEVEL)
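# Illustrative summary (assumed from get_kv_cache_shape() above and the shapes
# asserted in forward() below) of how the caches are laid out per level:
#   quant_type 0: kv_cache -> (2, num_blocks, num_kv_heads, block_size, head_size), fp16 K/V
#   quant_type 1: int8 K and V caches; kv_cache_scale = (k_scale, v_scale) with
#                 k_scale (num_blocks, num_kv_heads, block_size) fp32 and
#                 v_scale (num_kv_heads, head_size) fp32
#   quant_type 2: kv_cache[0] -> int8 K cache, kv_cache[1:] reinterpreted as the
#                 fp16 V cache; kv_cache_scale is the fp32 K scale only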
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -698,7 +742,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
sqrt_alibi: bool = False,
|
||||
kv_cache_scale: torch.Tensor | None = None,
|
||||
kv_cache_scale: Union[torch.Tensor, List[torch.Tensor]] | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -711,6 +755,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
kv_cache: shape =
|
||||
[2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
kv_cache_scale: k-scale cache, shape = [num_blocks, num_kv_heads, block_size];
    plus, for quant_type 1, a v-scale cache, shape = [num_kv_heads, head_size]
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
NOTE: FP8 quantization, flash-attn expect the size of
|
||||
@@ -718,9 +763,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
We use torch's .expand() to avoid duplicating values
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
assert self.vllm_flash_attn_version is not None, (
|
||||
"FlashAttention version not detected."
|
||||
)
|
||||
# assert self.vllm_flash_attn_version is not None, (
|
||||
# "FlashAttention version not detected."
|
||||
# )
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
@@ -729,13 +774,12 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
# return output.fill_(0)
|
||||
return output.fill_(0).view(-1, self.num_heads * self.head_size)
|
||||
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
softmax_scale: float = self.scale
|
||||
window_size = self.sliding_window
|
||||
alibi_slopes: torch.Tensor | None = self.alibi_slopes
|
||||
logits_soft_cap: float | None = self.logits_soft_cap
|
||||
alibi_slopes: torch.Tensor = self.alibi_slopes
|
||||
logits_soft_cap: float = self.logits_soft_cap
|
||||
|
||||
attn_type = self.attn_type
|
||||
|
||||
@@ -761,18 +805,140 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
output[:num_actual_tokens],
|
||||
attn_metadata,
|
||||
layer,
|
||||
)
|
||||
).view(-1, self.num_heads * self.head_size)
|
||||
|
||||
# For decoder and cross-attention, use KV cache as before
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
decode_only = has_decode and not has_prefill
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
|
||||
if self.quant_type == 0:
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
elif self.quant_type == 1:
|
||||
i8_key_cache, i8_value_cache = kv_cache.unbind(0)
|
||||
num_blocks, num_kv_heads, block_size, head_size = i8_key_cache.shape
|
||||
|
||||
key_scale_cache, value_scale_cache = kv_cache_scale
|
||||
assert key_scale_cache.shape == (num_blocks, num_kv_heads, block_size) and key_scale_cache.dtype == torch.float32, (
    f"key_scale_cache.shape {key_scale_cache.shape} != (num_blocks, num_kv_heads, block_size) "
    f"or key_scale_cache.dtype {key_scale_cache.dtype} != torch.float32"
)
assert value_scale_cache.shape == (num_kv_heads, head_size) and value_scale_cache.dtype == torch.float32, (
    f"value_scale_cache.shape {value_scale_cache.shape} != (num_kv_heads, head_size) "
    f"or value_scale_cache.dtype {value_scale_cache.dtype} != torch.float32"
)
|
||||
value_cache_info = (i8_value_cache, value_scale_cache)
|
||||
|
||||
elif self.quant_type == 2:
|
||||
# For quant_type 2 the key cache is int8 (with fp32 scales) while the value
# cache stays fp16, stored in the remaining int8 planes of kv_cache.
|
||||
i8_key_cache = kv_cache[0]
|
||||
num_blocks, num_kv_heads, block_size, head_size = i8_key_cache.shape
|
||||
value_cache = kv_cache[1:].view(query.dtype).reshape(num_blocks, num_kv_heads, block_size, head_size)
|
||||
key_scale_cache = kv_cache_scale
|
||||
value_cache_info = (value_cache, None)
|
||||
|
||||
decode_q = query[:num_decode_tokens]
|
||||
prefill_q = query[num_decode_tokens:]
|
||||
prefill_output = output[num_decode_tokens:]
|
||||
decode_output = output[:num_decode_tokens]
|
||||
|
||||
if self.quant_type == 1:
|
||||
if decode_only:
|
||||
int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
|
||||
query, 2, transpose_scale=False
|
||||
)
|
||||
i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
|
||||
key, 2, transpose_scale=False
|
||||
)
|
||||
i8_value, _value_scale = ixf_ops.scaled_int8_quant_for_attn(
|
||||
value, 0, transpose_scale=False, scale=value_cache_info[1]
|
||||
)
|
||||
else:
|
||||
int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
|
||||
query, 2, transpose_scale=True
|
||||
)
|
||||
i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
|
||||
key, 2, transpose_scale=False
|
||||
)
|
||||
i8_value, _value_scale = ixf_ops.scaled_int8_quant_for_attn(
|
||||
value, 0, transpose_scale=False, scale=value_cache_info[1]
|
||||
)
|
||||
elif self.quant_type == 2:
|
||||
'''
original key cache:
    (num_blocks, num_kv_heads, block_size, head_size)   fp16
reformatted key cache:
    key_cache_i8:    (num_blocks, num_kv_heads, block_size, head_size)   int8
    key_scale_cache: (num_blocks, num_kv_heads, block_size)              fp32
'''
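# Sketch of the quantization assumed to be performed by
# ixf_ops.scaled_int8_quant_for_attn (the kernel itself is not part of this diff):
# symmetric per-(token, head) int8 quantization along head_size, roughly
#   k_scale = k.abs().amax(dim=-1) / 127        # fp32, one scale per token and head
#   k_int8  = (k / k_scale.unsqueeze(-1)).round().clamp(-127, 127).to(torch.int8)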
|
||||
|
||||
if decode_only:
|
||||
int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
|
||||
query, 2, transpose_scale=False
|
||||
)
|
||||
i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
|
||||
key, 2, transpose_scale=False
|
||||
)
|
||||
else:
|
||||
int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
|
||||
query, 2, transpose_scale=True
|
||||
)
|
||||
i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
|
||||
key, 2, transpose_scale=False
|
||||
)
|
||||
else:
|
||||
if layer.quant_manager is not None and layer.quant_manager.check_enable():
|
||||
i8_value, value_scale = ixf_ops.scaled_int8_quant_for_attn(
|
||||
value, 0, transpose_scale=False
|
||||
)
|
||||
layer.quant_manager.update_data(value_scale)
|
||||
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
# in KV cache.
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
|
||||
# not padded. However, we don't need to do key[:num_actual_tokens]
|
||||
# and value[:num_actual_tokens] because the reshape_and_cache_flash
|
||||
# op uses the slot_mapping's shape to determine the number of
|
||||
# actual tokens.
|
||||
if self.quant_type == 1:
|
||||
if has_prefill:
|
||||
ixf_ops.reshape_and_cache_flash_int8(
|
||||
key=i8_key,
|
||||
value=i8_value,
|
||||
k_scale=key_scale,
|
||||
key_cache=i8_key_cache,
|
||||
value_cache=value_cache_info[0],
|
||||
key_scale_cache=key_scale_cache,
|
||||
slot_mapping=attn_metadata.slot_mapping,
|
||||
kv_cache_dtype="",
|
||||
)
|
||||
elif self.quant_type == 2:
|
||||
if has_prefill:
|
||||
ixf_ops.reshape_and_cache_flash_mix(
|
||||
key=i8_key,
|
||||
value=value,
|
||||
k_scale=key_scale,
|
||||
key_cache=i8_key_cache,
|
||||
value_cache=value_cache_info[0],
|
||||
key_scale_cache=key_scale_cache,
|
||||
slot_mapping=attn_metadata.slot_mapping,
|
||||
kv_cache_dtype="",
|
||||
)
|
||||
|
||||
else:
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
# queries are quantized in the attention layer
|
||||
@@ -783,19 +949,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
value_cache = value_cache.view(dtype)
|
||||
|
||||
if not attn_metadata.use_cascade:
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
scheduler_metadata = attn_metadata.scheduler_metadata
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
|
||||
|
||||
q_descale = layer._q_scale.expand(descale_shape)
|
||||
k_descale = layer._k_scale.expand(descale_shape)
|
||||
v_descale = layer._v_scale.expand(descale_shape)
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
self._forward_with_dcp(
|
||||
query[:num_actual_tokens],
|
||||
@@ -805,79 +958,140 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
value_cache,
|
||||
output[:num_actual_tokens],
|
||||
attn_metadata,
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
)
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
else:
|
||||
sliding_window_size = (
|
||||
list(self.sliding_window)
|
||||
if self.sliding_window is not None
|
||||
else None
|
||||
)
|
||||
if has_prefill:
|
||||
flash_attn_varlen_func(
|
||||
q=prefill_q,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
|
||||
cu_seqlens_k=attn_metadata.prefill.query_start_loc,
|
||||
max_seqlen_q=attn_metadata.prefill.max_query_len,
|
||||
max_seqlen_k=attn_metadata.max_query_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=sliding_window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
sqrt_alibi=sqrt_alibi,
|
||||
sinks=self.sinks,
|
||||
out=prefill_output,
|
||||
block_table=attn_metadata.prefill.block_table,
|
||||
)
|
||||
# key = key[num_decode_tokens:]
|
||||
# value = value[num_decode_tokens:]
|
||||
|
||||
# int8 attn
|
||||
if self.quant_type > 0:
|
||||
flash_attn_varlen_int8_func(
|
||||
q=int8_query[num_decode_tokens:],
|
||||
k=i8_key_cache,
|
||||
v=value_cache_info[0],
|
||||
q_scale=query_scale[:, num_decode_tokens:],
|
||||
k_scale=key_scale_cache,
|
||||
v_scale=value_cache_info[1],
|
||||
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
|
||||
cu_seqlens_k=attn_metadata.prefill.key_start_loc,
|
||||
max_seqlen_q=attn_metadata.prefill.max_query_len,
|
||||
max_seqlen_k=attn_metadata.max_query_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
sqrt_alibi=sqrt_alibi,
|
||||
out=prefill_output,
|
||||
block_table=attn_metadata.prefill.block_table,
|
||||
output_dtype=query.dtype
|
||||
)
|
||||
else:
|
||||
flash_attn_varlen_func(
|
||||
q=prefill_q,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
|
||||
cu_seqlens_k=attn_metadata.prefill.key_start_loc,
|
||||
max_seqlen_q=attn_metadata.prefill.max_query_len,
|
||||
max_seqlen_k=attn_metadata.max_query_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
sqrt_alibi=sqrt_alibi,
|
||||
sinks=self.sinks,
|
||||
out=prefill_output,
|
||||
block_table=attn_metadata.prefill.block_table,
|
||||
)
|
||||
if has_decode:
|
||||
flash_attn_with_kvcache(
|
||||
q=decode_q.unsqueeze(1),
|
||||
k_cache=key_cache.contiguous(),
|
||||
v_cache=value_cache.contiguous(),
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
cache_seqlens=attn_metadata.decode.seq_lens,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=sliding_window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
use_sqrt_alibi=sqrt_alibi,
|
||||
out=decode_output.unsqueeze(1),
|
||||
max_context_len=attn_metadata.decode.max_decode_seq_len,
|
||||
# sinks=self.sinks,
|
||||
)
|
||||
# for mtp + cuda graph
|
||||
max_q_len = (attn_metadata.decode.max_query_len
             if attn_metadata.decode is not None else attn_metadata.max_query_len)
max_ct_len = (attn_metadata.decode.max_decode_seq_len
              if attn_metadata.decode is not None else attn_metadata.max_seq_len)
|
||||
if self.quant_type in [1, 2]:
|
||||
para_dict = dict(
|
||||
output=decode_output,
|
||||
query=int8_query[:num_decode_tokens],
|
||||
key_cache=i8_key_cache,
|
||||
query_scale=query_scale[:num_decode_tokens] if decode_only else query_scale[:, :num_decode_tokens].t().contiguous(),
|
||||
key_scale_cache=key_scale_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
scale=softmax_scale,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
context_lens=attn_metadata.decode.seq_lens,
|
||||
block_size=i8_key_cache.shape[-2],
|
||||
softcap=logits_soft_cap,
|
||||
alibi_slopes=alibi_slopes,
|
||||
causal=True,
|
||||
window_left=window_size[0],
|
||||
window_right=window_size[1],
|
||||
use_sqrt_alibi=sqrt_alibi,
|
||||
use_cuda_graph=attn_metadata.decode.use_graph if decode_only else False,
|
||||
max_context_len=max_ct_len,
|
||||
# mtp
|
||||
cu_query_lens=attn_metadata.decode.query_start_loc,
|
||||
max_query_len=max_q_len,
|
||||
)
|
||||
|
||||
if self.quant_type == 1:
|
||||
para_dict.update(
|
||||
dict(
|
||||
value_cache=value_cache_info[0],
|
||||
value_scale_cache=value_cache_info[1],
|
||||
)
|
||||
)
|
||||
# for kv + k_scale write fusion
|
||||
if decode_only:
|
||||
para_dict.update(
|
||||
dict(
|
||||
save_key=i8_key[:num_decode_tokens],
|
||||
save_value=i8_value[:num_decode_tokens],
|
||||
save_key_scale=key_scale[:num_decode_tokens],
|
||||
)
|
||||
)
|
||||
ixf_ops.vllm_paged_attention_int8(**para_dict)
|
||||
elif self.quant_type == 2:
|
||||
para_dict.update(
|
||||
dict(
|
||||
value_cache=value_cache,
|
||||
)
|
||||
)
|
||||
if decode_only:
|
||||
para_dict.update(
|
||||
dict(
|
||||
save_key=i8_key[:num_decode_tokens],
|
||||
save_value=value[:num_decode_tokens].contiguous(),
|
||||
save_key_scale=key_scale[:num_decode_tokens],
|
||||
)
|
||||
)
|
||||
ixf_ops.vllm_paged_attention_mix(
|
||||
**para_dict
|
||||
)
|
||||
else:
|
||||
flash_attn_with_kvcache(
|
||||
q=decode_q.unsqueeze(1),
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
cache_seqlens=attn_metadata.decode.seq_lens,
|
||||
max_query_len=max_q_len,
|
||||
cu_query_lens=attn_metadata.decode.query_start_loc,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
use_sqrt_alibi=sqrt_alibi,
|
||||
sinks=self.sinks,
|
||||
out=decode_output.unsqueeze(1),
|
||||
use_cuda_graph=attn_metadata.decode.use_graph,
|
||||
max_context_len=max_ct_len
|
||||
)
|
||||
# Compute attention and update output up to `num_actual_tokens`.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
# flash_attn_varlen_func(
|
||||
# q=query[:num_actual_tokens],
|
||||
# k=key_cache,
|
||||
# v=value_cache,
|
||||
# out=output[:num_actual_tokens],
|
||||
# cu_seqlens_q=cu_seqlens_q,
|
||||
# max_seqlen_q=max_seqlen_q,
|
||||
# seqused_k=seqused_k,
|
||||
# max_seqlen_k=max_seqlen_k,
|
||||
# softmax_scale=self.scale,
|
||||
# causal=attn_metadata.causal,
|
||||
# alibi_slopes=self.alibi_slopes,
|
||||
# window_size=sliding_window_size,
|
||||
# block_table=block_table,
|
||||
# softcap=self.logits_soft_cap,
|
||||
# scheduler_metadata=scheduler_metadata,
|
||||
# fa_version=self.vllm_flash_attn_version,
|
||||
# q_descale=q_descale,
|
||||
# k_descale=k_descale,
|
||||
# v_descale=v_descale,
|
||||
# num_splits=attn_metadata.max_num_splits,
|
||||
# s_aux=self.sinks,
|
||||
# )
|
||||
# return output
|
||||
|
||||
# Cascade attention (rare case).
|
||||
cascade_attention(
|
||||
@@ -906,12 +1120,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
v_descale=layer._v_scale,
|
||||
s_aux=self.sinks,
|
||||
)
|
||||
# return output
|
||||
return (
|
||||
output[:num_actual_tokens]
|
||||
.contiguous()
|
||||
.view(-1, self.num_heads * self.head_size)
|
||||
)
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
def do_kv_cache_update(
|
||||
self,
|
||||
@@ -935,7 +1144,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
# and value[:num_actual_tokens] because the reshape_and_cache_flash
|
||||
# op uses the slot_mapping's shape to determine the number of
|
||||
# actual tokens.
|
||||
reshape_and_cache_flash(
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
@@ -959,9 +1168,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
k_descale: torch.Tensor | None = None,
|
||||
v_descale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.vllm_flash_attn_version is not None, (
|
||||
"FlashAttention version not detected."
|
||||
)
|
||||
# assert self.vllm_flash_attn_version is not None, (
|
||||
# "FlashAttention version not detected."
|
||||
# )
|
||||
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
@@ -969,27 +1178,22 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
query = query.contiguous()
|
||||
query_across_dcp = get_dcp_group().all_gather(query, dim=1)
|
||||
cu_dcp_kv_klens = attn_metadata.dcp_context_kv_lens.cumsum(dim=0, dtype=torch.int32)
|
||||
new_tensor = torch.tensor([0],
|
||||
device=attn_metadata.dcp_context_kv_lens.device,
|
||||
dtype=attn_metadata.dcp_context_kv_lens.dtype)
|
||||
cu_seqlens_k = torch.cat([new_tensor, cu_dcp_kv_klens])
|
||||
sliding_window_size = (
|
||||
list(self.sliding_window) if self.sliding_window is not None else None
|
||||
)
|
||||
cu_seqlens_k = torch.cat(
|
||||
[
|
||||
torch.zeros(1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype),
|
||||
attn_metadata.dcp_context_kv_lens.cumsum(
|
||||
dim=0, dtype=cu_seqlens_q.dtype
|
||||
),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
context_attn_out, context_lse = flash_attn_varlen_func(
|
||||
q=query_across_dcp,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=None,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
# seqused_k=attn_metadata.dcp_context_kv_lens,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=False,
|
||||
@@ -998,11 +1202,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
# scheduler_metadata=attn_metadata.scheduler_metadata,
|
||||
# fa_version=self.vllm_flash_attn_version,
|
||||
# q_descale=q_descale,
|
||||
# k_descale=k_descale,
|
||||
# v_descale=v_descale,
|
||||
)
|
||||
# FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
|
||||
context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
|
||||
@@ -1028,10 +1227,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
window_size=sliding_window_size,
|
||||
softcap=self.logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
# fa_version=self.vllm_flash_attn_version,
|
||||
# q_descale=q_descale,
|
||||
# k_descale=k_descale,
|
||||
# v_descale=v_descale,
|
||||
)
|
||||
assert context_attn_out_cor.shape == query_attn_out.shape
|
||||
assert context_lse_cor.shape == query_lse.shape
|
||||
@@ -1040,7 +1235,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
context_lse_cor,
|
||||
query_attn_out,
|
||||
query_lse,
|
||||
output,
|
||||
output
|
||||
)
|
||||
|
||||
def _forward_encoder_attention(
|
||||
@@ -1062,9 +1257,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
attn_metadata: Encoder attention metadata
|
||||
layer: The attention layer
|
||||
"""
|
||||
assert self.vllm_flash_attn_version is not None, (
|
||||
"FlashAttention version not detected."
|
||||
)
|
||||
# assert self.vllm_flash_attn_version is not None, (
|
||||
# "FlashAttention version not detected."
|
||||
# )
|
||||
|
||||
# For encoder attention, process FP8 quantization if needed
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
@@ -1101,18 +1296,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=sliding_window_size,
|
||||
softcap=self.logits_soft_cap,
|
||||
# fa_version=self.vllm_flash_attn_version,
|
||||
# q_descale=layer._q_scale.expand(descale_shape),
|
||||
# k_descale=layer._k_scale.expand(descale_shape),
|
||||
# v_descale=layer._v_scale.expand(descale_shape),
|
||||
# num_splits=1 if self.batch_invariant_enabled else 0,
|
||||
)
|
||||
|
||||
return (
|
||||
output[: attn_metadata.num_actual_tokens]
|
||||
.contiguous()
|
||||
.view(-1, self.num_heads * self.head_size)
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def use_cascade_attention(
|
||||
@@ -1203,8 +1389,6 @@ def cascade_attention(
|
||||
cu_prefix_query_lens: torch.Tensor,
|
||||
cu_prefix_kv_lens: torch.Tensor,
|
||||
cu_suffix_kv_lens: torch.Tensor,
|
||||
# prefix_kv_lens: torch.Tensor,
|
||||
# suffix_kv_lens: torch.Tensor,
|
||||
max_kv_len: int,
|
||||
softmax_scale: float,
|
||||
alibi_slopes: torch.Tensor | None,
|
||||
@@ -1228,12 +1412,13 @@ def cascade_attention(
|
||||
)
|
||||
|
||||
num_tokens = query.shape[0]
|
||||
# block_size = key_cache.shape[-3]
|
||||
block_size = key_cache.shape[-2]
|
||||
assert common_prefix_len % block_size == 0
|
||||
num_common_kv_blocks = common_prefix_len // block_size
|
||||
assert num_common_kv_blocks > 0
|
||||
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
|
||||
assert q_descale is None or q_descale == 1, f"unsupported q_descale: {q_descale}"
assert k_descale is None or k_descale == 1, f"unsupported k_descale: {k_descale}"
assert v_descale is None or v_descale == 1, f"unsupported v_descale: {v_descale}"
|
||||
|
||||
# Process shared prefix.
|
||||
prefix_output, prefix_lse = flash_attn_varlen_func(
|
||||
@@ -1241,7 +1426,6 @@ def cascade_attention(
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=cu_prefix_query_lens,
|
||||
# seqused_k=prefix_kv_lens,
|
||||
cu_seqlens_k=cu_prefix_kv_lens,
|
||||
max_seqlen_q=num_tokens,
|
||||
max_seqlen_k=common_prefix_len,
|
||||
@@ -1251,26 +1435,14 @@ def cascade_attention(
|
||||
block_table=block_table[:1],
|
||||
softcap=logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
# scheduler_metadata=prefix_scheduler_metadata,
|
||||
# fa_version=fa_version,
|
||||
# q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
|
||||
# k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
|
||||
# v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
|
||||
# s_aux is incorporated into prefix_lse inside the GPU kernel,
|
||||
# enabling its effect during the final attention merge.
|
||||
# s_aux=s_aux,
|
||||
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
|
||||
)
|
||||
|
||||
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
|
||||
|
||||
# Process suffix per query.
|
||||
suffix_output, suffix_lse = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
# seqused_k=suffix_kv_lens,
|
||||
cu_seqlens_k=cu_suffix_kv_lens,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_kv_len - common_prefix_len,
|
||||
@@ -1280,14 +1452,6 @@ def cascade_attention(
|
||||
block_table=block_table[:, num_common_kv_blocks:],
|
||||
softcap=logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
# scheduler_metadata=suffix_scheduler_metadata,
|
||||
# fa_version=fa_version,
|
||||
# q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
|
||||
# k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
|
||||
# v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
|
||||
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
|
||||
)
|
||||
|
||||
# Merge prefix and suffix outputs, and store the result in output.
|
||||
# merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse)
|
||||
merge_attn_states(prefix_output, prefix_lse, suffix_output, suffix_lse, output)