Support FA3 as Attention backend by using --attention-backend fa3 (#4680)
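Usage (a minimal sketch; assumes the standard sglang server entry point and an
installed FlashAttention 3 build that provides flash_attn_interface):

    python -m sglang.launch_server --model-path <model> --attention-backend fa3

In this commit the FA3 backend implements the extend and decode paths; CUDA
graph capture is supported for decode only.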
Co-authored-by: qsong <qsong@linkedin.com>
Co-authored-by: qingquansong <ustcsqq@gmail.com>
python/sglang/srt/layers/attention/flashattention_backend.py (new file, 295 lines)
@@ -0,0 +1,295 @@
"""
Support different attention backends.
Now there are three backends: FlashInfer, Triton and FlashAttention.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union

import torch
from flash_attn_interface import flash_attn_with_kvcache

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner


@dataclass
class FlashAttentionMetadata:
    """Metadata precomputed per forward batch to avoid redundant computations.

    Used by both the extend (prefill with cached prefix) and decode paths.
    """

    cu_seqlens_q: torch.Tensor = None
    cu_seqlens_k: torch.Tensor = None
    max_seq_len_q: int = 0
    max_seq_len_k: int = 0
    window_size: tuple = (-1, -1)
    page_table: torch.Tensor = None
    cache_seqlens_int32: torch.Tensor = None


class FlashAttentionBackend(AttentionBackend):
    """FlashAttention (FA3) backend implementation."""

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
    ):
        super().__init__()

        assert not (
            model_runner.sliding_window_size is not None
            and model_runner.model_config.is_encoder_decoder
        ), "Sliding window and cross attention are not supported together"

        # Initialize metadata and cache handles from the model runner
        self.forward_metadata: FlashAttentionMetadata = None
        self.max_context_len = model_runner.model_config.context_len
        self.device = model_runner.device
        self.decode_cuda_graph_metadata = {}
        self.req_to_token = model_runner.req_to_token_pool.req_to_token

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialize forward metadata to cache repetitive calculations."""
        # Create metadata based on the forward mode
        metadata = FlashAttentionMetadata()

        extend_seq_lens = forward_batch.extend_seq_lens
        # Get sequence information
        seqlens_in_batch = forward_batch.seq_lens
        # Precompute the int32 version of the sequence lengths
        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
        batch_size = len(seqlens_in_batch)
        device = seqlens_in_batch.device
        # Precompute cumulative key lengths (prefix sum padded with a leading 0)
        metadata.cu_seqlens_k = torch.nn.functional.pad(
            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
        )
        # Precompute the maximum sequence length
        metadata.max_seq_len_k = seqlens_in_batch.max().item()
        # Precompute the page table (KV-cache token indices per request)
        metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
            forward_batch.req_pool_indices, : metadata.max_seq_len_k
        ]
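        # Illustrative example (hypothetical values): for a batch with
        # seq_lens = [3, 5],
        #   cache_seqlens_int32 = [3, 5]
        #   cu_seqlens_k        = [0, 3, 8]
        #   max_seq_len_k       = 5
        #   page_table          = a (2, 5) slice of req_to_token holding the
        #                         KV-cache token indices of each request.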
        if forward_batch.forward_mode == ForwardMode.DECODE:
            # Precompute cumulative query lengths: one new token per request
            metadata.cu_seqlens_q = torch.arange(
                0, batch_size + 1, dtype=torch.int32, device=device
            )
        else:
            extend_no_prefix = not any(forward_batch.extend_prefix_lens)
            # Precompute cumulative query lengths
            if not extend_no_prefix:
                metadata.cu_seqlens_q = torch.nn.functional.pad(
                    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                )
            else:
                metadata.cu_seqlens_q = metadata.cu_seqlens_k
            metadata.max_seq_len_q = seqlens_in_batch.max().item()
        self.forward_metadata = metadata

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        cache_loc = (
            forward_batch.out_cache_loc
            if not layer.is_cross_attention
            else forward_batch.encoder_out_cache_loc
        )

        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                )

        # Use precomputed metadata
        metadata = self.forward_metadata

        # Use FlashAttention for prefill.
        # Calculate the window size (could be moved to metadata if the layer
        # properties never change). We do not subtract 1 from
        # layer.sliding_window_size here because
        # model.get_attention_sliding_window_size() already subtracts 1;
        # the window is inclusive on both sides.
        window_size = (
            (layer.sliding_window_size, 0)
            if layer.sliding_window_size is not None
            else (-1, -1)
        )
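        # Illustrative example (hypothetical value): sliding_window_size = 127
        # gives window_size = (127, 0), i.e. each query token attends to itself
        # and up to the 127 preceding tokens; (-1, -1) means full attention.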
        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
        key_cache, value_cache = kv_cache[0], kv_cache[1]
        o = flash_attn_with_kvcache(
            q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
            k_cache=key_cache.unsqueeze(1),
            v_cache=value_cache.unsqueeze(1),
            page_table=metadata.page_table,
            cache_seqlens=metadata.cache_seqlens_int32,
            cu_seqlens_q=metadata.cu_seqlens_q,
            cu_seqlens_k_new=metadata.cu_seqlens_k,
            max_seqlen_q=metadata.max_seq_len_q,
            softmax_scale=layer.scaling,
            causal=True,
            window_size=window_size,
            softcap=layer.logit_cap,
            k_descale=layer.k_scale,
            v_descale=layer.v_scale,
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention using precomputed metadata."""
        # Save KV cache if needed
        if k is not None and v is not None and save_kv_cache:
            cache_loc = (
                forward_batch.out_cache_loc
                if not layer.is_cross_attention
                else forward_batch.encoder_out_cache_loc
            )
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
            )

        # Get the KV cache for this layer
        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
        key_cache, value_cache = kv_cache[0], kv_cache[1]

        # Use precomputed metadata
        metadata = self.forward_metadata

        # Pre-reshape the query tensor
        q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)

        # Calculate the window size (could be moved to metadata if the layer
        # properties never change). We do not subtract 1 from
        # layer.sliding_window_size here because
        # model.get_attention_sliding_window_size() already subtracts 1;
        # the window is inclusive on both sides.
        window_size = (
            (layer.sliding_window_size, 0)
            if layer.sliding_window_size is not None
            else (-1, -1)
        )
        # Run attention with the precomputed values; decode has exactly one
        # query token per request, so max_seqlen_q is 1.
        o = flash_attn_with_kvcache(
            q=q_reshaped,
            k_cache=key_cache.unsqueeze(1),
            v_cache=value_cache.unsqueeze(1),
            page_table=metadata.page_table,
            cache_seqlens=metadata.cache_seqlens_int32,
            cu_seqlens_q=metadata.cu_seqlens_q,
            cu_seqlens_k_new=metadata.cu_seqlens_k,
            max_seqlen_q=1,
            softmax_scale=layer.scaling,
            causal=True,
            window_size=window_size,
            softcap=layer.logit_cap,
            k_descale=layer.k_scale,
            v_descale=layer.v_scale,
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

    def init_cuda_graph_state(self, max_bs: int):
        """Initialize CUDA graph state for the attention backend.

        Args:
            max_bs (int): Maximum batch size to support in CUDA graphs

        This creates fixed-size tensors that will be reused during CUDA graph
        replay to avoid memory allocations.
        """
        # Initialize fixed-size tensors for decode operations
        self.decode_cuda_graph_metadata = {
            # Page table for token mapping (batch_size, max_context_len)
            "page_table": torch.zeros(
                max_bs, self.max_context_len, dtype=torch.int32, device=self.device
            ),
        }
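        # Rough footprint (hypothetical sizes): with max_bs = 160 and
        # max_context_len = 8192, this buffer holds 160 * 8192 int32 values,
        # i.e. about 5 MB, allocated once and reused on every replay.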

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        """Initialize forward metadata for capturing a CUDA graph."""
        metadata = FlashAttentionMetadata()
        # Get sequence information
        metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
        batch_size = len(seq_lens)
        device = seq_lens.device
        metadata.cu_seqlens_k = torch.nn.functional.pad(
            torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
        )
        # Precompute the maximum sequence length
        metadata.max_seq_len_k = seq_lens.max().item()
        # Precompute the page table from the preallocated fixed-size buffer
        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
            req_pool_indices, :
        ]
        if forward_mode == ForwardMode.DECODE:
            # Precompute cumulative query lengths: one new token per request
            metadata.cu_seqlens_q = torch.arange(
                0, batch_size + 1, dtype=torch.int32, device=device
            )
        else:
            raise ValueError("CUDA graph capture is only supported for decode mode")
        self.decode_cuda_graph_metadata[bs] = metadata
        self.forward_metadata = metadata
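
    # Illustrative flow (assuming graphs are captured for a fixed set of batch
    # sizes, e.g. bs in {1, 2, 4, 8}): init_forward_metadata_capture_cuda_graph()
    # runs once per bs and stores its metadata in decode_cuda_graph_metadata[bs];
    # init_forward_metadata_replay_cuda_graph() then refreshes that metadata with
    # the current sequence lengths and page table before every graph replay.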
    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Initialize forward metadata for replaying a CUDA graph."""
        seqlens_in_batch = seq_lens[:bs]
        metadata = self.decode_cuda_graph_metadata[bs]
        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
        metadata.cu_seqlens_k = torch.nn.functional.pad(
            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
        )
        # Precompute the maximum sequence length
        metadata.max_seq_len_k = seqlens_in_batch.max().item()
        # Only zero out the columns beyond max_seq_len_k ...
        metadata.page_table[:, metadata.max_seq_len_k :].fill_(0)
        # ... then copy in the current token indices
        metadata.page_table[:, : metadata.max_seq_len_k].copy_(
            self.req_to_token[req_pool_indices[:bs], : metadata.max_seq_len_k]
        )
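        # Illustrative example (hypothetical sizes): if the captured page_table
        # has shape (bs, max_context_len) = (8, 8192) and the current
        # max_seq_len_k is 100, columns 100..8191 are zeroed and columns 0..99
        # are overwritten from req_to_token, so no stale indices remain.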
        self.forward_metadata = metadata

    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for sequence length in CUDA graph."""
        return 0