from __future__ import annotations

import copy
import dataclasses
import logging
from dataclasses import replace
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union

import torch

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.communicator import (
    CommunicateContext,
    CommunicateSummableTensorPairFn,
    ScatterMode,
)
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
    ForwardBatch,
    ForwardMode,
    compute_position,
)
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip

if TYPE_CHECKING:
    from sglang.srt.layers.moe.token_dispatcher import DispatchOutput

_is_hip = is_hip()

_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")

logger = logging.getLogger(__name__)


# -------------------------------- Compute Basic Info ---------------------------------------


def get_token_num_per_seq(
    forward_mode: ForwardMode,
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
):
    if forward_mode.is_target_verify():
        return spec_info.draft_token_num
    elif forward_mode.is_decode():
        return 1
    elif forward_mode.is_idle():
        return 0
    else:
        # For extend, we should not use `token_num_per_seq`.
        return None


# TODO: may want to smartly disable TBO when the batch size is too small,
# since the overlap would slow things down.
def compute_split_seq_index(
    forward_mode: "ForwardMode",
    num_tokens: int,
    extend_lens: Optional[Sequence[int]],
    token_num_per_seq: Optional[int],
) -> Optional[int]:
    if forward_mode == ForwardMode.EXTEND:
        assert extend_lens is not None
        return _split_extend_seqs(extend_lens)
    elif forward_mode.is_target_verify() or forward_mode.is_decode():
        assert token_num_per_seq is not None
        return (num_tokens // token_num_per_seq) // 2
    elif forward_mode.is_idle():
        assert num_tokens == 0
        return 0
    else:
        raise NotImplementedError()
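

# Worked example (values are made up, not from the codebase): a decode batch
# with num_tokens=9 and token_num_per_seq=1 splits at (9 // 1) // 2 == 4,
# i.e. sequences [0, 4) go to sub-batch A and [4, 9) to sub-batch B.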


def _is_two_chunk_split_enabled(extend_lens: Optional[Sequence[int]]) -> bool:
    if extend_lens is None:
        return False

    vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
    left_sum = sum(extend_lens[:vanilla_split_seq_index])
    overall_sum = sum(extend_lens)
    threshold = global_server_args_dict["tbo_token_distribution_threshold"]
    assert threshold <= 0.5, f"{threshold=}"
    return left_sum < overall_sum * threshold or left_sum > overall_sum * (
        1 - threshold
    )
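

# Worked example (hypothetical numbers): extend_lens=[100, 1, 1] balances best
# by putting the first sequence alone on the left, i.e. 100 of 102 tokens.
# With tbo_token_distribution_threshold=0.25, 100 > 102 * 0.75, so the
# two-chunk split (splitting one sequence across both sub-batches) kicks in.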


def _split_extend_seqs(arr: Sequence[int]) -> int:
    if _is_two_chunk_split_enabled(arr):
        return _split_array_by_cum_less_than_half(arr)

    return _split_array_by_balanced_sum(arr)


def _split_array_by_cum_less_than_half(arr: Sequence[int]) -> int:
    left_sum = 0
    overall_sum = sum(arr)
    half_sum = overall_sum // 2
    chosen_index = 0

    for i in range(len(arr)):
        left_sum += arr[i]
        if left_sum > half_sum:
            chosen_index = i
            break

    return chosen_index
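

# Example (hypothetical): for arr=[3, 5, 4], half_sum is 6 and the cumulative
# sum first exceeds it at index 1, so 1 is returned: the sequence at index 1
# straddles the token midpoint and is the one split across the two chunks.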


def _split_array_by_balanced_sum(arr: Sequence[int]) -> int:
    overall_sum = sum(arr)
    left_sum = 0
    min_diff = float("inf")
    best_index = 0

    for i in range(1, len(arr)):
        left_sum += arr[i - 1]
        right_sum = overall_sum - left_sum
        diff = abs(left_sum - right_sum)
        if diff <= min_diff:
            min_diff = diff
            best_index = i
        else:
            break

    return best_index
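

# Example (hypothetical): for arr=[3, 5, 4], splitting before index 2 gives
# |(3 + 5) - 4| = 4, the smallest prefix/suffix imbalance, so 2 is returned.
# The early `break` is valid because the imbalance is V-shaped in the split
# index: once the difference starts growing, it can only keep growing.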


def _update_device_and_sum_field_from_cpu_field(
    batch: ForwardBatch,
    cpu_field: str,
    device_field: str,
    sum_field: Optional[str] = None,
):
    cpu_value = getattr(batch, cpu_field, None)
    old_device_value = getattr(batch, device_field, None)
    if (
        cpu_value is None
        or old_device_value is None
        or not isinstance(cpu_value, (torch.Tensor, list))
    ):
        return

    new_device_value = (
        cpu_value
        if isinstance(cpu_value, torch.Tensor)
        else torch.tensor(cpu_value, dtype=old_device_value.dtype)
    ).to(device=global_server_args_dict["device"], non_blocking=True)
    setattr(batch, device_field, new_device_value)

    if sum_field is not None:
        sum_value = (
            cpu_value.sum().item()
            if isinstance(cpu_value, torch.Tensor)
            else sum(cpu_value)
        )
        setattr(batch, sum_field, sum_value)


def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
    if seq_index == 0:
        return 0

    offset = 0
    max_seq_len = min(seq_index, spec_info.seq_lens_cpu.shape[0])
    for i in range(max_seq_len):
        offset += (
            spec_info.seq_lens_cpu[i] + spec_info.draft_token_num
        ) * spec_info.draft_token_num
    return offset
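

# The custom mask is assumed to be laid out per sequence as a flattened
# (seq_len + draft_token_num) x draft_token_num block, so the offset of
# sequence i is the total block size of sequences 0..i-1. Example
# (hypothetical): seq_lens_cpu=[7, 9] and draft_token_num=4 give
# _compute_mask_offset(1, spec_info) == (7 + 4) * 4 == 44.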


def split_spec_info(
    spec_info: Optional[EagleVerifyInput],
    start_seq_index: int,
    end_seq_index: int,
    start_token_index: int,
    end_token_index: int,
):
    if spec_info is None:
        return None
    if spec_info.draft_token is not None:
        draft_token = spec_info.draft_token[start_token_index:end_token_index]
    else:
        draft_token = None
    if spec_info.custom_mask is not None and spec_info.draft_token is not None:
        custom_mask_start = _compute_mask_offset(start_seq_index, spec_info)
        if end_seq_index == spec_info.seq_lens_cpu.shape[0]:
            custom_mask_end = spec_info.custom_mask.shape[0]
        else:
            custom_mask_end = _compute_mask_offset(end_seq_index, spec_info)

        if custom_mask_end > custom_mask_start:
            custom_mask = spec_info.custom_mask[custom_mask_start:custom_mask_end]
        else:
            custom_mask = spec_info.custom_mask
    else:
        custom_mask = spec_info.custom_mask
    if spec_info.positions is not None:
        positions = spec_info.positions[start_token_index:end_token_index]
    else:
        positions = None
    if spec_info.retrive_index is not None:
        retrive_index = spec_info.retrive_index[start_seq_index:end_seq_index]
    else:
        retrive_index = None
    if spec_info.retrive_next_token is not None:
        retrive_next_token = spec_info.retrive_next_token[start_seq_index:end_seq_index]
    else:
        retrive_next_token = None
    if spec_info.retrive_next_sibling is not None:
        retrive_next_sibling = spec_info.retrive_next_sibling[
            start_seq_index:end_seq_index
        ]
    else:
        retrive_next_sibling = None
    if spec_info.retrive_cum_len is not None:
        retrive_cum_len = spec_info.retrive_cum_len[start_seq_index:end_seq_index]
    else:
        retrive_cum_len = None

    if spec_info.seq_lens_cpu is not None:
        seq_lens_cpu = spec_info.seq_lens_cpu[start_seq_index:end_seq_index]
        seq_lens_sum = seq_lens_cpu.sum()
    else:
        seq_lens_cpu = None
        seq_lens_sum = None
    output_spec_info = replace(
        spec_info,
        custom_mask=custom_mask,
        draft_token=draft_token,
        positions=positions,
        retrive_index=retrive_index,
        retrive_next_token=retrive_next_token,
        retrive_next_sibling=retrive_next_sibling,
        retrive_cum_len=retrive_cum_len,
        seq_lens_cpu=seq_lens_cpu,
        seq_lens_sum=seq_lens_sum,
    )
    return output_spec_info


def compute_split_token_index(
    split_seq_index: int,
    forward_mode: "ForwardMode",
    extend_seq_lens: Optional[Sequence[int]],
    token_num_per_seq: Optional[int],
) -> int:
    if forward_mode == ForwardMode.EXTEND:
        assert extend_seq_lens is not None
        if _is_two_chunk_split_enabled(extend_seq_lens):
            return sum(extend_seq_lens) // 2
        return sum(extend_seq_lens[:split_seq_index])
    elif forward_mode.is_target_verify() or forward_mode.is_decode():
        assert token_num_per_seq is not None
        return split_seq_index * token_num_per_seq
    elif forward_mode.is_idle():
        assert split_seq_index == 0
        return 0
    else:
        raise NotImplementedError


def compute_split_indices_for_cuda_graph_replay(
    forward_mode: ForwardMode,
    cuda_graph_num_tokens: int,
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
    forward_mode_for_tbo_split = (
        forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
    )
    token_num_per_seq = get_token_num_per_seq(
        forward_mode=forward_mode, spec_info=spec_info
    )
    tbo_split_seq_index = compute_split_seq_index(
        forward_mode=forward_mode_for_tbo_split,
        num_tokens=cuda_graph_num_tokens,
        extend_lens=None,
        token_num_per_seq=token_num_per_seq,
    )
    tbo_split_token_index = compute_split_token_index(
        split_seq_index=tbo_split_seq_index,
        forward_mode=forward_mode_for_tbo_split,
        extend_seq_lens=None,
        token_num_per_seq=token_num_per_seq,
    )
    return tbo_split_seq_index, tbo_split_token_index
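

# Example (hypothetical): replaying a decode CUDA graph with bs=8 and no
# spec_info gives token_num_per_seq=1, so the split indices are
# (8 // 1) // 2 == 4 sequences and 4 * 1 == 4 tokens.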


# -------------------------------- Preparation ---------------------------------------


class TboCudaGraphRunnerPlugin:
    def __init__(self):
        self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)

    def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
        if not global_server_args_dict["enable_two_batch_overlap"]:
            return
        token_num_per_seq = get_token_num_per_seq(
            forward_mode=batch.forward_mode, spec_info=batch.spec_info
        )

        batch.tbo_split_seq_index = compute_split_seq_index(
            forward_mode=batch.forward_mode,
            num_tokens=num_tokens,
            extend_lens=None,
            token_num_per_seq=token_num_per_seq,
        )
        # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
        assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"

        self._tbo_children_num_token_non_padded[...] = (
            TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded(batch)
        )

        TboForwardBatchPreparer.prepare_raw(
            batch,
            tbo_children_num_token_non_padded=self._tbo_children_num_token_non_padded,
        )

    def replay_prepare(
        self,
        forward_mode: ForwardMode,
        bs: int,
        num_token_non_padded: int,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        token_num_per_seq = get_token_num_per_seq(
            forward_mode=forward_mode, spec_info=spec_info
        )
        tbo_split_seq_index, tbo_split_token_index = (
            compute_split_indices_for_cuda_graph_replay(
                forward_mode=forward_mode,
                cuda_graph_num_tokens=bs * token_num_per_seq,
                spec_info=spec_info,
            )
        )

        self._tbo_children_num_token_non_padded[...] = (
            TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded_raw(
                tbo_split_token_index=tbo_split_token_index,
                num_token_non_padded=num_token_non_padded,
            )
        )


class TboDPAttentionPreparer:
    def prepare_all_gather(
        self,
        local_batch: ScheduleBatch,
        deepep_mode: DeepEPMode,
        enable_deepep_moe: bool,
        enable_two_batch_overlap: bool,
    ):
        self.enable_two_batch_overlap = enable_two_batch_overlap

        if local_batch is not None:
            token_num_per_seq = get_token_num_per_seq(
                forward_mode=local_batch.forward_mode, spec_info=local_batch.spec_info
            )

            if (
                local_batch.forward_mode.is_target_verify()
                or local_batch.forward_mode.is_decode()
            ):
                num_tokens = local_batch.batch_size() * token_num_per_seq
            else:
                num_tokens = local_batch.extend_num_tokens
            self.local_tbo_split_seq_index = compute_split_seq_index(
                forward_mode=local_batch.forward_mode,
                num_tokens=num_tokens,
                extend_lens=local_batch.extend_lens,
                token_num_per_seq=token_num_per_seq,
            )
            resolved_deepep_mode = deepep_mode.resolve(local_batch.is_extend_in_batch)
            local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
                (
                    local_batch.forward_mode.is_extend()
                    and not local_batch.forward_mode.is_target_verify()
                )
                and enable_deepep_moe
                and (resolved_deepep_mode == DeepEPMode.LOW_LATENCY)
            )
        else:
            self.local_tbo_split_seq_index = 0
            local_can_run_tbo = True

        local_forward_mode = self._compute_local_forward_mode(local_batch)

        return local_can_run_tbo, local_forward_mode

    def compute_output(self, partial_global_info):
        local_can_run_tbo_aggregated = min(partial_global_info[:, 0, 0].tolist())
        forward_modes = partial_global_info[:, 0, 1].tolist()

        global_forward_mode, forward_mode_agree = self._compute_global_forward_mode(
            forward_modes
        )

        can_run_tbo = (
            self.enable_two_batch_overlap
            and local_can_run_tbo_aggregated
            and forward_mode_agree
        )

        tbo_split_seq_index = self.local_tbo_split_seq_index if can_run_tbo else None
        global_forward_mode = global_forward_mode if can_run_tbo else None
        return tbo_split_seq_index, global_forward_mode

    @staticmethod
    def _compute_local_forward_mode(local_batch):
        return (
            local_batch.forward_mode if local_batch is not None else ForwardMode.IDLE
        ).value

    @staticmethod
    def _compute_global_forward_mode(forward_modes):
        forward_modes_excluding_idle = [
            x for x in forward_modes if x != ForwardMode.IDLE.value
        ]

        if not forward_modes_excluding_idle:
            return ForwardMode.IDLE, False

        forward_mode_agree = TboDPAttentionPreparer._is_all_same(
            forward_modes_excluding_idle
        )
        global_forward_mode = (
            ForwardMode(forward_modes_excluding_idle[0]) if forward_mode_agree else None
        )
        return global_forward_mode, forward_mode_agree

    @staticmethod
    def _is_all_same(x):
        return all(value == x[0] for value in x)
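

# A sketch of the agreement protocol (hypothetical ranks): if four DP ranks
# report (can_run_tbo, forward_mode) = (1, DECODE), (1, DECODE), (1, IDLE),
# (1, DECODE), the idle rank is skipped when comparing modes, the remaining
# modes agree on DECODE, and all ranks enable TBO with that global mode.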


class TboForwardBatchPreparer:
    @classmethod
    def prepare(cls, batch: ForwardBatch, is_draft_worker: bool = False):
        if batch.tbo_split_seq_index is None or is_draft_worker:
            return

        tbo_children_num_token_non_padded = (
            cls.compute_tbo_children_num_token_non_padded(batch)
        )
        cls.prepare_raw(
            batch, tbo_children_num_token_non_padded=tbo_children_num_token_non_padded
        )

    @classmethod
    def prepare_raw(
        cls, batch: ForwardBatch, tbo_children_num_token_non_padded: torch.Tensor
    ):
        from sglang.srt.layers.attention.tbo_backend import TboAttnBackend

        tbo_split_token_index = cls._compute_split_token_index(batch)

        is_enable_two_chunk = (
            batch.forward_mode == ForwardMode.EXTEND
            and _is_two_chunk_split_enabled(batch.extend_seq_lens_cpu)
        )

        if _tbo_debug:
            logger.info(
                f"TboForwardBatchPreparer.prepare "
                f"is_enable_two_chunk={is_enable_two_chunk} "
                f"tbo_split_seq_index={batch.tbo_split_seq_index} "
                f"tbo_split_token_index={tbo_split_token_index} "
                f"extend_seq_lens={batch.extend_seq_lens_cpu} "
                f"bs={batch.batch_size} "
                f"forward_mode={batch.forward_mode}"
            )

        assert isinstance(batch.attn_backend, TboAttnBackend)
        attn_backend_child_a, attn_backend_child_b = batch.attn_backend.children

        [out_num_token_non_padded_a, out_num_token_non_padded_b] = (
            tbo_children_num_token_non_padded
        )

        child_a = cls.filter_batch(
            batch,
            start_token_index=0,
            end_token_index=tbo_split_token_index,
            start_seq_index=0,
            end_seq_index=(
                batch.tbo_split_seq_index + 1
                if is_enable_two_chunk
                else batch.tbo_split_seq_index
            ),
            output_attn_backend=attn_backend_child_a,
            out_num_token_non_padded=out_num_token_non_padded_a,
        )
        child_b = cls.filter_batch(
            batch,
            start_token_index=tbo_split_token_index,
            end_token_index=batch.input_ids.shape[0],
            start_seq_index=batch.tbo_split_seq_index,
            end_seq_index=batch.batch_size,
            output_attn_backend=attn_backend_child_b,
            out_num_token_non_padded=out_num_token_non_padded_b,
        )

        if is_enable_two_chunk:
            cls.derive_fields_related_to_seq_len_for_two_chunk(
                batch,
                child_a=child_a,
                child_b=child_b,
                tbo_split_seq_index=batch.tbo_split_seq_index,
            )

        assert batch.tbo_children is None
        batch.tbo_children = [child_a, child_b]
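
    # Note: in the two-chunk case above, the children deliberately overlap by
    # one sequence (child_a ends at tbo_split_seq_index + 1 while child_b
    # starts at tbo_split_seq_index), since the boundary sequence contributes
    # tokens to both halves; the helper below then patches its per-child
    # lengths.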

    @classmethod
    def derive_fields_related_to_seq_len_for_two_chunk(
        cls,
        batch: ForwardBatch,
        *,
        child_a: ForwardBatch,
        child_b: ForwardBatch,
        tbo_split_seq_index: int,
    ):
        extend_seq_lens_cpu = batch.extend_seq_lens_cpu
        overall_seq_lens_sum = sum(extend_seq_lens_cpu)
        half_seq_lens_sum = overall_seq_lens_sum // 2
        left_last_seq_token_num = half_seq_lens_sum - sum(
            extend_seq_lens_cpu[:tbo_split_seq_index]
        )
        right_first_seq_token_num = (
            extend_seq_lens_cpu[tbo_split_seq_index] - left_last_seq_token_num
        )

        # Make a deep copy to be extra safe.
        child_a.extend_seq_lens_cpu = copy.deepcopy(child_a.extend_seq_lens_cpu)
        child_a.extend_seq_lens_cpu[-1] = left_last_seq_token_num
        child_b.extend_seq_lens_cpu = copy.deepcopy(child_b.extend_seq_lens_cpu)
        child_b.extend_seq_lens_cpu[0] = right_first_seq_token_num
        for child in [child_a, child_b]:
            _update_device_and_sum_field_from_cpu_field(
                batch=child,
                cpu_field="extend_seq_lens_cpu",
                device_field="extend_seq_lens",
                sum_field="extend_num_tokens",
            )

        assert (
            child_a.extend_num_tokens == half_seq_lens_sum
        ), f"{child_a.extend_num_tokens=}, {half_seq_lens_sum=}"

        child_a.seq_lens_cpu = copy.deepcopy(child_a.seq_lens_cpu)
        child_a.seq_lens_cpu[-1] = (
            child_a.extend_seq_lens_cpu[-1] + child_a.extend_prefix_lens_cpu[-1]
        )
        _update_device_and_sum_field_from_cpu_field(
            batch=child_a,
            cpu_field="seq_lens_cpu",
            device_field="seq_lens",
            sum_field="seq_lens_sum",
        )

        child_b.extend_prefix_lens_cpu = copy.deepcopy(child_b.extend_prefix_lens_cpu)
        child_b.extend_prefix_lens_cpu[0] += left_last_seq_token_num
        _update_device_and_sum_field_from_cpu_field(
            batch=child_b,
            cpu_field="extend_prefix_lens_cpu",
            device_field="extend_prefix_lens",
            sum_field=None,
        )
        _, child_b.extend_start_loc = compute_position(
            global_server_args_dict["attention_backend"],
            child_b.extend_prefix_lens,
            child_b.extend_seq_lens,
            child_b.extend_num_tokens,
        )
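
    # Worked example (hypothetical): extend_seq_lens_cpu=[3, 5, 4] with
    # tbo_split_seq_index=1 gives half_seq_lens_sum=6, so the boundary
    # sequence (length 5) contributes left_last_seq_token_num=3 tokens to
    # child_a and right_first_seq_token_num=2 tokens to child_b, whose prefix
    # for that sequence grows by the 3 tokens already handled by child_a.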

    @classmethod
    def filter_batch(
        cls,
        batch: ForwardBatch,
        *,
        start_token_index: int,
        end_token_index: int,
        start_seq_index: int,
        end_seq_index: int,
        output_attn_backend: AttentionBackend,
        out_num_token_non_padded: torch.Tensor,
    ):
        assert (
            end_token_index >= start_token_index
        ), f"{end_token_index=}, {start_token_index=}, batch={batch}"
        num_tokens = batch.input_ids.shape[0]
        num_seqs = batch.batch_size

        output_dict = dict()

        for key in [
            "input_ids",
            "positions",
            "out_cache_loc",
        ]:
            old_value = getattr(batch, key)
            assert (
                old_value.shape[0] == num_tokens
            ), f"{key=} {old_value=} {num_tokens=} {batch=}"
            output_dict[key] = old_value[start_token_index:end_token_index]

        for key in [
            "req_pool_indices",
            "seq_lens",
            "seq_lens_cpu",
            "extend_seq_lens",
            "extend_prefix_lens",
            "extend_start_loc",
            "extend_prefix_lens_cpu",
            "extend_seq_lens_cpu",
            "extend_logprob_start_lens_cpu",
            "lora_ids",
        ]:
            old_value = getattr(batch, key)
            if old_value is None:
                continue
            elif batch.forward_mode.is_target_verify() and key in (
                "extend_seq_lens",
                "extend_prefix_lens",
                "extend_start_loc",
                "extend_prefix_lens_cpu",
                "extend_seq_lens_cpu",
                "extend_logprob_start_lens_cpu",
            ):
                output_dict[key] = None
                continue
            assert (
                len(old_value) == num_seqs
            ), f"{key=} {old_value=} {num_seqs=} {batch=}"
            output_dict[key] = old_value[start_seq_index:end_seq_index]

        spec_info = batch.spec_info
        output_spec_info = split_spec_info(
            spec_info=spec_info,
            start_token_index=start_token_index,
            end_token_index=end_token_index,
            start_seq_index=start_seq_index,
            end_seq_index=end_seq_index,
        )
        output_dict["spec_info"] = output_spec_info
        for key in [
            "forward_mode",
            "is_extend_in_batch",
            "return_logprob",
            "req_to_token_pool",
            "token_to_kv_pool",
            "can_run_dp_cuda_graph",
            "global_forward_mode",
            "spec_algorithm",
            "capture_hidden_mode",
            "padded_static_len",
            "mrope_positions",  # only used by qwen2-vl, so we do not care
            "split_index",  # for split prefill
            "orig_seq_lens",  # only used by qwen-1m, so we do not care
        ]:
            output_dict[key] = getattr(batch, key)
        if not batch.forward_mode.is_target_verify():
            assert (
                _compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
                == batch.extend_num_tokens
            ), f"{batch=}"
        extend_num_tokens = _compute_extend_num_tokens(
            output_dict["input_ids"], output_dict["forward_mode"]
        )

        # TODO: improve this, e.g. unify it with `init_raw`.
        if (
            global_server_args_dict["moe_dense_tp_size"] == 1
            and batch.gathered_buffer is not None
        ):
            sum_len = end_token_index - start_token_index
            gathered_buffer = torch.zeros(
                (sum_len, batch.gathered_buffer.shape[1]),
                dtype=batch.gathered_buffer.dtype,
                device=batch.gathered_buffer.device,
            )
        else:
            gathered_buffer = None

        output_dict.update(
            dict(
                batch_size=end_seq_index - start_seq_index,
                seq_lens_sum=(
                    output_dict["seq_lens_cpu"].sum()
                    if "seq_lens_cpu" in output_dict
                    else None
                ),
                extend_num_tokens=extend_num_tokens,
                attn_backend=output_attn_backend,
                num_token_non_padded=out_num_token_non_padded,
                tbo_split_seq_index=None,
                tbo_parent_token_range=(start_token_index, end_token_index),
                tbo_children=None,
                global_num_tokens_gpu=None,
                global_num_tokens_cpu=None,
                dp_padding_mode=None,
                gathered_buffer=gathered_buffer,
                global_num_tokens_for_logprob_gpu=None,
                global_num_tokens_for_logprob_cpu=None,
                sampling_info=None,
                # Only used for logits and logprobs post-processing, so we do not care.
                temp_scaled_logprobs=False,
                temperature=None,
                top_p_normalized_logprobs=False,
                top_p=None,
                mm_inputs=None,
                top_logprobs_nums=None,
                token_ids_logprobs=None,
                next_token_logits_buffer=None,
            )
        )

        errors = []
        for field in dataclasses.fields(ForwardBatch):
            if getattr(batch, field.name) is not None and field.name not in output_dict:
                errors.append(
                    f"Field {field.name} has a value, but is not yet supported "
                    f"(value={getattr(batch, field.name)} batch={batch})"
                )
        if len(errors) > 0:
            raise Exception(f"{len(errors)} errors happened:\n" + "\n\n".join(errors))

        return ForwardBatch(**output_dict)

    @classmethod
    def compute_tbo_children_num_token_non_padded(cls, batch: ForwardBatch):
        return cls.compute_tbo_children_num_token_non_padded_raw(
            tbo_split_token_index=cls._compute_split_token_index(batch),
            num_token_non_padded=len(batch.input_ids),
        )

    @classmethod
    def compute_tbo_children_num_token_non_padded_raw(
        cls, tbo_split_token_index: int, num_token_non_padded: int
    ):
        # TODO: we may pad both sub-batches to make them slightly more balanced.
        value_a = min(tbo_split_token_index, num_token_non_padded)
        value_b = max(0, num_token_non_padded - tbo_split_token_index)
        return torch.tensor([value_a, value_b], dtype=torch.int32).to(
            device=global_server_args_dict["device"], non_blocking=True
        )
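
    # Example (hypothetical): tbo_split_token_index=6 with
    # num_token_non_padded=10 yields [6, 4]; if the real tokens end before the
    # split (num_token_non_padded=5), it yields [5, 0], so sub-batch B is pure
    # padding.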

    @classmethod
    def _compute_split_token_index(cls, batch: ForwardBatch):
        token_num_per_seq = get_token_num_per_seq(
            forward_mode=batch.forward_mode, spec_info=batch.spec_info
        )
        return compute_split_token_index(
            split_seq_index=batch.tbo_split_seq_index,
            forward_mode=batch.forward_mode,
            extend_seq_lens=batch.extend_seq_lens_cpu,
            token_num_per_seq=token_num_per_seq,
        )


def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
    if (
        forward_mode.is_decode()
        or forward_mode.is_idle()
        or forward_mode.is_target_verify()
    ):
        return None
    elif forward_mode.is_extend():
        return input_ids.shape[0]
    raise NotImplementedError


# -------------------------------- Execution ---------------------------------------


def model_forward_maybe_tbo(
    layers,
    enable_tbo: bool,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    hidden_states: torch.Tensor,
    input_data_scatter_mode: ScatterMode,
    residual: Optional[torch.Tensor],
    zero_allocator: Optional[BumpAllocator] = None,
):
    inputs = dict(
        positions=positions,
        hidden_states=hidden_states,
        forward_batch=forward_batch,
        residual=residual,
        zero_allocator=zero_allocator,
    )
    layer_input_scatter_mode = layers[0].layer_scatter_modes.layer_input_mode
    operations_strategy = OperationsStrategy.init_new_tbo(
        layers, forward_batch.global_forward_mode
    )
    if enable_tbo:
        return _model_forward_tbo(
            inputs=inputs,
            operations_strategy=operations_strategy,
            input_data_scatter_mode=input_data_scatter_mode,
            layer_input_scatter_mode=layer_input_scatter_mode,
        )
    else:
        return _model_forward_non_tbo(inputs, operations_strategy)


def _model_forward_tbo(
    inputs,
    operations_strategy: OperationsStrategy,
    input_data_scatter_mode: ScatterMode,
    layer_input_scatter_mode: ScatterMode,
):
    inputs_arr = _model_forward_tbo_split_inputs(
        **inputs,
        input_data_scatter_mode=input_data_scatter_mode,
        layer_input_scatter_mode=layer_input_scatter_mode,
    )
    del inputs

    context = (
        empty_context()
        if _is_hip
        else deep_gemm_wrapper.configure_deep_gemm_num_sms(
            operations_strategy.deep_gemm_num_sms
        )
    )

    with context:
        outputs_arr = execute_overlapped_operations(
            inputs_arr=inputs_arr,
            operations_arr=[operations_strategy.operations] * 2,
            delta_stages=[0, operations_strategy.tbo_delta_stages],
        )

    return _model_forward_tbo_merge_outputs(*outputs_arr)
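

# The `delta_stages` offset appears to stagger the two sub-batches: while one
# sub-batch sits in a communication-heavy stage (e.g. a DeepEP dispatch or
# combine), the other runs compute, which is what makes the two batches
# overlap.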


def _model_forward_non_tbo(inputs, operations_strategy: OperationsStrategy):
    outputs = execute_operations(inputs, operations_strategy.operations)
    return outputs["hidden_states"], outputs["residual"]


def _model_forward_tbo_split_inputs(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    zero_allocator: Optional[BumpAllocator],
    input_data_scatter_mode: ScatterMode,
    layer_input_scatter_mode: ScatterMode,
) -> List[Dict]:
    tbo_splitter_scatter_mode = ScatterMode.TP_ATTN_FULL
    context = CommunicateContext.init_new()

    hidden_states, residual = CommunicateSummableTensorPairFn.execute(
        hidden_states_input_mode=input_data_scatter_mode,
        residual_input_mode=input_data_scatter_mode,
        output_mode=tbo_splitter_scatter_mode,
        hidden_states=hidden_states,
        residual=residual,
        forward_batch=forward_batch,
        context=context,
    )

    inputs_arr = _model_forward_tbo_split_inputs_raw(
        hidden_states=hidden_states,
        residual=residual,
        positions=positions,
        forward_batch=forward_batch,
        zero_allocator=zero_allocator,
    )

    def _post_transform(hidden_states, residual, forward_batch, **kwargs):
        hidden_states, residual = CommunicateSummableTensorPairFn.execute(
            hidden_states_input_mode=tbo_splitter_scatter_mode,
            residual_input_mode=tbo_splitter_scatter_mode,
            output_mode=layer_input_scatter_mode,
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=forward_batch,
            context=context,
        )
        return dict(
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=forward_batch,
            **kwargs,
        )

    return [_post_transform(**inputs) for inputs in inputs_arr]


def _model_forward_tbo_split_inputs_raw(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    zero_allocator: Optional[BumpAllocator],
) -> List[Dict]:
    return [
        dict(
            **_model_forward_filter_inputs(
                hidden_states=hidden_states,
                residual=residual,
                positions=positions,
                output_forward_batch=output_forward_batch,
                tbo_subbatch_index=tbo_subbatch_index,
            ),
            **(
                dict(zero_allocator=zero_allocator)
                if zero_allocator is not None
                else {}
            ),
        )
        for tbo_subbatch_index, output_forward_batch in enumerate(
            forward_batch.tbo_children
        )
    ]


def _model_forward_filter_inputs(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    output_forward_batch: ForwardBatch,
    tbo_subbatch_index: int,
) -> Dict:
    token_slice = slice(*output_forward_batch.tbo_parent_token_range)
    return dict(
        hidden_states=hidden_states[token_slice],
        residual=None if residual is None else residual[token_slice],
        positions=positions[token_slice],
        forward_batch=output_forward_batch,
        tbo_subbatch_index=tbo_subbatch_index,
    )


def _model_forward_tbo_merge_outputs(output_a, output_b):
    def _handle_key(name):
        value_a = output_a[name]
        value_b = output_b[name]
        assert (value_a is None) == (value_b is None)
        if value_a is None:
            return None
        return torch.concat([value_a, value_b], dim=0)

    return _handle_key("hidden_states"), _handle_key("residual")


# -------------------------------- Utilities and wrappers ---------------------------------------


class MaybeTboDeepEPDispatcher:
    def __init__(self, **kwargs):
        num_inner_dispatchers = (
            2 if global_server_args_dict["enable_two_batch_overlap"] else 1
        )
        self._inners = [
            DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
        ]

    def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
        return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)

    def dispatch(self, **kwargs) -> DispatchOutput:
        return self._execute("dispatch", **kwargs)

    def dispatch_a(self, **kwargs):
        return self._execute("dispatch_a", **kwargs)

    def dispatch_b(self, **kwargs):
        return self._execute("dispatch_b", **kwargs)

    def combine(self, **kwargs) -> torch.Tensor:
        return self._execute("combine", **kwargs)

    def combine_a(self, **kwargs):
        return self._execute("combine_a", **kwargs)

    def combine_b(self, **kwargs):
        return self._execute("combine_b", **kwargs)
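

# Presumably, TBO needs one DeepEPDispatcher per sub-batch so that each
# overlapped sub-batch owns independent dispatch/combine state;
# `tbo_subbatch_index` picks the inner dispatcher, defaulting to the first
# when not overlapping.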