diff --git a/python/sglang/srt/layers/attention/tbo_backend.py b/python/sglang/srt/layers/attention/tbo_backend.py index f99387e59..afded3c33 100644 --- a/python/sglang/srt/layers/attention/tbo_backend.py +++ b/python/sglang/srt/layers/attention/tbo_backend.py @@ -119,24 +119,15 @@ class TboAttnBackend(AttentionBackend): replay_seq_lens_sum: int = None, replay_seq_lens_cpu: Optional[torch.Tensor] = None, ): - from sglang.srt.model_executor.forward_batch_info import ForwardMode - if fn_name == "init_forward_metadata_capture_cuda_graph": assert capture_num_tokens == bs, "Only support num_tokens==bs currently" num_tokens = bs - forward_mode_for_tbo_split = ( - forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE - ) - tbo_split_seq_index = two_batch_overlap.compute_split_seq_index( - forward_mode=forward_mode_for_tbo_split, - num_tokens=num_tokens, - extend_lens=None, - ) - tbo_split_token_index = two_batch_overlap.compute_split_token_index( - split_seq_index=tbo_split_seq_index, - forward_mode=forward_mode_for_tbo_split, - extend_seq_lens=None, + tbo_split_seq_index, tbo_split_token_index = ( + two_batch_overlap.compute_split_indices_for_cuda_graph_replay( + forward_mode=forward_mode, + cuda_graph_num_tokens=num_tokens, + ) ) num_tokens_child_left = tbo_split_token_index diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 74f45fb09..341cce09a 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -40,7 +40,7 @@ from sglang.srt.model_executor.forward_batch_info import ( ) from sglang.srt.patch_torch import monkey_patch_torch_compile from sglang.srt.two_batch_overlap import ( - TboCudaGraphRunnerUtils, + TboCudaGraphRunnerPlugin, TboForwardBatchPreparer, ) from sglang.srt.utils import ( @@ -256,6 +256,7 @@ class CudaGraphRunner: self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64) self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32) + self.tbo_plugin = TboCudaGraphRunnerPlugin() # pipeline parallelism if self.pp_size > 1: @@ -481,12 +482,9 @@ class CudaGraphRunner: capture_hidden_mode=self.capture_hidden_mode, lora_paths=lora_paths, num_token_non_padded=self.num_token_non_padded, - tbo_split_seq_index=TboCudaGraphRunnerUtils.compute_tbo_split_seq_index( - self, num_tokens - ), global_forward_mode=self.capture_forward_mode, ) - TboForwardBatchPreparer.prepare(forward_batch) + self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens) if lora_paths is not None: self.model_runner.lora_manager.prepare_lora_batch(forward_batch) @@ -581,7 +579,13 @@ class CudaGraphRunner: self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) self.positions[:raw_num_token].copy_(forward_batch.positions) - self.num_token_non_padded[...] = len(forward_batch.input_ids) + num_token_non_padded = len(forward_batch.input_ids) + self.num_token_non_padded[...] = num_token_non_padded + self.tbo_plugin.replay_prepare( + forward_mode=forward_batch.forward_mode, + bs=bs, + num_token_non_padded=num_token_non_padded, + ) if forward_batch.seq_lens_cpu is not None: if bs != raw_bs: self.seq_lens_cpu.fill_(1) diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index 78bc6b431..6b0241f40 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -85,25 +85,54 @@ def compute_split_token_index( raise NotImplementedError +def compute_split_indices_for_cuda_graph_replay( + forward_mode: ForwardMode, + cuda_graph_num_tokens: int, +): + forward_mode_for_tbo_split = ( + forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE + ) + tbo_split_seq_index = compute_split_seq_index( + forward_mode=forward_mode_for_tbo_split, + num_tokens=cuda_graph_num_tokens, + extend_lens=None, + ) + tbo_split_token_index = compute_split_token_index( + split_seq_index=tbo_split_seq_index, + forward_mode=forward_mode_for_tbo_split, + extend_seq_lens=None, + ) + return tbo_split_seq_index, tbo_split_token_index + + # -------------------------------- Preparation --------------------------------------- -class TboCudaGraphRunnerUtils: - @staticmethod - def compute_tbo_split_seq_index(that: "CudaGraphRunner", num_tokens: int): - if that.model_runner.server_args.enable_two_batch_overlap: - tbo_split_seq_index = compute_split_seq_index( - forward_mode=that.capture_forward_mode, - num_tokens=num_tokens, - extend_lens=None, - ) - # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true - assert ( - tbo_split_seq_index is not None - ), f"{that.capture_forward_mode=} {num_tokens=}" - else: - tbo_split_seq_index = None - return tbo_split_seq_index +class TboCudaGraphRunnerPlugin: + def __init__(self): + pass # TODO add logic here + + def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int): + if not global_server_args_dict["enable_two_batch_overlap"]: + return + + batch.tbo_split_seq_index = compute_split_seq_index( + forward_mode=batch.forward_mode, + num_tokens=num_tokens, + extend_lens=None, + ) + # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true + assert batch.tbo_split_seq_index is not None, f"{num_tokens=}" + + TboForwardBatchPreparer.prepare(batch) + + def replay_prepare( + self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int + ): + if not global_server_args_dict["enable_two_batch_overlap"]: + return + + pass # TODO add logic here class TboDPAttentionPreparer: