[Feature] improve TBO: two chunk overlap (#8144)
This commit is contained in:
@@ -84,6 +84,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
||||
"disable_radix_cache",
|
||||
"enable_dp_attention",
|
||||
"enable_two_batch_overlap",
|
||||
"tbo_token_distribution_threshold",
|
||||
"enable_dp_lm_head",
|
||||
"moe_a2a_backend",
|
||||
"deepep_mode",
|
||||
|
||||
@@ -420,16 +420,12 @@ class ForwardBatch:
|
||||
batch.extend_prefix_lens, dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
ret.extend_num_tokens = batch.extend_num_tokens
|
||||
if support_triton(model_runner.server_args.attention_backend):
|
||||
positions, ret.extend_start_loc = compute_position_triton(
|
||||
ret.extend_prefix_lens,
|
||||
ret.extend_seq_lens,
|
||||
ret.extend_num_tokens,
|
||||
)
|
||||
else:
|
||||
positions, ret.extend_start_loc = compute_position_torch(
|
||||
ret.extend_prefix_lens, ret.extend_seq_lens
|
||||
)
|
||||
positions, ret.extend_start_loc = compute_position(
|
||||
model_runner.server_args.attention_backend,
|
||||
ret.extend_prefix_lens,
|
||||
ret.extend_seq_lens,
|
||||
ret.extend_num_tokens,
|
||||
)
|
||||
if ret.positions is None:
|
||||
ret.positions = positions
|
||||
ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
|
||||
@@ -882,6 +878,25 @@ class PPProxyTensors:
|
||||
return f"PPProxyTensors(tensors={self.tensors})"
|
||||
|
||||
|
||||
def compute_position(
    attn_backend: str,
    extend_prefix_lens: torch.Tensor,
    extend_seq_lens: torch.Tensor,
    extend_seq_lens_sum: int,
):
    """Return ``(positions, extend_start_loc)`` for an extend forward pass.

    Delegates to the Triton kernel when the selected attention backend
    supports it, and to the pure-torch fallback otherwise. Both helpers
    return the same ``(positions, extend_start_loc)`` pair.
    """
    if not support_triton(attn_backend):
        return compute_position_torch(extend_prefix_lens, extend_seq_lens)
    return compute_position_triton(
        extend_prefix_lens,
        extend_seq_lens,
        extend_seq_lens_sum,
    )
|
||||
|
||||
|
||||
def compute_position_triton(
|
||||
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
|
||||
):
|
||||
|
||||
@@ -229,6 +229,7 @@ class ServerArgs:
|
||||
enable_dp_attention: bool = False
|
||||
enable_dp_lm_head: bool = False
|
||||
enable_two_batch_overlap: bool = False
|
||||
tbo_token_distribution_threshold: float = 0.48
|
||||
enable_torch_compile: bool = False
|
||||
torch_compile_max_bs: int = 32
|
||||
torchao_config: str = ""
|
||||
@@ -1689,6 +1690,12 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Enabling two micro batches to overlap.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tbo-token-distribution-threshold",
|
||||
type=float,
|
||||
default=ServerArgs.tbo_token_distribution_threshold,
|
||||
help="The threshold of token distribution between two batches in micro-batch-overlap, determines whether to two-batch-overlap or two-chunk-overlap. Set to 0 denote disable two-chunk-overlap.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-torch-compile",
|
||||
action="store_true",
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import logging
|
||||
from dataclasses import replace
|
||||
@@ -17,7 +18,11 @@ from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
|
||||
from sglang.srt.layers.moe.utils import DeepEPMode
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
compute_position,
|
||||
)
|
||||
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
||||
from sglang.srt.operations_strategy import OperationsStrategy
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
@@ -58,7 +63,7 @@ def compute_split_seq_index(
|
||||
) -> Optional[int]:
|
||||
if forward_mode == ForwardMode.EXTEND:
|
||||
assert extend_lens is not None
|
||||
return _split_array_by_half_sum(extend_lens)
|
||||
return _split_extend_seqs(extend_lens)
|
||||
elif forward_mode.is_target_verify() or forward_mode.is_decode():
|
||||
assert token_num_per_seq is not None
|
||||
return (num_tokens // token_num_per_seq) // 2
|
||||
@@ -69,7 +74,43 @@ def compute_split_seq_index(
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def _split_array_by_half_sum(arr: Sequence[int]) -> int:
|
||||
def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
    """Report whether the two-chunk split should replace the vanilla split.

    The vanilla per-sequence balanced split is considered too imbalanced —
    and two-chunk mode enabled — when the left partition's token count falls
    outside ``[threshold, 1 - threshold]`` of the total token count.
    """
    if extend_lens is None:
        return False

    split_index = _split_array_by_balanced_sum(extend_lens)
    left_tokens = sum(extend_lens[:split_index])
    total_tokens = sum(extend_lens)
    threshold = global_server_args_dict["tbo_token_distribution_threshold"]
    assert threshold <= 0.5, f"{threshold=}"
    # Equivalent to: left < total*threshold or left > total*(1 - threshold).
    balanced = total_tokens * threshold <= left_tokens <= total_tokens * (
        1 - threshold
    )
    return not balanced
|
||||
|
||||
|
||||
def _split_extend_seqs(arr: Sequence[int]) -> int:
    """Choose the split index for extend-mode sequence lengths.

    Uses the cumulative half-sum split when two-chunk overlap is enabled
    for this distribution, and the balanced-sum split otherwise.
    """
    splitter = (
        _split_array_by_cum_less_than_half
        if _is_two_chunk_split_enabled(arr)
        else _split_array_by_balanced_sum
    )
    return splitter(arr)
|
||||
|
||||
|
||||
def _split_array_by_cum_less_than_half(arr: Sequence[int]) -> int:
|
||||
left_sum = 0
|
||||
overall_sum = sum(arr)
|
||||
half_sum = overall_sum // 2
|
||||
chosen_index = 0
|
||||
|
||||
for i in range(len(arr)):
|
||||
left_sum += arr[i]
|
||||
if left_sum > half_sum:
|
||||
chosen_index = i
|
||||
break
|
||||
|
||||
return chosen_index
|
||||
|
||||
|
||||
def _split_array_by_balanced_sum(arr: Sequence[int]) -> int:
|
||||
overall_sum = sum(arr)
|
||||
left_sum = 0
|
||||
min_diff = float("inf")
|
||||
@@ -88,6 +129,34 @@ def _split_array_by_half_sum(arr: Sequence[int]) -> int:
|
||||
return best_index
|
||||
|
||||
|
||||
def _update_device_and_sum_field_from_cpu_field(
|
||||
batch: ForwardBatch, cpu_field: str, device_field: str, sum_field: str = None
|
||||
):
|
||||
cpu_value = getattr(batch, cpu_field, None)
|
||||
old_device_value = getattr(batch, device_field, None)
|
||||
if (
|
||||
cpu_value is None
|
||||
or old_device_value is None
|
||||
or not (isinstance(cpu_value, torch.Tensor) or isinstance(cpu_value, list))
|
||||
):
|
||||
return
|
||||
|
||||
new_device_value = (
|
||||
cpu_value
|
||||
if isinstance(cpu_value, torch.Tensor)
|
||||
else torch.tensor(cpu_value, dtype=old_device_value.dtype)
|
||||
).to(device=global_server_args_dict["device"], non_blocking=True)
|
||||
setattr(batch, device_field, new_device_value)
|
||||
|
||||
if sum_field is not None:
|
||||
sum_value = (
|
||||
cpu_value.sum().item()
|
||||
if isinstance(cpu_value, torch.Tensor)
|
||||
else sum(cpu_value)
|
||||
)
|
||||
setattr(batch, sum_field, sum_value)
|
||||
|
||||
|
||||
def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
|
||||
if seq_index == 0:
|
||||
return 0
|
||||
@@ -181,6 +250,8 @@ def compute_split_token_index(
|
||||
) -> int:
|
||||
if forward_mode == ForwardMode.EXTEND:
|
||||
assert extend_seq_lens is not None
|
||||
if _is_two_chunk_split_enabled(extend_seq_lens):
|
||||
return sum(extend_seq_lens) // 2
|
||||
return sum(extend_seq_lens[:split_seq_index])
|
||||
elif forward_mode.is_target_verify() or forward_mode.is_decode():
|
||||
assert token_num_per_seq is not None
|
||||
@@ -388,9 +459,15 @@ class TboForwardBatchPreparer:
|
||||
|
||||
tbo_split_token_index = cls._compute_split_token_index(batch)
|
||||
|
||||
is_enable_two_chunk = (
|
||||
batch.forward_mode == ForwardMode.EXTEND
|
||||
and _is_two_chunk_split_enabled(batch.extend_seq_lens_cpu)
|
||||
)
|
||||
|
||||
if _tbo_debug:
|
||||
logger.info(
|
||||
f"TboForwardBatchPreparer.prepare "
|
||||
f"is_enable_two_chunk={is_enable_two_chunk} "
|
||||
f"tbo_split_seq_index={batch.tbo_split_seq_index} "
|
||||
f"tbo_split_token_index={tbo_split_token_index} "
|
||||
f"extend_seq_lens={batch.extend_seq_lens_cpu} "
|
||||
@@ -410,7 +487,11 @@ class TboForwardBatchPreparer:
|
||||
start_token_index=0,
|
||||
end_token_index=tbo_split_token_index,
|
||||
start_seq_index=0,
|
||||
end_seq_index=batch.tbo_split_seq_index,
|
||||
end_seq_index=(
|
||||
batch.tbo_split_seq_index + 1
|
||||
if is_enable_two_chunk
|
||||
else batch.tbo_split_seq_index
|
||||
),
|
||||
output_attn_backend=attn_backend_child_a,
|
||||
out_num_token_non_padded=out_num_token_non_padded_a,
|
||||
)
|
||||
@@ -424,9 +505,79 @@ class TboForwardBatchPreparer:
|
||||
out_num_token_non_padded=out_num_token_non_padded_b,
|
||||
)
|
||||
|
||||
if is_enable_two_chunk:
|
||||
cls.derive_fields_related_to_seq_len_for_two_chunk(
|
||||
batch,
|
||||
child_a=child_a,
|
||||
child_b=child_b,
|
||||
tbo_split_seq_index=batch.tbo_split_seq_index,
|
||||
)
|
||||
|
||||
assert batch.tbo_children is None
|
||||
batch.tbo_children = [child_a, child_b]
|
||||
|
||||
@classmethod
def derive_fields_related_to_seq_len_for_two_chunk(
    cls,
    batch: ForwardBatch,
    *,
    child_a: ForwardBatch,
    child_b: ForwardBatch,
    tbo_split_seq_index: int,
):
    """Fix up the seq-len-related fields of the two micro-batches after a
    two-chunk split, where the sequence at ``tbo_split_seq_index`` is cut in
    half so it ends child_a and begins child_b.

    Mutates child_a and child_b in place: extend_seq_lens(_cpu),
    extend_num_tokens, child_a's seq_lens(_cpu)/seq_lens_sum, child_b's
    extend_prefix_lens(_cpu) and extend_start_loc.
    """
    extend_seq_lens_cpu = batch.extend_seq_lens_cpu
    overall_seq_lens_sum = sum(extend_seq_lens_cpu)
    half_seq_lens_sum = overall_seq_lens_sum // 2
    # Tokens of the boundary sequence assigned to child_a so that child_a
    # ends up with exactly half of the total extend tokens.
    left_last_seq_token_num = half_seq_lens_sum - sum(
        extend_seq_lens_cpu[:tbo_split_seq_index]
    )
    # The remainder of the boundary sequence goes to child_b.
    right_first_seq_token_num = (
        extend_seq_lens_cpu[tbo_split_seq_index] - left_last_seq_token_num
    )

    # making deepcopy to be extra safe
    child_a.extend_seq_lens_cpu = copy.deepcopy(child_a.extend_seq_lens_cpu)
    child_a.extend_seq_lens_cpu[-1] = left_last_seq_token_num
    child_b.extend_seq_lens_cpu = copy.deepcopy(child_b.extend_seq_lens_cpu)
    child_b.extend_seq_lens_cpu[0] = right_first_seq_token_num
    # Push the edited CPU lists back to the device tensors and recompute the
    # per-child extend token counts.
    for child in [child_a, child_b]:
        _update_device_and_sum_field_from_cpu_field(
            batch=child,
            cpu_field="extend_seq_lens_cpu",
            device_field="extend_seq_lens",
            sum_field="extend_num_tokens",
        )

    assert (
        child_a.extend_num_tokens == half_seq_lens_sum
    ), f"{child_a.extend_num_tokens=}, {half_seq_lens_sum=}"

    # child_a's last sequence now holds fewer extend tokens, so its total
    # (prefix + extend) length shrinks accordingly.
    child_a.seq_lens_cpu = copy.deepcopy(child_a.seq_lens_cpu)
    child_a.seq_lens_cpu[-1] = (
        child_a.extend_seq_lens_cpu[-1] + child_a.extend_prefix_lens_cpu[-1]
    )
    _update_device_and_sum_field_from_cpu_field(
        batch=child_a,
        cpu_field="seq_lens_cpu",
        device_field="seq_lens",
        sum_field="seq_lens_sum",
    )

    # The tokens child_a consumed from the boundary sequence become prefix
    # tokens for child_b's first sequence.
    child_b.extend_prefix_lens_cpu = copy.deepcopy(child_b.extend_prefix_lens_cpu)
    child_b.extend_prefix_lens_cpu[0] += left_last_seq_token_num
    _update_device_and_sum_field_from_cpu_field(
        batch=child_b,
        cpu_field="extend_prefix_lens_cpu",
        device_field="extend_prefix_lens",
        sum_field=None,
    )
    # Recompute start offsets from the updated lens; positions are unchanged
    # by the split, so the first return value is discarded.
    _, child_b.extend_start_loc = compute_position(
        global_server_args_dict["attention_backend"],
        child_b.extend_prefix_lens,
        child_b.extend_seq_lens,
        child_b.extend_num_tokens,
    )
|
||||
|
||||
@classmethod
|
||||
def filter_batch(
|
||||
cls,
|
||||
|
||||
Reference in New Issue
Block a user