From e62d60fe6d7ce656e9338284d08fc97917a1c70e Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Sun, 30 Mar 2025 13:53:44 -0700
Subject: [PATCH] [Fix] avoid stream sync and torch compile in prefill for fa3 backend (#4932)
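
Calling .item() on a CUDA tensor copies the value to the host, which
stalls the caller until every kernel queued on the current stream has
finished. The fa3 backend did exactly that on the prefill path
(seqlens_in_batch.max().item()), and decode positions went through
clamp_position, a @torch.compile'd helper that can kick off compilation
from inside prefill. This patch generalizes decode_seq_lens_cpu into a
seq_lens_cpu field that get_model_worker_batch fills once per batch for
the flashinfer_mla / flashmla / fa3 backends, and inlines the clamp, so
the forward hot path neither syncs the stream nor enters torch.compile.

A minimal sketch of the sync being avoided (illustrative only, not part
of this diff; the tensor names are made up):

    import torch

    seq_lens = torch.randint(1, 4096, (256,), device="cuda")

    # Reading a scalar out of a CUDA tensor blocks the host until the
    # device stream drains.
    max_len = seq_lens.max().item()  # implicit stream synchronization

    # A CPU twin pays the device-to-host copy once, where the scheduler
    # builds the batch, instead of inside each backend's forward-metadata
    # setup.
    seq_lens_cpu = seq_lens.cpu()        # one copy per batch
    max_len = seq_lens_cpu.max().item()  # host-side read, no GPU sync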
---
 .../attention/flashattention_backend.py      |  2 +-
 .../attention/flashinfer_mla_backend.py      |  2 +-
 .../srt/layers/attention/flashmla_backend.py |  2 +-
 python/sglang/srt/managers/schedule_batch.py | 25 +++++++++----------
 .../srt/model_executor/cuda_graph_runner.py  |  4 +--
 .../srt/model_executor/forward_batch_info.py | 20 ++++++---------
 .../eagle_draft_cuda_graph_runner.py         | 10 ++++----
 7 files changed, 30 insertions(+), 35 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index ac549d5ea..93c263f74 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -79,7 +79,7 @@ class FlashAttentionBackend(AttentionBackend):
                 torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
             )
             # Precompute maximum sequence length
-            metadata.max_seq_len_k = seqlens_in_batch.max().item()
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
             # Precompute page table
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index 9af027bd1..65bcdf513 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -797,7 +797,7 @@ class FlashInferMLAMultiStepDraftBackend:
                 encoder_lens=None,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
-                seq_lens_cpu=forward_batch.decode_seq_lens_cpu,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
             )
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
diff --git a/python/sglang/srt/layers/attention/flashmla_backend.py b/python/sglang/srt/layers/attention/flashmla_backend.py
index 730c79495..85fe4a2fb 100644
--- a/python/sglang/srt/layers/attention/flashmla_backend.py
+++ b/python/sglang/srt/layers/attention/flashmla_backend.py
@@ -92,7 +92,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         if forward_batch.forward_mode.is_decode_or_idle():
             if spec_info is None:
                 max_seqlen_pad = triton.cdiv(
-                    forward_batch.decode_seq_lens_cpu.max().item(), PAGE_SIZE
+                    forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
                 )
                 block_kv_indices = torch.full(
                     (bs, max_seqlen_pad),
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 012788708..a327f60dc 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1398,21 +1398,22 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
     def get_model_worker_batch(self) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
-            if (
-                global_server_args_dict["enable_flashinfer_mla"]
-                or global_server_args_dict["enable_flashmla"]
-                or global_server_args_dict["attention_backend"] == "fa3"
-            ):
-                decode_seq_lens = self.seq_lens.cpu()
-            else:
-                decode_seq_lens = None
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
-            decode_seq_lens = None
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
 
+        # Create seq_lens_cpu when needed
+        if (
+            global_server_args_dict["enable_flashinfer_mla"]
+            or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "fa3"
+        ):
+            seq_lens_cpu = self.seq_lens.cpu()
+        else:
+            seq_lens_cpu = None
+
         if self.sampling_info:
             if self.has_grammar:
                 self.sampling_info.grammars = [req.grammar for req in self.reqs]
@@ -1435,7 +1436,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
-            decode_seq_lens=decode_seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1496,6 +1497,7 @@ class ModelWorkerBatch:
     req_pool_indices: torch.Tensor
     # The sequence length
     seq_lens: torch.Tensor
+    seq_lens_cpu: Optional[torch.Tensor]
     # The indices of output tokens in the token_to_kv_pool_allocator
     out_cache_loc: torch.Tensor
 
@@ -1512,9 +1514,6 @@
     global_num_tokens_for_logprob: Optional[List[int]]
     can_run_dp_cuda_graph: bool
 
-    # For decode
-    decode_seq_lens: Optional[torch.Tensor]
-
     # For extend
     extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 1039ddaae..033111e45 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -491,10 +491,10 @@ class CudaGraphRunner:
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
-        if forward_batch.decode_seq_lens_cpu is not None:
+        if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(1)
-            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 8d2da2b25..d2991249d 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -39,7 +39,6 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -148,6 +147,9 @@ class ForwardBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int
 
+    # Optional seq_lens on cpu
+    seq_lens_cpu: Optional[torch.Tensor] = None
+
     # For logprob
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -162,9 +164,6 @@ class ForwardBatch:
     # Position information
     positions: torch.Tensor = None
 
-    # For decode
-    decode_seq_lens_cpu: Optional[torch.Tensor] = None
-
     # For extend
     extend_num_tokens: Optional[int] = None
     extend_seq_lens: Optional[torch.Tensor] = None
@@ -293,12 +292,14 @@ class ForwardBatch:
         ):
             ret.positions = ret.spec_info.positions
 
+        # Get seq_lens_cpu if needed
+        if ret.seq_lens_cpu is None:
+            ret.seq_lens_cpu = batch.seq_lens_cpu
+
         # Init position information
         if ret.forward_mode.is_decode():
             if ret.positions is None:
-                ret.positions = clamp_position(batch.seq_lens)
-            if ret.decode_seq_lens_cpu is None:
-                ret.decode_seq_lens_cpu = batch.decode_seq_lens
+                ret.positions = torch.clamp((batch.seq_lens - 1), min=0).to(torch.int64)
         else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
@@ -518,8 +519,3 @@ def compute_position_torch(
     extend_start_loc = torch.zeros_like(extend_seq_lens)
     extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
     return positions.to(torch.int64), extend_start_loc
-
-
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def clamp_position(seq_lens):
-    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
index 323f47de9..d3ad50060 100644
--- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -214,10 +214,10 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.positions = self.positions[:num_tokens]
 
         # Special handle for seq_len_cpu used when flashinfer mla is used
-        if (forward_batch.decode_seq_lens_cpu is not None) and (bs != raw_bs):
+        if (forward_batch.seq_lens_cpu is not None) and (bs != raw_bs):
             self.seq_lens_cpu.fill_(1)
-            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
-            forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:bs]
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+            forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
 
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
             forward_batch, bs
@@ -233,7 +233,7 @@ class EAGLEDraftCudaGraphRunner:
             forward_batch.positions = self.positions[:raw_num_token]
             forward_batch.seq_lens = self.seq_lens[:raw_bs]
             forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
-            if forward_batch.decode_seq_lens_cpu is not None:
-                forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
+            if forward_batch.seq_lens_cpu is not None:
+                forward_batch.seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
 
         return out