diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index c7da38ac5..73cf574dd 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -66,6 +66,10 @@ class PrefillMetadata: # Reuse this workspace buffer across all flashinfer wrappers global_workspace_buffer = None +# Use as a fast path to override the indptr in flashinfer's plan function +# This is used to remove some host-to-device copy overhead. +global_override_indptr_cpu = None + class FlashInferAttnBackend(AttentionBackend): """Flashinfer attention kernels.""" @@ -205,6 +209,7 @@ class FlashInferAttnBackend(AttentionBackend): self.indices_updater_decode.update( forward_batch.req_pool_indices, forward_batch.seq_lens, + forward_batch.seq_lens_cpu, forward_batch.seq_lens_sum, decode_wrappers=self.decode_wrappers, encoder_lens=forward_batch.encoder_lens, @@ -215,6 +220,7 @@ class FlashInferAttnBackend(AttentionBackend): self.indices_updater_prefill.update( forward_batch.req_pool_indices, forward_batch.seq_lens, + forward_batch.seq_lens_cpu, forward_batch.seq_lens_sum, prefix_lens=None, prefill_wrappers=self.prefill_wrappers_paged, @@ -229,6 +235,7 @@ class FlashInferAttnBackend(AttentionBackend): self.indices_updater_prefill.update( forward_batch.req_pool_indices, forward_batch.seq_lens, + forward_batch.seq_lens_cpu, forward_batch.seq_lens_sum, prefix_lens=None, prefill_wrappers=self.prefill_wrappers_verify, @@ -252,6 +259,7 @@ class FlashInferAttnBackend(AttentionBackend): self.indices_updater_prefill.update( forward_batch.req_pool_indices, forward_batch.seq_lens, + forward_batch.seq_lens_cpu, forward_batch.seq_lens_sum, prefix_lens, prefill_wrappers=self.prefill_wrappers_paged, @@ -327,6 +335,7 @@ class FlashInferAttnBackend(AttentionBackend): self.indices_updater_decode.update( req_pool_indices, seq_lens, + seq_lens.cpu(), # may add a little overhead in 
capture stage seq_lens_sum, decode_wrappers=decode_wrappers, encoder_lens=encoder_lens, @@ -358,6 +367,7 @@ class FlashInferAttnBackend(AttentionBackend): self.indices_updater_prefill.update( req_pool_indices, seq_lens, + seq_lens.cpu(), # may add a little overhead in capture stage seq_lens_sum, prefix_lens=None, prefill_wrappers=prefill_wrappers, @@ -387,6 +397,7 @@ class FlashInferAttnBackend(AttentionBackend): self.indices_updater_prefill.update( req_pool_indices, seq_lens, + seq_lens.cpu(), # may add a little overhead in capture stage seq_lens_sum, prefix_lens=None, prefill_wrappers=prefill_wrappers, @@ -414,6 +425,7 @@ class FlashInferAttnBackend(AttentionBackend): self.indices_updater_decode.update( req_pool_indices[:bs], seq_lens[:bs], + seq_lens_cpu[:bs] if seq_lens_cpu is not None else None, seq_lens_sum, decode_wrappers=self.decode_cuda_graph_metadata[bs], encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None, @@ -423,6 +435,7 @@ class FlashInferAttnBackend(AttentionBackend): self.indices_updater_prefill.update( req_pool_indices[:bs], seq_lens[:bs], + seq_lens_cpu[:bs] if seq_lens_cpu is not None else None, seq_lens_sum, prefix_lens=None, prefill_wrappers=self.prefill_cuda_graph_metadata[bs], @@ -434,6 +447,7 @@ class FlashInferAttnBackend(AttentionBackend): self.indices_updater_prefill.update( req_pool_indices[:bs], seq_lens[:bs], + seq_lens_cpu[:bs] if seq_lens_cpu is not None else None, seq_lens_sum, prefix_lens=None, prefill_wrappers=self.prefill_cuda_graph_metadata[bs], @@ -581,7 +595,7 @@ class FlashInferAttnBackend(AttentionBackend): class FlashInferIndicesUpdaterDecode: - def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): + def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend): # Parse Constants self.num_qo_heads = ( model_runner.model_config.num_attention_heads // get_attention_tp_size() @@ -614,6 +628,7 @@ class FlashInferIndicesUpdaterDecode: self, req_pool_indices: 
torch.Tensor, seq_lens: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor], seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], @@ -626,6 +641,7 @@ class FlashInferIndicesUpdaterDecode: self, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor], seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], @@ -640,30 +656,41 @@ class FlashInferIndicesUpdaterDecode: self.kv_indptr[0], None, spec_info, + seq_lens_cpu, ) def update_sliding_window( self, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor], seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], ): + assert self.sliding_window_size is not None for wrapper_id in range(2): if wrapper_id == 0: # Sliding window attention - paged_kernel_lens_tmp = torch.minimum( # TODO: replace this with clamp - seq_lens, - torch.tensor(self.sliding_window_size + 1), + paged_kernel_lens_tmp = torch.clamp( + seq_lens, max=self.sliding_window_size + 1 ) - paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item() + if seq_lens_cpu is not None: + seq_lens_cpu_tmp = torch.clamp( + seq_lens_cpu, max=self.sliding_window_size + 1 + ) + paged_kernel_lens_sum_tmp = seq_lens_cpu_tmp.sum().item() + else: + # No CPU copy of seq_lens: bind None so the call_begin_forward call below does not read an unbound local + seq_lens_cpu_tmp = None + paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item() kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp else: # Full attention paged_kernel_lens_tmp = seq_lens paged_kernel_lens_sum_tmp = seq_lens_sum + seq_lens_cpu_tmp = seq_lens_cpu kv_start_idx_tmp = None use_sliding_window_kv_pool = wrapper_id == 0 and isinstance( @@ -678,6 +705,7 @@ class FlashInferIndicesUpdaterDecode: self.kv_indptr[wrapper_id], kv_start_idx_tmp, spec_info, + seq_lens_cpu=seq_lens_cpu_tmp, 
use_sliding_window_kv_pool=use_sliding_window_kv_pool, ) @@ -685,6 +713,7 @@ class FlashInferIndicesUpdaterDecode: self, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor], seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], @@ -709,6 +738,7 @@ class FlashInferIndicesUpdaterDecode: self.kv_indptr[wrapper_id], kv_start_idx, spec_info, + seq_lens_cpu=seq_lens_cpu, ) def call_begin_forward( @@ -720,6 +750,7 @@ class FlashInferIndicesUpdaterDecode: kv_indptr: torch.Tensor, kv_start_idx: torch.Tensor, spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + seq_lens_cpu: Optional[torch.Tensor], use_sliding_window_kv_pool: bool = False, ): if spec_info is None: @@ -756,6 +787,14 @@ class FlashInferIndicesUpdaterDecode: ) ) + global global_override_indptr_cpu + locally_override = False + if seq_lens_cpu is not None and global_override_indptr_cpu is None: + locally_override = True + global_override_indptr_cpu = torch.empty_like(kv_indptr, device="cpu") + global_override_indptr_cpu[0] = 0 + global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0) + wrapper.begin_forward( kv_indptr, kv_indices, @@ -769,9 +808,12 @@ class FlashInferIndicesUpdaterDecode: non_blocking=True, ) + if locally_override: + global_override_indptr_cpu = None + class FlashInferIndicesUpdaterPrefill: - def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): + def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend): # Parse Constants self.num_qo_heads = ( model_runner.model_config.num_attention_heads // get_attention_tp_size() @@ -806,6 +848,7 @@ class FlashInferIndicesUpdaterPrefill: self, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor], seq_lens_sum: int, prefix_lens: torch.Tensor, prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], @@ -820,6 +863,7 @@ class 
FlashInferIndicesUpdaterPrefill: self, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor], seq_lens_sum: int, prefix_lens: torch.Tensor, prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], @@ -853,6 +897,7 @@ class FlashInferIndicesUpdaterPrefill: self, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor], seq_lens_sum: int, prefix_lens: torch.Tensor, prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], @@ -898,6 +943,7 @@ class FlashInferIndicesUpdaterPrefill: self, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor], seq_lens_sum: int, prefix_lens: torch.Tensor, prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], @@ -1020,11 +1066,6 @@ class FlashInferIndicesUpdaterPrefill: ) -# Use as a fast path to override the indptr in flashinfer's plan function -# This is used to remove some host-to-device copy overhead. -global global_override_indptr_cpu - - class FlashInferMultiStepDraftBackend: """ Wrap multiple flashinfer attention backends as one for multiple consecutive @@ -1056,7 +1097,7 @@ class FlashInferMultiStepDraftBackend: self.kv_last_page_len = torch.ones( (max_bs,), dtype=torch.int32, device=model_runner.device ) - self.attn_backends = [] + self.attn_backends: List[FlashInferAttnBackend] = [] for i in range(self.speculative_num_steps): self.attn_backends.append( FlashInferAttnBackend( @@ -1176,7 +1217,7 @@ class FlashInferMultiStepDraftBackend: encoder_lens=None, forward_mode=ForwardMode.DECODE, spec_info=forward_batch.spec_info, - seq_lens_cpu=None, + seq_lens_cpu=forward_batch.seq_lens_cpu, ) self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index fd7630b3e..e6b8d42ba 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ 
-1714,16 +1714,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): attention_backend_str = global_server_args_dict["prefill_attention_backend"] # Create seq_lens_cpu when needed if ( - attention_backend_str == "fa3" - or ( - global_server_args_dict["use_mla_backend"] - and attention_backend_str == "flashinfer" - ) - or attention_backend_str == "flashmla" - or attention_backend_str == "cutlass_mla" - or attention_backend_str == "ascend" - or attention_backend_str == "trtllm_mha" - or attention_backend_str == "aiter" + attention_backend_str + in [ + "fa3", + "flashinfer", + "flashmla", + "cutlass_mla", + "ascend", + "trtllm_mha", + "aiter", + ] or global_server_args_dict["enable_two_batch_overlap"] ): seq_lens_cpu = ( diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 05599c697..303919505 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -729,10 +729,12 @@ class CudaGraphRunner: self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) self.positions[:raw_num_token].copy_(forward_batch.positions) + seq_lens_cpu = None if forward_batch.seq_lens_cpu is not None: if bs != raw_bs: self.seq_lens_cpu.fill_(self.seq_len_fill_value) self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu) + seq_lens_cpu = self.seq_lens_cpu[:bs] if pp_proxy_tensors: for key in self.pp_proxy_tensors.keys(): @@ -766,7 +768,7 @@ class CudaGraphRunner: self.encoder_lens[:bs] if self.is_encoder_decoder else None, self.capture_forward_mode, forward_batch.spec_info, - seq_lens_cpu=self.seq_lens_cpu[:bs], + seq_lens_cpu=seq_lens_cpu, ) # Store fields