Add Llama4 support (#5092)
Co-authored-by: Cheng Wan <cwan39@gatech.edu> Co-authored-by: fzyzcjy <ch271828n@outlook.com> Co-authored-by: ispobock <ispobaoke@163.com>
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
|
||||
"""
|
||||
@@ -45,6 +47,206 @@ class FlashAttentionMetadata:
|
||||
# Sequence lengths for the forward batch
|
||||
cache_seqlens_int32: torch.Tensor = None
|
||||
|
||||
    @dataclass
    class LocalAttentionMetadata:
        """Metadata for chunked ("local") attention, where a query token may
        only attend to keys inside its own fixed-size chunk.

        The tensors describe the "virtual" batch produced by
        make_local_attention_virtual_batches: one entry per local attention
        block instead of one entry per request.
        """

        local_query_start_loc: torch.Tensor = None  # cu_seqlens_q for local attention
        local_seqused_k: torch.Tensor = None  # sequence lengths for local attention
        local_block_table: torch.Tensor = None  # block table for local attention
        local_max_query_len: int = 0  # max query length for local attention
        local_max_seq_len: int = 0  # max sequence length for local attention
|
||||
|
||||
local_attn_metadata: Optional[LocalAttentionMetadata] = None
|
||||
|
||||
|
||||
# Adapted from:
# https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    query_start_loc_np: np.ndarray,
    seq_lens_np: np.ndarray,
    block_table: torch.Tensor,
    page_size: int = 1,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
    """Split each request into independent local-attention ("virtual") batches.

    With chunked local attention (chunk size C) a query token may only attend
    to keys inside its own C-token chunk.  A standard causal flash-attention
    kernel can emulate that mask if every chunk is presented as an independent
    batch item, so this routine re-describes the batch with one ("virtual")
    entry per local attention block.

    Worked example with attn_chunk_size = 4, q_seqlens = [4, 10, 5],
    kv_seqlens = [6, 17, 9] (so query_start_loc_np = [0, 4, 14, 19]):

        #                       __b0__  _____b1_____  _b2_   < orig batch idx
        # seqlens_q_local    = [ 2, 2,  1, 4, 4, 1,   4, 1]
        # cu_seqlens_q_local = [0, 2, 4, 5, 9, 13, 14, 18, 19]
        # seqlens_k_local    = [ 4, 2,  4, 4, 4, 1,   4, 1]

    (An earlier version of this comment showed
    cu_seqlens_q_local = [0, 4, 6, 10, ...], which was actually the
    cumulative sum of seqlens_k_local; the value above is correct.)

    Args:
        attn_chunk_size: Size of a local attention chunk; must be a multiple
            of page_size.
        query_start_loc_np: cu_seqlens_q of the original batch, length B + 1.
        seq_lens_np: KV sequence length of each request, length B.
        block_table: (B, max_pages) KV-cache page table.
        page_size: Tokens per KV-cache page.  Defaults to 1 (token-sized
            pages); the previous default of 0 always raised
            ZeroDivisionError.

    Returns:
        seqlens_q_local: Query length of each virtual batch item.
        cu_seqlens_q_local: Cumulative query lengths, length V + 1 (int32).
        seqlens_k_local: Key length of each virtual batch item (int32).
        block_table_local: (V, attn_chunk_size // page_size) page table with
            one row of KV-cache pages per virtual batch item.
    """
    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]

    # Handle starting in the middle of a local attention block: assuming
    # q_seqlens > 0 (for all elements), compute for each batch idx how many
    # new tokens land in its first (possibly partial) local block; the rest
    # of the blocks follow from a simple ceiling division.
    # For the worked example above:
    #   q_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
    ).astype(np.int32)
    # x % -c maps x into (-c, 0], so this is "KV tokens in the (possibly
    # partial) last chunk", always in the range (0, attn_chunk_size].
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    # 1 (first block) + ceil-div of the remaining query tokens;
    # -(x // -c) == ceil(x / c) in pure integer arithmetic.
    local_blocks = 1 - ((q_seqlens - q_tokens_in_first_block) // -attn_chunk_size)

    # Knowing the number of local blocks per request tells us how many
    # "virtual" requests there are in total, and lets us build a batched
    # arange over them.
    # (TODO: make a utility to share this code with _prepare_inputs)
    # step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    # step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # step 3. batched arange: [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # reverse arange (blocks remaining after this one): [1, 0, 3, 2, 1, 0, 1, 0]
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1

    # Per-virtual-batch query lengths; the first and last blocks of each
    # request may be partial.
    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks (full chunks except possibly the last)
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
    )[arange > 0]

    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)

    # Per-virtual-batch key lengths: a full attn_chunk_size for every block
    # except the last block of each request, which only covers the tail:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block

    # Absolute KV position at which each local block starts, e.g.:
    #                          _b0_  _____b1_____  _b2_
    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
        rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
    )

    # Validate BEFORE dividing: with an invalid page_size the old order
    # raised ZeroDivisionError before this assert could ever fire.
    assert attn_chunk_size % page_size == 0, (
        f"attn_chunk_size {attn_chunk_size} is not "
        f"divisible by page_size {page_size}"
    )
    block_starts = k_seqstarts_absolute // page_size
    pages_per_local_batch = attn_chunk_size // page_size

    # Build a block table for the local attention blocks.  For the worked
    # example with page_size=2 and
    #   block_table = [
    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
    #   ]
    # we want
    #   block_table_local = [
    #     [ 0,  1], [ 2,  3],                        < batch 0, k[0], k[4]
    #     [12, 13], [14, 15], [16, 17], [18, 19],    < batch 1, k[4]..k[16]
    #     [22, 23], [24, 25],                        < batch 2, k[4], k[8]
    #   ]
    block_indices = np.broadcast_to(
        np.arange(pages_per_local_batch, dtype=np.int32),
        (virtual_batches, pages_per_local_batch),
    ) + np.expand_dims(block_starts, axis=1)
    block_indices = block_indices.flatten()
    batch_indices = np.repeat(
        np.arange(actual_batch_size, dtype=np.int32),
        local_blocks * pages_per_local_batch,
    )
    block_table_local = block_table[batch_indices, block_indices].view(
        virtual_batches, -1
    )

    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
|
||||
|
||||
|
||||
def cdiv(a: int, b: int) -> int:
    """Integer ceiling division: smallest integer >= a / b.

    Uses the negate-floor identity ceil(a / b) == -floor(-a / b), so no
    floats are involved; also works elementwise on numpy integer arrays.
    """
    return -(-a // b)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
"""FlashAttention backend implementation.
|
||||
@@ -100,6 +302,13 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
self.step_id = step_id
|
||||
self.speculative_num_steps = speculative_num_steps
|
||||
|
||||
# Local attention settings
|
||||
self.attention_chunk_size = (
|
||||
model_runner.attention_chunk_size
|
||||
if hasattr(model_runner, "attention_chunk_size")
|
||||
else None
|
||||
)
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
"""Initialize forward metadata to cache repetitive calculations."""
|
||||
metadata = FlashAttentionMetadata()
|
||||
@@ -189,6 +398,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
|
||||
# Precompute cumulative sequence lengths
|
||||
if (
|
||||
any(forward_batch.extend_prefix_lens_cpu)
|
||||
@@ -203,6 +413,51 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.cu_seqlens_q = metadata.cu_seqlens_k
|
||||
metadata.max_seq_len_q = metadata.max_seq_len_k
|
||||
|
||||
# Setup local attention if enabled
|
||||
if (
|
||||
self.attention_chunk_size is not None
|
||||
and forward_batch.forward_mode == ForwardMode.EXTEND
|
||||
):
|
||||
# Convert tensors to numpy for local attention processing
|
||||
cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
|
||||
seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
|
||||
|
||||
# Adjust attention_chunk_size based on the actual sequence length
|
||||
# to avoid index out of bounds errors
|
||||
max_seq_len = seq_lens_np.max()
|
||||
effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
|
||||
# Make sure effective_chunk_size is divisible by page_size
|
||||
effective_chunk_size = (
|
||||
effective_chunk_size // self.page_size
|
||||
) * self.page_size
|
||||
if effective_chunk_size < self.page_size:
|
||||
effective_chunk_size = self.page_size
|
||||
|
||||
# Create local attention metadata
|
||||
(
|
||||
seqlens_q_local_np,
|
||||
cu_seqlens_q_local_np,
|
||||
seqlens_k_local_np,
|
||||
block_table_local,
|
||||
) = make_local_attention_virtual_batches(
|
||||
effective_chunk_size,
|
||||
cu_seqlens_q_np,
|
||||
seq_lens_np,
|
||||
metadata.page_table,
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
||||
local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
|
||||
device
|
||||
),
|
||||
local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
|
||||
local_block_table=block_table_local,
|
||||
local_max_query_len=seqlens_q_local_np.max(),
|
||||
local_max_seq_len=seqlens_k_local_np.max(),
|
||||
)
|
||||
metadata.local_attn_metadata = local_metadata
|
||||
|
||||
# Precompute strided indices
|
||||
if self.page_size > 1:
|
||||
self.strided_indices = torch.arange(
|
||||
@@ -211,6 +466,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.page_table = (
|
||||
metadata.page_table[:, self.strided_indices] // self.page_size
|
||||
)
|
||||
|
||||
self.forward_metadata = metadata
|
||||
|
||||
def forward_extend(
|
||||
@@ -254,7 +510,28 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
else (-1, -1)
|
||||
)
|
||||
|
||||
page_table = metadata.page_table
|
||||
# Check if we should use local attention
|
||||
use_local_attn = (
|
||||
self.attention_chunk_size is not None
|
||||
and metadata.local_attn_metadata is not None
|
||||
and (hasattr(layer, "use_irope") and layer.use_irope)
|
||||
)
|
||||
|
||||
# Get the appropriate page table based on whether we're using local attention
|
||||
if use_local_attn:
|
||||
local_metadata = metadata.local_attn_metadata
|
||||
page_table = local_metadata.local_block_table
|
||||
cu_seqlens_q = local_metadata.local_query_start_loc
|
||||
cache_seqlens = local_metadata.local_seqused_k
|
||||
max_seqlen_q = local_metadata.local_max_query_len
|
||||
max_seqlen_k = local_metadata.local_max_seq_len
|
||||
else:
|
||||
page_table = metadata.page_table
|
||||
cu_seqlens_q = metadata.cu_seqlens_q
|
||||
cache_seqlens = metadata.cache_seqlens_int32
|
||||
max_seqlen_q = metadata.max_seq_len_q
|
||||
max_seqlen_k = metadata.max_seq_len_k
|
||||
cu_seqlens_k = metadata.cu_seqlens_k
|
||||
|
||||
# Use Flash Attention for prefill
|
||||
if not self.use_mla:
|
||||
@@ -272,10 +549,10 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
page_table=page_table,
|
||||
cache_seqlens=metadata.cache_seqlens_int32,
|
||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||
max_seqlen_q=metadata.max_seq_len_q,
|
||||
cache_seqlens=cache_seqlens,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
@@ -307,10 +584,10 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
v_cache=c_kv_cache,
|
||||
qv=q_nope,
|
||||
page_table=page_table,
|
||||
cache_seqlens=metadata.cache_seqlens_int32,
|
||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||
max_seqlen_q=metadata.max_seq_len_q,
|
||||
cache_seqlens=cache_seqlens,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=True,
|
||||
softcap=layer.logit_cap,
|
||||
|
||||
Reference in New Issue
Block a user