from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import torch

from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashattention_backend import (
    FlashAttentionMetadata,
    make_local_attention_virtual_batches,
    merge_state_v2_wrapper,
    prepare_swa_spec_page_table_triton,
)
from sglang.srt.managers.schedule_batch import get_global_server_args
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner

from sgl_kernel import merge_state_v2

# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sglang.srt.layers.attention.flashattention_interface import (
    flash_attn_varlen_func,
    flash_attn_with_kvcache,
)
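# On XPU, the varlen and kv-cache attention entry points are provided by
# flashattention_interface rather than sgl_kernel.flash_attn (the original
# import is kept above, commented out, for reference).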


class XPUAttentionBackend(AttentionBackend):
    """XPU FlashAttention backend.

    Currently based on FlashAttentionBackend; it will be refactored later.

    TODO:
    - Prefill and Decode disaggregation; currently only chunked prefill is supported
    - Speculative Decoding support
    - XPU Graph support, see https://github.com/pytorch/pytorch/issues/162143
    - MLA support
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        speculative_step_id=0,
        topk=0,
        speculative_num_steps=0,
    ):
        super().__init__()

        assert not (
            model_runner.sliding_window_size is not None
            and model_runner.model_config.is_encoder_decoder
        ), "Sliding window and cross attention are not supported together"

        self.forward_metadata: Optional[FlashAttentionMetadata] = None
        # Extra metadata for handling speculative decoding topk > 1, extended draft decode and verify
        self.forward_metadata_spec_decode_expand: Optional[FlashAttentionMetadata] = None
        self.max_context_len = model_runner.model_config.context_len
        self.device = model_runner.device
        self.decode_cuda_graph_metadata = {}
        self.target_verify_metadata = {}
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.kv_cache_dtype = model_runner.kv_cache_dtype
        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
        self.page_size = model_runner.page_size
        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
        assert (
            self.use_mla is False
        ), "XPUAttentionBackend doesn't support MLA yet, please use --attention-backend triton instead."
        self.skip_prefill = skip_prefill
        self.is_hybrid = model_runner.is_hybrid
        if self.is_hybrid:
            self.full_to_swa_index_mapping = (
                model_runner.token_to_kv_pool.full_to_swa_index_mapping
            )
        self.topk = model_runner.server_args.speculative_eagle_topk or 0
        self.speculative_num_steps = speculative_num_steps
        self.speculative_num_draft_tokens = (
            model_runner.server_args.speculative_num_draft_tokens
        )
        self.speculative_step_id = speculative_step_id

        # Local attention settings
        self.attention_chunk_size = (
            model_runner.attention_chunk_size
            if hasattr(model_runner, "attention_chunk_size")
            else None
        )

        # For each layer, the sliding_window_size can be different. This is only used for preparing SWA metadata.
        # We use `layer.sliding_window_size` to decide whether to use SWA for each layer.
        self.sliding_window_size = model_runner.sliding_window_size
        self.has_swa = (
            self.sliding_window_size is not None and self.sliding_window_size > -1
        )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialize forward metadata so that all layers in the forward pass can reuse it."""
        metadata = FlashAttentionMetadata()
        seqlens_in_batch = forward_batch.seq_lens
        batch_size = forward_batch.batch_size
        device = seqlens_in_batch.device

        if forward_batch.forward_mode.is_decode_or_idle():
            # Draft Decode
            if forward_batch.spec_info is not None:
                assert (
                    False
                ), "XPUAttentionBackend doesn't support speculative decoding yet, please use --attention-backend triton instead."
                if self.topk <= 1:
                    metadata.cache_seqlens_int32 = (
                        seqlens_in_batch + (self.speculative_step_id + 1)
                    ).to(torch.int32)
                    metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
                        self.speculative_step_id + 1
                    )
                    metadata.cu_seqlens_q = torch.arange(
                        0, batch_size + 1, dtype=torch.int32, device=device
                    )
                    metadata.cu_seqlens_k = torch.nn.functional.pad(
                        torch.cumsum(
                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                        ),
                        (1, 0),
                    )
                    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                        forward_batch.req_pool_indices, : metadata.max_seq_len_k
                    ]
                else:
                    metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32)
                    metadata.max_seq_len_q = self.topk
                    metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                    metadata.cu_seqlens_q = torch.arange(
                        0,
                        batch_size * self.topk + 1,
                        step=self.topk,
                        dtype=torch.int32,
                        device=device,
                    )
                    metadata.cu_seqlens_k = torch.nn.functional.pad(
                        torch.cumsum(
                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                        ),
                        (1, 0),
                    )
                    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                        forward_batch.req_pool_indices, : metadata.max_seq_len_k
                    ]

                    metadata_expand = FlashAttentionMetadata()
                    decode_length = self.speculative_step_id + 1
                    metadata_expand.cache_seqlens_int32 = torch.full(
                        (seqlens_in_batch.numel() * self.topk,),
                        decode_length,
                        device=device,
                        dtype=torch.int32,
                    )
                    metadata_expand.max_seq_len_q = 1
                    metadata_expand.cu_seqlens_q = torch.arange(
                        0,
                        metadata_expand.cache_seqlens_int32.numel() + 1,
                        dtype=torch.int32,
                        device=device,
                    )
                    metadata_expand.cu_seqlens_k = torch.arange(
                        0,
                        metadata_expand.cache_seqlens_int32.numel() * decode_length + 1,
                        step=decode_length,
                        dtype=torch.int32,
                        device=device,
                    )
                    # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
                    cache_loc = forward_batch.out_cache_loc.view(
                        -1, self.speculative_num_steps
                    )
                    metadata_expand.page_table = (
                        cache_loc[:, :decode_length].contiguous().to(torch.int32)
                    )
                    self.forward_metadata_spec_decode_expand = metadata_expand
            else:
                # Normal Decode
                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                metadata.cu_seqlens_q = torch.arange(
                    0, batch_size + 1, dtype=torch.int32, device=device
                )
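                # Note (illustrative): cu_seqlens_k below is the exclusive prefix sum of
                # the per-request KV lengths, e.g. seq_lens [3, 5, 2] -> [0, 3, 8, 10];
                # the kernel uses it to locate each request's KV segment.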
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
                )
                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
                ]
                # TODO: we need to test this part for llama 4 eagle case
                self._init_local_attn_metadata(forward_batch, metadata, device)
        elif forward_batch.forward_mode.is_target_verify():
            if self.topk <= 1:
                metadata.cache_seqlens_int32 = (
                    forward_batch.seq_lens + self.speculative_num_draft_tokens
                ).to(torch.int32)
                metadata.max_seq_len_q = self.speculative_num_draft_tokens
                metadata.max_seq_len_k = (
                    forward_batch.seq_lens_cpu.max().item()
                    + self.speculative_num_draft_tokens
                )
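                # Each request contributes speculative_num_draft_tokens query rows; e.g.
                # with 4 draft tokens and batch_size=2, cu_seqlens_q = [0, 4, 8].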
                metadata.cu_seqlens_q = torch.arange(
                    0,
                    batch_size * self.speculative_num_draft_tokens + 1,
                    self.speculative_num_draft_tokens,
                    dtype=torch.int32,
                    device=device,
                )
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(
                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                    ),
                    (1, 0),
                )
                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
                ]

                self._init_local_attn_metadata(forward_batch, metadata, device)
            else:
                metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
                metadata.max_seq_len_q = self.speculative_num_draft_tokens
                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                metadata.cu_seqlens_q = torch.arange(
                    0,
                    batch_size * self.speculative_num_draft_tokens + 1,
                    step=self.speculative_num_draft_tokens,
                    dtype=torch.int32,
                    device=device,
                )
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(
                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                    ),
                    (1, 0),
                )
                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
                ]

                metadata_expand = FlashAttentionMetadata()

                metadata_expand.max_seq_len_q = 1
                metadata_expand.cu_seqlens_q = torch.arange(
                    0,
                    forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens
                    + 1,
                    dtype=torch.int32,
                    device=device,
                )

                # create expand page table
                offsets = torch.arange(
                    self.speculative_num_draft_tokens, device=device
                ).unsqueeze(
                    0
                )  # shape: (1, self.speculative_num_draft_tokens)
                cols = offsets.expand(
                    forward_batch.seq_lens.numel(), -1
                ) + forward_batch.seq_lens.unsqueeze(1)
                cum_len = torch.nn.functional.pad(
                    torch.cumsum(
                        (
                            forward_batch.seq_lens + self.speculative_num_draft_tokens
                        ).repeat_interleave(self.speculative_num_draft_tokens),
                        dim=0,
                    ),
                    (1, 0),
                )[:-1]
                mask_extraction_indices = (
                    cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
                    + cum_len[:, None]
                ).view(1, -1)
                mask = forward_batch.spec_info.custom_mask[
                    mask_extraction_indices
                ].view(
                    -1, self.speculative_num_draft_tokens
                )  # (bsz * draft_num, draft_num)

                # shift table indices to avoid padding
                # non_masked_page_table [[8, 9, 10],   mask (display with int format) [[1, 0, 0],
                #                        [8, 9, 10],                                    [1, 1, 0],
                #                        [8, 9, 10]]                                    [1, 0, 1]]
                # if masked with padding [[8, 0, 0],   our mask without padding [[8, 9, 10],
                #                         [8, 9, 0],                             [8, 9, 10],
                #                         [8, 0, 10]]                            [8, 10, 9]]
                # note here cache_seqlens_int32 is [1, 2, 2] so extra page indices will be ignored in each row
                col_indices = offsets.expand(
                    mask.shape[0], self.speculative_num_draft_tokens
                )
                # Build keys: if an entry is valid (mask==True), keep its original index;
                # if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries.
                keys = torch.where(
                    mask, col_indices, col_indices + self.speculative_num_draft_tokens
                )
                _, sort_order = torch.sort(keys, dim=1)
                non_masked_page_table = (
                    forward_batch.req_to_token_pool.req_to_token[
                        forward_batch.req_pool_indices, :
                    ]
                    .gather(1, cols)
                    .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
                )  # (bsz, draft_num)
                metadata_expand.page_table = non_masked_page_table.gather(1, sort_order)
                metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32)
                metadata_expand.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(
                        metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32
                    ),
                    (1, 0),
                )
                self.forward_metadata_spec_decode_expand = metadata_expand

                if self.has_swa:
                    self._init_sliding_window_attn_spec_metadata(
                        metadata, metadata_expand
                    )

        elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
            metadata.cu_seqlens_k = torch.nn.functional.pad(
                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
            )
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]

            if (
                any(forward_batch.extend_prefix_lens_cpu)
                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
            ):
                extend_seq_lens = forward_batch.extend_seq_lens
                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                metadata.cu_seqlens_q = torch.nn.functional.pad(
                    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                )
            else:
                metadata.max_seq_len_q = metadata.max_seq_len_k
                metadata.cu_seqlens_q = metadata.cu_seqlens_k

            # Setup local attention if enabled
            if forward_batch.forward_mode == ForwardMode.EXTEND:
                self._init_local_attn_metadata(forward_batch, metadata, device)

        # Encoder metadata for cross attention
        if forward_batch.encoder_lens is not None:
            assert (
                forward_batch.encoder_lens.numel() == 1
            ), "Only encoder size 1 is supported for now"

            metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
            metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
                torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
                (1, 0),
            )
            metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
            metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
            ]

            # Currently only support forward_batch.encoder_lens.numel() == 1
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices,
                metadata.encoder_max_seq_len_k : (
                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
                ),
            ]

        # Convert the page table to a strided format which is needed by FA3 API
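        # For example (illustrative): with page_size=4, per-token slot indices
        # [8, 9, 10, 11, 20, 21, 22, 23] reduce to page indices [2, 5] by taking
        # every page_size-th column and dividing by page_size.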
        if self.page_size > 1:
            self.strided_indices = torch.arange(
                0, metadata.page_table.shape[1], self.page_size, device=self.device
            )
            metadata.page_table = (
                metadata.page_table[:, self.strided_indices] // self.page_size
            )

        self.forward_metadata = metadata

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        sinks: Optional[torch.Tensor] = None,
    ):
        if k is not None:
            assert v is not None
            if save_kv_cache:
                cache_loc = (
                    forward_batch.out_cache_loc
                    if not layer.is_cross_attention
                    else forward_batch.encoder_out_cache_loc
                )
                if not self.use_mla:
                    forward_batch.token_to_kv_pool.set_kv_buffer(
                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                    )
                else:
                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                        layer,
                        cache_loc,
                        k,
                        k_rope,
                    )

        # Use precomputed metadata across all layers
        metadata = self.forward_metadata

        # Calculate window size (can be moved to metadata if layer properties don't change)
        # We don't subtract 1 from layer.sliding_window_size here because
        # model.get_attention_sliding_window_size() already does; the window is
        # inclusive on both sides.
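        # e.g. sliding_window_size=127 gives window_size=(127, 0): each query attends
        # to itself and at most 127 earlier tokens, and to no later tokens.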
        is_swa = (
            layer.sliding_window_size is not None and layer.sliding_window_size > -1
        )
        window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)

        # currently no FP8 KV cache supported
        k_descale, v_descale = None, None
        # # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
        # # has corresponding quantization method so that layer.k_scale is not None,
        # # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
        # if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
        #     if layer.k_scale is not None:
        #         descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
        #         k_descale = layer.k_scale.expand(descale_shape)
        #         v_descale = layer.v_scale.expand(descale_shape)
        #     q = q.to(self.kv_cache_dtype)
        #     q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
        #     k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
        causal = not layer.is_cross_attention

        # Check if we should use local attention
        use_local_attn = (
            self.attention_chunk_size is not None
            and metadata.local_attn_metadata is not None
            and (hasattr(layer, "use_irope") and layer.use_irope)
        )

        # We do cascade attention for Target Verify with topk > 1
        # We don't use cascade attention for Sliding Window Attention:
        # - Different window sizes would have to be passed in for each q in the first stage of
        #   cascade attention, but the FA3 interface doesn't support passing a list of window sizes.
        # - The overhead of duplicated computation of the common prefix part is small for sliding
        #   window layers (seq_len <= window_size), so we can just expand it.
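        # Sketch of the idea: in cascade mode, the shared prefix is attended once with the
        # regular page table, the per-draft-token suffix is attended with the expanded
        # spec-decode metadata, and the two partial outputs are combined below via
        # merge_state_v2 using their log-sum-exp statistics.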
        use_cascade_attn = (
            forward_batch.forward_mode.is_target_verify()
            and self.topk > 1
            and not is_swa
        )

        # For fa3 interface version compatibility, we put new fields into conditional keyword args
        kwargs = {}
        if sinks is not None:
            kwargs["sinks"] = sinks

        # Get the appropriate page table based on whether we're using local attention
        if use_local_attn:
            local_metadata = metadata.local_attn_metadata
            page_table = local_metadata.local_block_table
            cu_seqlens_q = local_metadata.local_query_start_loc
            cache_seqlens = local_metadata.local_seqused_k
            max_seqlen_q = local_metadata.local_max_query_len
        elif is_swa and metadata.swa_spec_metadata is not None:
            swa_spec_metadata = metadata.swa_spec_metadata
            page_table = swa_spec_metadata.page_table
            cu_seqlens_q = swa_spec_metadata.cu_seqlens_q
            cache_seqlens = swa_spec_metadata.cache_seqlens_int32
            max_seqlen_q = swa_spec_metadata.max_seq_len_q
            cu_seqlens_k = swa_spec_metadata.cu_seqlens_k
        else:
            page_table = metadata.page_table
            cu_seqlens_q = metadata.cu_seqlens_q
            cache_seqlens = metadata.cache_seqlens_int32
            max_seqlen_q = metadata.max_seq_len_q
            cu_seqlens_k = metadata.cu_seqlens_k

        # Use Flash Attention for prefill
        if not self.use_mla:
            # Do multi-head attention
            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                layer.layer_id
            )
            key_cache = key_cache.view(
                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
            )
            value_cache = value_cache.view(
                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
            )
            if layer.is_cross_attention:
                page_table = metadata.encoder_page_table
                cache_seqlens = metadata.encoder_lens_int32
                cu_seqlens_k = metadata.encoder_cu_seqlens_k
                window_size = (-1, -1)

            result = flash_attn_with_kvcache(
                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                k_cache=key_cache,
                v_cache=value_cache,
                page_table=page_table,
                cache_seqlens=cache_seqlens,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                max_seqlen_q=max_seqlen_q,
                softmax_scale=layer.scaling,
                causal=False if use_cascade_attn else causal,
                window_size=window_size,
                softcap=layer.logit_cap,
                k_descale=k_descale,
                v_descale=v_descale,
                return_softmax_lse=use_cascade_attn,
                **kwargs,
            )

            if use_cascade_attn:
                o, softmax_lse, *rest = result
                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    page_table=self.forward_metadata_spec_decode_expand.page_table,
                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
                    softmax_scale=layer.scaling,
                    causal=False,
                    window_size=window_size,
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                    return_softmax_lse=True,
                    **kwargs,
                )
                o, _ = merge_state_v2_wrapper(
                    o,
                    softmax_lse.T.contiguous(),
                    o_expand,
                    softmax_lse_expand.T.contiguous(),
                )
            else:
                o = result
        else:
            if (
                forward_batch.attn_attend_prefix_cache is not None
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
            ):
                # Do multi-head attention with chunked prefix cache
                if forward_batch.attn_attend_prefix_cache:
                    assert not get_global_server_args().disable_chunked_prefix_cache
                    # MHA for chunked prefix kv cache when running model with MLA
                    assert forward_batch.prefix_chunk_idx is not None
                    assert forward_batch.prefix_chunk_cu_seq_lens is not None
                    assert forward_batch.prefix_chunk_max_seq_lens is not None

                    chunk_idx = forward_batch.prefix_chunk_idx
                    assert chunk_idx >= 0

                    assert forward_batch.mha_return_lse
                    output = flash_attn_varlen_func(
                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                        cu_seqlens_q=metadata.cu_seqlens_q,
                        cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                        max_seqlen_q=metadata.max_seq_len_q,
                        max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                        softmax_scale=layer.scaling,
                        causal=False,
                        return_softmax_lse=True,
                    )
                else:
                    # MHA for extend part of sequence without attending prefix kv cache
                    output = flash_attn_varlen_func(
                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                        cu_seqlens_q=metadata.cu_seqlens_q,
                        cu_seqlens_k=metadata.cu_seqlens_q,
                        max_seqlen_q=metadata.max_seq_len_q,
                        max_seqlen_k=metadata.max_seq_len_q,
                        softmax_scale=layer.scaling,
                        causal=True,
                        return_softmax_lse=forward_batch.mha_return_lse,
                    )
                if forward_batch.mha_return_lse:
                    output, lse, *rest = output
                    lse = torch.transpose(lse, 0, 1).contiguous()
                    return output, lse
                return output
            else:
                # Do absorbed multi-latent attention
                kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
                    layer.layer_id
                ).to(q.dtype)
                k_rope = kv_cache[:, :, layer.v_head_dim :]
                c_kv = kv_cache[:, :, : layer.v_head_dim]
                k_rope_cache = k_rope.view(
                    -1,
                    self.page_size,
                    layer.tp_k_head_num,
                    layer.head_dim - layer.v_head_dim,
                )
                c_kv_cache = c_kv.view(
                    -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
                )
                if q_rope is not None:
                    q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
                    q_rope = q_rope.view(
                        -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
                    )
                else:
                    q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
                    q_nope = q_all[:, :, : layer.v_head_dim]
                    q_rope = q_all[:, :, layer.v_head_dim :]

                result = flash_attn_with_kvcache(
                    q=q_rope,
                    k_cache=k_rope_cache,
                    v_cache=c_kv_cache,
                    qv=q_nope,
                    page_table=page_table,
                    cache_seqlens=cache_seqlens,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                    max_seqlen_q=max_seqlen_q,
                    softmax_scale=layer.scaling,
                    causal=False if use_cascade_attn else causal,
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                    return_softmax_lse=use_cascade_attn,
                )
                if use_cascade_attn:
                    o, softmax_lse, *rest = result
                    o_expand, softmax_lse_expand, *rest_expand = (
                        flash_attn_with_kvcache(
                            q=q_rope,
                            k_cache=k_rope_cache,
                            v_cache=c_kv_cache,
                            qv=q_nope,
                            page_table=self.forward_metadata_spec_decode_expand.page_table,
                            cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
                            cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
                            cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
                            max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
                            softmax_scale=layer.scaling,
                            causal=False,
                            window_size=window_size,
                            softcap=layer.logit_cap,
                            k_descale=k_descale,
                            v_descale=v_descale,
                            return_softmax_lse=True,
                        )
                    )
                    o, _ = merge_state_v2_wrapper(
                        o,
                        softmax_lse.T.contiguous(),
                        o_expand,
                        softmax_lse_expand.T.contiguous(),
                    )
                else:
                    o = result

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if k is not None:
            assert v is not None
            if save_kv_cache:
                cache_loc = (
                    forward_batch.out_cache_loc
                    if not layer.is_cross_attention
                    else forward_batch.encoder_out_cache_loc
                )
                if not self.use_mla:
                    forward_batch.token_to_kv_pool.set_kv_buffer(
                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                    )
                else:
                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                        layer,
                        cache_loc,
                        k,
                        k_rope,
                    )

        # Use precomputed metadata across all layers
        metadata = self.forward_metadata
        local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
        use_local_attn = (
            self.attention_chunk_size is not None
            and local_attn_metadata is not None
            and (hasattr(layer, "use_irope") and layer.use_irope)
        )

        # When Spec Decode is enabled, forward_decode is called in two modes:
        # 1. DRAFT_DECODE: we enable cascade attention when top_k > 1
        # 2. IDLE: we don't need cascade attention; spec_info will be None in this case
        use_cascade_attn = forward_batch.spec_info is not None and self.topk > 1

        # Calculate window size (can be moved to metadata if layer properties don't change)
        # We don't subtract 1 from layer.sliding_window_size here because
        # model.get_attention_sliding_window_size() already does; the window is
        # inclusive on both sides.
        window_size = (
            (layer.sliding_window_size, 0)
            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
            else (-1, -1)
        )
        causal = not layer.is_cross_attention

        # For fa3 interface version compatibility, we put new fields into conditional keyword args
        kwargs = {}
        if sinks is not None:
            kwargs["sinks"] = sinks

        k_descale, v_descale = None, None
        # Only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
        # has a corresponding quantization method so that layer.k_scale is not None,
        # 3) layer.head_dim <= 256 since the fa3 kernel requires fp16/bf16 data types in this case.
        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
            if layer.k_scale is not None:
                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                k_descale = layer.k_scale.expand(descale_shape)
                v_descale = layer.v_scale.expand(descale_shape)
            q = q.to(self.kv_cache_dtype)
            q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
            k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
        if not self.use_mla:
            # Do multi-head attention

            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                layer.layer_id
            )
            key_cache = key_cache.view(
                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
            )
            value_cache = value_cache.view(
                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
            )

            if layer.is_cross_attention:
                # Always use non-chunked logic for cross-attention
                o = flash_attn_with_kvcache(
                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    page_table=metadata.encoder_page_table,
                    cache_seqlens=metadata.encoder_lens_int32,
                    cu_seqlens_q=metadata.cu_seqlens_q,
                    cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
                    max_seqlen_q=1,
                    softmax_scale=layer.scaling,
                    causal=False,
                    window_size=(-1, -1),
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                    **kwargs,
                )
            elif use_local_attn:
                # Use chunked (local) attention batching for self-attention
                o = flash_attn_with_kvcache(
                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    page_table=local_attn_metadata.local_block_table,
                    cache_seqlens=local_attn_metadata.local_seqused_k,
                    cu_seqlens_q=local_attn_metadata.local_query_start_loc,
                    cu_seqlens_k_new=None,
                    max_seqlen_q=local_attn_metadata.local_max_query_len,
                    softmax_scale=layer.scaling,
                    causal=True,
                    window_size=(-1, -1),
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                    **kwargs,
                )
            else:
                page_table = metadata.page_table
                cache_seqlens = metadata.cache_seqlens_int32
                cu_seqlens_k = metadata.cu_seqlens_k
                max_seqlen_q = metadata.max_seq_len_q
                q_reshaped = q.contiguous().view(
                    -1, layer.tp_q_head_num, layer.head_dim
                )

                # Default: single-token self-attention
                result = flash_attn_with_kvcache(
                    q=q_reshaped,
                    k_cache=key_cache,
                    v_cache=value_cache,
                    page_table=page_table,
                    cache_seqlens=cache_seqlens,
                    cu_seqlens_q=metadata.cu_seqlens_q,
                    cu_seqlens_k_new=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_q,
                    softmax_scale=layer.scaling,
                    causal=False if use_cascade_attn else causal,
                    window_size=window_size,
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                    return_softmax_lse=use_cascade_attn,
                    **kwargs,
                )
                if use_cascade_attn:
                    o, softmax_lse, *rest = result
                    o_expand, softmax_lse_expand, *rest_expand = (
                        flash_attn_with_kvcache(
                            q=q_reshaped,
                            k_cache=key_cache,
                            v_cache=value_cache,
                            page_table=self.forward_metadata_spec_decode_expand.page_table,
                            cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
                            cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
                            cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
                            max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
                            softmax_scale=layer.scaling,
                            causal=False,
                            window_size=window_size,
                            softcap=layer.logit_cap,
                            k_descale=k_descale,
                            v_descale=v_descale,
                            return_softmax_lse=True,
                            **kwargs,
                        )
                    )
                    o, _ = merge_state_v2(
                        o,
                        softmax_lse.T.contiguous(),
                        o_expand,
                        softmax_lse_expand.T.contiguous(),
                    )
                else:
                    o = result
        else:
            # Do absorbed multi-latent attention
            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                q.dtype
            )
            k_rope = kv_cache[:, :, layer.v_head_dim :]
            c_kv = kv_cache[:, :, : layer.v_head_dim]
            k_rope_cache = k_rope.view(
                -1,
                self.page_size,
                layer.tp_k_head_num,
                layer.head_dim - layer.v_head_dim,
            )
            c_kv_cache = c_kv.view(
                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
            )

            if q_rope is not None:
                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
                q_rope = q_rope.view(
                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
                )
            else:
                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
                q_nope = q_all[:, :, : layer.v_head_dim]
                q_rope = q_all[:, :, layer.v_head_dim :]
            max_seqlen_q = metadata.max_seq_len_q

            result = flash_attn_with_kvcache(
                q=q_rope,
                k_cache=k_rope_cache,
                v_cache=c_kv_cache,
                qv=q_nope,
                page_table=metadata.page_table,
                cache_seqlens=metadata.cache_seqlens_int32,
                cu_seqlens_q=metadata.cu_seqlens_q,
                cu_seqlens_k_new=metadata.cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                softmax_scale=layer.scaling,
                causal=False if use_cascade_attn else causal,
                softcap=layer.logit_cap,
                k_descale=k_descale,
                v_descale=v_descale,
                return_softmax_lse=use_cascade_attn,  # softmax_lse is needed for merge states
            )
            if use_cascade_attn:
                o, softmax_lse, *rest = result
                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
                    q=q_rope,
                    k_cache=k_rope_cache,
                    v_cache=c_kv_cache,
                    qv=q_nope,
                    page_table=self.forward_metadata_spec_decode_expand.page_table,
                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
                    softmax_scale=layer.scaling,
                    causal=False,
                    window_size=window_size,
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                    return_softmax_lse=True,
                )
                o, _ = merge_state_v2(
                    o,
                    softmax_lse.T.contiguous(),
                    o_expand,
                    softmax_lse_expand.T.contiguous(),
                )
            else:
                o = result

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for sequence length in CUDA graph."""
        return 1

    def _init_local_attn_metadata(
        self, forward_batch: ForwardBatch, metadata: FlashAttentionMetadata, device
    ):
        """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
        if self.attention_chunk_size is None:
            metadata.local_attn_metadata = None
            return

        cu_seqlens_q = metadata.cu_seqlens_q
        cache_seqlens_int32 = metadata.cache_seqlens_int32
        if self.is_hybrid:
            page_table = self.full_to_swa_index_mapping[metadata.page_table].to(
                torch.int32
            )
        else:
            page_table = metadata.page_table
        if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
            metadata.local_attn_metadata = None
            return

        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
        seq_lens_np = cache_seqlens_int32.cpu().numpy()
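        # Assumption: make_local_attention_virtual_batches (shared with the CUDA
        # FlashAttention backend) splits each request into attention_chunk_size windows
        # and returns per-chunk query/key lengths plus a chunk-level block table, which
        # the local-attention paths in forward_extend/forward_decode consume.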
        (
            seqlens_q_local_np,
            cu_seqlens_q_local_np,
            seqlens_k_local_np,
            block_table_local,
        ) = make_local_attention_virtual_batches(
            self.attention_chunk_size,
            cu_seqlens_q_np,
            seq_lens_np,
            page_table,
            self.page_size,
        )

        local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
            local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
            local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
            local_block_table=block_table_local.to(device),
            local_max_query_len=int(seqlens_q_local_np.max()),
            local_max_seq_len=int(seqlens_k_local_np.max()),
        )
        metadata.local_attn_metadata = local_metadata

    def _init_sliding_window_attn_spec_metadata(
        self,
        metadata: FlashAttentionMetadata,
        metadata_expand: FlashAttentionMetadata,
        metadata_swa: Optional[FlashAttentionMetadata] = None,
    ):
        # TODO: support page_size > 1 for swa spec
        assert (
            self.page_size == 1
        ), "FlashAttention backend doesn't support topk > 1 speculative decoding with page size > 1 sliding window attention"
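        # For each expanded (request, draft-token) row, the effective KV length is the
        # prefix length plus the number of draft tokens that row may attend to; the
        # combined page table built below is (presumably) laid out as
        # [prefix pages | draft-token pages] per row.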
        cache_seqlens_int32 = (
            metadata.cache_seqlens_int32.repeat_interleave(
                self.speculative_num_draft_tokens
            )
            + metadata_expand.cache_seqlens_int32
        )
        cu_seqlens_k = torch.nn.functional.pad(
            torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
        )
        bs = cache_seqlens_int32.shape[0]
        page_table = (
            metadata.page_table.new_zeros(
                (bs, metadata.max_seq_len_k + metadata_expand.page_table.shape[1])
            )
            if metadata_swa is None
            else metadata_swa.page_table
        )

        prepare_swa_spec_page_table_triton(
            page_table,
            metadata.page_table,
            metadata_expand.page_table,
            metadata.cache_seqlens_int32,
            metadata_expand.cache_seqlens_int32,
            self.speculative_num_draft_tokens,
        )

        if metadata_swa is None:
            metadata_swa = FlashAttentionMetadata()
            metadata_swa.max_seq_len_q = 1
            metadata_swa.cu_seqlens_q = metadata_expand.cu_seqlens_q
            metadata_swa.cache_seqlens_int32 = cache_seqlens_int32
            metadata_swa.cu_seqlens_k = cu_seqlens_k
            metadata_swa.page_table = page_table
        else:
            metadata_swa.cache_seqlens_int32.copy_(cache_seqlens_int32)
            metadata_swa.cu_seqlens_k.copy_(cu_seqlens_k)

        metadata.swa_spec_metadata = metadata_swa