[FA3 Feature] Support multi modal Llama-3.2-11B-Vision-Instruct (#5103)
This commit is contained in:
@@ -86,8 +86,8 @@ def eval_mmmu(args):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
args = add_common_sglang_args_and_parse(parser)
|
|
||||||
EvalArgs.add_cli_args(parser)
|
EvalArgs.add_cli_args(parser)
|
||||||
|
args = add_common_sglang_args_and_parse(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
eval_mmmu(args)
|
eval_mmmu(args)
|
||||||
|
|||||||
@@ -42,6 +42,16 @@ class FlashAttentionMetadata:
|
|||||||
# Page table, the index of KV Cache Tables/Blocks
|
# Page table, the index of KV Cache Tables/Blocks
|
||||||
page_table: torch.Tensor = None
|
page_table: torch.Tensor = None
|
||||||
|
|
||||||
|
# Encoder metadata
|
||||||
|
# Cumulative sequence lengths for encoder key
|
||||||
|
encoder_cu_seqlens_k: torch.Tensor = None
|
||||||
|
# Maximum sequence length for encoder key
|
||||||
|
encoder_max_seq_len_k: int = 0
|
||||||
|
# Sequence lengths for the forward batch
|
||||||
|
encoder_lens_int32: torch.Tensor = None
|
||||||
|
# Page table for the encoder
|
||||||
|
encoder_page_table: torch.Tensor = None
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LocalAttentionMetadata:
|
class LocalAttentionMetadata:
|
||||||
local_query_start_loc: torch.Tensor = None # cu_seqlens_q for local attention
|
local_query_start_loc: torch.Tensor = None # cu_seqlens_q for local attention
|
||||||
@@ -435,6 +445,30 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
metadata.local_attn_metadata = local_metadata
|
metadata.local_attn_metadata = local_metadata
|
||||||
|
|
||||||
|
# Encoder metadata for cross attention
|
||||||
|
if forward_batch.encoder_lens is not None:
|
||||||
|
assert (
|
||||||
|
forward_batch.encoder_lens.numel() == 1
|
||||||
|
), "Only encoder size 1 is supported for now"
|
||||||
|
|
||||||
|
metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
|
||||||
|
metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
|
||||||
|
(1, 0),
|
||||||
|
)
|
||||||
|
metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
|
||||||
|
metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
|
forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
|
||||||
|
]
|
||||||
|
|
||||||
|
# Currently only support forward_batch.encoder_lens.numel() == 1
|
||||||
|
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
metadata.encoder_max_seq_len_k : (
|
||||||
|
metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
# Convert the page table to a strided format which is needed by FA3 API
|
# Convert the page table to a strided format which is needed by FA3 API
|
||||||
if self.page_size > 1:
|
if self.page_size > 1:
|
||||||
self.strided_indices = torch.arange(
|
self.strided_indices = torch.arange(
|
||||||
@@ -486,6 +520,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
if layer.sliding_window_size is not None
|
if layer.sliding_window_size is not None
|
||||||
else (-1, -1)
|
else (-1, -1)
|
||||||
)
|
)
|
||||||
|
causal = not layer.is_cross_attention
|
||||||
|
|
||||||
# Check if we should use local attention
|
# Check if we should use local attention
|
||||||
use_local_attn = (
|
use_local_attn = (
|
||||||
@@ -521,6 +556,12 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
value_cache = value_cache.view(
|
value_cache = value_cache.view(
|
||||||
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
|
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
|
||||||
)
|
)
|
||||||
|
if layer.is_cross_attention:
|
||||||
|
page_table = metadata.encoder_page_table
|
||||||
|
cache_seqlens = metadata.encoder_lens_int32
|
||||||
|
cu_seqlens_k = metadata.encoder_cu_seqlens_k
|
||||||
|
window_size = (-1, -1)
|
||||||
|
|
||||||
o = flash_attn_with_kvcache(
|
o = flash_attn_with_kvcache(
|
||||||
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
k_cache=key_cache,
|
k_cache=key_cache,
|
||||||
@@ -531,7 +572,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
|
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
|
||||||
max_seqlen_q=max_seqlen_q,
|
max_seqlen_q=max_seqlen_q,
|
||||||
softmax_scale=layer.scaling,
|
softmax_scale=layer.scaling,
|
||||||
causal=True,
|
causal=causal,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
softcap=layer.logit_cap,
|
softcap=layer.logit_cap,
|
||||||
k_descale=layer.k_scale,
|
k_descale=layer.k_scale,
|
||||||
@@ -614,6 +655,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
if layer.sliding_window_size is not None
|
if layer.sliding_window_size is not None
|
||||||
else (-1, -1)
|
else (-1, -1)
|
||||||
)
|
)
|
||||||
|
causal = not layer.is_cross_attention
|
||||||
|
|
||||||
if not self.use_mla:
|
if not self.use_mla:
|
||||||
# Do multi-head attention
|
# Do multi-head attention
|
||||||
@@ -627,17 +669,27 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
if layer.is_cross_attention:
|
||||||
|
page_table = metadata.encoder_page_table
|
||||||
|
cache_seqlens = metadata.encoder_lens_int32
|
||||||
|
cu_seqlens_k = metadata.encoder_cu_seqlens_k
|
||||||
|
window_size = (-1, -1)
|
||||||
|
else:
|
||||||
|
page_table = metadata.page_table
|
||||||
|
cache_seqlens = metadata.cache_seqlens_int32
|
||||||
|
cu_seqlens_k = metadata.cu_seqlens_k
|
||||||
|
|
||||||
o = flash_attn_with_kvcache(
|
o = flash_attn_with_kvcache(
|
||||||
q=q_reshaped,
|
q=q_reshaped,
|
||||||
k_cache=key_cache,
|
k_cache=key_cache,
|
||||||
v_cache=value_cache,
|
v_cache=value_cache,
|
||||||
page_table=metadata.page_table,
|
page_table=page_table,
|
||||||
cache_seqlens=metadata.cache_seqlens_int32,
|
cache_seqlens=cache_seqlens,
|
||||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
cu_seqlens_k_new=cu_seqlens_k,
|
||||||
max_seqlen_q=1,
|
max_seqlen_q=1,
|
||||||
softmax_scale=layer.scaling,
|
softmax_scale=layer.scaling,
|
||||||
causal=True,
|
causal=causal,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
softcap=layer.logit_cap,
|
softcap=layer.logit_cap,
|
||||||
k_descale=layer.k_scale,
|
k_descale=layer.k_scale,
|
||||||
@@ -733,6 +785,21 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.encoder_metadata = {
|
||||||
|
"encoder_page_table": torch.zeros(
|
||||||
|
max_bs,
|
||||||
|
self.max_context_len,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
"encoder_lens_int32": torch.zeros(
|
||||||
|
max_bs, dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
"encoder_cu_seqlens_k": torch.zeros(
|
||||||
|
max_bs + 1, dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self,
|
self,
|
||||||
bs: int,
|
bs: int,
|
||||||
@@ -818,6 +885,19 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
self.target_verify_metadata[bs] = metadata
|
self.target_verify_metadata[bs] = metadata
|
||||||
|
|
||||||
|
if encoder_lens is not None:
|
||||||
|
encoder_bs = encoder_lens.numel()
|
||||||
|
metadata.encoder_lens_int32 = self.encoder_metadata["encoder_lens_int32"][
|
||||||
|
:encoder_bs
|
||||||
|
]
|
||||||
|
metadata.encoder_cu_seqlens_k = self.encoder_metadata[
|
||||||
|
"encoder_cu_seqlens_k"
|
||||||
|
][: (encoder_bs + 1)]
|
||||||
|
|
||||||
|
metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
|
||||||
|
req_pool_indices, :
|
||||||
|
]
|
||||||
|
|
||||||
self.forward_metadata = metadata
|
self.forward_metadata = metadata
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
@@ -903,6 +983,30 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
|
page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
|
||||||
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
|
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
|
||||||
|
|
||||||
|
if encoder_lens is not None:
|
||||||
|
# Only support encoder size 1 for now
|
||||||
|
metadata.encoder_max_seq_len_k = encoder_lens[0]
|
||||||
|
metadata.encoder_lens_int32.copy_(encoder_lens[:1])
|
||||||
|
metadata.encoder_cu_seqlens_k.copy_(
|
||||||
|
torch.nn.functional.pad(
|
||||||
|
torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
|
||||||
|
(1, 0),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
|
||||||
|
self.req_to_token[req_pool_indices, : metadata.encoder_max_seq_len_k]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update the regular page table
|
||||||
|
page_table = self.req_to_token[
|
||||||
|
req_pool_indices,
|
||||||
|
metadata.encoder_max_seq_len_k : (
|
||||||
|
metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
|
||||||
|
),
|
||||||
|
]
|
||||||
|
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
|
||||||
|
|
||||||
self.forward_metadata = metadata
|
self.forward_metadata = metadata
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
@@ -956,7 +1060,7 @@ class FlashAttentionMultiStepBackend:
|
|||||||
forward_batch.batch_size * self.topk,
|
forward_batch.batch_size * self.topk,
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
encoder_lens=None,
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
forward_mode=ForwardMode.DECODE,
|
forward_mode=ForwardMode.DECODE,
|
||||||
spec_info=forward_batch.spec_info,
|
spec_info=forward_batch.spec_info,
|
||||||
)
|
)
|
||||||
@@ -973,7 +1077,7 @@ class FlashAttentionMultiStepBackend:
|
|||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
forward_batch.seq_lens_sum,
|
forward_batch.seq_lens_sum,
|
||||||
encoder_lens=None,
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
forward_mode=ForwardMode.DECODE,
|
forward_mode=ForwardMode.DECODE,
|
||||||
spec_info=forward_batch.spec_info,
|
spec_info=forward_batch.spec_info,
|
||||||
seq_lens_cpu=forward_batch.seq_lens_cpu,
|
seq_lens_cpu=forward_batch.seq_lens_cpu,
|
||||||
|
|||||||
@@ -886,7 +886,7 @@ class ModelRunner:
|
|||||||
"Please use `--attention-backend flashinfer`."
|
"Please use `--attention-backend flashinfer`."
|
||||||
)
|
)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
|
"FlashAttention v3 Backend is in Beta. FP8 is not supported."
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.attention.flashattention_backend import (
|
from sglang.srt.layers.attention.flashattention_backend import (
|
||||||
FlashAttentionBackend,
|
FlashAttentionBackend,
|
||||||
|
|||||||
Reference in New Issue
Block a user