diff --git a/benchmark/mmmu/bench_sglang.py b/benchmark/mmmu/bench_sglang.py
index 19f60d661..3786d2b09 100644
--- a/benchmark/mmmu/bench_sglang.py
+++ b/benchmark/mmmu/bench_sglang.py
@@ -86,8 +86,8 @@ def eval_mmmu(args):
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    args = add_common_sglang_args_and_parse(parser)
     EvalArgs.add_cli_args(parser)
+    args = add_common_sglang_args_and_parse(parser)
     args = parser.parse_args()
     eval_mmmu(args)
diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 45e64c45e..dca38d9bb 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -42,6 +42,16 @@ class FlashAttentionMetadata:
     # Page table, the index of KV Cache Tables/Blocks
     page_table: torch.Tensor = None
 
+    # Encoder metadata
+    # Cumulative sequence lengths for encoder key
+    encoder_cu_seqlens_k: torch.Tensor = None
+    # Maximum sequence length for encoder key
+    encoder_max_seq_len_k: int = 0
+    # Sequence lengths for the forward batch
+    encoder_lens_int32: torch.Tensor = None
+    # Page table for the encoder
+    encoder_page_table: torch.Tensor = None
+
     @dataclass
     class LocalAttentionMetadata:
         local_query_start_loc: torch.Tensor = None  # cu_seqlens_q for local attention
@@ -435,6 +445,30 @@ class FlashAttentionBackend(AttentionBackend):
                 )
                 metadata.local_attn_metadata = local_metadata
 
+            # Encoder metadata for cross attention
+            if forward_batch.encoder_lens is not None:
+                assert (
+                    forward_batch.encoder_lens.numel() == 1
+                ), "Only encoder size 1 is supported for now"
+
+                metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
+                metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                    (1, 0),
+                )
+                metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
+                metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
+                ]
+
+                # Currently only support forward_batch.encoder_lens.numel() == 1
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices,
+                    metadata.encoder_max_seq_len_k : (
+                        metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                    ),
+                ]
+
         # Convert the page table to a strided format which is needed by FA3 API
         if self.page_size > 1:
             self.strided_indices = torch.arange(
@@ -486,6 +520,7 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
+        causal = not layer.is_cross_attention
 
         # Check if we should use local attention
        use_local_attn = (
@@ -521,6 +556,12 @@ class FlashAttentionBackend(AttentionBackend):
            value_cache = value_cache.view(
                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
            )
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+
            o = flash_attn_with_kvcache(
                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                k_cache=key_cache,
@@ -531,7 +572,7 @@ class FlashAttentionBackend(AttentionBackend):
                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                max_seqlen_q=max_seqlen_q,
                softmax_scale=layer.scaling,
-                causal=True,
+                causal=causal,
                window_size=window_size,
                softcap=layer.logit_cap,
                k_descale=layer.k_scale,
@@ -614,6 +655,7 @@ class FlashAttentionBackend(AttentionBackend):
            if layer.sliding_window_size is not None
            else (-1, -1)
        )
+        causal = not layer.is_cross_attention
 
        if not self.use_mla:
            # Do multi-head attention
@@ -627,17 +669,27 @@ class FlashAttentionBackend(AttentionBackend):
            )
 
            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+            else:
+                page_table = metadata.page_table
+                cache_seqlens = metadata.cache_seqlens_int32
+                cu_seqlens_k = metadata.cu_seqlens_k
+
            o = flash_attn_with_kvcache(
                q=q_reshaped,
                k_cache=key_cache,
                v_cache=value_cache,
-                page_table=metadata.page_table,
-                cache_seqlens=metadata.cache_seqlens_int32,
+                page_table=page_table,
+                cache_seqlens=cache_seqlens,
                cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                cu_seqlens_k_new=cu_seqlens_k,
                max_seqlen_q=1,
                softmax_scale=layer.scaling,
-                causal=True,
+                causal=causal,
                window_size=window_size,
                softcap=layer.logit_cap,
                k_descale=layer.k_scale,
@@ -733,6 +785,21 @@ class FlashAttentionBackend(AttentionBackend):
            ),
        }
 
+        self.encoder_metadata = {
+            "encoder_page_table": torch.zeros(
+                max_bs,
+                self.max_context_len,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "encoder_lens_int32": torch.zeros(
+                max_bs, dtype=torch.int32, device=self.device
+            ),
+            "encoder_cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+        }
+
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
@@ -818,6 +885,19 @@ class FlashAttentionBackend(AttentionBackend):
            self.target_verify_metadata[bs] = metadata
 
+        if encoder_lens is not None:
+            encoder_bs = encoder_lens.numel()
+            metadata.encoder_lens_int32 = self.encoder_metadata["encoder_lens_int32"][
+                :encoder_bs
+            ]
+            metadata.encoder_cu_seqlens_k = self.encoder_metadata[
+                "encoder_cu_seqlens_k"
+            ][: (encoder_bs + 1)]
+
+            metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
+                req_pool_indices, :
+            ]
+
        self.forward_metadata = metadata
 
    def init_forward_metadata_replay_cuda_graph(
        self,
@@ -903,6 +983,30 @@ class FlashAttentionBackend(AttentionBackend):
            page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
 
+            if encoder_lens is not None:
+                # Only support encoder size 1 for now
+                metadata.encoder_max_seq_len_k = encoder_lens[0]
+                metadata.encoder_lens_int32.copy_(encoder_lens[:1])
+                metadata.encoder_cu_seqlens_k.copy_(
+                    torch.nn.functional.pad(
+                        torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                        (1, 0),
+                    )
+                )
+
+                metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
+                    self.req_to_token[req_pool_indices, : metadata.encoder_max_seq_len_k]
+                )
+
+                # Update the regular page table
+                page_table = self.req_to_token[
+                    req_pool_indices,
+                    metadata.encoder_max_seq_len_k : (
+                        metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                    ),
+                ]
+                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
        self.forward_metadata = metadata
 
    def get_cuda_graph_seq_len_fill_value(self):
@@ -956,7 +1060,7 @@ class FlashAttentionMultiStepBackend:
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
-                encoder_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )
@@ -973,7 +1077,7 @@ class FlashAttentionMultiStepBackend:
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
-                encoder_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=forward_batch.seq_lens_cpu,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 8506485fe..e7fae3a68 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -886,7 +886,7 @@ class ModelRunner:
                "Please use `--attention-backend flashinfer`."
            )
            logger.warning(
-                "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
+                "FlashAttention v3 Backend is in Beta. FP8 is not supported."
            )
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
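
Note for reviewers: the core idea of the cross-attention support above is that each request's row in `req_to_token` now holds the encoder tokens first, followed by the decoder tokens, so the backend derives an encoder page table plus encoder `cu_seqlens`, and shifts the regular page table past the encoder prefix. A minimal standalone sketch of that derivation (not part of the patch; tensor names, sizes, and values here are illustrative only):

```python
import torch

# Illustrative inputs: one request (the patch asserts encoder batch size == 1).
encoder_lens = torch.tensor([5], dtype=torch.int32)  # 5 encoder tokens
seq_lens = torch.tensor([3], dtype=torch.int32)      # 3 decoder tokens
req_to_token = torch.arange(16, dtype=torch.int32).view(1, 16)  # toy request-to-token map

# Cumulative encoder key lengths with a leading zero, built the same way as in the patch.
encoder_lens_int32 = encoder_lens.to(torch.int32)
encoder_cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(encoder_lens_int32, dim=0, dtype=torch.int32), (1, 0)
)
encoder_max_seq_len_k = int(encoder_lens_int32.max().item())
max_seq_len_k = int(seq_lens.max().item())

# Encoder page table: the first encoder_max_seq_len_k slots of each request row.
encoder_page_table = req_to_token[:, :encoder_max_seq_len_k]
# Regular (decoder) page table: the slots after the encoder prefix.
page_table = req_to_token[:, encoder_max_seq_len_k : encoder_max_seq_len_k + max_seq_len_k]

print(encoder_cu_seqlens_k)  # tensor([0, 5], dtype=torch.int32)
print(encoder_page_table)    # KV slots 0..4 (encoder prefix)
print(page_table)            # KV slots 5..7 (decoder suffix)
```

Cross-attention layers then read keys/values through `encoder_page_table` with `causal=False`, while self-attention layers keep the shifted `page_table` with `causal=True`, which is what the `causal = not layer.is_cross_attention` lines in the patch select.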