[FA3 Feature] Support multimodal Llama-3.2-11B-Vision-Instruct (#5103)
@@ -42,6 +42,16 @@ class FlashAttentionMetadata:
     # Page table, the index of KV Cache Tables/Blocks
     page_table: torch.Tensor = None
 
+    # Encoder metadata
+    # Cumulative sequence lengths for encoder key
+    encoder_cu_seqlens_k: torch.Tensor = None
+    # Maximum sequence length for encoder key
+    encoder_max_seq_len_k: int = 0
+    # Sequence lengths for the forward batch
+    encoder_lens_int32: torch.Tensor = None
+    # Page table for the encoder
+    encoder_page_table: torch.Tensor = None
+
     @dataclass
     class LocalAttentionMetadata:
         local_query_start_loc: torch.Tensor = None  # cu_seqlens_q for local attention
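The new encoder fields mirror the existing per-request KV metadata but describe the vision encoder's cross-attention keys. A minimal sketch of the values they would hold for a single-image request, assuming a hypothetical encoder KV length of 4100 (not a number taken from the model):

import torch
import torch.nn.functional as F

# Hypothetical: one request whose image encoder produced 4100 KV tokens.
encoder_lens = torch.tensor([4100])

encoder_lens_int32 = encoder_lens.to(torch.int32)        # tensor([4100], dtype=int32)
encoder_cu_seqlens_k = F.pad(
    torch.cumsum(encoder_lens_int32, dim=0, dtype=torch.int32), (1, 0)
)                                                        # tensor([0, 4100], dtype=int32)
encoder_max_seq_len_k = encoder_lens_int32.max().item()  # 4100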
@@ -435,6 +445,30 @@ class FlashAttentionBackend(AttentionBackend):
                 )
                 metadata.local_attn_metadata = local_metadata
 
+            # Encoder metadata for cross attention
+            if forward_batch.encoder_lens is not None:
+                assert (
+                    forward_batch.encoder_lens.numel() == 1
+                ), "Only encoder size 1 is supported for now"
+
+                metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
+                metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                    (1, 0),
+                )
+                metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
+                metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
+                ]
+
+                # Currently only support forward_batch.encoder_lens.numel() == 1
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices,
+                    metadata.encoder_max_seq_len_k : (
+                        metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                    ),
+                ]
+
         # Convert the page table to a strided format which is needed by FA3 API
         if self.page_size > 1:
             self.strided_indices = torch.arange(
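The slicing above relies on each row of req_to_token storing the encoder KV slots first, followed by the decoder KV slots, so the regular page table is simply the row shifted by encoder_max_seq_len_k. A minimal sketch of that split with made-up slot indices:

import torch

# Hypothetical req_to_token row: 4 encoder slots, then decoder slots, then padding.
req_to_token = torch.tensor([[100, 101, 102, 103, 200, 201, 202, 0, 0, 0]])
req_pool_indices = torch.tensor([0])
encoder_max_seq_len_k = 4  # encoder KV length
max_seq_len_k = 3          # decoder tokens seen so far

encoder_page_table = req_to_token[req_pool_indices, :encoder_max_seq_len_k]
# tensor([[100, 101, 102, 103]])
page_table = req_to_token[
    req_pool_indices, encoder_max_seq_len_k : encoder_max_seq_len_k + max_seq_len_k
]
# tensor([[200, 201, 202]])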
@@ -486,6 +520,7 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
+        causal = not layer.is_cross_attention
 
         # Check if we should use local attention
         use_local_attn = (
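The causal flag is dropped for cross-attention because the decoder queries attend over the whole encoder sequence rather than a prefix of themselves; only self-attention keeps the lower-triangular mask. A minimal illustration with arbitrary lengths:

import torch

q_len, k_len = 3, 5  # 3 decoder queries, 5 encoder keys (arbitrary)

# Self-attention: query i may only look at keys 0..i.
self_attn_mask = torch.tril(torch.ones(q_len, q_len, dtype=torch.bool))

# Cross-attention: every decoder query sees every encoder key,
# hence causal=False and window_size=(-1, -1) in the kernel call.
cross_attn_mask = torch.ones(q_len, k_len, dtype=torch.bool)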
@@ -521,6 +556,12 @@ class FlashAttentionBackend(AttentionBackend):
             value_cache = value_cache.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.head_dim
             )
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+
             o = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                 k_cache=key_cache,
@@ -531,7 +572,7 @@ class FlashAttentionBackend(AttentionBackend):
                 cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                 max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
-                causal=True,
+                causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
                 k_descale=layer.k_scale,
@@ -614,6 +655,7 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
+        causal = not layer.is_cross_attention
 
         if not self.use_mla:
             # Do multi-head attention
@@ -627,17 +669,27 @@ class FlashAttentionBackend(AttentionBackend):
             )
 
             q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+            else:
+                page_table = metadata.page_table
+                cache_seqlens = metadata.cache_seqlens_int32
+                cu_seqlens_k = metadata.cu_seqlens_k
+
             o = flash_attn_with_kvcache(
                 q=q_reshaped,
                 k_cache=key_cache,
                 v_cache=value_cache,
-                page_table=metadata.page_table,
-                cache_seqlens=metadata.cache_seqlens_int32,
+                page_table=page_table,
+                cache_seqlens=cache_seqlens,
                 cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                cu_seqlens_k_new=cu_seqlens_k,
                 max_seqlen_q=1,
                 softmax_scale=layer.scaling,
-                causal=True,
+                causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
                 k_descale=layer.k_scale,
@@ -733,6 +785,21 @@ class FlashAttentionBackend(AttentionBackend):
             ),
         }
 
+        self.encoder_metadata = {
+            "encoder_page_table": torch.zeros(
+                max_bs,
+                self.max_context_len,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "encoder_lens_int32": torch.zeros(
+                max_bs, dtype=torch.int32, device=self.device
+            ),
+            "encoder_cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+        }
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
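These buffers are allocated once at worst-case sizes because CUDA graph capture records fixed tensor addresses; replay can only reuse memory that already existed at capture time. A rough sizing sketch, with hypothetical limits rather than sglang defaults:

# Hypothetical capture limits.
max_bs = 128
max_context_len = 8192

# encoder_page_table dominates: max_bs x max_context_len int32 entries.
page_table_bytes = max_bs * max_context_len * 4   # 4,194,304 bytes (4 MiB)
lens_bytes = max_bs * 4                            # 512 bytes
cu_seqlens_bytes = (max_bs + 1) * 4                # 516 bytes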
@@ -818,6 +885,19 @@ class FlashAttentionBackend(AttentionBackend):
 
             self.target_verify_metadata[bs] = metadata
 
+        if encoder_lens is not None:
+            encoder_bs = encoder_lens.numel()
+            metadata.encoder_lens_int32 = self.encoder_metadata["encoder_lens_int32"][
+                :encoder_bs
+            ]
+            metadata.encoder_cu_seqlens_k = self.encoder_metadata[
+                "encoder_cu_seqlens_k"
+            ][: (encoder_bs + 1)]
+
+            metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
+                req_pool_indices, :
+            ]
+
         self.forward_metadata = metadata
 
     def init_forward_metadata_replay_cuda_graph(
@@ -903,6 +983,30 @@ class FlashAttentionBackend(AttentionBackend):
             page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
             metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
 
+            if encoder_lens is not None:
+                # Only support encoder size 1 for now
+                metadata.encoder_max_seq_len_k = encoder_lens[0]
+                metadata.encoder_lens_int32.copy_(encoder_lens[:1])
+                metadata.encoder_cu_seqlens_k.copy_(
+                    torch.nn.functional.pad(
+                        torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                        (1, 0),
+                    )
+                )
+
+                metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
+                    self.req_to_token[req_pool_indices, : metadata.encoder_max_seq_len_k]
+                )
+
+                # Update the regular page table
+                page_table = self.req_to_token[
+                    req_pool_indices,
+                    metadata.encoder_max_seq_len_k : (
+                        metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                    ),
+                ]
+                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
         self.forward_metadata = metadata
 
     def get_cuda_graph_seq_len_fill_value(self):
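During replay the captured kernels keep reading the same preallocated tensors, so the encoder metadata is updated in place with copy_ rather than by rebinding the attributes to freshly allocated tensors. A minimal sketch of the pattern (shapes and values are illustrative):

import torch

# Capture time: allocate once at maximum size.
max_bs, max_context_len = 4, 16
encoder_lens_buf = torch.zeros(max_bs, dtype=torch.int32)
encoder_page_table_buf = torch.zeros(max_bs, max_context_len, dtype=torch.int32)

# Replay time: write into the existing buffers; do not reassign them.
new_len = 6
encoder_lens_buf[:1].copy_(torch.tensor([new_len], dtype=torch.int32))
encoder_page_table_buf[:1, :new_len].copy_(
    torch.arange(new_len, dtype=torch.int32).unsqueeze(0)
)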
@@ -956,7 +1060,7 @@ class FlashAttentionMultiStepBackend:
                 forward_batch.batch_size * self.topk,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
-                encoder_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
             )
@@ -973,7 +1077,7 @@ class FlashAttentionMultiStepBackend:
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
-                encoder_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
                 seq_lens_cpu=forward_batch.seq_lens_cpu,
@@ -886,7 +886,7 @@ class ModelRunner:
                 "Please use `--attention-backend flashinfer`."
             )
             logger.warning(
-                "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
+                "FlashAttention v3 Backend is in Beta. FP8 is not supported."
             )
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,