diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index a79911bc9..008953bc2 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -262,6 +262,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--enable-dp-attention` | Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported. | False |
 | `--enable-dp-lm-head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | False |
 | `--enable-two-batch-overlap` | Enabling two micro batches to overlap. | False |
+| `--tbo-token-distribution-threshold` | The token distribution threshold between the two micro-batches in two-batch overlap; it determines whether to use two-batch-overlap or two-chunk-overlap. Set to 0 to disable two-chunk-overlap. | 0.48 |
 | `--enable-torch-compile` | Optimize the model with torch.compile. Experimental feature. | False |
 | `--torch-compile-max-bs` | Set the maximum batch size when using torch compile. | 32 |
 | `--torchao-config` | Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-, fp8wo, fp8dq-per_tensor, fp8dq-per_row. | |
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 3452608b3..689ef94b3 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -84,6 +84,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "disable_radix_cache",
     "enable_dp_attention",
     "enable_two_batch_overlap",
+    "tbo_token_distribution_threshold",
     "enable_dp_lm_head",
     "moe_a2a_backend",
     "deepep_mode",
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 984239cc3..6d09f1fdb 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -420,16 +420,12 @@ class ForwardBatch:
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
             ret.extend_num_tokens = batch.extend_num_tokens
-            if support_triton(model_runner.server_args.attention_backend):
-                positions, ret.extend_start_loc = compute_position_triton(
-                    ret.extend_prefix_lens,
-                    ret.extend_seq_lens,
-                    ret.extend_num_tokens,
-                )
-            else:
-                positions, ret.extend_start_loc = compute_position_torch(
-                    ret.extend_prefix_lens, ret.extend_seq_lens
-                )
+            positions, ret.extend_start_loc = compute_position(
+                model_runner.server_args.attention_backend,
+                ret.extend_prefix_lens,
+                ret.extend_seq_lens,
+                ret.extend_num_tokens,
+            )
             if ret.positions is None:
                 ret.positions = positions
             ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
@@ -882,6 +878,25 @@ class PPProxyTensors:
         return f"PPProxyTensors(tensors={self.tensors})"
 
 
+def compute_position(
+    attn_backend: str,
+    extend_prefix_lens: torch.Tensor,
+    extend_seq_lens: torch.Tensor,
+    extend_seq_lens_sum: int,
+):
+    if support_triton(attn_backend):
+        positions, extend_start_loc = compute_position_triton(
+            extend_prefix_lens,
+            extend_seq_lens,
+            extend_seq_lens_sum,
+        )
+    else:
+        positions, extend_start_loc = compute_position_torch(
+            extend_prefix_lens, extend_seq_lens
+        )
+    return positions, extend_start_loc
+
+
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
 ):
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 10e8278a6..225caaf60 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -229,6 +229,7 @@ class ServerArgs:
     enable_dp_attention: bool = False
     enable_dp_lm_head: bool = False
     enable_two_batch_overlap: bool = False
+    tbo_token_distribution_threshold: float = 0.48
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     torchao_config: str = ""
@@ -1689,6 +1690,12 @@ class ServerArgs:
             action="store_true",
             help="Enabling two micro batches to overlap.",
         )
+        parser.add_argument(
+            "--tbo-token-distribution-threshold",
+            type=float,
+            default=ServerArgs.tbo_token_distribution_threshold,
+            help="The token distribution threshold between the two micro-batches in two-batch overlap; it determines whether to use two-batch-overlap or two-chunk-overlap. Set to 0 to disable two-chunk-overlap.",
+        )
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py
index 34afd043f..7e0602a20 100644
--- a/python/sglang/srt/two_batch_overlap.py
+++ b/python/sglang/srt/two_batch_overlap.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import copy
 import dataclasses
 import logging
 from dataclasses import replace
@@ -17,7 +18,11 @@ from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    compute_position,
+)
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
 from sglang.srt.operations_strategy import OperationsStrategy
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -58,7 +63,7 @@ def compute_split_seq_index(
 ) -> Optional[int]:
     if forward_mode == ForwardMode.EXTEND:
         assert extend_lens is not None
-        return _split_array_by_half_sum(extend_lens)
+        return _split_extend_seqs(extend_lens)
     elif forward_mode.is_target_verify() or forward_mode.is_decode():
         assert token_num_per_seq is not None
         return (num_tokens // token_num_per_seq) // 2
@@ -69,7 +74,43 @@ def compute_split_seq_index(
         raise NotImplementedError()
 
 
-def _split_array_by_half_sum(arr: Sequence[int]) -> int:
+def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
+    if extend_lens is None:
+        return False
+
+    vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
+    left_sum = sum(extend_lens[:vanilla_split_seq_index])
+    overall_sum = sum(extend_lens)
+    threshold = global_server_args_dict["tbo_token_distribution_threshold"]
+    assert threshold <= 0.5, f"{threshold=}"
+    return left_sum < overall_sum * threshold or left_sum > overall_sum * (
+        1 - threshold
+    )
+
+
+def _split_extend_seqs(arr: Sequence[int]) -> int:
+    if _is_two_chunk_split_enabled(arr):
+        return _split_array_by_cum_less_than_half(arr)
+
+    return _split_array_by_balanced_sum(arr)
+
+
+def _split_array_by_cum_less_than_half(arr: Sequence[int]) -> int:
+    left_sum = 0
+    overall_sum = sum(arr)
+    half_sum = overall_sum // 2
+    chosen_index = 0
+
+    for i in range(len(arr)):
+        left_sum += arr[i]
+        if left_sum > half_sum:
+            chosen_index = i
+            break
+
+    return chosen_index
+
+
+def _split_array_by_balanced_sum(arr: Sequence[int]) -> int:
     overall_sum = sum(arr)
     left_sum = 0
     min_diff = float("inf")
@@ -88,6 +129,34 @@ def _split_array_by_half_sum(arr: Sequence[int]) -> int:
     return best_index
 
 
+def _update_device_and_sum_field_from_cpu_field(
+    batch: ForwardBatch, cpu_field: str, device_field: str, sum_field: str = None
+):
+    cpu_value = getattr(batch, cpu_field, None)
+    old_device_value = getattr(batch, device_field, None)
+    if (
+        cpu_value is None
+        or old_device_value is None
+        or not (isinstance(cpu_value, torch.Tensor) or isinstance(cpu_value, list))
+    ):
+        return
+
+    new_device_value = (
+        cpu_value
+        if isinstance(cpu_value, torch.Tensor)
+        else torch.tensor(cpu_value, dtype=old_device_value.dtype)
+    ).to(device=global_server_args_dict["device"], non_blocking=True)
+    setattr(batch, device_field, new_device_value)
+
+    if sum_field is not None:
+        sum_value = (
+            cpu_value.sum().item()
+            if isinstance(cpu_value, torch.Tensor)
+            else sum(cpu_value)
+        )
+        setattr(batch, sum_field, sum_value)
+
+
 def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
     if seq_index == 0:
         return 0
@@ -181,6 +250,8 @@ def compute_split_token_index(
 ) -> int:
     if forward_mode == ForwardMode.EXTEND:
         assert extend_seq_lens is not None
+        if _is_two_chunk_split_enabled(extend_seq_lens):
+            return sum(extend_seq_lens) // 2
         return sum(extend_seq_lens[:split_seq_index])
     elif forward_mode.is_target_verify() or forward_mode.is_decode():
         assert token_num_per_seq is not None
@@ -388,9 +459,15 @@ class TboForwardBatchPreparer:
 
         tbo_split_token_index = cls._compute_split_token_index(batch)
 
+        is_enable_two_chunk = (
+            batch.forward_mode == ForwardMode.EXTEND
+            and _is_two_chunk_split_enabled(batch.extend_seq_lens_cpu)
+        )
+
         if _tbo_debug:
             logger.info(
                 f"TboForwardBatchPreparer.prepare "
+                f"is_enable_two_chunk={is_enable_two_chunk} "
                 f"tbo_split_seq_index={batch.tbo_split_seq_index} "
                 f"tbo_split_token_index={tbo_split_token_index} "
                 f"extend_seq_lens={batch.extend_seq_lens_cpu} "
@@ -410,7 +487,11 @@ class TboForwardBatchPreparer:
             start_token_index=0,
             end_token_index=tbo_split_token_index,
             start_seq_index=0,
-            end_seq_index=batch.tbo_split_seq_index,
+            end_seq_index=(
+                batch.tbo_split_seq_index + 1
+                if is_enable_two_chunk
+                else batch.tbo_split_seq_index
+            ),
             output_attn_backend=attn_backend_child_a,
             out_num_token_non_padded=out_num_token_non_padded_a,
         )
@@ -424,9 +505,79 @@ class TboForwardBatchPreparer:
             out_num_token_non_padded=out_num_token_non_padded_b,
         )
 
+        if is_enable_two_chunk:
+            cls.derive_fields_related_to_seq_len_for_two_chunk(
+                batch,
+                child_a=child_a,
+                child_b=child_b,
+                tbo_split_seq_index=batch.tbo_split_seq_index,
+            )
+
         assert batch.tbo_children is None
         batch.tbo_children = [child_a, child_b]
 
+    @classmethod
+    def derive_fields_related_to_seq_len_for_two_chunk(
+        cls,
+        batch: ForwardBatch,
+        *,
+        child_a: ForwardBatch,
+        child_b: ForwardBatch,
+        tbo_split_seq_index: int,
+    ):
+        extend_seq_lens_cpu = batch.extend_seq_lens_cpu
+        overall_seq_lens_sum = sum(extend_seq_lens_cpu)
+        half_seq_lens_sum = overall_seq_lens_sum // 2
+        left_last_seq_token_num = half_seq_lens_sum - sum(
+            extend_seq_lens_cpu[:tbo_split_seq_index]
+        )
+        right_first_seq_token_num = (
+            extend_seq_lens_cpu[tbo_split_seq_index] - left_last_seq_token_num
+        )
+
+        # making deepcopy to be extra safe
+        child_a.extend_seq_lens_cpu = copy.deepcopy(child_a.extend_seq_lens_cpu)
+        child_a.extend_seq_lens_cpu[-1] = left_last_seq_token_num
+        child_b.extend_seq_lens_cpu = copy.deepcopy(child_b.extend_seq_lens_cpu)
+        child_b.extend_seq_lens_cpu[0] = right_first_seq_token_num
+        for child in [child_a, child_b]:
+            _update_device_and_sum_field_from_cpu_field(
+                batch=child,
+                cpu_field="extend_seq_lens_cpu",
+                device_field="extend_seq_lens",
+                sum_field="extend_num_tokens",
+            )
+
+        assert (
+            child_a.extend_num_tokens == half_seq_lens_sum
+        ), f"{child_a.extend_num_tokens=}, {half_seq_lens_sum=}"
+
+        child_a.seq_lens_cpu = copy.deepcopy(child_a.seq_lens_cpu)
+        child_a.seq_lens_cpu[-1] = (
+            child_a.extend_seq_lens_cpu[-1] + child_a.extend_prefix_lens_cpu[-1]
+        )
+        _update_device_and_sum_field_from_cpu_field(
+            batch=child_a,
+            cpu_field="seq_lens_cpu",
+            device_field="seq_lens",
+            sum_field="seq_lens_sum",
+        )
+
+        child_b.extend_prefix_lens_cpu = copy.deepcopy(child_b.extend_prefix_lens_cpu)
+        child_b.extend_prefix_lens_cpu[0] += left_last_seq_token_num
+        _update_device_and_sum_field_from_cpu_field(
+            batch=child_b,
+            cpu_field="extend_prefix_lens_cpu",
+            device_field="extend_prefix_lens",
+            sum_field=None,
+        )
+        _, child_b.extend_start_loc = compute_position(
+            global_server_args_dict["attention_backend"],
+            child_b.extend_prefix_lens,
+            child_b.extend_seq_lens,
+            child_b.extend_num_tokens,
+        )
+
     @classmethod
     def filter_batch(
         cls,
diff --git a/test/srt/test_two_batch_overlap.py b/test/srt/test_two_batch_overlap.py
index 257d43ca8..6aa550c46 100644
--- a/test/srt/test_two_batch_overlap.py
+++ b/test/srt/test_two_batch_overlap.py
@@ -5,7 +5,10 @@ from types import SimpleNamespace
 import requests
 
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.two_batch_overlap import compute_split_seq_index
+from sglang.srt.two_batch_overlap import (
+    compute_split_seq_index,
+    compute_split_token_index,
+)
 from sglang.srt.utils import kill_process_tree
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
@@ -73,35 +76,46 @@ class TestTwoBatchOverlap(unittest.TestCase):
 
 
 class TestTwoBatchOverlapUnitTest(unittest.TestCase):
-    # TODO change tests when having 6328
-    def test_compute_split_seq_index(self):
+    def test_compute_split_seq_and_token_index(self):
         for num_tokens, expect in [
             (0, 0),
             (100, 50),
             (99, 49),
         ]:
             actual = compute_split_seq_index(
-                forward_mode=ForwardMode.DECODE, num_tokens=num_tokens, extend_lens=None
+                forward_mode=ForwardMode.DECODE,
+                num_tokens=num_tokens,
+                extend_lens=None,
+                token_num_per_seq=1,
             )
             self.assertEqual(actual, expect)
 
         for extend_lens, expect in [
-            ([], 0),
-            ([42], 0),
-            ([42, 999], 1),
-            ([999, 42], 1),
-            ([4096, 4096, 4096, 4096], 2),
-            ([4095, 4096, 4096, 4096, 1], 2),
-            ([1, 4095, 4096, 4096, 4096], 3),
-            ([4097, 4096, 4096, 4095, 1], 2),
-            ([1, 1, 1, 1, 99999], 4),
-            ([99999, 1, 1, 1, 1], 1),
+            ([], (0, 0)),
+            ([42], (0, 21)),
+            ([42, 999], (1, 520)),
+            ([999, 42], (0, 520)),
+            ([498, 502], (1, 498)),
+            ([4096, 4096, 4096, 4096], (2, 8192)),
+            ([4095, 4096, 4096, 4096, 1], (2, 8191)),
+            ([1, 4095, 4096, 4096, 4096], (3, 8192)),
+            ([4097, 4096, 4096, 4095, 1], (2, 8193)),
+            ([1, 1, 1, 1, 99999], (4, 50001)),
+            ([99999, 1, 1, 1, 1], (0, 50001)),
         ]:
-            actual = compute_split_seq_index(
+            actual_seq_idx = compute_split_seq_index(
                 forward_mode=ForwardMode.EXTEND,
                 num_tokens=None,
                 extend_lens=extend_lens,
+                token_num_per_seq=None,
            )
+            actual_token_idx = compute_split_token_index(
+                split_seq_index=actual_seq_idx,
+                forward_mode=ForwardMode.EXTEND,
+                extend_seq_lens=extend_lens,
+                token_num_per_seq=None,
+            )
+            actual = (actual_seq_idx, actual_token_idx)
             print(f"{extend_lens=} {expect=} {actual=}")
             self.assertEqual(actual, expect)
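
Illustrative sketch (not part of the patch): the standalone Python below re-derives the split selection from the hunks above and from the new unit-test expectations, to show how `--tbo-token-distribution-threshold` chooses between two-batch-overlap and two-chunk-overlap for an extend batch. The body of the balanced-sum helper is only partially visible in this diff, so its reimplementation here is an assumption.

# Illustrative sketch only; not part of the patch. Helper bodies are
# reconstructed from the hunks above and may differ in tie-breaking details.
from typing import Sequence


def split_by_balanced_sum(arr: Sequence[int]) -> int:
    # Assumed body of _split_array_by_balanced_sum: pick the sequence
    # boundary whose left/right token sums are closest to each other.
    overall_sum, left_sum = sum(arr), 0
    best_index, min_diff = 0, float("inf")
    for i in range(1, len(arr)):
        left_sum += arr[i - 1]
        diff = abs(left_sum - (overall_sum - left_sum))
        if diff < min_diff:
            min_diff, best_index = diff, i
    return best_index


def split_by_cum_less_than_half(arr: Sequence[int]) -> int:
    # Mirrors _split_array_by_cum_less_than_half: the first sequence whose
    # cumulative length crosses half of the total token count.
    left_sum, half_sum, chosen_index = 0, sum(arr) // 2, 0
    for i, seq_len in enumerate(arr):
        left_sum += seq_len
        if left_sum > half_sum:
            chosen_index = i
            break
    return chosen_index


def choose_split(extend_lens: Sequence[int], threshold: float = 0.48):
    # Mirrors _is_two_chunk_split_enabled + _split_extend_seqs: if the balanced
    # split is still too skewed, fall back to chunking the boundary sequence.
    seq_index = split_by_balanced_sum(extend_lens)
    left, total = sum(extend_lens[:seq_index]), sum(extend_lens)
    if left < total * threshold or left > total * (1 - threshold):
        return "two-chunk", split_by_cum_less_than_half(extend_lens), total // 2
    return "two-batch", seq_index, left


print(choose_split([498, 502]))  # ('two-batch', 1, 498), matching the new test
print(choose_split([42, 999]))   # ('two-chunk', 1, 520), matching the new test

Setting `--tbo-token-distribution-threshold 0` makes the skew test always fail, so only the plain two-batch split is used; that is what the new help text means by disabling two-chunk-overlap.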