diff --git a/python/sglang/srt/layers/attention/tbo_backend.py b/python/sglang/srt/layers/attention/tbo_backend.py index 4ad8c5b87..06cfbd4ef 100644 --- a/python/sglang/srt/layers/attention/tbo_backend.py +++ b/python/sglang/srt/layers/attention/tbo_backend.py @@ -119,21 +119,27 @@ class TboAttnBackend(AttentionBackend): replay_seq_lens_sum: int = None, replay_seq_lens_cpu: Optional[torch.Tensor] = None, ): + token_num_per_seq = two_batch_overlap.get_token_num_per_seq( + forward_mode=forward_mode, spec_info=spec_info + ) if fn_name == "init_forward_metadata_capture_cuda_graph": - assert capture_num_tokens == bs, "Only support num_tokens==bs currently" - num_tokens = bs + assert ( + capture_num_tokens == bs * token_num_per_seq + ), "For target-verify or decode mode, num_tokens should be equal to token_num_per_seq * bs" + num_tokens = bs * token_num_per_seq tbo_split_seq_index, tbo_split_token_index = ( two_batch_overlap.compute_split_indices_for_cuda_graph_replay( forward_mode=forward_mode, cuda_graph_num_tokens=num_tokens, + spec_info=spec_info, ) ) num_tokens_child_left = tbo_split_token_index num_tokens_child_right = num_tokens - tbo_split_token_index - bs_child_left = num_tokens_child_left - bs_child_right = num_tokens_child_right + bs_child_left = tbo_split_seq_index + bs_child_right = bs - bs_child_left assert ( num_tokens_child_left > 0 and num_tokens_child_right > 0 @@ -190,16 +196,36 @@ def _init_forward_metadata_cuda_graph_split( seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: "ForwardMode", - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[EagleVerifyInput], # capture args capture_num_tokens: int = None, # replay args replay_seq_lens_sum: int = None, replay_seq_lens_cpu: Optional[torch.Tensor] = None, ): + token_num_per_seq = two_batch_overlap.get_token_num_per_seq( + forward_mode=forward_mode, spec_info=spec_info + ) assert encoder_lens is None, "encoder_lens is not supported yet" - assert spec_info is None, "spec_info is not supported yet" + if spec_info is not None: + output_spec_info = two_batch_overlap.split_spec_info( + spec_info=spec_info, + start_seq_index=seq_slice.start if seq_slice.start is not None else 0, + end_seq_index=seq_slice.stop if seq_slice.stop is not None else bs, + start_token_index=( + seq_slice.start * token_num_per_seq + if seq_slice.start is not None + else 0 + ), + end_token_index=( + seq_slice.stop * token_num_per_seq + if seq_slice.stop is not None + else bs * token_num_per_seq + ), + ) + else: + output_spec_info = None ans = dict( bs=output_bs, req_pool_indices=req_pool_indices[seq_slice], @@ -208,14 +234,16 @@ def _init_forward_metadata_cuda_graph_split( forward_mode=forward_mode, # ignore encoder_lens=None, - spec_info=None, + spec_info=output_spec_info, ) if fn_name == "init_forward_metadata_capture_cuda_graph": - assert capture_num_tokens == bs, "Only support num_tokens==bs currently" + assert ( + capture_num_tokens == bs * token_num_per_seq + ), "Only support num_tokens==bs * token_num_per_seq for target-verify or decode mode" ans.update( dict( - num_tokens=output_bs, + num_tokens=output_bs * token_num_per_seq, ) ) elif fn_name == "init_forward_metadata_replay_cuda_graph": diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 48f62d28b..1f654ca7e 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -679,6 +679,7 @@ class CudaGraphRunner: forward_mode=self.capture_forward_mode, bs=bs, num_token_non_padded=len(forward_batch.input_ids), + spec_info=forward_batch.spec_info, ) if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None: forward_batch.spec_info.custom_mask = self.custom_mask diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 28d8c2bfa..cc01e963e 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -352,7 +352,9 @@ class ForwardBatch: if ret.forward_mode.is_idle(): ret.positions = torch.empty((0,), device=device) - TboForwardBatchPreparer.prepare(ret) + TboForwardBatchPreparer.prepare( + ret, is_draft_worker=model_runner.is_draft_worker + ) return ret # Override the positions with spec_info @@ -397,7 +399,9 @@ class ForwardBatch: if model_runner.server_args.lora_paths is not None: model_runner.lora_manager.prepare_lora_batch(ret) - TboForwardBatchPreparer.prepare(ret) + TboForwardBatchPreparer.prepare( + ret, is_draft_worker=model_runner.is_draft_worker + ) return ret diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 21f4b968d..0c9185bb9 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1039,7 +1039,7 @@ class ModelRunner: def init_attention_backend(self): """Init attention kernel backend.""" - if self.server_args.enable_two_batch_overlap: + if self.server_args.enable_two_batch_overlap and not self.is_draft_worker: self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend) else: self.attn_backend = self._get_attention_backend() diff --git a/python/sglang/srt/operations_strategy.py b/python/sglang/srt/operations_strategy.py index 8821a05eb..6000b5e8f 100644 --- a/python/sglang/srt/operations_strategy.py +++ b/python/sglang/srt/operations_strategy.py @@ -71,7 +71,9 @@ def _compute_moe_deepseek_layer_operations_strategy_tbo( assert layer.is_layer_sparse, "dense layer TBO not yet implemented" if forward_mode == ForwardMode.EXTEND: return _compute_moe_deepseek_blog_prefill(layer) - elif forward_mode == ForwardMode.DECODE: + elif ( + forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY + ): return _compute_moe_deepseek_blog_decode(layer) else: raise NotImplementedError(f"Unsupported {forward_mode=}") @@ -146,7 +148,9 @@ def _compute_moe_qwen3_layer_operations_strategy_tbo( assert layer.is_layer_sparse, "qwen3 moe only support sparse layers" if forward_mode == ForwardMode.EXTEND: return _compute_moe_qwen3_prefill(layer) - elif forward_mode == ForwardMode.DECODE: + elif ( + forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY + ): return _compute_moe_qwen3_decode(layer) else: raise NotImplementedError(f"Unsupported {forward_mode=}") diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index a2f3936a4..0879bcf2b 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -1,6 +1,7 @@ import dataclasses import logging -from typing import Dict, List, Optional, Sequence +from dataclasses import replace +from typing import Dict, List, Optional, Sequence, Union import torch @@ -16,6 +17,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.operations import execute_operations, execute_overlapped_operations from sglang.srt.operations_strategy import OperationsStrategy +from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG") @@ -26,17 +28,34 @@ logger = logging.getLogger(__name__) # -------------------------------- Compute Basic Info --------------------------------------- +def get_token_num_per_seq( + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None, +): + if forward_mode.is_target_verify(): + return spec_info.draft_token_num + elif forward_mode.is_decode(): + return 1 + elif forward_mode.is_idle(): + return 0 + else: + # For extend, we should not use `token_num_per_seq`. + return None + + # TODO: may smartly disable TBO when batch size is too small b/c it will slow down def compute_split_seq_index( forward_mode: "ForwardMode", num_tokens: int, extend_lens: Optional[Sequence[int]], + token_num_per_seq: Optional[int], ) -> Optional[int]: - if forward_mode.is_extend(): + if forward_mode == ForwardMode.EXTEND: assert extend_lens is not None return _split_array_by_half_sum(extend_lens) - elif forward_mode.is_decode(): - return num_tokens // 2 + elif forward_mode.is_target_verify() or forward_mode.is_decode(): + assert token_num_per_seq is not None + return (num_tokens // token_num_per_seq) // 2 elif forward_mode.is_idle(): assert num_tokens == 0 return 0 @@ -63,16 +82,103 @@ def _split_array_by_half_sum(arr: Sequence[int]) -> int: return best_index +def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int: + if seq_index == 0: + return 0 + + offset = 0 + max_seq_len = min(seq_index, spec_info.seq_lens_cpu.shape[0]) + for i in range(max_seq_len): + offset += ( + spec_info.seq_lens_cpu[i] + spec_info.draft_token_num + ) * spec_info.draft_token_num + return offset + + +def split_spec_info( + spec_info: Optional[EagleVerifyInput], + start_seq_index: int, + end_seq_index: int, + start_token_index: int, + end_token_index: int, +): + if spec_info is None: + return None + if spec_info.draft_token is not None: + draft_token = spec_info.draft_token[start_token_index:end_token_index] + else: + draft_token = None + if spec_info.custom_mask is not None and spec_info.draft_token is not None: + custom_mask_start = _compute_mask_offset(start_seq_index, spec_info) + if end_seq_index == spec_info.seq_lens_cpu.shape[0]: + custom_mask_end = spec_info.custom_mask.shape[0] + else: + custom_mask_end = _compute_mask_offset(end_seq_index, spec_info) + + if custom_mask_end > custom_mask_start: + custom_mask = spec_info.custom_mask[custom_mask_start:custom_mask_end] + else: + custom_mask = spec_info.custom_mask + else: + custom_mask = spec_info.custom_mask + if spec_info.positions is not None: + positions = spec_info.positions[start_token_index:end_token_index] + else: + positions = None + if spec_info.retrive_index is not None: + retrive_index = spec_info.retrive_index[start_seq_index:end_seq_index] + else: + retrive_index = None + if spec_info.retrive_next_token is not None: + retrive_next_token = spec_info.retrive_next_token[start_seq_index:end_seq_index] + else: + retrive_next_token = None + if spec_info.retrive_next_sibling is not None: + retrive_next_sibling = spec_info.retrive_next_sibling[ + start_seq_index:end_seq_index + ] + else: + retrive_next_sibling = None + if spec_info.retrive_cum_len is not None: + retrive_cum_len = spec_info.retrive_cum_len[start_seq_index:end_seq_index] + else: + retrive_cum_len = None + + if spec_info.seq_lens_cpu is not None: + seq_lens_cpu = spec_info.seq_lens_cpu[start_seq_index:end_seq_index] + else: + seq_lens_cpu = None + if seq_lens_cpu is not None: + seq_lens_sum = seq_lens_cpu.sum() + else: + seq_lens_sum = None + output_spec_info = replace( + spec_info, + custom_mask=custom_mask, + draft_token=draft_token, + positions=positions, + retrive_index=retrive_index, + retrive_next_token=retrive_next_token, + retrive_next_sibling=retrive_next_sibling, + retrive_cum_len=retrive_cum_len, + seq_lens_cpu=seq_lens_cpu, + seq_lens_sum=seq_lens_sum, + ) + return output_spec_info + + def compute_split_token_index( split_seq_index: int, forward_mode: "ForwardMode", extend_seq_lens: Optional[Sequence[int]], + token_num_per_seq: Optional[int], ) -> int: - if forward_mode.is_extend(): + if forward_mode == ForwardMode.EXTEND: assert extend_seq_lens is not None return sum(extend_seq_lens[:split_seq_index]) - elif forward_mode.is_decode(): - return split_seq_index + elif forward_mode.is_target_verify() or forward_mode.is_decode(): + assert token_num_per_seq is not None + return split_seq_index * token_num_per_seq elif forward_mode.is_idle(): assert split_seq_index == 0 return 0 @@ -83,19 +189,25 @@ def compute_split_token_index( def compute_split_indices_for_cuda_graph_replay( forward_mode: ForwardMode, cuda_graph_num_tokens: int, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], ): forward_mode_for_tbo_split = ( forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE ) + token_num_per_seq = get_token_num_per_seq( + forward_mode=forward_mode, spec_info=spec_info + ) tbo_split_seq_index = compute_split_seq_index( forward_mode=forward_mode_for_tbo_split, num_tokens=cuda_graph_num_tokens, extend_lens=None, + token_num_per_seq=token_num_per_seq, ) tbo_split_token_index = compute_split_token_index( split_seq_index=tbo_split_seq_index, forward_mode=forward_mode_for_tbo_split, extend_seq_lens=None, + token_num_per_seq=token_num_per_seq, ) return tbo_split_seq_index, tbo_split_token_index @@ -110,11 +222,15 @@ class TboCudaGraphRunnerPlugin: def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int): if not global_server_args_dict["enable_two_batch_overlap"]: return + token_num_per_seq = get_token_num_per_seq( + forward_mode=batch.forward_mode, spec_info=batch.spec_info + ) batch.tbo_split_seq_index = compute_split_seq_index( forward_mode=batch.forward_mode, num_tokens=num_tokens, extend_lens=None, + token_num_per_seq=token_num_per_seq, ) # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true assert batch.tbo_split_seq_index is not None, f"{num_tokens=}" @@ -129,13 +245,20 @@ class TboCudaGraphRunnerPlugin: ) def replay_prepare( - self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int + self, + forward_mode: ForwardMode, + bs: int, + num_token_non_padded: int, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], ): + token_num_per_seq = get_token_num_per_seq( + forward_mode=forward_mode, spec_info=spec_info + ) tbo_split_seq_index, tbo_split_token_index = ( compute_split_indices_for_cuda_graph_replay( forward_mode=forward_mode, - # TODO support bs!=num_tokens - cuda_graph_num_tokens=bs, + cuda_graph_num_tokens=bs * token_num_per_seq, + spec_info=spec_info, ) ) @@ -154,14 +277,29 @@ class TboDPAttentionPreparer: self.enable_two_batch_overlap = enable_two_batch_overlap if local_batch is not None: + token_num_per_seq = get_token_num_per_seq( + forward_mode=local_batch.forward_mode, spec_info=local_batch.spec_info + ) + + if ( + local_batch.forward_mode.is_target_verify() + or local_batch.forward_mode.is_decode() + ): + num_tokens = local_batch.batch_size() * token_num_per_seq + else: + num_tokens = local_batch.extend_num_tokens self.local_tbo_split_seq_index = compute_split_seq_index( forward_mode=local_batch.forward_mode, - num_tokens=local_batch.input_ids.shape[0], + num_tokens=num_tokens, extend_lens=local_batch.extend_lens, + token_num_per_seq=token_num_per_seq, ) resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode) local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not ( - local_batch.forward_mode.is_extend() + ( + local_batch.forward_mode.is_extend() + and not local_batch.forward_mode.is_target_verify() + ) and enable_deepep_moe and (resolved_deepep_mode == DeepEPMode.low_latency) ) @@ -218,8 +356,8 @@ class TboDPAttentionPreparer: class TboForwardBatchPreparer: @classmethod - def prepare(cls, batch: ForwardBatch): - if batch.tbo_split_seq_index is None: + def prepare(cls, batch: ForwardBatch, is_draft_worker: bool = False): + if batch.tbo_split_seq_index is None or is_draft_worker: return tbo_children_num_token_non_padded = ( @@ -242,7 +380,9 @@ class TboForwardBatchPreparer: f"TboForwardBatchPreparer.prepare " f"tbo_split_seq_index={batch.tbo_split_seq_index} " f"tbo_split_token_index={tbo_split_token_index} " - f"extend_seq_lens={batch.extend_seq_lens_cpu}" + f"extend_seq_lens={batch.extend_seq_lens_cpu} " + f"bs={batch.batch_size} " + f"forward_mode={batch.forward_mode}" ) assert isinstance(batch.attn_backend, TboAttnBackend) @@ -286,6 +426,9 @@ class TboForwardBatchPreparer: output_attn_backend: AttentionBackend, out_num_token_non_padded: torch.Tensor, ): + assert ( + end_token_index >= start_token_index + ), f"{end_token_index=}, {start_token_index=}, batch={batch}" num_tokens = batch.input_ids.shape[0] num_seqs = batch.batch_size @@ -317,11 +460,30 @@ class TboForwardBatchPreparer: old_value = getattr(batch, key) if old_value is None: continue + elif batch.forward_mode.is_target_verify() and ( + key == "extend_seq_lens" + or key == "extend_prefix_lens" + or key == "extend_start_loc" + or key == "extend_prefix_lens_cpu" + or key == "extend_seq_lens_cpu" + or key == "extend_logprob_start_lens_cpu" + ): + output_dict[key] = None + continue assert ( len(old_value) == num_seqs ), f"{key=} {old_value=} {num_seqs=} {batch=}" output_dict[key] = old_value[start_seq_index:end_seq_index] + spec_info = getattr(batch, "spec_info") + output_spec_info = split_spec_info( + spec_info=spec_info, + start_token_index=start_token_index, + end_token_index=end_token_index, + start_seq_index=start_seq_index, + end_seq_index=end_seq_index, + ) + output_dict["spec_info"] = output_spec_info for key in [ "forward_mode", "return_logprob", @@ -329,18 +491,17 @@ class TboForwardBatchPreparer: "token_to_kv_pool", "can_run_dp_cuda_graph", "global_forward_mode", - "spec_info", "spec_algorithm", "capture_hidden_mode", "padded_static_len", "mrope_positions", # only used by qwen2-vl, thus not care ]: output_dict[key] = getattr(batch, key) - - assert ( - _compute_extend_num_tokens(batch.input_ids, batch.forward_mode) - == batch.extend_num_tokens - ), f"{batch=}" + if not batch.forward_mode.is_target_verify(): + assert ( + _compute_extend_num_tokens(batch.input_ids, batch.forward_mode) + == batch.extend_num_tokens + ), f"{batch=}" extend_num_tokens = _compute_extend_num_tokens( output_dict["input_ids"], output_dict["forward_mode"] ) @@ -419,18 +580,26 @@ class TboForwardBatchPreparer: @classmethod def _compute_split_token_index(cls, batch: ForwardBatch): + token_num_per_seq = get_token_num_per_seq( + forward_mode=batch.forward_mode, spec_info=batch.spec_info + ) return compute_split_token_index( split_seq_index=batch.tbo_split_seq_index, forward_mode=batch.forward_mode, extend_seq_lens=batch.extend_seq_lens_cpu, + token_num_per_seq=token_num_per_seq, ) def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode): - if forward_mode.is_extend(): - return input_ids.shape[0] - elif forward_mode.is_decode() or forward_mode.is_idle(): + if ( + forward_mode.is_decode() + or forward_mode.is_idle() + or forward_mode.is_target_verify() + ): return None + elif forward_mode.is_extend(): + return input_ids.shape[0] raise NotImplementedError diff --git a/test/srt/test_dp_attention.py b/test/srt/test_dp_attention.py index af50dc780..085dc206b 100644 --- a/test/srt/test_dp_attention.py +++ b/test/srt/test_dp_attention.py @@ -137,5 +137,86 @@ class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase): self.assertGreater(avg_spec_accept_length, 2.5) +# TODO: enable this test later +# class TestDPAttentionDP2TP2DeepseekV3MTPTBO(CustomTestCase): +# @classmethod +# def setUpClass(cls): +# import os + +# # print debug log for tbo +# os.environ["SGLANG_TBO_DEBUG"] = "1" +# cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA +# cls.base_url = DEFAULT_URL_FOR_TEST +# other_args = [ +# "--trust-remote-code", +# "--disable-radix", +# "--speculative-algorithm", +# "EAGLE", +# "--speculative-num-steps", +# "2", +# "--speculative-eagle-topk", +# "4", +# "--speculative-num-draft-tokens", +# "4", +# "--speculative-draft", +# DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN, +# "--tp-size", +# "2", +# "--enable-dp-attention", +# "--dp-size", +# "2", +# "--enable-two-batch-overlap", +# "--enable-deepep-moe", +# "--deepep-mode", +# "low_latency", +# "--chunked-prefill-size", +# "256", +# "--cuda-graph-max-bs", +# "32", +# "--max-running-requests", +# "32", +# ] +# if not is_in_amd_ci(): +# other_args += ["--mem-frac", "0.7"] +# cls.process = popen_launch_server( +# cls.model, +# cls.base_url, +# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, +# other_args=other_args, +# ) + +# @classmethod +# def tearDownClass(cls): +# kill_process_tree(cls.process.pid) + +# def test_gsm8k(self): +# requests.get(self.base_url + "/flush_cache") + +# args = SimpleNamespace( +# num_shots=5, +# data_path=None, +# num_questions=200, +# max_new_tokens=512, +# parallel=128, +# host="http://127.0.0.1", +# port=int(self.base_url.split(":")[-1]), +# ) +# metrics = run_eval_few_shot_gsm8k(args) +# print(metrics) + +# self.assertGreater(metrics["accuracy"], 0.60) + +# server_info = requests.get(self.base_url + "/get_server_info") +# avg_spec_accept_length = server_info.json()["internal_states"][0][ +# "avg_spec_accept_length" +# ] +# print( +# f"###test_gsm8k (deepseek-v3 mtp + dp + tbo):\n" +# f"accuracy={metrics['accuracy']=:.3f}\n" +# f"{avg_spec_accept_length=:.3f}\n" +# ) +# self.assertGreater(avg_spec_accept_length, 2.3) + + if __name__ == "__main__": unittest.main()